Compare commits

..

1 Commits

Author SHA1 Message Date
Jamison Lahman
bfdeb65bbb feat(ux): handle when chat session id cannot be found 2026-03-20 16:44:34 -07:00
91 changed files with 1295 additions and 2914 deletions

View File

@@ -25,6 +25,9 @@ from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import TENANT_ID_PREFIX
# Default number of pre-provisioned tenants to maintain
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
# Soft time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
# Hard time limit for tenant pre-provisioning tasks (in seconds)
@@ -55,7 +58,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_check: RedisLock = r.lock(
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
)
# These tasks should never overlap
@@ -71,7 +74,9 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
num_available_tenants = db_session.query(AvailableTenant).count()
# Get the target number of available tenants
num_minimum_available_tenants = TARGET_AVAILABLE_TENANTS
num_minimum_available_tenants = getattr(
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
)
# Calculate how many new tenants we need to provision
if num_available_tenants < num_minimum_available_tenants:
@@ -93,12 +98,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
task_logger.exception("Error in check_available_tenants task")
finally:
try:
lock_check.release()
except Exception:
task_logger.warning(
"Could not release check lock (likely expired), continuing"
)
lock_check.release()
def pre_provision_tenant() -> None:
@@ -113,7 +113,7 @@ def pre_provision_tenant() -> None:
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_provision: RedisLock = r.lock(
OnyxRedisLocks.CLOUD_PRE_PROVISION_TENANT_LOCK,
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
)
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
@@ -185,9 +185,4 @@ def pre_provision_tenant() -> None:
except Exception:
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
finally:
try:
lock_provision.release()
except Exception:
task_logger.warning(
"Could not release provision lock (likely expired), continuing"
)
lock_provision.release()

View File

@@ -9,12 +9,12 @@ from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.app_configs import HOOK_ENABLED
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.hooks.utils import HOOKS_AVAILABLE
from shared_configs.configs import MULTI_TENANT
# choosing 15 minutes because it roughly gives us enough time to process many tasks
@@ -362,7 +362,7 @@ if not MULTI_TENANT:
tasks_to_schedule.extend(beat_task_templates)
if HOOKS_AVAILABLE:
if not MULTI_TENANT and HOOK_ENABLED:
tasks_to_schedule.append(
{
"name": "hook-execution-log-cleanup",

View File

@@ -30,8 +30,6 @@ from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.utils import plaintext_file_name_for_id
from onyx.file_store.utils import store_plaintext
from onyx.kg.models import KGException
from onyx.kg.setup.kg_default_entity_definitions import (
populate_missing_default_entity_types__commit,
@@ -291,33 +289,6 @@ def process_kg_commands(
raise KGException("KG setup done")
def _get_or_extract_plaintext(
file_id: str,
extract_fn: Callable[[], str],
) -> str:
"""Load cached plaintext for a file, or extract and store it.
Tries to read pre-stored plaintext from the file store. On a miss,
calls extract_fn to produce the text, then stores the result so
future calls skip the expensive extraction.
"""
file_store = get_default_file_store()
plaintext_key = plaintext_file_name_for_id(file_id)
# Try cached plaintext first.
try:
plaintext_io = file_store.read_file(plaintext_key, mode="b")
return plaintext_io.read().decode("utf-8")
except Exception:
logger.exception(f"Error when reading file, id={file_id}")
# Cache miss — extract and store.
content_text = extract_fn()
if content_text:
store_plaintext(file_id, content_text)
return content_text
@log_function_time(print_only=True)
def load_chat_file(
file_descriptor: FileDescriptor, db_session: Session
@@ -332,23 +303,12 @@ def load_chat_file(
file_type = ChatFileType(file_descriptor["type"])
if file_type.is_text_file():
file_id = file_descriptor["id"]
def _extract() -> str:
return extract_file_text(
try:
content_text = extract_file_text(
file=file_io,
file_name=file_descriptor.get("name") or "",
break_on_unprocessable=False,
)
# Use the user_file_id as cache key when available (matches what
# the celery indexing worker stores), otherwise fall back to the
# file store id (covers code-interpreter-generated files, etc.).
user_file_id_str = file_descriptor.get("user_file_id")
cache_key = user_file_id_str or file_id
try:
content_text = _get_or_extract_plaintext(cache_key, _extract)
except Exception as e:
logger.warning(
f"Failed to retrieve content for file {file_descriptor['id']}: {str(e)}"

View File

@@ -177,8 +177,8 @@ class ExtractedContextFiles(BaseModel):
class SearchParams(BaseModel):
"""Resolved search filter IDs and search-tool usage for a chat turn."""
project_id_filter: int | None
persona_id_filter: int | None
search_project_id: int | None
search_persona_id: int | None
search_usage: SearchToolUsage

View File

@@ -399,13 +399,13 @@ def determine_search_params(
"""
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
project_id_filter: int | None = None
persona_id_filter: int | None = None
search_project_id: int | None = None
search_persona_id: int | None = None
if extracted_context_files.use_as_search_filter:
if is_custom_persona:
persona_id_filter = persona_id
search_persona_id = persona_id
else:
project_id_filter = project_id
search_project_id = project_id
search_usage = SearchToolUsage.AUTO
if not is_custom_persona and project_id:
@@ -418,8 +418,8 @@ def determine_search_params(
search_usage = SearchToolUsage.DISABLED
return SearchParams(
project_id_filter=project_id_filter,
persona_id_filter=persona_id_filter,
search_project_id=search_project_id,
search_persona_id=search_persona_id,
search_usage=search_usage,
)
@@ -474,18 +474,11 @@ def handle_stream_message_objects(
db_session=db_session,
)
yield CreateChatSessionID(chat_session_id=chat_session.id)
chat_session = get_chat_session_by_id(
chat_session_id=chat_session.id,
user_id=user_id,
db_session=db_session,
eager_load_persona=True,
)
else:
chat_session = get_chat_session_by_id(
chat_session_id=new_msg_req.chat_session_id,
user_id=user_id,
db_session=db_session,
eager_load_persona=True,
)
persona = chat_session.persona
@@ -718,8 +711,8 @@ def handle_stream_message_objects(
llm=llm,
search_tool_config=SearchToolConfig(
user_selected_filters=new_msg_req.internal_search_filters,
project_id_filter=search_params.project_id_filter,
persona_id_filter=search_params.persona_id_filter,
project_id=search_params.search_project_id,
persona_id=search_params.search_persona_id,
bypass_acl=bypass_acl,
slack_context=slack_context,
enable_slack_search=_should_enable_slack_search(

View File

@@ -2,6 +2,7 @@ from collections.abc import Sequence
from datetime import datetime
from enum import Enum
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
@@ -69,13 +70,9 @@ class BaseFilters(BaseModel):
class UserFileFilters(BaseModel):
# Scopes search to user files tagged with a given project/persona in Vespa.
# These are NOT simply the IDs of the current project or persona — they are
# only set when the persona's/project's user files overflowed the LLM
# context window and must be searched via vector DB instead of being loaded
# directly into the prompt.
project_id_filter: int | None = None
persona_id_filter: int | None = None
user_file_ids: list[UUID] | None = None
project_id: int | None = None
persona_id: int | None = None
class AssistantKnowledgeFilters(BaseModel):

View File

@@ -1,5 +1,6 @@
from collections import defaultdict
from datetime import datetime
from uuid import UUID
from sqlalchemy.orm import Session
@@ -38,8 +39,9 @@ logger = setup_logger()
def _build_index_filters(
user_provided_filters: BaseFilters | None,
user: User, # Used for ACLs, anonymous users only see public docs
project_id_filter: int | None,
persona_id_filter: int | None,
project_id: int | None,
persona_id: int | None,
user_file_ids: list[UUID] | None,
persona_document_sets: list[str] | None,
persona_time_cutoff: datetime | None,
db_session: Session | None = None,
@@ -95,6 +97,16 @@ def _build_index_filters(
if not source_filter and detected_source_filter:
source_filter = detected_source_filter
# CRITICAL FIX: If user_file_ids are present, we must ensure "user_file"
# source type is included in the filter, otherwise user files will be excluded!
if user_file_ids and source_filter:
from onyx.configs.constants import DocumentSource
# Add user_file to the source filter if not already present
if DocumentSource.USER_FILE not in source_filter:
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
logger.debug("Added USER_FILE to source_filter for user knowledge search")
if bypass_acl:
user_acl_filters = None
elif acl_filters is not None:
@@ -105,8 +117,9 @@ def _build_index_filters(
user_acl_filters = build_access_filters_for_user(user, db_session)
final_filters = IndexFilters(
project_id_filter=project_id_filter,
persona_id_filter=persona_id_filter,
user_file_ids=user_file_ids,
project_id=project_id,
persona_id=persona_id,
source_type=source_filter,
document_set=document_set_filter,
time_cutoff=time_filter,
@@ -252,16 +265,19 @@ def search_pipeline(
db_session: Session | None = None,
auto_detect_filters: bool = False,
llm: LLM | None = None,
# Vespa metadata filters for overflowing user files. NOT the raw IDs
# of the current project/persona — only set when user files couldn't fit
# in the LLM context and need to be searched via vector DB.
project_id_filter: int | None = None,
persona_id_filter: int | None = None,
# If a project ID is provided, it will be exclusively scoped to that project
project_id: int | None = None,
# If a persona_id is provided, search scopes to files attached to this persona
persona_id: int | None = None,
# Pre-fetched data — when provided, avoids DB queries (no session needed)
acl_filters: list[str] | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
) -> list[InferenceChunk]:
user_uploaded_persona_files: list[UUID] | None = (
[user_file.id for user_file in persona.user_files] if persona else None
)
persona_document_sets: list[str] | None = (
[persona_document_set.name for persona_document_set in persona.document_sets]
if persona
@@ -286,8 +302,9 @@ def search_pipeline(
filters = _build_index_filters(
user_provided_filters=chunk_search_request.user_selected_filters,
user=user,
project_id_filter=project_id_filter,
persona_id_filter=persona_id_filter,
project_id=project_id,
persona_id=persona_id,
user_file_ids=user_uploaded_persona_files,
persona_document_sets=persona_document_sets,
persona_time_cutoff=persona_time_cutoff,
db_session=db_session,

View File

@@ -110,6 +110,7 @@ def search_chunks(
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
federated_sources = set(

View File

@@ -28,7 +28,6 @@ from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ChatSession
from onyx.db.models import ChatSessionSharedStatus
from onyx.db.models import Persona
from onyx.db.models import SearchDoc as DBSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
@@ -54,17 +53,9 @@ def get_chat_session_by_id(
db_session: Session,
include_deleted: bool = False,
is_shared: bool = False,
eager_load_persona: bool = False,
) -> ChatSession:
stmt = select(ChatSession).where(ChatSession.id == chat_session_id)
if eager_load_persona:
stmt = stmt.options(
selectinload(ChatSession.persona).selectinload(Persona.tools),
selectinload(ChatSession.persona).selectinload(Persona.user_files),
selectinload(ChatSession.project),
)
if is_shared:
stmt = stmt.where(ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC)
else:

View File

@@ -2,7 +2,6 @@ import time
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.db.connector_credential_pair import get_connector_credential_pairs
@@ -150,9 +149,6 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
Returns None if search settings did not change, or the old search settings if they
did change.
"""
if DISABLE_VECTOR_DB:
return None
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = max(len(all_cc_pairs) - 1, 0)

View File

@@ -10,8 +10,8 @@ How `IndexFilters` fields combine into the final query filter. Applies to both V
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
| **ACL** | `access_control_list` | OR within, AND with rest |
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
| **Knowledge scope** | `document_set`, `attached_document_ids`, `hierarchy_node_ids`, `persona_id_filter` | OR within group, AND with rest |
| **Additive scope** | `project_id_filter` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
## How filters combine
@@ -31,22 +31,12 @@ AND time >= cutoff -- if set
The knowledge scope filter controls **what knowledge an assistant can access**.
### Primary vs additive triggers
- **`persona_id_filter`** is a **primary** trigger. A persona with user files IS explicit
knowledge, so `persona_id_filter` alone can start a knowledge scope. Note: this is
NOT the raw ID of the persona being used — it is only set when the persona's
user files overflowed the LLM context window.
- **`project_id_filter`** is **additive**. It widens an existing scope to include project
files but never restricts on its own — a chat inside a project should still search
team knowledge when no other knowledge is attached.
### No explicit knowledge attached
When `document_set`, `attached_document_ids`, `hierarchy_node_ids`, and `persona_id_filter` are all empty/None:
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
- `project_id_filter` is ignored — it never restricts on its own.
- `project_id` and `persona_id` are ignored — they never restrict on their own.
### One explicit knowledge type
@@ -54,40 +44,39 @@ When `document_set`, `attached_document_ids`, `hierarchy_node_ids`, and `persona
-- Only document sets
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
-- Only persona user files (overflowed context)
AND (personas contains 42)
-- Only user files
AND (document_id = "uuid-1" OR document_id = "uuid-2")
```
### Multiple explicit knowledge types (OR'd)
```
-- Document sets + persona user files
-- Document sets + user files
AND (
document_sets contains "Engineering"
OR document_id = "uuid-1"
)
```
### Explicit knowledge + overflowing user files
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
```
-- Document sets + persona user files overflowed
AND (
document_sets contains "Engineering"
OR personas contains 42
)
```
### Explicit knowledge + overflowing project files
When an explicit knowledge restriction is in effect **and** `project_id_filter` is set (project files overflowed the LLM context window), `project_id_filter` widens the filter:
```
-- Document sets + project files overflowed
-- User files + project files overflowed
AND (
document_sets contains "Engineering"
OR user_project contains 7
)
-- Persona user files + project files (won't happen in practice;
-- custom personas ignore project files per the precedence rule)
AND (
personas contains 42
document_id = "uuid-1"
OR user_project contains 7
)
```
### Only project_id_filter (no explicit knowledge)
### Only project_id or persona_id (no explicit knowledge)
No knowledge scope filter. The assistant searches everything.
@@ -102,10 +91,11 @@ AND (acl contains ...)
| Filter field | Vespa field | Vespa type | Purpose |
|---|---|---|---|
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
| `persona_id_filter` | `personas` | `array<int>` | Persona tag for overflowing user files (**primary** trigger) |
| `project_id_filter` | `user_project` | `array<int>` | Project tag for overflowing project files (**additive** only) |
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from uuid import UUID
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
@@ -218,8 +219,9 @@ class DocumentQuery:
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
project_id_filter=index_filters.project_id_filter,
persona_id_filter=index_filters.persona_id_filter,
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=min_chunk_index,
max_chunk_index=max_chunk_index,
@@ -284,8 +286,9 @@ class DocumentQuery:
source_types=[],
tags=[],
document_sets=[],
project_id_filter=None,
persona_id_filter=None,
user_file_ids=[],
project_id=None,
persona_id=None,
time_cutoff=None,
min_chunk_index=None,
max_chunk_index=None,
@@ -353,8 +356,9 @@ class DocumentQuery:
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
project_id_filter=index_filters.project_id_filter,
persona_id_filter=index_filters.persona_id_filter,
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
@@ -445,8 +449,9 @@ class DocumentQuery:
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
project_id_filter=index_filters.project_id_filter,
persona_id_filter=index_filters.persona_id_filter,
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
@@ -524,8 +529,9 @@ class DocumentQuery:
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
project_id_filter=index_filters.project_id_filter,
persona_id_filter=index_filters.persona_id_filter,
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
@@ -585,8 +591,9 @@ class DocumentQuery:
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
project_id_filter=index_filters.project_id_filter,
persona_id_filter=index_filters.persona_id_filter,
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
@@ -817,8 +824,9 @@ class DocumentQuery:
source_types: list[DocumentSource],
tags: list[Tag],
document_sets: list[str],
project_id_filter: int | None,
persona_id_filter: int | None,
user_file_ids: list[UUID],
project_id: int | None,
persona_id: int | None,
time_cutoff: datetime | None,
min_chunk_index: int | None,
max_chunk_index: int | None,
@@ -849,12 +857,12 @@ class DocumentQuery:
list corresponding to a tag will be retrieved.
document_sets: If supplied, only documents with at least one
document set ID from this list will be retrieved.
project_id_filter: If not None, only documents with this project ID
in user projects will be retrieved. Additive — only applied
when a knowledge scope already exists.
persona_id_filter: If not None, only documents whose personas array
contains this persona ID will be retrieved. Primary — creates
a knowledge scope on its own.
user_file_ids: If supplied, only document IDs in this list will be
retrieved.
project_id: If not None, only documents with this project ID in user
projects will be retrieved.
persona_id: If not None, only documents whose personas array
contains this persona ID will be retrieved.
time_cutoff: Time cutoff for the documents to retrieve. If not None,
Documents which were last updated before this date will not be
returned. For documents which do not have a value for their last
@@ -871,6 +879,10 @@ class DocumentQuery:
NOTE: See DocumentChunk.max_chunk_size.
document_id: The document ID to retrieve. If None, no filter will be
applied for this. Defaults to None.
WARNING: This filters on the same property as user_file_ids.
Although it would never make sense to supply both, note that if
user_file_ids is supplied and does not contain document_id, no
matches will be retrieved.
attached_document_ids: Document IDs explicitly attached to the
assistant. If provided along with hierarchy_node_ids, documents
matching EITHER criteria will be retrieved (OR logic).
@@ -931,6 +943,15 @@ class DocumentQuery:
)
return document_set_filter
def _get_user_file_id_filter(user_file_ids: list[UUID]) -> dict[str, Any]:
# Logical OR operator on its elements.
user_file_id_filter: dict[str, Any] = {"bool": {"should": []}}
for user_file_id in user_file_ids:
user_file_id_filter["bool"]["should"].append(
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": str(user_file_id)}}}
)
return user_file_id_filter
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
# Logical OR operator on its elements.
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
@@ -1031,17 +1052,14 @@ class DocumentQuery:
# assistant can see. When none are set the assistant searches
# everything.
#
# persona_id_filter is a primary trigger — a persona with user files IS
# explicit knowledge, so it can start a knowledge scope on its own.
#
# project_id_filter is additive — it widens the scope to also cover
# overflowing project files but never restricts on its own (a chat
# inside a project should still search team knowledge).
# project_id / persona_id are additive: they make overflowing user files
# findable but must NOT trigger the restriction on their own (an agent
# with no explicit knowledge should search everything).
has_knowledge_scope = (
attached_document_ids
or hierarchy_node_ids
or user_file_ids
or document_sets
or persona_id_filter is not None
)
if has_knowledge_scope:
@@ -1056,17 +1074,23 @@ class DocumentQuery:
knowledge_filter["bool"]["should"].append(
_get_hierarchy_node_filter(hierarchy_node_ids)
)
if user_file_ids:
knowledge_filter["bool"]["should"].append(
_get_user_file_id_filter(user_file_ids)
)
if document_sets:
knowledge_filter["bool"]["should"].append(
_get_document_set_filter(document_sets)
)
if persona_id_filter is not None:
# Additive: widen scope to also cover overflowing user files, but
# only when an explicit restriction is already in effect.
if project_id is not None:
knowledge_filter["bool"]["should"].append(
_get_persona_filter(persona_id_filter)
_get_user_project_filter(project_id)
)
if project_id_filter is not None:
if persona_id is not None:
knowledge_filter["bool"]["should"].append(
_get_user_project_filter(project_id_filter)
_get_persona_filter(persona_id)
)
filter_clauses.append(knowledge_filter)
@@ -1084,6 +1108,8 @@ class DocumentQuery:
)
if document_id is not None:
# WARNING: If user_file_ids has elements and if none of them are
# document_id, no matches will be retrieved.
filter_clauses.append(
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
)

View File

@@ -199,29 +199,31 @@ def build_vespa_filters(
]
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
# Knowledge scope: explicit knowledge attachments restrict what an
# assistant can see. When none are set, the assistant can see
# everything.
# Knowledge scope: explicit knowledge attachments (document_sets,
# user_file_ids) restrict what an assistant can see. When none are
# set, the assistant can see everything.
#
# persona_id_filter is a primary trigger — a persona with user files IS
# explicit knowledge, so it can start a knowledge scope on its own.
#
# project_id_filter is additive — it widens the scope to also cover
# overflowing project files but never restricts on its own (a chat
# inside a project should still search team knowledge).
# project_id / persona_id are additive: they make overflowing user
# files findable in Vespa but must NOT trigger the restriction on
# their own (an agent with no explicit knowledge should search
# everything).
knowledge_scope_parts: list[str] = []
_append(
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
)
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id_filter))
# project_id_filter only widens an existing scope.
user_file_ids_str = (
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
)
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
# Only include project/persona scopes when an explicit knowledge
# restriction is already in effect — they widen the scope to also
# cover overflowing user files but never restrict on their own.
if knowledge_scope_parts:
_append(
knowledge_scope_parts,
_build_user_project_filter(filters.project_id_filter),
)
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
if len(knowledge_scope_parts) > 1:
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")

View File

@@ -88,7 +88,6 @@ class OnyxErrorCode(Enum):
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
BAD_GATEWAY = ("BAD_GATEWAY", 502)
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
HOOK_EXECUTION_FAILED = ("HOOK_EXECUTION_FAILED", 502)
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
def __init__(self, code: str, status_code: int) -> None:

View File

@@ -38,7 +38,17 @@ def get_federated_retrieval_functions(
source_types: list[DocumentSource] | None,
document_set_names: list[str] | None,
slack_context: SlackContext | None = None,
user_file_ids: list[UUID] | None = None,
) -> list[FederatedRetrievalInfo]:
# When User Knowledge (user files) is the only knowledge source enabled,
# skip federated connectors entirely. User Knowledge mode means the agent
# should ONLY use uploaded files, not team connectors like Slack.
if user_file_ids and not document_set_names:
logger.debug(
"Skipping all federated connectors: User Knowledge mode enabled "
f"with {len(user_file_ids)} user files and no document sets"
)
return []
# Check for Slack bot context first (regardless of user_id)
if slack_context:

View File

@@ -23,55 +23,45 @@ from onyx.utils.timing import log_function_time
logger = setup_logger()
def plaintext_file_name_for_id(file_id: str) -> str:
"""Generate a consistent file name for storing plaintext content of a file."""
return f"plaintext_{file_id}"
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
"""Generate a consistent file name for storing plaintext content of a user file."""
return f"plaintext_{user_file_id}"
def store_plaintext(file_id: str, plaintext_content: str) -> bool:
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
"""
Store plaintext content for a file in the file store.
Store plaintext content for a user file in the file store.
Args:
file_id: The ID of the file (user_file or artifact_file)
user_file_id: The ID of the user file
plaintext_content: The plaintext content to store
Returns:
bool: True if storage was successful, False otherwise
"""
# Skip empty content
if not plaintext_content:
return False
plaintext_file_name = plaintext_file_name_for_id(file_id)
# Get plaintext file name
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
try:
file_store = get_default_file_store()
file_content = BytesIO(plaintext_content.encode("utf-8"))
file_store.save_file(
content=file_content,
display_name=f"Plaintext for {file_id}",
display_name=f"Plaintext for user file {user_file_id}",
file_origin=FileOrigin.PLAINTEXT_CACHE,
file_type="text/plain",
file_id=plaintext_file_name,
)
return True
except Exception as e:
logger.warning(f"Failed to store plaintext for {file_id}: {e}")
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
return False
# --- Convenience wrappers for callers that use user-file UUIDs ---
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
"""Generate a consistent file name for storing plaintext content of a user file."""
return plaintext_file_name_for_id(str(user_file_id))
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
"""Store plaintext content for a user file (delegates to :func:`store_plaintext`)."""
return store_plaintext(str(user_file_id), plaintext_content)
def load_chat_file_by_id(file_id: str) -> InMemoryChatFile:
"""Load a file directly from the file store using its file_record ID.

View File

@@ -1,330 +0,0 @@
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
Usage (Celery tasks and FastAPI handlers):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
)
if isinstance(result, HookSkipped):
# no active hook configured — continue with original behavior
...
elif isinstance(result, HookSoftFailed):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is the response payload dict from the customer's endpoint
...
is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:
NetworkError (DNS, connection refused) → False (cannot reach the server)
HTTP 401 / 403 → False (api_key revoked or invalid)
TimeoutException → None (server may be slow, skip write)
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
Unknown exception → None (no signal, skip write)
Non-JSON / non-dict response → None (server responded, skip write)
Success (2xx, valid dict) → True (confirmed reachable)
None means "leave the current value unchanged" — no DB round-trip is made.
DB session design
-----------------
The executor uses three sessions:
1. Caller's session (db_session) — used only for the hook lookup read. All
needed fields are extracted from the Hook object before the HTTP call, so
the caller's session is not held open during the external HTTP request.
2. Log session — a separate short-lived session opened after the HTTP call
completes to write the HookExecutionLog row on failure. Success runs are
not recorded. Committed independently of everything else.
3. Reachable session — a second short-lived session to update is_reachable on
the Hook. Kept separate from the log session so a concurrent hook deletion
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
prevent the execution log from being written. This update is best-effort.
"""
import json
import time
from typing import Any
import httpx
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.utils.logger import setup_logger
logger = setup_logger()
class HookSkipped:
"""No active hook configured for this hook point."""
class HookSoftFailed:
"""Hook was called but failed with SOFT fail strategy — continuing."""
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
class _HttpOutcome(BaseModel):
"""Structured result of an HTTP hook call, returned by _process_response."""
is_success: bool
updated_is_reachable: (
bool | None
) # True/False = write to DB, None = unchanged (skip write)
status_code: int | None
error_message: str | None
response_payload: dict[str, Any] | None
def _lookup_hook(
db_session: Session,
hook_point: HookPoint,
) -> Hook | HookSkipped:
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
No HTTP call is made and no DB writes are performed for any HookSkipped path.
There is nothing to log and no reachability information to update.
"""
if not HOOKS_AVAILABLE:
return HookSkipped()
hook = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if hook is None or not hook.is_active:
return HookSkipped()
if not hook.endpoint_url:
return HookSkipped()
return hook
def _process_response(
*,
response: httpx.Response | None,
exc: Exception | None,
timeout: float,
) -> _HttpOutcome:
"""Process the result of an HTTP call and return a structured outcome.
Called after the client.post() try/except. If post() raised, exc is set and
response is None. Otherwise response is set and exc is None. Handles
raise_for_status(), JSON decoding, and the dict shape check.
"""
if exc is not None:
if isinstance(exc, httpx.NetworkError):
msg = f"Hook network error (endpoint unreachable): {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False,
status_code=None,
error_message=msg,
response_payload=None,
)
if isinstance(exc, httpx.TimeoutException):
msg = f"Hook timed out after {timeout}s: {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # timeout doesn't indicate unreachability
status_code=None,
error_message=msg,
response_payload=None,
)
msg = f"Hook call failed: {exc}"
logger.exception(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # unknown error — don't make assumptions
status_code=None,
error_message=msg,
response_payload=None,
)
if response is None:
raise ValueError(
"exactly one of response or exc must be non-None; both are None"
)
status_code = response.status_code
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
logger.warning(msg, exc_info=e)
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
# so the operator knows to update it. All other HTTP errors keep is_reachable
# as-is (server is up, the request just failed for application reasons).
auth_failed = e.response.status_code in (401, 403)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False if auth_failed else None,
status_code=status_code,
error_message=msg,
response_payload=None,
)
try:
response_payload = response.json()
except (json.JSONDecodeError, httpx.DecodingError) as e:
msg = f"Hook returned non-JSON response: {e}"
logger.warning(msg, exc_info=e)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
if not isinstance(response_payload, dict):
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
logger.warning(msg)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
return _HttpOutcome(
is_success=True,
updated_is_reachable=True,
status_code=status_code,
error_message=None,
response_payload=response_payload,
)
def _persist_result(
*,
hook_id: int,
outcome: _HttpOutcome,
duration_ms: int,
) -> None:
"""Write the execution log on failure and optionally update is_reachable, each
in its own session so a failure in one does not affect the other."""
# Only write the execution log on failure — success runs are not recorded.
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
# deleted between the initial lookup and here).
if not outcome.is_success:
try:
with get_session_with_current_tenant() as log_session:
create_hook_execution_log__no_commit(
db_session=log_session,
hook_id=hook_id,
is_success=False,
error_message=outcome.error_message,
status_code=outcome.status_code,
duration_ms=duration_ms,
)
log_session.commit()
except Exception:
logger.exception(
f"Failed to persist hook execution log for hook_id={hook_id}"
)
# Update is_reachable separately — best-effort, non-critical.
# None means the value is unchanged (set by the caller to skip the no-op write).
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
# concurrently deleted, so keep this isolated from the log write above.
if outcome.updated_is_reachable is not None:
try:
with get_session_with_current_tenant() as reachable_session:
update_hook__no_commit(
db_session=reachable_session,
hook_id=hook_id,
is_reachable=outcome.updated_is_reachable,
)
reachable_session.commit()
except Exception:
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously."""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
"active hooks without an endpoint_url must be rejected by _lookup_hook"
)
start = time.monotonic()
response: httpx.Response | None = None
exc: Exception | None = None
try:
api_key: str | None = (
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
)
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(timeout=timeout) as client:
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
outcome = outcome.model_copy(update={"updated_is_reachable": None})
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
if not outcome.is_success:
if fail_strategy == HookFailStrategy.HARD:
raise OnyxError(
OnyxErrorCode.HOOK_EXECUTION_FAILED,
outcome.error_message or "Hook execution failed.",
)
logger.warning(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if outcome.response_payload is None:
raise ValueError(
f"response_payload is None for successful hook call (hook_id={hook_id})"
)
return outcome.response_payload

View File

@@ -42,8 +42,12 @@ class HookUpdateRequest(BaseModel):
name: str | None = None
endpoint_url: str | None = None
api_key: NonEmptySecretStr | None = None
fail_strategy: HookFailStrategy | None = None
timeout_seconds: float | None = Field(default=None, gt=0)
fail_strategy: HookFailStrategy | None = (
None # if None in model_fields_set, reset to spec default
)
timeout_seconds: float | None = Field(
default=None, gt=0
) # if None in model_fields_set, reset to spec default
@model_validator(mode="after")
def require_at_least_one_field(self) -> "HookUpdateRequest":
@@ -56,14 +60,6 @@ class HookUpdateRequest(BaseModel):
and not (self.endpoint_url or "").strip()
):
raise ValueError("endpoint_url cannot be cleared.")
if "fail_strategy" in self.model_fields_set and self.fail_strategy is None:
raise ValueError(
"fail_strategy cannot be null; omit the field to leave it unchanged."
)
if "timeout_seconds" in self.model_fields_set and self.timeout_seconds is None:
raise ValueError(
"timeout_seconds cannot be null; omit the field to leave it unchanged."
)
return self
@@ -94,28 +90,38 @@ class HookResponse(BaseModel):
fail_strategy: HookFailStrategy
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
is_active: bool
is_reachable: bool | None
creator_email: str | None
created_at: datetime
updated_at: datetime
class HookValidateStatus(str, Enum):
passed = "passed" # server responded (any status except 401/403)
auth_failed = "auth_failed" # server responded with 401 or 403
timeout = (
"timeout" # TCP connected, but read/write timed out (server exists but slow)
)
cannot_connect = "cannot_connect" # could not connect to the server
class HookValidateResponse(BaseModel):
status: HookValidateStatus
success: bool
error_message: str | None = None
class HookExecutionRecord(BaseModel):
# ---------------------------------------------------------------------------
# Health models
# ---------------------------------------------------------------------------
class HookHealthStatus(str, Enum):
healthy = "healthy" # green — reachable, no failures in last 1h
degraded = "degraded" # yellow — reachable, failures in last 1h
unreachable = "unreachable" # red — is_reachable=false or null
class HookFailureRecord(BaseModel):
error_message: str | None = None
status_code: int | None = None
duration_ms: int | None = None
created_at: datetime
class HookHealthResponse(BaseModel):
status: HookHealthStatus
recent_failures: list[HookFailureRecord] = Field(
default_factory=list,
description="Last 10 failures, newest first",
max_length=10,
)

View File

@@ -1,7 +1,6 @@
from abc import ABC
from abc import abstractmethod
from typing import Any
from typing import ClassVar
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
@@ -14,25 +13,22 @@ _REQUIRED_ATTRS = (
"default_timeout_seconds",
"fail_hard_description",
"default_fail_strategy",
"payload_model",
"response_model",
)
class HookPointSpec:
class HookPointSpec(ABC):
"""Static metadata and contract for a pipeline hook point.
This is NOT a regular class meant for direct instantiation by callers.
Each concrete subclass represents exactly one hook point and is instantiated
once at startup, registered in onyx.hooks.registry._REGISTRY. Prefer
get_hook_point_spec() or get_all_specs() from the registry over direct
instantiation.
once at startup, registered in onyx.hooks.registry._REGISTRY. No caller
should ever create instances directly — use get_hook_point_spec() or
get_all_specs() from the registry instead.
Each hook point is a concrete subclass of this class. Onyx engineers
own these definitions — customers never touch this code.
Subclasses must define all attributes as class-level constants.
payload_model and response_model must be Pydantic BaseModel subclasses;
input_schema and output_schema are derived from them automatically.
"""
hook_point: HookPoint
@@ -43,33 +39,21 @@ class HookPointSpec:
default_fail_strategy: HookFailStrategy
docs_url: str | None = None
payload_model: ClassVar[type[BaseModel]]
response_model: ClassVar[type[BaseModel]]
# Computed once at class definition time from payload_model / response_model.
input_schema: ClassVar[dict[str, Any]]
output_schema: ClassVar[dict[str, Any]]
def __init_subclass__(cls, **kwargs: object) -> None:
"""Enforce that every concrete subclass declares all required class attributes.
Called automatically by Python whenever a class inherits from HookPointSpec.
Abstract subclasses (those still carrying unimplemented abstract methods) are
skipped — they are intermediate base classes and may not yet define everything.
Only fully concrete subclasses are validated, ensuring a clear TypeError at
import time rather than a confusing AttributeError at runtime.
"""
super().__init_subclass__(**kwargs)
# Skip intermediate abstract subclasses — they may still be partially defined.
if getattr(cls, "__abstractmethods__", None):
return
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
if missing:
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
for attr in ("payload_model", "response_model"):
val = getattr(cls, attr, None)
if val is None or not (
isinstance(val, type) and issubclass(val, BaseModel)
):
raise TypeError(
f"{cls.__name__}.{attr} must be a Pydantic BaseModel subclass, got {val!r}"
)
cls.input_schema = cls.payload_model.model_json_schema()
cls.output_schema = cls.response_model.model_json_schema()
@property
@abstractmethod
def input_schema(self) -> dict[str, Any]:
"""JSON schema describing the request payload sent to the customer's endpoint."""
@property
@abstractmethod
def output_schema(self) -> dict[str, Any]:
"""JSON schema describing the expected response from the customer's endpoint."""

View File

@@ -1,19 +1,10 @@
from pydantic import BaseModel
from typing import Any
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
# TODO(@Bo-Onyx): define payload and response fields
class DocumentIngestionPayload(BaseModel):
pass
class DocumentIngestionResponse(BaseModel):
pass
class DocumentIngestionSpec(HookPointSpec):
"""Hook point that runs during document ingestion.
@@ -27,5 +18,12 @@ class DocumentIngestionSpec(HookPointSpec):
fail_hard_description = "The document will not be indexed."
default_fail_strategy = HookFailStrategy.HARD
payload_model = DocumentIngestionPayload
response_model = DocumentIngestionResponse
@property
def input_schema(self) -> dict[str, Any]:
# TODO(@Bo-Onyx): define input schema
return {"type": "object", "properties": {}}
@property
def output_schema(self) -> dict[str, Any]:
# TODO(@Bo-Onyx): define output schema
return {"type": "object", "properties": {}}

View File

@@ -1,39 +1,10 @@
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from typing import Any
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
class QueryProcessingPayload(BaseModel):
model_config = ConfigDict(extra="forbid")
query: str = Field(description="The raw query string exactly as the user typed it.")
user_email: str | None = Field(
description="Email of the user submitting the query, or null if unauthenticated."
)
chat_session_id: str = Field(
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
)
class QueryProcessingResponse(BaseModel):
# Intentionally permissive — customer endpoints may return extra fields.
query: str | None = Field(
default=None,
description=(
"The query to use in the pipeline. "
"Null, empty string, or absent = reject the query."
),
)
rejection_message: str | None = Field(
default=None,
description="Message shown to the user when the query is rejected. Falls back to a generic message if not provided.",
)
class QueryProcessingSpec(HookPointSpec):
"""Hook point that runs on every user query before it enters the pipeline.
@@ -66,5 +37,47 @@ class QueryProcessingSpec(HookPointSpec):
)
default_fail_strategy = HookFailStrategy.HARD
payload_model = QueryProcessingPayload
response_model = QueryProcessingResponse
@property
def input_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The raw query string exactly as the user typed it.",
},
"user_email": {
"type": ["string", "null"],
"description": "Email of the user submitting the query, or null if unauthenticated.",
},
"chat_session_id": {
"type": "string",
"description": "UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires.",
},
},
"required": ["query", "user_email", "chat_session_id"],
"additionalProperties": False,
}
@property
def output_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": ["string", "null"],
"description": (
"The (optionally modified) query to use. "
"Set to null to reject the query."
),
},
"rejection_message": {
"type": ["string", "null"],
"description": (
"Message shown to the user when query is null. "
"Falls back to a generic message if not provided."
),
},
},
"required": ["query"],
}

View File

@@ -1,5 +0,0 @@
from onyx.configs.app_configs import HOOK_ENABLED
from shared_configs.configs import MULTI_TENANT
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT

View File

@@ -77,7 +77,6 @@ from onyx.server.features.default_assistant.api import (
)
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.hierarchy.api import router as hierarchy_router
from onyx.server.features.hooks.api import router as hook_router
from onyx.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
@@ -454,7 +453,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
register_onyx_exception_handlers(application)
include_router_with_global_prefix_prepended(application, hook_router)
include_router_with_global_prefix_prepended(application, password_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, query_router)

View File

@@ -1,453 +0,0 @@
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Query
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.constants import UNSET
from onyx.db.constants import UnsetType
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.hook import create_hook__no_commit
from onyx.db.hook import delete_hook__no_commit
from onyx.db.hook import get_hook_by_id
from onyx.db.hook import get_hook_execution_logs
from onyx.db.hook import get_hooks
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.api_dependencies import require_hook_enabled
from onyx.hooks.models import HookCreateRequest
from onyx.hooks.models import HookExecutionRecord
from onyx.hooks.models import HookPointMetaResponse
from onyx.hooks.models import HookResponse
from onyx.hooks.models import HookUpdateRequest
from onyx.hooks.models import HookValidateResponse
from onyx.hooks.models import HookValidateStatus
from onyx.hooks.registry import get_all_specs
from onyx.hooks.registry import get_hook_point_spec
from onyx.utils.logger import setup_logger
from onyx.utils.url import SSRFException
from onyx.utils.url import validate_outbound_http_url
logger = setup_logger()
# ---------------------------------------------------------------------------
# SSRF protection
# ---------------------------------------------------------------------------
def _check_ssrf_safety(endpoint_url: str) -> None:
"""Raise OnyxError if endpoint_url could be used for SSRF.
Delegates to validate_outbound_http_url with https_only=True.
"""
try:
validate_outbound_http_url(endpoint_url, https_only=True)
except (SSRFException, ValueError) as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookResponse:
return HookResponse(
id=hook.id,
name=hook.name,
hook_point=hook.hook_point,
endpoint_url=hook.endpoint_url,
fail_strategy=hook.fail_strategy,
timeout_seconds=hook.timeout_seconds,
is_active=hook.is_active,
is_reachable=hook.is_reachable,
creator_email=(
creator_email
if creator_email is not None
else (hook.creator.email if hook.creator else None)
),
created_at=hook.created_at,
updated_at=hook.updated_at,
)
def _get_hook_or_404(
db_session: Session,
hook_id: int,
include_creator: bool = False,
) -> Hook:
hook = get_hook_by_id(
db_session=db_session,
hook_id=hook_id,
include_creator=include_creator,
)
if hook is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook {hook_id} not found.")
return hook
def _raise_for_validation_failure(validation: HookValidateResponse) -> None:
"""Raise an appropriate OnyxError for a non-passed validation result."""
if validation.status == HookValidateStatus.auth_failed:
raise OnyxError(OnyxErrorCode.CREDENTIAL_INVALID, validation.error_message)
if validation.status == HookValidateStatus.timeout:
raise OnyxError(
OnyxErrorCode.GATEWAY_TIMEOUT,
f"Endpoint validation failed: {validation.error_message}",
)
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Endpoint validation failed: {validation.error_message}",
)
def _validate_endpoint(
endpoint_url: str,
api_key: str | None,
timeout_seconds: float,
) -> HookValidateResponse:
"""Check whether endpoint_url is reachable by sending an empty POST request.
We use POST since hook endpoints expect POST requests. The server will typically
respond with 4xx (missing/invalid body) — that is fine. Any HTTP response means
the server is up and routable. A 401/403 response returns auth_failed
(not reachable — indicates the api_key is invalid).
Timeout handling:
- ConnectTimeout: TCP handshake never completed → cannot_connect.
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
(operator should consider increasing timeout_seconds).
- All other exceptions → cannot_connect.
"""
_check_ssrf_safety(endpoint_url)
headers: dict[str, str] = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
try:
with httpx.Client(timeout=timeout_seconds, follow_redirects=False) as client:
response = client.post(endpoint_url, headers=headers)
if response.status_code in (401, 403):
return HookValidateResponse(
status=HookValidateStatus.auth_failed,
error_message=f"Authentication failed (HTTP {response.status_code})",
)
return HookValidateResponse(status=HookValidateStatus.passed)
except httpx.TimeoutException as exc:
# ConnectTimeout: TCP handshake never completed → cannot_connect.
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
if isinstance(exc, httpx.ConnectTimeout):
logger.warning(
"Hook endpoint validation: connect timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
logger.warning(
"Hook endpoint validation: read/write timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.timeout,
error_message="Endpoint timed out — consider increasing timeout_seconds.",
)
except Exception as exc:
logger.warning(
"Hook endpoint validation: connection error for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
# ---------------------------------------------------------------------------
# Routers
# ---------------------------------------------------------------------------
router = APIRouter(prefix="/admin/hooks")
# ---------------------------------------------------------------------------
# Hook endpoints
# ---------------------------------------------------------------------------
@router.get("/specs")
def get_hook_point_specs(
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
) -> list[HookPointMetaResponse]:
return [
HookPointMetaResponse(
hook_point=spec.hook_point,
display_name=spec.display_name,
description=spec.description,
docs_url=spec.docs_url,
input_schema=spec.input_schema,
output_schema=spec.output_schema,
default_timeout_seconds=spec.default_timeout_seconds,
default_fail_strategy=spec.default_fail_strategy,
fail_hard_description=spec.fail_hard_description,
)
for spec in get_all_specs()
]
@router.get("")
def list_hooks(
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> list[HookResponse]:
hooks = get_hooks(db_session=db_session, include_creator=True)
return [_hook_to_response(h) for h in hooks]
@router.post("")
def create_hook(
req: HookCreateRequest,
user: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
"""Create a new hook. The endpoint is validated before persisting — creation fails if
the endpoint cannot be reached or the api_key is invalid. Hooks are created inactive;
use POST /{hook_id}/activate once ready to receive traffic."""
spec = get_hook_point_spec(req.hook_point)
api_key = req.api_key.get_secret_value() if req.api_key else None
validation = _validate_endpoint(
endpoint_url=req.endpoint_url,
api_key=api_key,
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
)
if validation.status != HookValidateStatus.passed:
_raise_for_validation_failure(validation)
hook = create_hook__no_commit(
db_session=db_session,
name=req.name,
hook_point=req.hook_point,
endpoint_url=req.endpoint_url,
api_key=api_key,
fail_strategy=req.fail_strategy or spec.default_fail_strategy,
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
creator_id=user.id,
)
hook.is_reachable = True
db_session.commit()
return _hook_to_response(hook, creator_email=user.email)
@router.get("/{hook_id}")
def get_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = _get_hook_or_404(db_session, hook_id, include_creator=True)
return _hook_to_response(hook)
@router.patch("/{hook_id}")
def update_hook(
hook_id: int,
req: HookUpdateRequest,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
"""Update hook fields. If endpoint_url, api_key, or timeout_seconds changes, the
endpoint is re-validated using the effective values. For active hooks the update is
rejected on validation failure, keeping live traffic unaffected. For inactive hooks
the update goes through regardless and is_reachable is updated to reflect the result.
Note: if an active hook's endpoint is currently down, even a timeout_seconds-only
increase will be rejected. The recovery flow is: deactivate → update → reactivate.
"""
# api_key: UNSET = no change, None = clear, value = update
api_key: str | None | UnsetType
if "api_key" not in req.model_fields_set:
api_key = UNSET
elif req.api_key is None:
api_key = None
else:
api_key = req.api_key.get_secret_value()
endpoint_url_changing = "endpoint_url" in req.model_fields_set
api_key_changing = not isinstance(api_key, UnsetType)
timeout_changing = "timeout_seconds" in req.model_fields_set
validated_is_reachable: bool | None = None
if endpoint_url_changing or api_key_changing or timeout_changing:
existing = _get_hook_or_404(db_session, hook_id)
effective_url: str = (
req.endpoint_url if endpoint_url_changing else existing.endpoint_url # type: ignore[assignment] # endpoint_url is required on create and cannot be cleared on update
)
effective_api_key: str | None = (
(api_key if not isinstance(api_key, UnsetType) else None)
if api_key_changing
else (
existing.api_key.get_value(apply_mask=False)
if existing.api_key
else None
)
)
effective_timeout: float = (
req.timeout_seconds if timeout_changing else existing.timeout_seconds # type: ignore[assignment] # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest)
)
validation = _validate_endpoint(
endpoint_url=effective_url,
api_key=effective_api_key,
timeout_seconds=effective_timeout,
)
if existing.is_active and validation.status != HookValidateStatus.passed:
_raise_for_validation_failure(validation)
validated_is_reachable = validation.status == HookValidateStatus.passed
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
name=req.name,
endpoint_url=(req.endpoint_url if endpoint_url_changing else UNSET),
api_key=api_key,
fail_strategy=req.fail_strategy,
timeout_seconds=req.timeout_seconds,
is_reachable=validated_is_reachable,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
@router.delete("/{hook_id}")
def delete_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> None:
delete_hook__no_commit(db_session=db_session, hook_id=hook_id)
db_session.commit()
@router.post("/{hook_id}/activate")
def activate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = _get_hook_or_404(db_session, hook_id)
if not hook.endpoint_url:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
)
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
validation = _validate_endpoint(
endpoint_url=hook.endpoint_url,
api_key=api_key,
timeout_seconds=hook.timeout_seconds,
)
if validation.status != HookValidateStatus.passed:
# Persist is_reachable=False in a separate session so the request
# session has no commits on the failure path and the transaction
# boundary stays clean.
if hook.is_reachable is not False:
with get_session_with_current_tenant() as side_session:
update_hook__no_commit(
db_session=side_session, hook_id=hook_id, is_reachable=False
)
side_session.commit()
_raise_for_validation_failure(validation)
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
is_active=True,
is_reachable=True,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
@router.post("/{hook_id}/validate")
def validate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookValidateResponse:
hook = _get_hook_or_404(db_session, hook_id)
if not hook.endpoint_url:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
)
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
validation = _validate_endpoint(
endpoint_url=hook.endpoint_url,
api_key=api_key,
timeout_seconds=hook.timeout_seconds,
)
validation_passed = validation.status == HookValidateStatus.passed
if hook.is_reachable != validation_passed:
update_hook__no_commit(
db_session=db_session, hook_id=hook_id, is_reachable=validation_passed
)
db_session.commit()
return validation
@router.post("/{hook_id}/deactivate")
def deactivate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
is_active=False,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
# ---------------------------------------------------------------------------
# Execution log endpoints
# ---------------------------------------------------------------------------
@router.get("/{hook_id}/execution-logs")
def list_hook_execution_logs(
hook_id: int,
limit: int = Query(default=10, ge=1, le=100),
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> list[HookExecutionRecord]:
_get_hook_or_404(db_session, hook_id)
logs = get_hook_execution_logs(db_session=db_session, hook_id=hook_id, limit=limit)
return [
HookExecutionRecord(
error_message=log.error_message,
status_code=log.status_code,
duration_ms=log.duration_ms,
created_at=log.created_at,
)
for log in logs
]

View File

@@ -53,12 +53,8 @@ logger = setup_logger()
class SearchToolConfig(BaseModel):
user_selected_filters: BaseFilters | None = None
# Vespa metadata filters for overflowing user files. These are NOT the
# IDs of the current project/persona — they are only set when the
# project's/persona's user files didn't fit in the LLM context window and
# must be found via vector DB search instead.
project_id_filter: int | None = None
persona_id_filter: int | None = None
project_id: int | None = None
persona_id: int | None = None
bypass_acl: bool = False
additional_context: str | None = None
slack_context: SlackContext | None = None
@@ -184,8 +180,8 @@ def construct_tools(
llm=llm,
document_index=document_index,
user_selected_filters=search_tool_config.user_selected_filters,
project_id_filter=search_tool_config.project_id_filter,
persona_id_filter=search_tool_config.persona_id_filter,
project_id=search_tool_config.project_id,
persona_id=search_tool_config.persona_id,
bypass_acl=search_tool_config.bypass_acl,
slack_context=search_tool_config.slack_context,
enable_slack_search=search_tool_config.enable_slack_search,
@@ -432,8 +428,8 @@ def construct_tools(
llm=llm,
document_index=document_index,
user_selected_filters=search_tool_config.user_selected_filters,
project_id_filter=search_tool_config.project_id_filter,
persona_id_filter=search_tool_config.persona_id_filter,
project_id=search_tool_config.project_id,
persona_id=search_tool_config.persona_id,
bypass_acl=search_tool_config.bypass_acl,
slack_context=search_tool_config.slack_context,
enable_slack_search=search_tool_config.enable_slack_search,

View File

@@ -764,7 +764,8 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
tags=None,
access_control_list=access_control_list,
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
project_id_filter=None,
user_file_ids=None,
project_id=None,
)
def _merge_indexed_and_crawled_results(

View File

@@ -244,11 +244,10 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
document_index: DocumentIndex,
# Respecting user selections
user_selected_filters: BaseFilters | None,
# Vespa metadata filters for overflowing user files. NOT the raw IDs
# of the current project/persona — only set when user files couldn't
# fit in the LLM context and need to be searched via vector DB.
project_id_filter: int | None,
persona_id_filter: int | None = None,
# If the chat is part of a project
project_id: int | None,
# If set, search scopes to files attached to this persona
persona_id: int | None = None,
bypass_acl: bool = False,
# Slack context for federated Slack search (tokens fetched internally)
slack_context: SlackContext | None = None,
@@ -262,8 +261,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
self.llm = llm
self.document_index = document_index
self.user_selected_filters = user_selected_filters
self.project_id_filter = project_id_filter
self.persona_id_filter = persona_id_filter
self.project_id = project_id
self.persona_id = persona_id
self.bypass_acl = bypass_acl
self.slack_context = slack_context
self.enable_slack_search = enable_slack_search
@@ -452,15 +451,13 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
hybrid_alpha=hybrid_alpha,
# For projects, the search scope is the project and has no other limits
user_selected_filters=(
self.user_selected_filters
if self.project_id_filter is None
else None
self.user_selected_filters if self.project_id is None else None
),
bypass_acl=self.bypass_acl,
limit=num_hits,
),
project_id_filter=self.project_id_filter,
persona_id_filter=self.persona_id_filter,
project_id=self.project_id,
persona_id=self.persona_id,
document_index=self.document_index,
user=self.user,
persona=self.persona,
@@ -577,7 +574,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
)
# Federated retrieval functions (non-Slack; Slack is separate)
if self.project_id_filter is not None:
if self.project_id is not None:
# Project mode ignores user filters → no federated sources
prefetch_source_types = None
else:
@@ -590,12 +587,16 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
persona_document_sets = (
[ds.name for ds in self.persona.document_sets] if self.persona else None
)
user_file_ids = (
[uf.id for uf in self.persona.user_files] if self.persona else None
)
federated_retrieval_infos = (
get_federated_retrieval_functions(
db_session=db_session,
user_id=self.user.id if self.user else None,
source_types=prefetch_source_types,
document_set_names=persona_document_sets,
user_file_ids=user_file_ids,
)
or []
)

View File

@@ -140,20 +140,10 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
return validated_ip, hostname, port
def validate_outbound_http_url(
url: str,
*,
allow_private_network: bool = False,
https_only: bool = False,
) -> str:
def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
"""
Validate a URL that will be used by backend outbound HTTP calls.
Args:
url: The URL to validate.
allow_private_network: If True, skip private/reserved IP checks.
https_only: If True, reject http:// URLs (only https:// is allowed).
Returns:
A normalized URL string with surrounding whitespace removed.
@@ -167,12 +157,7 @@ def validate_outbound_http_url(
parsed = urlparse(normalized_url)
if https_only:
if parsed.scheme != "https":
raise SSRFException(
f"Invalid URL scheme '{parsed.scheme}'. Only https is allowed."
)
elif parsed.scheme not in ("http", "https"):
if parsed.scheme not in ("http", "https"):
raise SSRFException(
f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
)

View File

@@ -1,37 +0,0 @@
from sqlalchemy import inspect
from sqlalchemy.orm import Session
from onyx.db.chat import create_chat_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.models import Persona
def test_eager_load_persona_loads_relationships(db_session: Session) -> None:
"""Verify that eager_load_persona pre-loads persona, its collections, and project."""
persona = Persona(name="eager-load-test", description="test")
db_session.add(persona)
db_session.flush()
chat_session = create_chat_session(
db_session=db_session,
description="test",
user_id=None,
persona_id=persona.id,
)
loaded = get_chat_session_by_id(
chat_session_id=chat_session.id,
user_id=None,
db_session=db_session,
eager_load_persona=True,
)
unloaded = inspect(loaded).unloaded
assert "persona" not in unloaded
assert "project" not in unloaded
persona_unloaded = inspect(loaded.persona).unloaded
assert "tools" not in persona_unloaded
assert "user_files" not in persona_unloaded
db_session.rollback()

View File

@@ -1,30 +1,34 @@
"""Tests for OpenSearch assistant knowledge filter construction.
These tests verify that when an assistant (persona) has knowledge attached,
the search filter includes the appropriate scope filters with OR logic (not AND),
ensuring documents are discoverable across knowledge types like attached documents,
hierarchy nodes, document sets, and persona/project user files.
These tests verify that when an assistant (persona) has user files attached,
the search filter includes those user file IDs in the assistant knowledge filter
with OR logic (not AND), ensuring user files are discoverable alongside other
knowledge types like attached documents and hierarchy nodes.
This prevents a regression where user_file_ids were added as a separate AND
filter, making it impossible to find user files when the assistant also had
attached documents or hierarchy nodes (since no document could match both).
"""
from typing import Any
from uuid import UUID
from onyx.configs.constants import DocumentSource
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
from onyx.document_index.opensearch.schema import DOCUMENT_ID_FIELD_NAME
from onyx.document_index.opensearch.search import DocumentQuery
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
USER_FILE_ID = UUID("6ad84e45-4450-406c-9d36-fcb5e74aca6b")
ATTACHED_DOCUMENT_ID = "https://docs.google.com/document/d/test-doc-id"
HIERARCHY_NODE_ID = 42
PERSONA_ID = 7
def _get_search_filters(
source_types: list[DocumentSource],
user_file_ids: list[UUID],
attached_document_ids: list[str] | None,
hierarchy_node_ids: list[int] | None,
persona_id_filter: int | None = None,
document_sets: list[str] | None = None,
) -> list[dict[str, Any]]:
return DocumentQuery._get_search_filters(
tenant_state=TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False),
@@ -32,14 +36,15 @@ def _get_search_filters(
access_control_list=["user_email:test@example.com"],
source_types=source_types,
tags=[],
document_sets=document_sets or [],
project_id_filter=None,
persona_id_filter=persona_id_filter,
document_sets=[],
project_id=None,
persona_id=None,
time_cutoff=None,
min_chunk_index=None,
max_chunk_index=None,
max_chunk_size=None,
document_id=None,
user_file_ids=user_file_ids,
attached_document_ids=attached_document_ids,
hierarchy_node_ids=hierarchy_node_ids,
)
@@ -48,97 +53,137 @@ def _get_search_filters(
class TestAssistantKnowledgeFilter:
"""Tests for assistant knowledge filter construction in OpenSearch queries."""
def test_persona_id_filter_added_when_knowledge_scope_exists(self) -> None:
"""persona_id_filter should be OR'd into the knowledge scope filter
when explicit knowledge attachments (attached_document_ids,
hierarchy_node_ids, document_sets) are present."""
def test_user_file_ids_included_in_assistant_knowledge_filter(self) -> None:
"""
Tests that user_file_ids are included in the assistant knowledge filter
with OR logic when the assistant has both user files and attached documents.
This prevents the regression where user files were ANDed with other
knowledge types, making them unfindable.
"""
# Under test: Call the filter construction method directly
filter_clauses = _get_search_filters(
source_types=[DocumentSource.FILE],
source_types=[DocumentSource.FILE, DocumentSource.USER_FILE],
user_file_ids=[USER_FILE_ID],
attached_document_ids=[ATTACHED_DOCUMENT_ID],
hierarchy_node_ids=[HIERARCHY_NODE_ID],
persona_id_filter=PERSONA_ID,
)
# Postcondition: Find the assistant knowledge filter (bool with should clauses)
knowledge_filter = None
for clause in filter_clauses:
if "bool" in clause and "should" in clause["bool"]:
# Check if this is the knowledge filter (has minimum_should_match=1)
if clause["bool"].get("minimum_should_match") == 1:
knowledge_filter = clause
break
assert knowledge_filter is not None, (
"Expected to find an assistant knowledge filter with "
"'minimum_should_match: 1'"
)
assert (
knowledge_filter is not None
), "Expected to find an assistant knowledge filter with 'minimum_should_match: 1'"
# The knowledge filter should have 3 should clauses (user files, attached docs, hierarchy nodes)
should_clauses = knowledge_filter["bool"]["should"]
persona_found = any(
clause.get("term", {}).get(PERSONAS_FIELD_NAME, {}).get("value")
== PERSONA_ID
for clause in should_clauses
)
assert persona_found, (
f"Expected persona_id={PERSONA_ID} filter on {PERSONAS_FIELD_NAME} "
f"in should clauses. Got: {should_clauses}"
assert (
len(should_clauses) == 3
), f"Expected 3 should clauses (user_file, attached_doc, hierarchy_node), got {len(should_clauses)}"
# Verify user_file_id is in one of the should clauses
user_file_filter_found = False
for should_clause in should_clauses:
# The user file filter uses a nested bool with should for each file ID
if "bool" in should_clause and "should" in should_clause["bool"]:
for term_clause in should_clause["bool"]["should"]:
if "term" in term_clause:
term_value = term_clause["term"].get(DOCUMENT_ID_FIELD_NAME, {})
if term_value.get("value") == str(USER_FILE_ID):
user_file_filter_found = True
break
assert user_file_filter_found, (
f"Expected user_file_id {USER_FILE_ID} to be in the assistant knowledge "
f"filter's should clauses. Filter structure: {knowledge_filter}"
)
def test_persona_id_filter_alone_creates_knowledge_scope(self) -> None:
"""persona_id_filter IS a primary knowledge scope trigger — a persona
with user files is explicit knowledge, so it should restrict
search on its own."""
def test_user_file_ids_only_creates_knowledge_filter(self) -> None:
"""
Tests that when only user_file_ids are provided (no attached_documents or
hierarchy_nodes), the assistant knowledge filter is still created with the
user file IDs.
"""
# Precondition
filter_clauses = _get_search_filters(
source_types=[],
source_types=[DocumentSource.USER_FILE],
user_file_ids=[USER_FILE_ID],
attached_document_ids=None,
hierarchy_node_ids=None,
persona_id_filter=PERSONA_ID,
)
knowledge_filter = None
# Postcondition: Find filter that contains our user file ID
user_file_filter_found = False
for clause in filter_clauses:
if "bool" in clause and "should" in clause["bool"]:
if clause["bool"].get("minimum_should_match") == 1:
knowledge_filter = clause
break
clause_str = str(clause)
if str(USER_FILE_ID) in clause_str:
user_file_filter_found = True
break
assert (
knowledge_filter is not None
), "Expected persona_id_filter alone to create a knowledge scope filter"
persona_found = any(
clause.get("term", {}).get(PERSONAS_FIELD_NAME, {}).get("value")
== PERSONA_ID
for clause in knowledge_filter["bool"]["should"]
)
assert persona_found, (
f"Expected persona_id={PERSONA_ID} filter in knowledge scope. "
f"Got: {knowledge_filter}"
)
user_file_filter_found
), f"Expected user_file_id {USER_FILE_ID} to be in the filter clauses. Got: {filter_clauses}"
def test_no_separate_user_file_filter_when_assistant_has_knowledge(self) -> None:
"""
Tests that user_file_ids are NOT added as a separate AND filter when the
assistant has other knowledge attached (attached_documents or hierarchy_nodes).
"""
def test_knowledge_filter_with_document_sets_and_persona_filter(self) -> None:
"""document_sets and persona_id_filter should be OR'd together in
the knowledge scope filter."""
filter_clauses = _get_search_filters(
source_types=[],
attached_document_ids=None,
source_types=[DocumentSource.FILE, DocumentSource.USER_FILE],
user_file_ids=[USER_FILE_ID],
attached_document_ids=[ATTACHED_DOCUMENT_ID],
hierarchy_node_ids=None,
persona_id_filter=PERSONA_ID,
document_sets=["engineering"],
)
knowledge_filter = None
# Postcondition: Count how many times user_file_id appears in filter clauses
# It should appear exactly once (in the knowledge filter), not twice
user_file_id_str = str(USER_FILE_ID)
occurrences = 0
for clause in filter_clauses:
if "bool" in clause and "should" in clause["bool"]:
if clause["bool"].get("minimum_should_match") == 1:
knowledge_filter = clause
break
if user_file_id_str in str(clause):
occurrences += 1
assert (
knowledge_filter is not None
), "Expected knowledge filter when document_sets is provided"
assert occurrences == 1, (
f"Expected user_file_id to appear exactly once in filter clauses "
f"(inside the assistant knowledge filter), but found {occurrences} "
f"occurrences. This suggests user_file_ids is being added as both a "
f"separate AND filter and inside the knowledge filter. "
f"Filter clauses: {filter_clauses}"
)
filter_str = str(knowledge_filter)
assert (
"engineering" in filter_str
), "Expected document_set 'engineering' in knowledge filter"
assert (
str(PERSONA_ID) in filter_str
), f"Expected persona_id_filter {PERSONA_ID} in knowledge filter"
def test_multiple_user_files_all_included_in_filter(self) -> None:
"""
Tests that when multiple user files are attached to an assistant,
all of them are included in the filter.
"""
# Precondition
user_file_ids = [
UUID("6ad84e45-4450-406c-9d36-fcb5e74aca6b"),
UUID("7be95f56-5561-517d-ae47-acd6f85bdb7c"),
UUID("8cf06a67-6672-628e-bf58-ade7a96cec8d"),
]
filter_clauses = _get_search_filters(
source_types=[DocumentSource.USER_FILE],
user_file_ids=user_file_ids,
attached_document_ids=[ATTACHED_DOCUMENT_ID],
hierarchy_node_ids=None,
)
# Postcondition: All user file IDs should be in the filter
filter_str = str(filter_clauses)
for user_file_id in user_file_ids:
assert (
str(user_file_id) in filter_str
), f"Expected user_file_id {user_file_id} to be in the filter clauses"

View File

@@ -14,7 +14,6 @@ from __future__ import annotations
import os
import subprocess
import sys
import time
import uuid
from collections.abc import Generator
@@ -29,9 +28,6 @@ _BACKEND_DIR = os.path.normpath(
os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
)
_DROP_SCHEMA_MAX_RETRIES = 3
_DROP_SCHEMA_RETRY_DELAY_SEC = 2
# ---------------------------------------------------------------------------
# Helpers
@@ -54,39 +50,6 @@ def _run_script(
)
def _force_drop_schema(engine: Engine, schema: str) -> None:
"""Terminate backends using *schema* then drop it, retrying on deadlock.
Background Celery workers may discover test schemas (they match the
``tenant_`` prefix) and hold locks on tables inside them. A bare
``DROP SCHEMA … CASCADE`` can deadlock with those workers, so we
first kill their connections and retry if we still hit a deadlock.
"""
for attempt in range(_DROP_SCHEMA_MAX_RETRIES):
try:
with engine.connect() as conn:
conn.execute(
text(
"""
SELECT pg_terminate_backend(l.pid)
FROM pg_locks l
JOIN pg_class c ON c.oid = l.relation
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = :schema
AND l.pid != pg_backend_pid()
"""
),
{"schema": schema},
)
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
return
except Exception:
if attempt == _DROP_SCHEMA_MAX_RETRIES - 1:
raise
time.sleep(_DROP_SCHEMA_RETRY_DELAY_SEC)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@@ -141,7 +104,9 @@ def tenant_schema_at_head(
yield schema
_force_drop_schema(engine, schema)
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
@pytest.fixture
@@ -158,7 +123,9 @@ def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
yield schema
_force_drop_schema(engine, schema)
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
@pytest.fixture
@@ -183,7 +150,9 @@ def tenant_schema_bad_rev(engine: Engine) -> Generator[str, None, None]:
yield schema
_force_drop_schema(engine, schema)
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
# ---------------------------------------------------------------------------

View File

@@ -1,5 +1,3 @@
import csv
import io
import os
from datetime import datetime
from datetime import timedelta
@@ -141,12 +139,12 @@ def test_chat_history_csv_export(
assert headers["Content-Type"] == "text/csv; charset=utf-8"
assert "Content-Disposition" in headers
# Use csv.reader to properly handle newlines inside quoted fields
csv_rows = list(csv.reader(io.StringIO(csv_content)))
assert len(csv_rows) == 3 # Header + 2 QA pairs
assert csv_rows[0][0] == "chat_session_id"
assert "user_message" in csv_rows[0]
assert "ai_response" in csv_rows[0]
# Verify CSV content
csv_lines = csv_content.strip().split("\n")
assert len(csv_lines) == 3 # Header + 2 QA pairs
assert "chat_session_id" in csv_content
assert "user_message" in csv_content
assert "ai_response" in csv_content
assert "What was the Q1 revenue?" in csv_content
assert "What about Q2 revenue?" in csv_content
@@ -158,5 +156,5 @@ def test_chat_history_csv_export(
end_time=past_end,
user_performing_action=admin_user,
)
csv_rows = list(csv.reader(io.StringIO(csv_content)))
assert len(csv_rows) == 1 # Only header, no data rows
csv_lines = csv_content.strip().split("\n")
assert len(csv_lines) == 1 # Only header, no data rows

View File

@@ -1,5 +1,6 @@
from typing import Any
import pytest
from pydantic import BaseModel
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
@@ -10,10 +11,12 @@ def test_init_subclass_raises_for_missing_attrs() -> None:
class IncompleteSpec(HookPointSpec):
hook_point = HookPoint.QUERY_PROCESSING
# missing display_name, description, payload_model, response_model, etc.
# missing display_name, description, etc.
class _Payload(BaseModel):
pass
@property
def input_schema(self) -> dict[str, Any]:
return {}
payload_model = _Payload
response_model = _Payload
@property
def output_schema(self) -> dict[str, Any]:
return {}

View File

@@ -1,541 +0,0 @@
"""Unit tests for the hook executor."""
import json
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
_RESPONSE_PAYLOAD: dict[str, Any] = {"rewritten_query": "better test"}
def _make_hook(
*,
is_active: bool = True,
endpoint_url: str | None = "https://hook.example.com/query",
api_key: MagicMock | None = None,
timeout_seconds: float = 5.0,
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
hook_id: int = 1,
is_reachable: bool | None = None,
) -> MagicMock:
hook = MagicMock()
hook.is_active = is_active
hook.endpoint_url = endpoint_url
hook.api_key = api_key
hook.timeout_seconds = timeout_seconds
hook.id = hook_id
hook.fail_strategy = fail_strategy
hook.is_reachable = is_reachable
return hook
def _make_api_key(value: str) -> MagicMock:
api_key = MagicMock()
api_key.get_value.return_value = value
return api_key
def _make_response(
*,
status_code: int = 200,
json_return: Any = _RESPONSE_PAYLOAD,
json_side_effect: Exception | None = None,
) -> MagicMock:
"""Build a response mock with controllable json() behaviour."""
response = MagicMock()
response.status_code = status_code
if json_side_effect is not None:
response.json.side_effect = json_side_effect
else:
response.json.return_value = json_return
return response
def _setup_client(
mock_client_cls: MagicMock,
*,
response: MagicMock | None = None,
side_effect: Exception | None = None,
) -> MagicMock:
"""Wire up the httpx.Client mock and return the inner client.
If side_effect is an httpx.HTTPStatusError, it is raised from
raise_for_status() (matching real httpx behaviour) and post() returns a
response mock with the matching status_code set. All other exceptions are
raised directly from post().
"""
mock_client = MagicMock()
if isinstance(side_effect, httpx.HTTPStatusError):
error_response = MagicMock()
error_response.status_code = side_effect.response.status_code
error_response.raise_for_status.side_effect = side_effect
mock_client.post = MagicMock(return_value=error_response)
else:
mock_client.post = MagicMock(
side_effect=side_effect, return_value=response if not side_effect else None
)
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
return mock_client
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def db_session() -> MagicMock:
return MagicMock()
# ---------------------------------------------------------------------------
# Early-exit guards (no HTTP call, no DB writes)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"hooks_available,hook",
[
# HOOKS_AVAILABLE=False exits before the DB lookup — hook is irrelevant.
pytest.param(False, None, id="hooks_not_available"),
pytest.param(True, None, id="hook_not_found"),
pytest.param(True, _make_hook(is_active=False), id="hook_inactive"),
pytest.param(True, _make_hook(endpoint_url=None), id="no_endpoint_url"),
],
)
def test_early_exit_returns_skipped_with_no_db_writes(
db_session: MagicMock,
hooks_available: bool,
hook: MagicMock | None,
) -> None:
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", hooks_available),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSkipped)
mock_update.assert_not_called()
mock_log.assert_not_called()
# ---------------------------------------------------------------------------
# Successful HTTP call
# ---------------------------------------------------------------------------
def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> None:
hook = _make_hook()
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, response=_make_response())
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert result == _RESPONSE_PAYLOAD
_, update_kwargs = mock_update.call_args
assert update_kwargs["is_reachable"] is True
mock_log.assert_not_called()
def test_success_skips_reachable_write_when_already_true(db_session: MagicMock) -> None:
"""Deduplication guard: a hook already at is_reachable=True that succeeds
must not trigger a DB write."""
hook = _make_hook(is_reachable=True)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, response=_make_response())
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert result == _RESPONSE_PAYLOAD
mock_update.assert_not_called()
def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
"""response.json() returning a non-dict (e.g. list) must be treated as failure.
The server responded, so is_reachable is not updated."""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response(json_return=["unexpected", "list"]),
)
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "non-dict" in (log_kwargs["error_message"] or "")
mock_update.assert_not_called()
def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
"""response.json() raising must be treated as failure with SOFT strategy.
The server responded, so is_reachable is not updated."""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response(
json_side_effect=json.JSONDecodeError("not JSON", "", 0)
),
)
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "non-JSON" in (log_kwargs["error_message"] or "")
mock_update.assert_not_called()
# ---------------------------------------------------------------------------
# HTTP failure paths
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"exception,fail_strategy,expected_type,expected_is_reachable",
[
# NetworkError → is_reachable=False
pytest.param(
httpx.ConnectError("refused"),
HookFailStrategy.SOFT,
HookSoftFailed,
False,
id="connect_error_soft",
),
pytest.param(
httpx.ConnectError("refused"),
HookFailStrategy.HARD,
OnyxError,
False,
id="connect_error_hard",
),
# 401/403 → is_reachable=False (api_key revoked)
pytest.param(
httpx.HTTPStatusError(
"401",
request=MagicMock(),
response=MagicMock(status_code=401, text="Unauthorized"),
),
HookFailStrategy.SOFT,
HookSoftFailed,
False,
id="auth_401_soft",
),
pytest.param(
httpx.HTTPStatusError(
"403",
request=MagicMock(),
response=MagicMock(status_code=403, text="Forbidden"),
),
HookFailStrategy.HARD,
OnyxError,
False,
id="auth_403_hard",
),
# TimeoutException → no is_reachable write (None)
pytest.param(
httpx.TimeoutException("timeout"),
HookFailStrategy.SOFT,
HookSoftFailed,
None,
id="timeout_soft",
),
pytest.param(
httpx.TimeoutException("timeout"),
HookFailStrategy.HARD,
OnyxError,
None,
id="timeout_hard",
),
# Other HTTP errors → no is_reachable write (None)
pytest.param(
httpx.HTTPStatusError(
"500",
request=MagicMock(),
response=MagicMock(status_code=500, text="error"),
),
HookFailStrategy.SOFT,
HookSoftFailed,
None,
id="http_status_error_soft",
),
pytest.param(
httpx.HTTPStatusError(
"500",
request=MagicMock(),
response=MagicMock(status_code=500, text="error"),
),
HookFailStrategy.HARD,
OnyxError,
None,
id="http_status_error_hard",
),
],
)
def test_http_failure_paths(
db_session: MagicMock,
exception: Exception,
fail_strategy: HookFailStrategy,
expected_type: type,
expected_is_reachable: bool | None,
) -> None:
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, side_effect=exception)
if expected_type is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, expected_type)
if expected_is_reachable is None:
mock_update.assert_not_called()
else:
mock_update.assert_called_once()
_, kwargs = mock_update.call_args
assert kwargs["is_reachable"] is expected_is_reachable
# ---------------------------------------------------------------------------
# Authorization header
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"api_key_value,expect_auth_header",
[
pytest.param("secret-token", True, id="api_key_present"),
pytest.param(None, False, id="api_key_absent"),
],
)
def test_authorization_header(
db_session: MagicMock,
api_key_value: str | None,
expect_auth_header: bool,
) -> None:
api_key = _make_api_key(api_key_value) if api_key_value else None
hook = _make_hook(api_key=api_key)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit"),
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
mock_client = _setup_client(mock_client_cls, response=_make_response())
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
_, call_kwargs = mock_client.post.call_args
if expect_auth_header:
assert call_kwargs["headers"]["Authorization"] == f"Bearer {api_key_value}"
else:
assert "Authorization" not in call_kwargs["headers"]
# ---------------------------------------------------------------------------
# Persist session failure
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"http_exception,expected_result",
[
pytest.param(None, _RESPONSE_PAYLOAD, id="success_path"),
pytest.param(httpx.ConnectError("refused"), OnyxError, id="hard_fail_path"),
],
)
def test_persist_session_failure_is_swallowed(
db_session: MagicMock,
http_exception: Exception | None,
expected_result: Any,
) -> None:
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor.get_session_with_current_tenant",
side_effect=RuntimeError("DB unavailable"),
),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response() if not http_exception else None,
side_effect=http_exception,
)
if expected_result is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert result == expected_result
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:
"""is_reachable update failing (e.g. concurrent hook deletion) must not
prevent the execution log from being written.
Simulates the production failure path: update_hook__no_commit raises
OnyxError(NOT_FOUND) as it would if the hook was concurrently deleted
between the initial lookup and the reachable update.
"""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch(
"onyx.hooks.executor.update_hook__no_commit",
side_effect=OnyxError(OnyxErrorCode.NOT_FOUND, "hook deleted"),
),
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, side_effect=httpx.ConnectError("refused"))
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
mock_log.assert_called_once()

View File

@@ -37,20 +37,18 @@ def test_input_schema_query_is_string() -> None:
def test_input_schema_user_email_is_nullable() -> None:
props = QueryProcessingSpec().input_schema["properties"]
# Pydantic v2 emits anyOf for nullable fields
assert any(s.get("type") == "null" for s in props["user_email"]["anyOf"])
assert "null" in props["user_email"]["type"]
def test_output_schema_query_is_optional() -> None:
# query defaults to None (absent = reject); not required in the schema
def test_output_schema_query_is_required() -> None:
schema = QueryProcessingSpec().output_schema
assert "query" not in schema.get("required", [])
assert "query" in schema["required"]
def test_output_schema_query_is_nullable() -> None:
# null means "reject the query"; Pydantic v2 emits anyOf for nullable fields
# null means "reject the query"
props = QueryProcessingSpec().output_schema["properties"]
assert any(s.get("type") == "null" for s in props["query"]["anyOf"])
assert "null" in props["query"]["type"]
def test_output_schema_rejection_message_is_optional() -> None:

View File

@@ -1,278 +0,0 @@
"""Unit tests for onyx.server.features.hooks.api helpers.
Covers:
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
- _validate_endpoint: httpx exception → HookValidateStatus mapping
ConnectTimeout → cannot_connect (TCP handshake never completed)
ConnectError → cannot_connect (DNS / TLS failure)
ReadTimeout et al. → timeout (TCP connected, server slow)
Any other exc → cannot_connect
- _raise_for_validation_failure: HookValidateStatus → OnyxError mapping
"""
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.models import HookValidateResponse
from onyx.hooks.models import HookValidateStatus
from onyx.server.features.hooks.api import _check_ssrf_safety
from onyx.server.features.hooks.api import _raise_for_validation_failure
from onyx.server.features.hooks.api import _validate_endpoint
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_URL = "https://example.com/hook"
_API_KEY = "secret"
_TIMEOUT = 5.0
def _mock_response(status_code: int) -> MagicMock:
response = MagicMock()
response.status_code = status_code
return response
# ---------------------------------------------------------------------------
# _check_ssrf_safety
# ---------------------------------------------------------------------------
class TestCheckSsrfSafety:
def _call(self, url: str) -> None:
_check_ssrf_safety(url)
# --- scheme checks ---
def test_https_is_allowed(self) -> None:
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
self._call("https://example.com/hook") # must not raise
@pytest.mark.parametrize(
"url", ["http://example.com/hook", "ftp://example.com/hook"]
)
def test_non_https_scheme_rejected(self, url: str) -> None:
with pytest.raises(OnyxError) as exc_info:
self._call(url)
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert "https" in (exc_info.value.detail or "").lower()
# --- private IP blocklist ---
@pytest.mark.parametrize(
"ip",
[
pytest.param("127.0.0.1", id="loopback"),
pytest.param("10.0.0.1", id="RFC1918-A"),
pytest.param("172.16.0.1", id="RFC1918-B"),
pytest.param("192.168.1.1", id="RFC1918-C"),
pytest.param("169.254.169.254", id="link-local-IMDS"),
pytest.param("100.64.0.1", id="shared-address-space"),
pytest.param("::1", id="IPv6-loopback"),
pytest.param("fc00::1", id="IPv6-ULA"),
pytest.param("fe80::1", id="IPv6-link-local"),
],
)
def test_private_ip_is_blocked(self, ip: str) -> None:
with (
patch("onyx.utils.url.socket.getaddrinfo") as mock_dns,
pytest.raises(OnyxError) as exc_info,
):
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
self._call("https://internal.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert ip in (exc_info.value.detail or "")
def test_public_ip_is_allowed(self) -> None:
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
self._call("https://example.com/hook") # must not raise
def test_dns_resolution_failure_raises(self) -> None:
import socket
with (
patch(
"onyx.utils.url.socket.getaddrinfo",
side_effect=socket.gaierror("name not found"),
),
pytest.raises(OnyxError) as exc_info,
):
self._call("https://no-such-host.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
# ---------------------------------------------------------------------------
# _validate_endpoint
# ---------------------------------------------------------------------------
class TestValidateEndpoint:
def _call(self, *, api_key: str | None = _API_KEY) -> HookValidateResponse:
# Bypass SSRF check — tested separately in TestCheckSsrfSafety.
with patch("onyx.server.features.hooks.api._check_ssrf_safety"):
return _validate_endpoint(
endpoint_url=_URL,
api_key=api_key,
timeout_seconds=_TIMEOUT,
)
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_2xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(200)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_5xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(500)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize("status_code", [401, 403])
def test_401_403_returns_auth_failed(
self, mock_client_cls: MagicMock, status_code: int
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(status_code)
)
result = self._call()
assert result.status == HookValidateStatus.auth_failed
assert str(status_code) in (result.error_message or "")
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_4xx_non_auth_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(422)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_timeout_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectTimeout("timed out")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize(
"exc",
[
httpx.ReadTimeout("read timeout"),
httpx.WriteTimeout("write timeout"),
httpx.PoolTimeout("pool timeout"),
],
)
def test_read_write_pool_timeout_returns_timeout(
self, mock_client_cls: MagicMock, exc: httpx.TimeoutException
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = exc
assert self._call().status == HookValidateStatus.timeout
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_error_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
# Covers DNS failures, TLS errors, and other connection-level errors.
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectError("name resolution failed")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_arbitrary_exception_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
ConnectionRefusedError("refused")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_api_key_sent_as_bearer(self, mock_client_cls: MagicMock) -> None:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = _mock_response(200)
self._call(api_key="mykey")
_, kwargs = mock_post.call_args
assert kwargs["headers"]["Authorization"] == "Bearer mykey"
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_no_api_key_omits_auth_header(self, mock_client_cls: MagicMock) -> None:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = _mock_response(200)
self._call(api_key=None)
_, kwargs = mock_post.call_args
assert "Authorization" not in kwargs["headers"]
# ---------------------------------------------------------------------------
# _raise_for_validation_failure
# ---------------------------------------------------------------------------
class TestRaiseForValidationFailure:
@pytest.mark.parametrize(
"status, expected_code",
[
(HookValidateStatus.auth_failed, OnyxErrorCode.CREDENTIAL_INVALID),
(HookValidateStatus.timeout, OnyxErrorCode.GATEWAY_TIMEOUT),
(HookValidateStatus.cannot_connect, OnyxErrorCode.BAD_GATEWAY),
],
)
def test_raises_correct_error_code(
self, status: HookValidateStatus, expected_code: OnyxErrorCode
) -> None:
validation = HookValidateResponse(status=status, error_message="some error")
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.error_code == expected_code
def test_auth_failed_passes_error_message_directly(self) -> None:
validation = HookValidateResponse(
status=HookValidateStatus.auth_failed, error_message="bad credentials"
)
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.detail == "bad credentials"
@pytest.mark.parametrize(
"status", [HookValidateStatus.timeout, HookValidateStatus.cannot_connect]
)
def test_timeout_and_cannot_connect_wrap_error_message(
self, status: HookValidateStatus
) -> None:
validation = HookValidateResponse(status=status, error_message="raw error")
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.detail == "Endpoint validation failed: raw error"
# ---------------------------------------------------------------------------
# HookValidateStatus enum string values (API contract)
# ---------------------------------------------------------------------------
class TestHookValidateStatusValues:
@pytest.mark.parametrize(
"status, expected",
[
(HookValidateStatus.passed, "passed"),
(HookValidateStatus.auth_failed, "auth_failed"),
(HookValidateStatus.timeout, "timeout"),
(HookValidateStatus.cannot_connect, "cannot_connect"),
],
)
def test_string_values(self, status: HookValidateStatus, expected: str) -> None:
assert status == expected

View File

@@ -1,6 +1,7 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import UUID
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
@@ -10,10 +11,10 @@ from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
build_vespa_filters,
)
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import USER_PROJECT
@@ -150,30 +151,56 @@ class TestBuildVespaFilters:
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
def test_user_project_filter(self) -> None:
"""Test user project filtering.
def test_user_file_ids_filter(self) -> None:
"""Test user file IDs filtering."""
id1 = UUID("00000000-0000-0000-0000-000000000123")
id2 = UUID("00000000-0000-0000-0000-000000000456")
project_id_filter alone does NOT trigger a knowledge scope restriction
(an agent with no explicit knowledge should search everything).
It only participates when explicit knowledge filters are present.
"""
# project_id_filter alone → no restriction
filters = IndexFilters(access_control_list=[], project_id_filter=789)
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
# project_id_filter with document_set → both OR'd
filters = IndexFilters(
access_control_list=[], project_id_filter=789, document_set=["set1"]
)
# Single user file ID (UUID)
filters = IndexFilters(access_control_list=[], user_file_ids=[id1])
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and (({DOCUMENT_SETS} contains "set1") or ({USER_PROJECT} contains "789")) and '
f'!({HIDDEN}=true) and ({DOCUMENT_ID} contains "{str(id1)}") and ' == result
)
# Multiple user file IDs (UUIDs)
filters = IndexFilters(access_control_list=[], user_file_ids=[id1, id2])
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and ({DOCUMENT_ID} contains "{str(id1)}" or {DOCUMENT_ID} contains "{str(id2)}") and '
== result
)
# No project id filter
filters = IndexFilters(access_control_list=[], project_id_filter=None)
# Empty user file IDs
filters = IndexFilters(access_control_list=[], user_file_ids=[])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
def test_user_project_filter(self) -> None:
"""Test user project filtering.
project_id alone does NOT trigger a knowledge scope restriction
(an agent with no explicit knowledge should search everything).
It only participates when explicit knowledge filters are present.
"""
# project_id alone → no restriction
filters = IndexFilters(access_control_list=[], project_id=789)
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
# project_id with user_file_ids → both OR'd
id1 = UUID("00000000-0000-0000-0000-000000000123")
filters = IndexFilters(
access_control_list=[], project_id=789, user_file_ids=[id1]
)
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and (({DOCUMENT_ID} contains "{str(id1)}") or ({USER_PROJECT} contains "789")) and '
== result
)
# No project id
filters = IndexFilters(access_control_list=[], project_id=None)
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
@@ -206,16 +233,17 @@ class TestBuildVespaFilters:
def test_combined_filters(self) -> None:
"""Test combining multiple filter types.
Knowledge-scope filters (document_set, project_id_filter, persona_id_filter)
are OR'd together, while all other filters are AND'd.
Knowledge-scope filters (document_set, user_file_ids, project_id,
persona_id) are OR'd together, while all other filters are AND'd.
"""
id1 = UUID("00000000-0000-0000-0000-000000000123")
filters = IndexFilters(
access_control_list=["user1", "group1"],
source_type=[DocumentSource.WEB],
tags=[Tag(tag_key="color", tag_value="red")],
document_set=["set1"],
project_id_filter=789,
persona_id_filter=42,
user_file_ids=[id1],
project_id=789,
time_cutoff=datetime(2023, 1, 1, tzinfo=timezone.utc),
)
@@ -226,10 +254,9 @@ class TestBuildVespaFilters:
expected += f'({SOURCE_TYPE} contains "web") and '
expected += f'({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
# Knowledge scope filters are OR'd together
# (persona_id_filter is primary, project_id_filter is additive — order reflects this)
expected += (
f'(({DOCUMENT_SETS} contains "set1")'
f' or ({PERSONAS} contains "42")'
f' or ({DOCUMENT_ID} contains "{str(id1)}")'
f' or ({USER_PROJECT} contains "789")'
f") and "
)
@@ -249,37 +276,18 @@ class TestBuildVespaFilters:
result = build_vespa_filters(filters)
assert f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1") and ' == result
def test_persona_id_filter_is_primary_knowledge_scope(self) -> None:
"""persona_id_filter alone should trigger a knowledge scope restriction
(a persona with user files IS explicit knowledge)."""
filters = IndexFilters(access_control_list=[], persona_id_filter=42)
result = build_vespa_filters(filters)
assert f'!({HIDDEN}=true) and ({PERSONAS} contains "42") and ' == result
def test_persona_id_filter_with_project_id_filter(self) -> None:
"""When persona_id_filter triggers the scope, project_id_filter should be
OR'd in additively."""
filters = IndexFilters(
access_control_list=[], persona_id_filter=42, project_id_filter=789
)
result = build_vespa_filters(filters)
expected = (
f"!({HIDDEN}=true) and "
f'(({PERSONAS} contains "42") or ({USER_PROJECT} contains "789")) and '
)
assert expected == result
def test_knowledge_scope_document_set_and_persona_filter_ored(self) -> None:
"""Document set filter and persona_id_filter must be OR'd so that
connector documents (in the set) and persona user files can
both be found."""
def test_knowledge_scope_document_set_and_user_files_ored(self) -> None:
"""Document set filter and user file IDs must be OR'd so that
connector documents (in the set) and user files (with specific
IDs) can both be found."""
id1 = UUID("00000000-0000-0000-0000-000000000123")
filters = IndexFilters(
access_control_list=[],
document_set=["engineering"],
persona_id_filter=42,
user_file_ids=[id1],
)
result = build_vespa_filters(filters)
expected = f'!({HIDDEN}=true) and (({DOCUMENT_SETS} contains "engineering") or ({PERSONAS} contains "42")) and '
expected = f'!({HIDDEN}=true) and (({DOCUMENT_SETS} contains "engineering") or ({DOCUMENT_ID} contains "{str(id1)}")) and '
assert expected == result
def test_acl_large_list_uses_weighted_set(self) -> None:

View File

@@ -489,18 +489,20 @@ services:
- "${HOST_PORT_80:-80}:80"
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/nginx-templates:ro
- ../data/nginx:/etc/nginx/conf.d
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
minio:
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1

View File

@@ -290,20 +290,25 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/nginx-templates:ro
- ../data/nginx:/etc/nginx/conf.d
- ../data/certbot/conf:/etc/letsencrypt
- ../data/certbot/www:/var/www/certbot
# sleep a little bit to allow the web_server / api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod"
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
env_file:
- .env.nginx
environment:

View File

@@ -314,19 +314,21 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/nginx-templates:ro
- ../data/nginx:/etc/nginx/conf.d
- ../data/sslcerts:/etc/nginx/sslcerts
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod.no-letsencrypt"
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod.no-letsencrypt"
env_file:
- .env.nginx
environment:

View File

@@ -333,20 +333,25 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/nginx-templates:ro
- ../data/nginx:/etc/nginx/conf.d
- ../data/certbot/conf:/etc/letsencrypt
- ../data/certbot/www:/var/www/certbot
# sleep a little bit to allow the web_server / api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod"
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
env_file:
- .env.nginx
environment:

View File

@@ -202,18 +202,20 @@ services:
ports:
- "${NGINX_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/nginx-templates:ro
- ../data/nginx:/etc/nginx/conf.d
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
minio:
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1

View File

@@ -477,10 +477,7 @@ services:
- "${HOST_PORT_80:-80}:80"
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
# Mount templates read-only; the startup command copies them into
# the writable /etc/nginx/conf.d/ inside the container. This avoids
# "Permission denied" errors on Windows Docker bind mounts.
- ../data/nginx:/nginx-templates:ro
- ../data/nginx:/etc/nginx/conf.d
# PRODUCTION: Add SSL certificate volumes for HTTPS support:
# - ../data/certbot/conf:/etc/letsencrypt
# - ../data/certbot/www:/var/www/certbot
@@ -492,13 +489,12 @@ services:
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not receive any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
# PRODUCTION: Change to app.conf.template.prod for production nginx config
command: >
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
cache:
image: redis:7.4-alpine

View File

@@ -2,7 +2,7 @@
# Usage: .\install.ps1 [OPTIONS]
# Remote (with params):
# & ([scriptblock]::Create((irm https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.ps1))) -Lite -NoPrompt
# Remote (defaults only, configure via interaction during script):
# Remote (defaults only):
# irm https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.ps1 | iex
param(
@@ -57,7 +57,11 @@ function Print-Step {
}
function Test-Interactive {
return -not $NoPrompt
if ($NoPrompt) { return $false }
try {
if ([Console]::IsInputRedirected) { return $false }
return $true
} catch { return [Environment]::UserInteractive }
}
function Prompt-OrDefault {
@@ -70,8 +74,8 @@ function Prompt-OrDefault {
function Confirm-Action {
param([string]$Description)
$reply = (Prompt-OrDefault "Install $Description? (Y/n) [default: Y]" "Y").Trim().ToLower()
if ($reply -match '^n') {
$reply = Prompt-OrDefault "Install $Description? (Y/n) [default: Y]" "Y"
if ($reply -match '^[Nn]') {
Print-Warning "Skipping: $Description"
return $false
}
@@ -85,12 +89,12 @@ function Prompt-VersionTag {
Write-Host " - Type a specific tag (e.g., craft-v1.0.0)"
$version = Prompt-OrDefault "Enter tag [default: craft-latest]" "craft-latest"
} else {
Write-Host " - Press Enter for edge (recommended)"
Write-Host " - Press Enter for latest (recommended)"
Write-Host " - Type a specific tag (e.g., v0.1.0)"
$version = Prompt-OrDefault "Enter tag [default: edge]" "edge"
$version = Prompt-OrDefault "Enter tag [default: latest]" "latest"
}
if ($script:IncludeCraftMode -and $version -eq "craft-latest") { Print-Info "Selected: craft-latest (Craft enabled)" }
elseif ($version -eq "edge") { Print-Info "Selected: edge (latest nightly)" }
elseif ($version -eq "latest") { Print-Info "Selected: Latest tag" }
else { Print-Info "Selected: $version" }
return $version
}
@@ -99,16 +103,16 @@ function Prompt-DeploymentMode {
param([string]$LiteOverlayPath)
if ($script:LiteMode) { Print-Info "Deployment mode: Lite (set via -Lite flag)"; return }
Print-Info "Which deployment mode would you like?"
Write-Host " 1) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
Write-Host " 1) Standard - Full deployment with search, connectors, and RAG"
Write-Host " 2) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
Write-Host " LLM chat, tools, file uploads, and Projects still work"
Write-Host " 2) Standard - Full deployment with search, connectors, and RAG"
$modeChoice = Prompt-OrDefault "Choose a mode (1 or 2) [default: 1]" "1"
if ($modeChoice -eq "2") {
Print-Info "Selected: Standard mode"
} else {
$script:LiteMode = $true
Print-Info "Selected: Lite mode"
if (-not (Ensure-OnyxFile $LiteOverlayPath "$($script:GitHubRawUrl)/$($script:LiteComposeFile)" $script:LiteComposeFile)) { exit 1 }
} else {
Print-Info "Selected: Standard mode"
}
}
@@ -354,8 +358,7 @@ function Invoke-OnyxShutdown {
return
}
if (-not (Initialize-ComposeCommand)) { Print-OnyxError "Docker Compose not found."; exit 1 }
$stopArgs = @("stop")
$result = Invoke-Compose -AutoDetect @stopArgs
$result = Invoke-Compose -AutoDetect stop
if ($result -ne 0) { Print-OnyxError "Failed to stop containers"; exit 1 }
Print-Success "Onyx containers stopped (paused)"
}
@@ -364,7 +367,7 @@ function Invoke-OnyxDeleteData {
Write-Host "`n=== WARNING: This will permanently delete all Onyx data ===`n" -ForegroundColor Red
Print-Warning "This action will remove all Onyx containers, volumes, files, and user data."
if (Test-Interactive) {
$confirm = Prompt-OrDefault "Type 'DELETE' to confirm" ""
$confirm = Read-Host "Type 'DELETE' to confirm"
if ($confirm -ne "DELETE") { Print-Info "Operation cancelled."; return }
} else {
Print-OnyxError "Cannot confirm destructive operation in non-interactive mode."
@@ -372,8 +375,7 @@ function Invoke-OnyxDeleteData {
}
$deployDir = Join-Path $script:InstallRoot "deployment"
if ((Test-Path (Join-Path $deployDir "docker-compose.yml")) -and (Initialize-ComposeCommand)) {
$downArgs = @("down", "-v")
$result = Invoke-Compose -AutoDetect @downArgs
$result = Invoke-Compose -AutoDetect down -v
if ($result -eq 0) { Print-Success "Containers and volumes removed" }
else { Print-OnyxError "Failed to remove containers" }
}
@@ -720,7 +722,6 @@ function Invoke-WslInstall {
# Ensure WSL2 is available
Invoke-NativeQuiet { wsl --status }
if ($LASTEXITCODE -ne 0) {
if (-not (Confirm-Action "WSL2 (Windows Subsystem for Linux)")) { exit 1 }
Print-Info "Installing WSL2..."
try {
$proc = Start-Process wsl -ArgumentList "--install", "--no-distribution" -Wait -PassThru -NoNewWindow
@@ -807,7 +808,7 @@ function Main {
if (Test-Interactive) {
Write-Host "`nPlease acknowledge and press Enter to continue..." -ForegroundColor Yellow
$null = Prompt-OrDefault "" ""
Read-Host | Out-Null
} else {
Write-Host "`nRunning in non-interactive mode - proceeding automatically..." -ForegroundColor Yellow
}
@@ -903,8 +904,8 @@ function Main {
if ($resourceWarning) {
Print-Warning "Onyx recommends at least $($script:ExpectedDockerRamGB)GB RAM and $($script:ExpectedDiskGB)GB disk for standard mode."
Print-Warning "Lite mode requires less (1-4GB RAM, 8-16GB disk) but has no vector database."
$reply = (Prompt-OrDefault "Do you want to continue anyway? (Y/n)" "y").Trim().ToLower()
if ($reply -notmatch '^y') { Print-Info "Installation cancelled."; exit 1 }
$reply = Prompt-OrDefault "Do you want to continue anyway? (Y/n)" "y"
if ($reply -notmatch '^[Yy]') { Print-Info "Installation cancelled."; exit 1 }
Print-Info "Proceeding despite resource limitations..."
}
@@ -926,13 +927,22 @@ function Main {
if ($composeVersion -ne "unknown" -and (Compare-SemVer $composeVersion "2.24.0") -lt 0) {
Print-Warning "Docker Compose $composeVersion is older than 2.24.0 (required for env_file format)."
Print-Info "Update Docker Desktop or install a newer Docker Compose. Installation may fail."
$reply = (Prompt-OrDefault "Continue anyway? (Y/n)" "y").Trim().ToLower()
if ($reply -notmatch '^y') { exit 1 }
$reply = Prompt-OrDefault "Continue anyway? (Y/n)" "y"
if ($reply -notmatch '^[Yy]') { exit 1 }
}
$liteOverlayPath = Join-Path $deploymentDir $script:LiteComposeFile
if ($script:LiteMode) {
if (-not (Ensure-OnyxFile $liteOverlayPath "$($script:GitHubRawUrl)/$($script:LiteComposeFile)" $script:LiteComposeFile)) { exit 1 }
} elseif (Test-Path $liteOverlayPath) {
if (Test-Path (Join-Path $deploymentDir ".env")) {
Print-Warning "Existing lite overlay found but -Lite was not passed."
$reply = Prompt-OrDefault "Remove lite overlay and switch to standard mode? (y/N)" "n"
if ($reply -match '^[Yy]') { Remove-Item -Force $liteOverlayPath; Print-Info "Switched to standard mode" }
else { $script:LiteMode = $true; Print-Info "Keeping lite mode" }
} else {
Remove-Item -Force $liteOverlayPath
}
}
$envTemplateDest = Join-Path $deploymentDir "env.template"
@@ -952,8 +962,7 @@ function Main {
# Check if services are already running
if ((Test-Path $composeDest) -and (Initialize-ComposeCommand)) {
$running = @()
$psArgs = @("ps", "-q")
try { $running = @(Invoke-Compose -AutoDetect @psArgs 2>$null | Where-Object { $_ }) } catch { }
try { $running = @(Invoke-Compose -AutoDetect ps -q 2>$null | Where-Object { $_ }) } catch { }
if ($running.Count -gt 0) {
Print-OnyxError "Onyx services are currently running!"
Print-Info "Run '.\install.ps1 -Shutdown' first, then re-run this script."
@@ -1019,12 +1028,6 @@ function Main {
Print-Info "You can customize .env later for OAuth/SAML, AI models, domain settings, and Craft."
}
# Clean up stale lite overlay if standard mode was selected
if (-not $script:LiteMode -and (Test-Path $liteOverlayPath)) {
Remove-Item -Force $liteOverlayPath
Print-Info "Removed previous lite overlay (switching to standard mode)"
}
# ── Step 6: Check Ports ───────────────────────────────────────────────
Print-Step "Checking for available ports"
$availablePort = Find-AvailablePort 3000
@@ -1034,7 +1037,7 @@ function Main {
Print-Success "Using port $availablePort for nginx"
$currentImageTag = Get-EnvFileValue -Path $envFile -Key "IMAGE_TAG"
$useLatest = ($currentImageTag -eq "edge" -or $currentImageTag -eq "latest" -or $currentImageTag -match '^craft-')
$useLatest = ($currentImageTag -eq "latest" -or $currentImageTag -match '^craft-')
if ($useLatest) { Print-Info "Using '$currentImageTag' tag - will force pull and recreate containers" }
# For pinned version tags, re-download config files from that tag so the
@@ -1066,9 +1069,8 @@ function Main {
# ── Step 8: Start Services ────────────────────────────────────────────
Print-Step "Starting Onyx services"
Print-Info "Launching containers..."
$upArgs = @("up", "-d")
if ($useLatest) { $upArgs += @("--pull", "always", "--force-recreate") }
$upResult = Invoke-Compose @upArgs
if ($useLatest) { $upResult = Invoke-Compose up -d --pull always --force-recreate }
else { $upResult = Invoke-Compose up -d }
if ($upResult -ne 0) { Print-OnyxError "Failed to start Onyx services"; exit 1 }
# ── Step 9: Container Health ──────────────────────────────────────────
@@ -1076,8 +1078,7 @@ function Main {
Start-Sleep -Seconds 10
$restartIssues = $false
$containerIds = @()
$psArgs = @("ps", "-q")
try { $containerIds = @(Invoke-Compose @psArgs 2>$null | Where-Object { $_ }) } catch { }
try { $containerIds = @(Invoke-Compose ps -q 2>$null | Where-Object { $_ }) } catch { }
foreach ($cid in $containerIds) {
if ([string]::IsNullOrWhiteSpace($cid)) { continue }

View File

@@ -96,8 +96,8 @@ fi
# When --lite is passed as a flag, lower resource thresholds early (before the
# resource check). When lite is chosen interactively, the thresholds are adjusted
# after the resource check has already passed with the standard thresholds —
# which is the safer direction.
# inside the new-deployment flow, after the resource check has already passed
# with the standard thresholds — which is the safer direction.
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
@@ -110,6 +110,9 @@ LITE_COMPOSE_FILE="docker-compose.onyx-lite.yml"
# Build the -f flags for docker compose.
# Pass "true" as $1 to auto-detect a previously-downloaded lite overlay
# (used by shutdown/delete-data so users don't need to remember --lite).
# Without the argument, the lite overlay is only included when --lite was
# explicitly passed — preventing install/start from silently staying in
# lite mode just because the file exists on disk from a prior run.
compose_file_args() {
local auto_detect="${1:-false}"
local args="-f docker-compose.yml"
@@ -174,42 +177,34 @@ ensure_file() {
# --- Interactive prompt helpers ---
is_interactive() {
[[ "$NO_PROMPT" = false ]] && [[ -r /dev/tty ]] && [[ -w /dev/tty ]]
}
read_prompt_line() {
local prompt_text="$1"
if ! is_interactive; then
REPLY=""
return
fi
[[ -n "$prompt_text" ]] && printf "%s" "$prompt_text" > /dev/tty
IFS= read -r REPLY < /dev/tty || REPLY=""
}
read_prompt_char() {
local prompt_text="$1"
if ! is_interactive; then
REPLY=""
return
fi
[[ -n "$prompt_text" ]] && printf "%s" "$prompt_text" > /dev/tty
IFS= read -r -n 1 REPLY < /dev/tty || REPLY=""
printf "\n" > /dev/tty
[[ "$NO_PROMPT" = false ]] && [[ -t 0 ]]
}
prompt_or_default() {
local prompt_text="$1"
local default_value="$2"
read_prompt_line "$prompt_text"
[[ -z "$REPLY" ]] && REPLY="$default_value"
if is_interactive; then
read -p "$prompt_text" -r REPLY
if [[ -z "$REPLY" ]]; then
REPLY="$default_value"
fi
else
REPLY="$default_value"
fi
}
prompt_yn_or_default() {
local prompt_text="$1"
local default_value="$2"
read_prompt_char "$prompt_text"
[[ -z "$REPLY" ]] && REPLY="$default_value"
if is_interactive; then
read -p "$prompt_text" -n 1 -r
echo ""
if [[ -z "$REPLY" ]]; then
REPLY="$default_value"
fi
else
REPLY="$default_value"
fi
}
confirm_action() {
@@ -310,8 +305,8 @@ if [ "$DELETE_DATA_MODE" = true ]; then
echo " • All user data and documents"
echo ""
if is_interactive; then
prompt_or_default "Are you sure you want to continue? Type 'DELETE' to confirm: " ""
echo "" > /dev/tty
read -p "Are you sure you want to continue? Type 'DELETE' to confirm: " -r
echo ""
if [ "$REPLY" != "DELETE" ]; then
print_info "Operation cancelled."
exit 0
@@ -505,7 +500,7 @@ echo ""
if is_interactive; then
echo -e "${YELLOW}${BOLD}Please acknowledge and press Enter to continue...${NC}"
read_prompt_line ""
read -r
echo ""
else
echo -e "${YELLOW}${BOLD}Running in non-interactive mode - proceeding automatically...${NC}"
@@ -750,48 +745,25 @@ if [ "$COMPOSE_VERSION" != "dev" ] && version_compare "$COMPOSE_VERSION" "2.24.0
print_info "Proceeding with installation despite Docker Compose version compatibility issues..."
fi
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
if [[ "$LITE_MODE" = false ]]; then
print_info "Which deployment mode would you like?"
echo ""
echo " 1) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
echo " LLM chat, tools, file uploads, and Projects still work"
echo " 2) Standard - Full deployment with search, connectors, and RAG"
echo ""
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
echo ""
case "$REPLY" in
2)
print_info "Selected: Standard mode"
;;
*)
LITE_MODE=true
print_info "Selected: Lite mode"
;;
esac
else
print_info "Deployment mode: Lite (set via --lite flag)"
fi
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
print_error "--include-craft cannot be used with Lite mode."
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
exit 1
fi
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
fi
# Handle lite overlay file based on selected mode
# Handle lite overlay: ensure it if --lite, clean up stale copies otherwise
if [[ "$LITE_MODE" = true ]]; then
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
elif [[ -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" ]]; then
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed previous lite overlay (switching to standard mode)"
if [[ -f "${INSTALL_ROOT}/deployment/.env" ]]; then
print_warning "Existing lite overlay found but --lite was not passed."
prompt_yn_or_default "Remove lite overlay and switch to standard mode? (y/N): " "n"
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
print_info "Keeping existing lite overlay. Pass --lite to keep using lite mode."
LITE_MODE=true
else
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed lite overlay (switching to standard mode)"
fi
else
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed previous lite overlay (switching to standard mode)"
fi
fi
ensure_file "${INSTALL_ROOT}/deployment/env.template" \
@@ -854,22 +826,22 @@ if [ -f "$ENV_FILE" ]; then
if [ "$REPLY" = "update" ]; then
print_info "Update selected. Which tag would you like to deploy?"
echo ""
echo "• Press Enter for edge (recommended)"
echo "• Press Enter for latest (recommended)"
echo "• Type a specific tag (e.g., v0.1.0)"
echo ""
if [ "$INCLUDE_CRAFT" = true ]; then
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
VERSION="$REPLY"
else
prompt_or_default "Enter tag [default: edge]: " "edge"
prompt_or_default "Enter tag [default: latest]: " "latest"
VERSION="$REPLY"
fi
echo ""
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
print_info "Selected: craft-latest (Craft enabled)"
elif [ "$VERSION" = "edge" ]; then
print_info "Selected: edge (latest nightly)"
elif [ "$VERSION" = "latest" ]; then
print_info "Selected: Latest version"
else
print_info "Selected: $VERSION"
fi
@@ -921,6 +893,45 @@ else
print_info "No existing .env file found. Setting up new deployment..."
echo ""
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
if [[ "$LITE_MODE" = false ]]; then
print_info "Which deployment mode would you like?"
echo ""
echo " 1) Standard - Full deployment with search, connectors, and RAG"
echo " 2) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
echo " LLM chat, tools, file uploads, and Projects still work"
echo ""
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
echo ""
case "$REPLY" in
2)
LITE_MODE=true
print_info "Selected: Lite mode"
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
;;
*)
print_info "Selected: Standard mode"
;;
esac
else
print_info "Deployment mode: Lite (set via --lite flag)"
fi
# Validate lite + craft combination (could now be set interactively)
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
print_error "--include-craft cannot be used with Lite mode."
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
exit 1
fi
# Adjust resource expectations for lite mode
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
fi
# Ask for version
print_info "Which tag would you like to deploy?"
echo ""
@@ -931,18 +942,18 @@ else
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
VERSION="$REPLY"
else
echo "• Press Enter for edge (recommended)"
echo "• Press Enter for latest (recommended)"
echo "• Type a specific tag (e.g., v0.1.0)"
echo ""
prompt_or_default "Enter tag [default: edge]: " "edge"
prompt_or_default "Enter tag [default: latest]: " "latest"
VERSION="$REPLY"
fi
echo ""
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
print_info "Selected: craft-latest (Craft enabled)"
elif [ "$VERSION" = "edge" ]; then
print_info "Selected: edge (latest nightly)"
elif [ "$VERSION" = "latest" ]; then
print_info "Selected: Latest tag"
else
print_info "Selected: $VERSION"
fi
@@ -1100,15 +1111,15 @@ fi
export HOST_PORT=$AVAILABLE_PORT
print_success "Using port $AVAILABLE_PORT for nginx"
# Determine if we're using a floating tag (edge, latest, craft-*) that should force pull
# Determine if we're using the latest tag or a craft tag (both should force pull)
# Read IMAGE_TAG from .env file and remove any quotes or whitespace
CURRENT_IMAGE_TAG=$(grep "^IMAGE_TAG=" "$ENV_FILE" | head -1 | cut -d'=' -f2 | tr -d ' "'"'"'')
if [ "$CURRENT_IMAGE_TAG" = "edge" ] || [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
if [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
USE_LATEST=true
if [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
print_info "Using craft tag '$CURRENT_IMAGE_TAG' - will force pull and recreate containers"
else
print_info "Using '$CURRENT_IMAGE_TAG' tag - will force pull and recreate containers"
print_info "Using 'latest' tag - will force pull and recreate containers"
fi
else
USE_LATEST=false

View File

@@ -127,7 +127,6 @@ Inputs (common):
- `name` (default `onyx`), `region` (default `us-west-2`), `tags`
- `postgres_username`, `postgres_password`
- `create_vpc` (default true) or existing VPC details and `s3_vpc_endpoint_id`
- WAF controls such as `waf_allowed_ip_cidrs`, `waf_common_rule_set_count_rules`, rate limits, geo restrictions, and logging retention
### `vpc`
- Builds a VPC sized for EKS with multiple private and public subnets

View File

@@ -88,8 +88,6 @@ module "waf" {
tags = local.merged_tags
# WAF configuration with sensible defaults
allowed_ip_cidrs = var.waf_allowed_ip_cidrs
common_rule_set_count_rules = var.waf_common_rule_set_count_rules
rate_limit_requests_per_5_minutes = var.waf_rate_limit_requests_per_5_minutes
api_rate_limit_requests_per_5_minutes = var.waf_api_rate_limit_requests_per_5_minutes
geo_restriction_countries = var.waf_geo_restriction_countries

View File

@@ -117,18 +117,6 @@ variable "waf_rate_limit_requests_per_5_minutes" {
default = 2000
}
variable "waf_allowed_ip_cidrs" {
type = list(string)
description = "Optional IPv4 CIDR ranges allowed through the WAF. Leave empty to disable IP allowlisting."
default = []
}
variable "waf_common_rule_set_count_rules" {
type = list(string)
description = "Subrules within AWSManagedRulesCommonRuleSet to override to COUNT instead of BLOCK."
default = []
}
variable "waf_api_rate_limit_requests_per_5_minutes" {
type = number
description = "Rate limit for API requests per 5 minutes per IP address"

View File

@@ -1,20 +1,6 @@
locals {
name = var.name
tags = var.tags
ip_allowlist_enabled = length(var.allowed_ip_cidrs) > 0
managed_rule_priority = local.ip_allowlist_enabled ? 1 : 0
}
resource "aws_wafv2_ip_set" "allowed_ips" {
count = local.ip_allowlist_enabled ? 1 : 0
name = "${local.name}-allowed-ips"
description = "IP allowlist for ${local.name}"
scope = "REGIONAL"
ip_address_version = "IPV4"
addresses = var.allowed_ip_cidrs
tags = local.tags
name = var.name
tags = var.tags
}
# AWS WAFv2 Web ACL
@@ -27,38 +13,10 @@ resource "aws_wafv2_web_acl" "main" {
allow {}
}
dynamic "rule" {
for_each = local.ip_allowlist_enabled ? [1] : []
content {
name = "BlockRequestsOutsideAllowedIPs"
priority = 1
action {
block {}
}
statement {
not_statement {
statement {
ip_set_reference_statement {
arn = aws_wafv2_ip_set.allowed_ips[0].arn
}
}
}
}
visibility_config {
cloudwatch_metrics_enabled = true
metric_name = "BlockRequestsOutsideAllowedIPsMetric"
sampled_requests_enabled = true
}
}
}
# AWS Managed Rules - Core Rule Set
rule {
name = "AWSManagedRulesCommonRuleSet"
priority = 1 + local.managed_rule_priority
priority = 1
override_action {
none {}
@@ -68,16 +26,6 @@ resource "aws_wafv2_web_acl" "main" {
managed_rule_group_statement {
name = "AWSManagedRulesCommonRuleSet"
vendor_name = "AWS"
dynamic "rule_action_override" {
for_each = var.common_rule_set_count_rules
content {
name = rule_action_override.value
action_to_use {
count {}
}
}
}
}
}
@@ -91,7 +39,7 @@ resource "aws_wafv2_web_acl" "main" {
# AWS Managed Rules - Known Bad Inputs
rule {
name = "AWSManagedRulesKnownBadInputsRuleSet"
priority = 2 + local.managed_rule_priority
priority = 2
override_action {
none {}
@@ -114,7 +62,7 @@ resource "aws_wafv2_web_acl" "main" {
# Rate Limiting Rule
rule {
name = "RateLimitRule"
priority = 3 + local.managed_rule_priority
priority = 3
action {
block {}
@@ -139,7 +87,7 @@ resource "aws_wafv2_web_acl" "main" {
for_each = length(var.geo_restriction_countries) > 0 ? [1] : []
content {
name = "GeoRestrictionRule"
priority = 4 + local.managed_rule_priority
priority = 4
action {
block {}
@@ -162,7 +110,7 @@ resource "aws_wafv2_web_acl" "main" {
# IP Rate Limiting
rule {
name = "APIRateLimitRule"
priority = 5 + local.managed_rule_priority
priority = 5
action {
block {}
@@ -185,7 +133,7 @@ resource "aws_wafv2_web_acl" "main" {
# SQL Injection Protection
rule {
name = "AWSManagedRulesSQLiRuleSet"
priority = 6 + local.managed_rule_priority
priority = 6
override_action {
none {}
@@ -208,7 +156,7 @@ resource "aws_wafv2_web_acl" "main" {
# Anonymous IP Protection
rule {
name = "AWSManagedRulesAnonymousIpList"
priority = 7 + local.managed_rule_priority
priority = 7
override_action {
none {}

View File

@@ -9,18 +9,6 @@ variable "tags" {
default = {}
}
variable "allowed_ip_cidrs" {
type = list(string)
description = "Optional IPv4 CIDR ranges allowed to reach the application. Leave empty to disable IP allowlisting."
default = []
}
variable "common_rule_set_count_rules" {
type = list(string)
description = "Subrules within AWSManagedRulesCommonRuleSet to override to COUNT instead of BLOCK."
default = []
}
variable "rate_limit_requests_per_5_minutes" {
type = number
description = "Rate limit for requests per 5 minutes per IP address"

View File

@@ -3839,9 +3839,9 @@
}
},
"node_modules/flatted": {
"version": "3.4.2",
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz",
"integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==",
"version": "3.3.3",
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz",
"integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==",
"dev": true,
"license": "ISC"
},

View File

@@ -1,38 +1,33 @@
"use client";
import { cn } from "@opal/utils";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
interface ActionsContainerProps {
type: "head" | "cell";
children: React.ReactNode;
size?: TableSize;
/** Pass-through click handler (e.g. stopPropagation on body cells). */
onClick?: (e: React.MouseEvent) => void;
children: React.ReactNode;
}
export default function ActionsContainer({
type,
children,
size,
onClick,
}: ActionsContainerProps) {
const size = useTableSize();
const contextSize = useTableSize();
const resolvedSize = size ?? contextSize;
const Tag = type === "head" ? "th" : "td";
return (
<Tag
className="tbl-actions"
data-type={type}
data-size={size}
data-size={resolvedSize}
onClick={onClick}
>
<div
className={cn(
"flex h-full items-center",
type === "cell" ? "justify-end" : "justify-center"
)}
>
{children}
</div>
<div className="flex h-full items-center justify-center">{children}</div>
</Tag>
);
}

View File

@@ -8,7 +8,6 @@ import {
type SortingState,
} from "@tanstack/react-table";
import { Button, LineItemButton } from "@opal/components";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import { SvgArrowUpDown, SvgSortOrder, SvgCheck } from "@opal/icons";
import Popover from "@/refresh-components/Popover";
import Divider from "@/refresh-components/Divider";
@@ -21,6 +20,7 @@ import Text from "@/refresh-components/texts/Text";
interface SortingPopoverProps<TData extends RowData = RowData> {
table: Table<TData>;
sorting: SortingState;
size?: "md" | "lg";
footerText?: string;
ascendingLabel?: string;
descendingLabel?: string;
@@ -29,11 +29,11 @@ interface SortingPopoverProps<TData extends RowData = RowData> {
function SortingPopover<TData extends RowData>({
table,
sorting,
size = "lg",
footerText,
ascendingLabel = "Ascending",
descendingLabel = "Descending",
}: SortingPopoverProps<TData>) {
const size = useTableSize();
const [open, setOpen] = useState(false);
const sortableColumns = table
.getAllLeafColumns()
@@ -158,6 +158,7 @@ function SortingPopover<TData extends RowData>({
// ---------------------------------------------------------------------------
interface CreateSortingColumnOptions {
size?: "md" | "lg";
footerText?: string;
ascendingLabel?: string;
descendingLabel?: string;
@@ -176,6 +177,7 @@ function createSortingColumn<TData>(
<SortingPopover
table={table}
sorting={table.getState().sorting}
size={options?.size}
footerText={options?.footerText}
ascendingLabel={options?.ascendingLabel}
descendingLabel={options?.descendingLabel}

View File

@@ -8,7 +8,6 @@ import {
type VisibilityState,
} from "@tanstack/react-table";
import { Button, LineItemButton, Tag } from "@opal/components";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import { SvgColumn, SvgCheck } from "@opal/icons";
import Popover from "@/refresh-components/Popover";
import Divider from "@/refresh-components/Divider";
@@ -20,13 +19,14 @@ import Divider from "@/refresh-components/Divider";
interface ColumnVisibilityPopoverProps<TData extends RowData = RowData> {
table: Table<TData>;
columnVisibility: VisibilityState;
size?: "md" | "lg";
}
function ColumnVisibilityPopover<TData extends RowData>({
table,
columnVisibility,
size = "lg",
}: ColumnVisibilityPopoverProps<TData>) {
const size = useTableSize();
const [open, setOpen] = useState(false);
// User-defined columns only (exclude internal qualifier/actions)
@@ -87,7 +87,13 @@ function ColumnVisibilityPopover<TData extends RowData>({
// Column definition factory
// ---------------------------------------------------------------------------
function createColumnVisibilityColumn<TData>(): ColumnDef<TData, unknown> {
interface CreateColumnVisibilityColumnOptions {
size?: "md" | "lg";
}
function createColumnVisibilityColumn<TData>(
options?: CreateColumnVisibilityColumnOptions
): ColumnDef<TData, unknown> {
return {
id: "__columnVisibility",
size: 44,
@@ -98,6 +104,7 @@ function createColumnVisibilityColumn<TData>(): ColumnDef<TData, unknown> {
<ColumnVisibilityPopover
table={table}
columnVisibility={table.getState().columnVisibility}
size={options?.size}
/>
),
cell: () => null,

View File

@@ -57,10 +57,9 @@ function DragOverlayRowInner<TData>({
<QualifierContainer key={cell.id} type="cell">
<TableQualifier
content={qualifierColumn.content}
icon={qualifierColumn.getContent?.(row.original)}
initials={qualifierColumn.getInitials?.(row.original)}
icon={qualifierColumn.getIcon?.(row.original)}
imageSrc={qualifierColumn.getImageSrc?.(row.original)}
imageAlt={qualifierColumn.getImageAlt?.(row.original)}
background={qualifierColumn.background}
selectable={isSelectable}
selected={isSelectable && row.getIsSelected()}
/>

View File

@@ -1,8 +1,10 @@
"use client";
import { cn } from "@opal/utils";
import { Button, Pagination, SelectButton } from "@opal/components";
import Text from "@/refresh-components/texts/Text";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
import { SvgEye, SvgXCircle } from "@opal/icons";
import type { ReactNode } from "react";
@@ -43,6 +45,9 @@ interface FooterSelectionModeProps {
onPageChange: (page: number) => void;
/** Unit label for count pagination. @default "items" */
units?: string;
/** Controls overall footer sizing. `"lg"` (default) or `"md"`. */
size?: TableSize;
className?: string;
}
/**
@@ -68,6 +73,7 @@ interface FooterSummaryModeProps {
leftExtra?: ReactNode;
/** Unit label for the summary text, e.g. "users". */
units?: string;
className?: string;
}
/**
@@ -104,7 +110,11 @@ export default function Footer(props: FooterProps) {
const isSmall = resolvedSize === "md";
return (
<div
className="table-footer flex w-full items-center justify-between border-t border-border-01"
className={cn(
"table-footer",
"flex w-full items-center justify-between border-t border-border-01",
props.className
)}
data-size={resolvedSize}
>
{/* Left side */}

View File

@@ -1,10 +1,10 @@
"use client";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
interface QualifierContainerProps {
type: "head" | "cell";
children?: React.ReactNode;
size?: TableSize;
/** Pass-through click handler (e.g. stopPropagation on body cells). */
onClick?: (e: React.MouseEvent) => void;
}
@@ -12,9 +12,11 @@ interface QualifierContainerProps {
export default function QualifierContainer({
type,
children,
size,
onClick,
}: QualifierContainerProps) {
const resolvedSize = useTableSize();
const contextSize = useTableSize();
const resolvedSize = size ?? contextSize;
const Tag = type === "head" ? "th" : "td";

View File

@@ -7,7 +7,6 @@ row selection, drag-and-drop reordering, and server-side mode.
```tsx
import { Table, createTableColumns } from "@opal/components";
import { SvgUser } from "@opal/icons";
interface User {
id: string;
@@ -19,10 +18,11 @@ interface User {
const tc = createTableColumns<User>();
const columns = [
tc.qualifier({ content: "icon", getContent: () => SvgUser }),
tc.qualifier({ content: "avatar-user", getInitials: (r) => r.name?.[0] ?? "?" }),
tc.column("email", {
header: "Name",
weight: 22,
minWidth: 140,
cell: (email, row) => <span>{row.name ?? email}</span>,
}),
tc.column("status", {
@@ -40,7 +40,7 @@ function UsersTable({ users }: { users: User[] }) {
columns={columns}
getRowId={(r) => r.id}
pageSize={10}
footer={{}}
footer={{ mode: "summary" }}
/>
);
}
@@ -55,7 +55,7 @@ function UsersTable({ users }: { users: User[] }) {
| `getRowId` | `(row: TData) => string` | required | Unique row identifier |
| `pageSize` | `number` | `10` | Rows per page (`Infinity` disables pagination) |
| `size` | `"md" \| "lg"` | `"lg"` | Density variant |
| `footer` | `DataTableFooterConfig` | — | Footer configuration (mode is derived from `selectionBehavior`) |
| `footer` | `DataTableFooterConfig` | — | Footer mode (`"selection"` or `"summary"`) |
| `initialSorting` | `SortingState` | — | Initial sort state |
| `initialColumnVisibility` | `VisibilityState` | — | Initial column visibility |
| `draggable` | `DataTableDraggableConfig` | — | Enable drag-and-drop reordering |
@@ -63,6 +63,7 @@ function UsersTable({ users }: { users: User[] }) {
| `onRowClick` | `(row: TData) => void` | — | Row click handler |
| `searchTerm` | `string` | — | Global text filter |
| `height` | `number \| string` | — | Max scrollable height |
| `headerBackground` | `string` | — | Sticky header background |
| `serverSide` | `ServerSideConfig` | — | Server-side pagination/sorting/filtering |
| `emptyState` | `ReactNode` | — | Empty state content |
@@ -75,8 +76,7 @@ function UsersTable({ users }: { users: User[] }) {
- `tc.displayColumn(opts)` — non-accessor custom column
- `tc.actions(opts)` — trailing actions column with visibility/sorting popovers
## Footer
## Footer Modes
The footer mode is derived automatically from `selectionBehavior`:
- **Selection footer** (when `selectionBehavior` is `"single-select"` or `"multi-select"`) — shows selection count, optional view/clear buttons, count pagination
- **Summary footer** (when `selectionBehavior` is `"no-select"` or omitted) — shows "Showing X\~Y of Z", list pagination, optional extra element
- **`"selection"`** — shows selection count, optional view/clear buttons, count pagination
- **`"summary"`** — shows "Showing X~Y of Z", list pagination, optional extra element

View File

@@ -1,6 +1,5 @@
import type { Meta, StoryObj } from "@storybook/react";
import { Table, createTableColumns } from "@opal/components";
import { SvgUser } from "@opal/icons";
// ---------------------------------------------------------------------------
// Sample data
@@ -109,14 +108,17 @@ const tc = createTableColumns<User>();
const columns = [
tc.qualifier({
content: "icon",
getContent: () => SvgUser,
background: true,
content: "avatar-user",
getInitials: (r) =>
r.name
.split(" ")
.map((n) => n[0])
.join(""),
}),
tc.column("name", { header: "Name", weight: 25 }),
tc.column("email", { header: "Email", weight: 30 }),
tc.column("role", { header: "Role", weight: 15 }),
tc.column("status", { header: "Status", weight: 15 }),
tc.column("name", { header: "Name", weight: 25, minWidth: 120 }),
tc.column("email", { header: "Email", weight: 30, minWidth: 160 }),
tc.column("role", { header: "Role", weight: 15, minWidth: 80 }),
tc.column("status", { header: "Status", weight: 15, minWidth: 80 }),
tc.actions(),
];
@@ -140,7 +142,7 @@ export const Default: Story = {
columns={columns}
getRowId={(r) => r.id}
pageSize={8}
footer={{}}
footer={{ mode: "summary" }}
/>
),
};

View File

@@ -1,20 +1,24 @@
import { cn } from "@opal/utils";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
import type { WithoutStyles } from "@/types";
interface TableCellProps
extends WithoutStyles<React.TdHTMLAttributes<HTMLTableCellElement>> {
children: React.ReactNode;
size?: TableSize;
/** Explicit pixel width for the cell. */
width?: number;
}
export default function TableCell({
size,
width,
children,
...props
}: TableCellProps) {
const resolvedSize = useTableSize();
const contextSize = useTableSize();
const resolvedSize = size ?? contextSize;
return (
<td
className="tbl-cell overflow-hidden"

View File

@@ -1,8 +1,5 @@
"use client";
import React from "react";
import { cn } from "@opal/utils";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { WithoutStyles } from "@/types";
import type { ExtremaSizeVariants, SizeVariants } from "@opal/types";
@@ -12,15 +9,20 @@ import type { ExtremaSizeVariants, SizeVariants } from "@opal/types";
type TableSize = Extract<SizeVariants, "md" | "lg">;
type TableVariant = "rows" | "cards";
type TableQualifier = "simple" | "avatar" | "icon";
type SelectionBehavior = "no-select" | "single-select" | "multi-select";
interface TableProps
extends WithoutStyles<React.TableHTMLAttributes<HTMLTableElement>> {
ref?: React.Ref<HTMLTableElement>;
/** Size preset for the table. @default "lg" */
size?: TableSize;
/** Visual row variant. @default "cards" */
variant?: TableVariant;
/** Row selection behavior. @default "no-select" */
selectionBehavior?: SelectionBehavior;
/** Leading qualifier column type. @default null */
qualifier?: TableQualifier;
/** Height behavior. `"fit"` = shrink to content, `"full"` = fill available space. */
heightVariant?: ExtremaSizeVariants;
/** Explicit pixel width for the table (e.g. from `table.getTotalSize()`).
@@ -36,13 +38,14 @@ interface TableProps
function Table({
ref,
size = "lg",
variant = "cards",
selectionBehavior = "no-select",
qualifier = "simple",
heightVariant,
width,
...props
}: TableProps) {
const size = useTableSize();
return (
<table
ref={ref}
@@ -51,6 +54,7 @@ function Table({
data-size={size}
data-variant={variant}
data-selection={selectionBehavior}
data-qualifier={qualifier}
data-height={heightVariant}
{...props}
/>
@@ -58,4 +62,10 @@ function Table({
}
export default Table;
export type { TableProps, TableSize, TableVariant, SelectionBehavior };
export type {
TableProps,
TableSize,
TableVariant,
TableQualifier,
SelectionBehavior,
};

View File

@@ -1,6 +1,7 @@
import { cn } from "@opal/utils";
import Text from "@/refresh-components/texts/Text";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
import type { WithoutStyles } from "@/types";
import { Button } from "@opal/components";
import { SvgChevronDown, SvgChevronUp, SvgHandle, SvgSort } from "@opal/icons";
@@ -29,6 +30,8 @@ interface TableHeadCustomProps {
icon?: (sorted: SortDirection) => IconFunctionComponent;
/** Text alignment for the column. Defaults to `"left"`. */
alignment?: "left" | "center" | "right";
/** Cell density. `"md"` uses tighter padding for denser layouts. */
size?: TableSize;
/** Column width in pixels. Applied as an inline style on the `<th>`. */
width?: number;
/** When `true`, shows a bottom border on hover. Defaults to `true`. */
@@ -78,11 +81,13 @@ export default function TableHead({
resizable,
onResizeStart,
alignment = "left",
size,
width,
bottomBorder = true,
...thProps
}: TableHeadProps) {
const resolvedSize = useTableSize();
const contextSize = useTableSize();
const resolvedSize = size ?? contextSize;
const isSmall = resolvedSize === "md";
return (
<th

View File

@@ -3,13 +3,19 @@
import React from "react";
import { cn } from "@opal/utils";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
import { SvgUser } from "@opal/icons";
import type { IconFunctionComponent } from "@opal/types";
import type { QualifierContentType } from "@opal/components/table/types";
import Checkbox from "@/refresh-components/inputs/Checkbox";
import Text from "@/refresh-components/texts/Text";
interface TableQualifierProps {
className?: string;
/** Content type displayed in the qualifier */
content: QualifierContentType;
/** Size variant */
size?: TableSize;
/** Disables interaction */
disabled?: boolean;
/** Whether to show a selection checkbox overlay */
@@ -18,33 +24,54 @@ interface TableQualifierProps {
selected?: boolean;
/** Called when the checkbox is toggled */
onSelectChange?: (selected: boolean) => void;
/** Icon component to render (for "icon" content). */
/** Icon component to render (for "icon" content type) */
icon?: IconFunctionComponent;
/** Image source URL (for "image" content). */
/** Image source URL (for "image" content type) */
imageSrc?: string;
/** Image alt text (for "image" content). */
/** Image alt text */
imageAlt?: string;
/** Show a tinted background container behind the content. */
background?: boolean;
/** User initials (for "avatar-user" content type) */
initials?: string;
}
const iconSizes = {
lg: 28,
md: 24,
lg: 16,
md: 14,
} as const;
function getOverlayStyles(selected: boolean, disabled: boolean) {
function getQualifierStyles(selected: boolean, disabled: boolean) {
if (disabled) {
return selected ? "flex bg-action-link-00" : "hidden";
return {
container: "bg-background-neutral-03",
icon: "stroke-text-02",
overlay: selected ? "flex bg-action-link-00" : "hidden",
overlayImage: selected ? "flex bg-mask-01 backdrop-blur-02" : "hidden",
};
}
if (selected) {
return "flex bg-action-link-00";
return {
container: "bg-action-link-00",
icon: "stroke-text-03",
overlay: "flex bg-action-link-00",
overlayImage: "flex bg-mask-01 backdrop-blur-02",
};
}
return "flex opacity-0 group-hover/row:opacity-100 group-focus-within/row:opacity-100 bg-background-tint-01";
return {
container: "bg-background-tint-01",
icon: "stroke-text-03",
overlay:
"flex opacity-0 group-hover/row:opacity-100 group-focus-within/row:opacity-100 bg-background-tint-01",
overlayImage:
"flex opacity-0 group-hover/row:opacity-100 group-focus-within/row:opacity-100 bg-mask-01 group-hover/row:backdrop-blur-02 group-focus-within/row:backdrop-blur-02",
};
}
function TableQualifier({
className,
content,
size,
disabled = false,
selectable = false,
selected = false,
@@ -52,67 +79,100 @@ function TableQualifier({
icon: Icon,
imageSrc,
imageAlt = "",
background = false,
initials,
}: TableQualifierProps) {
const resolvedSize = useTableSize();
const contextSize = useTableSize();
const resolvedSize = size ?? contextSize;
const isRound = content === "avatar-icon" || content === "avatar-user";
const iconSize = iconSizes[resolvedSize];
const overlayStyles = getOverlayStyles(selected, disabled);
const styles = getQualifierStyles(selected, disabled);
function renderContent() {
switch (content) {
case "icon":
return Icon ? <Icon size={iconSize} /> : null;
return Icon ? <Icon size={iconSize} className={styles.icon} /> : null;
case "simple":
return null;
case "image":
return imageSrc ? (
<img
src={imageSrc}
alt={imageAlt}
className="h-full w-full rounded-08 object-cover"
className={cn(
"h-full w-full object-cover",
isRound ? "rounded-full" : "rounded-08"
)}
/>
) : null;
case "simple":
case "avatar-icon":
return <SvgUser size={iconSize} className={styles.icon} />;
case "avatar-user":
return (
<div
className={cn(
"flex items-center justify-center rounded-full bg-background-neutral-inverted-00",
resolvedSize === "lg" ? "h-7 w-7" : "h-6 w-6"
)}
>
<Text
inverted
secondaryAction
text05
className="select-none uppercase"
>
{initials}
</Text>
</div>
);
default:
return null;
}
}
const inner = renderContent();
const showBackground = background && content !== "simple";
return (
<div
className={cn(
"group relative inline-flex shrink-0 items-center justify-center",
resolvedSize === "lg" ? "h-9 w-9" : "h-7 w-7",
disabled ? "cursor-not-allowed" : "cursor-default"
disabled ? "cursor-not-allowed" : "cursor-default",
className
)}
>
{showBackground ? (
{/* Inner qualifier container — no background for "simple" */}
{content !== "simple" && (
<div
className={cn(
"flex items-center justify-center overflow-hidden rounded-08 transition-colors",
"flex items-center justify-center overflow-hidden transition-colors",
resolvedSize === "lg" ? "h-9 w-9" : "h-7 w-7",
disabled
? "bg-background-neutral-03"
: selected
? "bg-action-link-00"
: "bg-background-tint-01"
isRound ? "rounded-full" : "rounded-08",
styles.container,
content === "image" && disabled && !selected && "opacity-50"
)}
>
{inner}
{renderContent()}
</div>
) : (
inner
)}
{/* Selection overlay */}
{selectable && (
<div
className={cn(
"absolute inset-0 items-center justify-center rounded-08",
content === "simple" ? "flex" : overlayStyles
"absolute inset-0 items-center justify-center",
content === "simple"
? "flex"
: isRound
? "rounded-full"
: "rounded-08",
content === "simple"
? "flex"
: content === "image"
? styles.overlayImage
: styles.overlay
)}
>
<Checkbox

View File

@@ -2,6 +2,7 @@
import { cn } from "@opal/utils";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
import type { WithoutStyles } from "@/types";
import { useSortable } from "@dnd-kit/sortable";
import { CSS } from "@dnd-kit/utilities";
@@ -11,7 +12,7 @@ import { SvgHandle } from "@opal/icons";
// Types
// ---------------------------------------------------------------------------
export interface TableRowProps
interface TableRowProps
extends WithoutStyles<React.HTMLAttributes<HTMLTableRowElement>> {
ref?: React.Ref<HTMLTableRowElement>;
selected?: boolean;
@@ -21,6 +22,8 @@ export interface TableRowProps
sortableId?: string;
/** Show drag handle overlay. Defaults to true when sortableId is set. */
showDragHandle?: boolean;
/** Size variant for the drag handle */
size?: TableSize;
}
// ---------------------------------------------------------------------------
@@ -30,13 +33,15 @@ export interface TableRowProps
function SortableTableRow({
sortableId,
showDragHandle = true,
size,
selected,
disabled,
ref: _externalRef,
children,
...props
}: TableRowProps) {
const resolvedSize = useTableSize();
const contextSize = useTableSize();
const resolvedSize = size ?? contextSize;
const {
attributes,
@@ -100,9 +105,10 @@ function SortableTableRow({
// Main component
// ---------------------------------------------------------------------------
export default function TableRow({
function TableRow({
sortableId,
showDragHandle,
size,
selected,
disabled,
ref,
@@ -113,6 +119,7 @@ export default function TableRow({
<SortableTableRow
sortableId={sortableId}
showDragHandle={showDragHandle}
size={size}
selected={selected}
disabled={disabled}
ref={ref}
@@ -131,3 +138,6 @@ export default function TableRow({
/>
);
}
export default TableRow;
export type { TableRowProps };

View File

@@ -25,14 +25,18 @@ import type { SortDirection } from "@opal/components/table/TableHead";
interface QualifierConfig<TData> {
/** Content type for body-row `<TableQualifier>`. @default "simple" */
content?: QualifierContentType;
/** Return the icon component to render for a row (for "icon" content). */
getContent?: (row: TData) => IconFunctionComponent;
/** Return the image URL to render for a row (for "image" content). */
/** Content type for the header `<TableQualifier>`. @default "simple" */
headerContentType?: QualifierContentType;
/** Extract initials from a row (for "avatar-user" content). */
getInitials?: (row: TData) => string;
/** Extract icon from a row (for "icon" / "avatar-icon" content). */
getIcon?: (row: TData) => IconFunctionComponent;
/** Extract image src from a row (for "image" content). */
getImageSrc?: (row: TData) => string;
/** Return the image alt text for a row (for "image" content). @default "" */
getImageAlt?: (row: TData) => string;
/** Show a tinted background container behind the content. @default false */
background?: boolean;
/** Whether to show selection checkboxes on the qualifier. @default true */
selectable?: boolean;
/** Whether to render qualifier content in the header. @default true */
header?: boolean;
}
// ---------------------------------------------------------------------------
@@ -54,6 +58,8 @@ interface DataColumnConfig<TData, TValue> {
icon?: (sorted: SortDirection) => IconFunctionComponent;
/** Column weight for proportional distribution. @default 20 */
weight?: number;
/** Minimum column width in pixels. @default 50 */
minWidth?: number;
}
// ---------------------------------------------------------------------------
@@ -126,9 +132,9 @@ interface TableColumnsBuilder<TData> {
* ```ts
* const tc = createTableColumns<TeamMember>();
* const columns = [
* tc.qualifier({ content: "icon", getContent: (r) => UserIcon }),
* tc.column("name", { header: "Name", weight: 23 }),
* tc.column("email", { header: "Email", weight: 28 }),
* tc.qualifier({ content: "avatar-user", getInitials: (r) => r.initials }),
* tc.column("name", { header: "Name", weight: 23, minWidth: 120 }),
* tc.column("email", { header: "Email", weight: 28, minWidth: 150 }),
* tc.actions(),
* ];
* ```
@@ -156,10 +162,12 @@ export function createTableColumns<TData>(): TableColumnsBuilder<TData> {
width: (size: TableSize) =>
size === "md" ? { fixed: 36 } : { fixed: 44 },
content,
getContent: config?.getContent,
headerContentType: config?.headerContentType,
getInitials: config?.getInitials,
getIcon: config?.getIcon,
getImageSrc: config?.getImageSrc,
getImageAlt: config?.getImageAlt,
background: config?.background,
selectable: config?.selectable,
header: config?.header,
};
},
@@ -175,6 +183,7 @@ export function createTableColumns<TData>(): TableColumnsBuilder<TData> {
enableHiding = true,
icon,
weight = 20,
minWidth = 50,
} = config;
const def = helper.accessor(accessor as any, {
@@ -192,7 +201,7 @@ export function createTableColumns<TData>(): TableColumnsBuilder<TData> {
kind: "data",
id: accessor as string,
def,
width: { weight, minWidth: Math.max(header.length * 8 + 40, 80) },
width: { weight, minWidth },
icon,
};
},

View File

@@ -39,12 +39,15 @@ import type {
import type { TableSize } from "@opal/components/table/TableSizeContext";
// ---------------------------------------------------------------------------
// SelectionBehavior
// Qualifier × SelectionBehavior
// ---------------------------------------------------------------------------
type Qualifier = "simple" | "avatar" | "icon";
type SelectionBehavior = "no-select" | "single-select" | "multi-select";
export type DataTableProps<TData> = BaseDataTableProps<TData> & {
/** Leading qualifier column type. @default "simple" */
qualifier?: Qualifier;
/** Row selection behavior. @default "no-select" */
selectionBehavior?: SelectionBehavior;
};
@@ -128,8 +131,8 @@ function processColumns<TData>(
* ```tsx
* const tc = createTableColumns<TeamMember>();
* const columns = [
* tc.qualifier({ content: "icon", getContent: (r) => UserIcon }),
* tc.column("name", { header: "Name", weight: 23 }),
* tc.qualifier({ content: "avatar-user", getInitials: (r) => r.initials }),
* tc.column("name", { header: "Name", weight: 23, minWidth: 120 }),
* tc.column("email", { header: "Email", weight: 28 }),
* tc.actions(),
* ];
@@ -149,11 +152,13 @@ export function Table<TData>(props: DataTableProps<TData>) {
footer,
size = "lg",
variant = "cards",
qualifier = "simple",
selectionBehavior = "no-select",
onSelectionChange,
onRowClick,
searchTerm,
height,
headerBackground,
serverSide,
emptyState,
} = props;
@@ -161,15 +166,11 @@ export function Table<TData>(props: DataTableProps<TData>) {
const effectivePageSize = pageSize ?? (footer ? 10 : data.length);
// Whether the qualifier column should exist in the DOM.
// Derived from the column definitions: if a qualifier column exists with
// content !== "simple", always show it. If content === "simple" (or no
// qualifier column defined), show only for multi-select (checkboxes).
const qualifierColDef = columns.find(
(c): c is OnyxQualifierColumn<TData> => c.kind === "qualifier"
);
// "simple" only gets a qualifier column for multi-select (checkboxes).
// "simple" + no-select/single-select = no qualifier column — single-select
// uses row-level background coloring instead.
const hasQualifierColumn =
(qualifierColDef != null && qualifierColDef.content !== "simple") ||
selectionBehavior === "multi-select";
qualifier !== "simple" || selectionBehavior === "multi-select";
// 1. Process columns (memoized on columns + size)
const { tanstackColumns, widthConfig, qualifierColumn, columnKindMap } =
@@ -348,9 +349,15 @@ export function Table<TData>(props: DataTableProps<TData>) {
overflowY: "auto" as const,
}
: undefined),
...(headerBackground
? ({
"--table-header-bg": headerBackground,
} as React.CSSProperties)
: undefined),
}}
>
<TableElement
size={size}
variant={variant}
selectionBehavior={selectionBehavior}
width={
@@ -412,12 +419,14 @@ export function Table<TData>(props: DataTableProps<TData>) {
columnVisibility={
table.getState().columnVisibility
}
size={size}
/>
)}
{actionsDef.showSorting !== false && (
<SortingPopover
table={table}
sorting={table.getState().sorting}
size={size}
footerText={actionsDef.sortingFooterText}
/>
)}
@@ -532,6 +541,12 @@ export function Table<TData>(props: DataTableProps<TData>) {
if (cellColDef?.kind === "qualifier") {
const qDef = cellColDef as OnyxQualifierColumn<TData>;
// Resolve content based on the qualifier prop:
// - "simple" renders nothing (checkbox only when selectable)
// - "avatar"/"icon" render from column config
const qualifierContent =
qualifier === "simple" ? "simple" : qDef.content;
return (
<QualifierContainer
key={cell.id}
@@ -539,11 +554,10 @@ export function Table<TData>(props: DataTableProps<TData>) {
onClick={(e) => e.stopPropagation()}
>
<TableQualifier
content={qDef.content}
icon={qDef.getContent?.(row.original)}
content={qualifierContent}
initials={qDef.getInitials?.(row.original)}
icon={qDef.getIcon?.(row.original)}
imageSrc={qDef.getImageSrc?.(row.original)}
imageAlt={qDef.getImageAlt?.(row.original)}
background={qDef.background}
selectable={showQualifierCheckbox}
selected={
showQualifierCheckbox && row.getIsSelected()

View File

@@ -277,7 +277,7 @@ function createSplitterResizeHandler(
* const { containerRef, columnWidths, createResizeHandler } = useColumnWidths({
* headers: table.getHeaderGroups()[0].headers,
* fixedColumnIds: new Set(["actions"]),
* columnMinWidths: { name: 72, status: 80 },
* columnMinWidths: { name: 120, status: 80 },
* });
* ```
*/

View File

@@ -25,7 +25,8 @@
/* ---- TableHead ---- */
.table-head {
@apply relative;
@apply relative sticky top-0 z-20;
background: var(--table-header-bg, transparent);
}
.table-head[data-size="lg"] {
@apply px-2 py-1;
@@ -129,7 +130,8 @@ table[data-variant="cards"] .tbl-row:has(:focus-visible) > td {
/* ---- QualifierContainer ---- */
.tbl-qualifier[data-type="head"] {
@apply w-px whitespace-nowrap py-1;
@apply w-px whitespace-nowrap py-1 sticky top-0 z-20;
background: var(--table-header-bg, transparent);
}
.tbl-qualifier[data-type="head"][data-size="md"] {
@apply py-0.5;
@@ -145,10 +147,11 @@ table[data-variant="cards"] .tbl-row:has(:focus-visible) > td {
/* ---- ActionsContainer ---- */
.tbl-actions {
@apply w-px whitespace-nowrap px-1;
@apply sticky right-0 w-px whitespace-nowrap px-1;
}
.tbl-actions[data-type="head"] {
@apply px-2 py-1;
@apply z-30 sticky top-0 px-2 py-1;
background: var(--table-header-bg, transparent);
}
/* ---- Footer ---- */

View File

@@ -30,7 +30,12 @@ export type ColumnWidth = DataColumnWidth | FixedColumnWidth;
// Column kind discriminant
// ---------------------------------------------------------------------------
export type QualifierContentType = "simple" | "icon" | "image";
export type QualifierContentType =
| "icon"
| "simple"
| "image"
| "avatar-icon"
| "avatar-user";
export type OnyxColumnKind = "qualifier" | "data" | "display" | "actions";
@@ -51,14 +56,18 @@ export interface OnyxQualifierColumn<TData> extends OnyxColumnBase<TData> {
kind: "qualifier";
/** Content type for body-row `<TableQualifier>`. */
content: QualifierContentType;
/** Return the icon component to render for a row (for "icon" content). */
getContent?: (row: TData) => IconFunctionComponent;
/** Return the image URL to render for a row (for "image" content). */
/** Content type for the header `<TableQualifier>`. @default "simple" */
headerContentType?: QualifierContentType;
/** Extract initials from a row (for "avatar-user" content). */
getInitials?: (row: TData) => string;
/** Extract icon from a row (for "icon" / "avatar-icon" content). */
getIcon?: (row: TData) => IconFunctionComponent;
/** Extract image src from a row (for "image" content). */
getImageSrc?: (row: TData) => string;
/** Return the image alt text for a row (for "image" content). @default "" */
getImageAlt?: (row: TData) => string;
/** Show a tinted background container behind the content. @default false */
background?: boolean;
/** Whether to show selection checkboxes on the qualifier. @default true */
selectable?: boolean;
/** Whether to render qualifier content in the header. @default true */
header?: boolean;
}
/** Data column — accessor-based column with sorting/resizing. */
@@ -165,6 +174,9 @@ export interface DataTableProps<TData> {
* Accepts a pixel number (e.g. `300`) or a CSS value string (e.g. `"50vh"`).
*/
height?: number | string;
/** Background color for the sticky header row, preventing rows from showing
* through when scrolling. Accepts any CSS color value. */
headerBackground?: string;
/**
* Enable server-side mode. When provided:
* - TanStack uses manualPagination/manualSorting/manualFiltering

6
web/package-lock.json generated
View File

@@ -10309,9 +10309,9 @@
}
},
"node_modules/flatted": {
"version": "3.4.2",
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz",
"integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==",
"version": "3.4.1",
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.1.tgz",
"integrity": "sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ==",
"dev": true,
"license": "ISC"
},

View File

@@ -45,7 +45,7 @@ function MemoryTagWithTooltip({
<MemoriesModal
initialTargetMemoryId={memoryId}
initialTargetIndex={memoryIndex}
highlightOnOpen
highlightFirstOnOpen
/>
</memoriesModal.Provider>
{memoriesModal.isOpen ? (
@@ -56,14 +56,8 @@ function MemoryTagWithTooltip({
side="bottom"
className="bg-background-neutral-00 text-text-01 shadow-md max-w-[17.5rem] p-1"
tooltip={
<Section
flexDirection="column"
alignItems="start"
padding={0.25}
gap={0.25}
height="auto"
>
<div className="p-1">
<Section flexDirection="column" gap={0.25} height="auto">
<div className="p-1 w-full">
<Text as="p" secondaryBody text03>
{memoryText}
</Text>
@@ -72,7 +66,6 @@ function MemoryTagWithTooltip({
icon={SvgAddLines}
title={operationLabel}
sizePreset="secondary"
paddingVariant="sm"
variant="body"
prominence="muted"
rightChildren={

View File

@@ -103,7 +103,7 @@ export const MemoryToolRenderer: MessageRenderer<MemoryToolPacket, {}> = ({
<MemoriesModal
initialTargetMemoryId={memoryId}
initialTargetIndex={index}
highlightOnOpen
highlightFirstOnOpen
/>
</memoriesModal.Provider>
{memoryText ? (

View File

@@ -7,8 +7,11 @@ import { processRawChatHistory } from "@/app/app/services/lib";
import { getLatestMessageChain } from "@/app/app/services/messageTree";
import HumanMessage from "@/app/app/message/HumanMessage";
import AgentMessage from "@/app/app/message/messageComponents/AgentMessage";
import { Callout } from "@/components/ui/callout";
import OnyxInitializingLoader from "@/components/OnyxInitializingLoader";
import { Section } from "@/layouts/general-layouts";
import { IllustrationContent } from "@opal/layouts";
import SvgNotFound from "@opal/illustrations/not-found";
import { Button } from "@opal/components";
import { Persona } from "@/app/admin/agents/interfaces";
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
import PreviewModal from "@/sections/modals/PreviewModal";
@@ -33,12 +36,17 @@ export default function SharedChatDisplay({
if (!chatSession) {
return (
<div className="min-h-full w-full">
<div className="mx-auto w-fit pt-8">
<Callout type="danger" title="Shared Chat Not Found">
Did not find a shared chat with the specified ID.
</Callout>
</div>
<div className="h-full w-full flex flex-col items-center justify-center">
<Section flexDirection="column" alignItems="center" gap={1}>
<IllustrationContent
illustration={SvgNotFound}
title="Shared chat not found"
description="Did not find a shared chat with the specified ID."
/>
<Button href="/app" prominence="secondary">
Start a new chat
</Button>
</Section>
</div>
);
}
@@ -51,12 +59,17 @@ export default function SharedChatDisplay({
if (firstMessage === undefined) {
return (
<div className="min-h-full w-full">
<div className="mx-auto w-fit pt-8">
<Callout type="danger" title="Shared Chat Not Found">
No messages found in shared chat.
</Callout>
</div>
<div className="h-full w-full flex flex-col items-center justify-center">
<Section flexDirection="column" alignItems="center" gap={1}>
<IllustrationContent
illustration={SvgNotFound}
title="Shared chat not found"
description="No messages found in shared chat."
/>
<Button href="/app" prominence="secondary">
Start a new chat
</Button>
</Section>
</div>
);
}

View File

@@ -61,6 +61,11 @@ interface UseChatSessionControllerProps {
}) => Promise<void>;
}
export type SessionFetchError = {
type: "not_found" | "access_denied" | "unknown";
detail: string;
} | null;
export default function useChatSessionController({
existingChatSessionId,
searchParams,
@@ -80,6 +85,8 @@ export default function useChatSessionController({
const [currentSessionFileTokenCount, setCurrentSessionFileTokenCount] =
useState<number>(0);
const [projectFiles, setProjectFiles] = useState<ProjectFile[]>([]);
const [sessionFetchError, setSessionFetchError] =
useState<SessionFetchError>(null);
// Store actions
const updateSessionAndMessageTree = useChatSessionStore(
(state) => state.updateSessionAndMessageTree
@@ -151,6 +158,8 @@ export default function useChatSessionController({
}
async function initialSessionFetch() {
setSessionFetchError(null);
if (existingChatSessionId === null) {
// Clear the current session in the store to show intro messages
setCurrentSession(null);
@@ -178,9 +187,42 @@ export default function useChatSessionController({
setCurrentSession(existingChatSessionId);
setIsFetchingChatMessages(existingChatSessionId, true);
const response = await fetch(
`/api/chat/get-chat-session/${existingChatSessionId}`
);
let response: Response;
try {
response = await fetch(
`/api/chat/get-chat-session/${existingChatSessionId}`
);
} catch (error) {
setIsFetchingChatMessages(existingChatSessionId, false);
console.error("Failed to fetch chat session", {
chatSessionId: existingChatSessionId,
error,
});
setSessionFetchError({
type: "unknown",
detail: "Failed to load chat session. Please check your connection.",
});
return;
}
if (!response.ok) {
setIsFetchingChatMessages(existingChatSessionId, false);
let detail = "An unexpected error occurred.";
try {
const errorBody = await response.json();
detail = errorBody.detail || detail;
} catch {
// ignore parse errors
}
const type =
response.status === 404
? "not_found"
: response.status === 403
? "access_denied"
: "unknown";
setSessionFetchError({ type, detail });
return;
}
const session = await response.json();
const chatSession = session as BackendChatSession;
@@ -356,5 +398,6 @@ export default function useChatSessionController({
currentSessionFileTokenCount,
onMessageSelection,
projectFiles,
sessionFetchError,
};
}

View File

@@ -20,15 +20,10 @@
"use client";
import {
cn,
ensureHrefProtocol,
INTERACTIVE_SELECTOR,
noProp,
} from "@/lib/utils";
import { cn, ensureHrefProtocol, noProp } from "@/lib/utils";
import type { Components } from "react-markdown";
import Text from "@/refresh-components/texts/Text";
import { useCallback, useMemo, useRef, useState, useEffect } from "react";
import { useCallback, useMemo, useState, useEffect } from "react";
import { useAppBackground } from "@/providers/AppBackgroundProvider";
import { useTheme } from "next-themes";
import ShareChatSessionModal from "@/sections/modals/ShareChatSessionModal";
@@ -537,37 +532,6 @@ function Root({ children, enableBackground }: AppRootProps) {
const { isSafari } = useBrowserInfo();
const isLightMode = resolvedTheme === "light";
const showBackground = hasBackground && enableBackground;
// Track whether the chat input was focused before a mousedown, so we can
// restore focus on mouseup if no text was selected. This preserves
// click-drag text selection while keeping the input focused on plain clicks.
const inputWasFocused = useRef(false);
const handleMouseDown = useCallback(
(event: React.MouseEvent<HTMLDivElement>) => {
const activeEl = document.activeElement;
const isFocused =
activeEl instanceof HTMLElement &&
activeEl.id === "onyx-chat-input-textarea";
const target = event.target;
const isInteractive =
target instanceof HTMLElement && !!target.closest(INTERACTIVE_SELECTOR);
inputWasFocused.current = isFocused && !isInteractive;
},
[]
);
const handleMouseUp = useCallback(() => {
if (!inputWasFocused.current) return;
inputWasFocused.current = false;
const sel = window.getSelection();
if (sel && !sel.isCollapsed) return;
const textarea = document.getElementById("onyx-chat-input-textarea");
// Only restore focus if no other element has grabbed it since mousedown.
if (textarea && document.activeElement !== textarea) {
textarea.focus();
}
}, []);
const horizontalBlurMask = `linear-gradient(
to right,
transparent 0%,
@@ -585,8 +549,6 @@ function Root({ children, enableBackground }: AppRootProps) {
*/
<div
data-main-container
onMouseDown={handleMouseDown}
onMouseUp={handleMouseUp}
className={cn(
"@container flex flex-col h-full w-full relative overflow-hidden",
showBackground && "bg-cover bg-center bg-fixed"

View File

@@ -125,7 +125,7 @@ export const MAX_FILES_TO_SHOW = 3;
export const MOBILE_SIDEBAR_BREAKPOINT_PX = 640;
export const DESKTOP_SMALL_BREAKPOINT_PX = 912;
export const DESKTOP_MEDIUM_BREAKPOINT_PX = 1232;
export const DEFAULT_AVATAR_SIZE_PX = 18;
export const DEFAULT_AGENT_AVATAR_SIZE_PX = 18;
export const HORIZON_DISTANCE_PX = 800;
export const LOGO_FOLDED_SIZE_PX = 24;
export const LOGO_UNFOLDED_SIZE_PX = 88;

View File

@@ -1,55 +0,0 @@
import { getUserInitials } from "@/lib/user";
describe("getUserInitials", () => {
it("returns first letters of first two name parts", () => {
expect(getUserInitials("Alice Smith", "alice@example.com")).toBe("AS");
});
it("returns first two chars of a single-word name", () => {
expect(getUserInitials("Alice", "alice@example.com")).toBe("AL");
});
it("handles three-word names (uses first two)", () => {
expect(getUserInitials("Alice B. Smith", "alice@example.com")).toBe("AB");
});
it("falls back to email local part with dot separator", () => {
expect(getUserInitials(null, "alice.smith@example.com")).toBe("AS");
});
it("falls back to email local part with underscore separator", () => {
expect(getUserInitials(null, "alice_smith@example.com")).toBe("AS");
});
it("falls back to email local part with hyphen separator", () => {
expect(getUserInitials(null, "alice-smith@example.com")).toBe("AS");
});
it("uses first two chars of email local if no separator", () => {
expect(getUserInitials(null, "alice@example.com")).toBe("AL");
});
it("returns null for empty email local part", () => {
expect(getUserInitials(null, "@example.com")).toBeNull();
});
it("uppercases the result", () => {
expect(getUserInitials("john doe", "jd@test.com")).toBe("JD");
});
it("trims whitespace from name", () => {
expect(getUserInitials(" Alice Smith ", "a@test.com")).toBe("AS");
});
it("returns null for numeric name parts", () => {
expect(getUserInitials("Alice 1st", "x@test.com")).toBeNull();
});
it("returns null for numeric email", () => {
expect(getUserInitials(null, "42@domain.com")).toBeNull();
});
it("falls back to email when name has non-alpha chars", () => {
expect(getUserInitials("A1", "alice@example.com")).toBe("AL");
});
});

View File

@@ -128,54 +128,3 @@ export function getUserDisplayName(user: User | null): string {
// If nothing works, then fall back to anonymous user name
return "Anonymous";
}
/**
* Derive display initials from a user's name or email.
*
* - If a name is provided, uses the first letter of the first two words.
* - Falls back to the email local part, splitting on `.`, `_`, or `-`.
* - Returns `null` when no valid alpha initials can be derived.
*/
export function getUserInitials(
name: string | null,
email: string
): string | null {
if (name) {
const words = name.trim().split(/\s+/);
if (words.length >= 2) {
const first = words[0]?.[0];
const second = words[1]?.[0];
if (first && second) {
const result = (first + second).toUpperCase();
if (/^[A-Z]{2}$/.test(result)) return result;
}
return null;
}
if (name.trim().length >= 1) {
const result = name.trim().slice(0, 2).toUpperCase();
if (/^[A-Z]{1,2}$/.test(result)) return result;
}
}
const local = email.split("@")[0];
if (!local || local.length === 0) return null;
const parts = local.split(/[._-]/);
if (parts.length >= 2) {
const first = parts[0]?.[0];
const second = parts[1]?.[0];
if (first && second) {
const result = (first + second).toUpperCase();
if (/^[A-Z]{2}$/.test(result)) return result;
}
return null;
}
if (local.length >= 2) {
const result = local.slice(0, 2).toUpperCase();
if (/^[A-Z]{2}$/.test(result)) return result;
}
if (local.length === 1) {
const result = local.toUpperCase();
if (/^[A-Z]$/.test(result)) return result;
}
return null;
}

View File

@@ -13,9 +13,6 @@ import { ALLOWED_URL_PROTOCOLS } from "./constants";
const URI_SCHEME_REGEX = /^[a-zA-Z][a-zA-Z\d+.-]*:/;
const BARE_EMAIL_REGEX = /^[^\s@/]+@[^\s@/:]+\.[^\s@/:]+$/;
export const INTERACTIVE_SELECTOR =
"a, button, input, textarea, select, label, [role='button'], [tabindex]:not([tabindex='-1']), [contenteditable]:not([contenteditable='false'])";
export function cn(...inputs: ClassValue[]) {
return twMerge(clsx(inputs));
}

View File

@@ -4,7 +4,10 @@ import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
import { buildImgUrl } from "@/app/app/components/files/images/utils";
import { OnyxIcon } from "@/components/icons/icons";
import { useSettingsContext } from "@/providers/SettingsProvider";
import { DEFAULT_AVATAR_SIZE_PX, DEFAULT_AGENT_ID } from "@/lib/constants";
import {
DEFAULT_AGENT_AVATAR_SIZE_PX,
DEFAULT_AGENT_ID,
} from "@/lib/constants";
import CustomAgentAvatar from "@/refresh-components/avatars/CustomAgentAvatar";
import Image from "next/image";
@@ -15,7 +18,7 @@ export interface AgentAvatarProps {
export default function AgentAvatar({
agent,
size = DEFAULT_AVATAR_SIZE_PX,
size = DEFAULT_AGENT_AVATAR_SIZE_PX,
...props
}: AgentAvatarProps) {
const settings = useSettingsContext();

View File

@@ -4,7 +4,7 @@ import { cn } from "@/lib/utils";
import type { IconProps } from "@opal/types";
import Text from "@/refresh-components/texts/Text";
import Image from "next/image";
import { DEFAULT_AVATAR_SIZE_PX } from "@/lib/constants";
import { DEFAULT_AGENT_AVATAR_SIZE_PX } from "@/lib/constants";
import {
SvgActivitySmall,
SvgAudioEqSmall,
@@ -96,7 +96,7 @@ export default function CustomAgentAvatar({
src,
iconName,
size = DEFAULT_AVATAR_SIZE_PX,
size = DEFAULT_AGENT_AVATAR_SIZE_PX,
}: CustomAgentAvatarProps) {
if (src) {
return (

View File

@@ -1,52 +0,0 @@
import { SvgUser } from "@opal/icons";
import { DEFAULT_AVATAR_SIZE_PX } from "@/lib/constants";
import { getUserInitials } from "@/lib/user";
import Text from "@/refresh-components/texts/Text";
import type { User } from "@/lib/types";
export interface UserAvatarProps {
user: User;
size?: number;
}
export default function UserAvatar({
user,
size = DEFAULT_AVATAR_SIZE_PX,
}: UserAvatarProps) {
const initials = getUserInitials(
user.personalization?.name ?? null,
user.email
);
if (!initials) {
return (
<div
role="img"
aria-label={`${user.email} avatar`}
className="flex items-center justify-center rounded-full bg-background-tint-01"
style={{ width: size, height: size }}
>
<SvgUser size={size * 0.55} className="stroke-text-03" aria-hidden />
</div>
);
}
return (
<div
role="img"
aria-label={`${user.email} avatar`}
className="flex items-center justify-center rounded-full bg-background-neutral-inverted-00"
style={{ width: size, height: size }}
>
<Text
inverted
secondaryAction
text05
className="select-none"
style={{ fontSize: size * 0.4 }}
>
{initials}
</Text>
</div>
);
}

View File

@@ -54,10 +54,8 @@ function MemoryItem({
const wrapperRef = useRef<HTMLDivElement>(null);
useEffect(() => {
if (shouldFocus && textareaRef.current) {
const el = textareaRef.current;
el.focus();
el.selectionStart = el.selectionEnd = el.value.length;
if (shouldFocus) {
textareaRef.current?.focus();
onFocused?.();
}
}, [shouldFocus, onFocused]);
@@ -65,10 +63,8 @@ function MemoryItem({
useEffect(() => {
if (!shouldHighlight) return;
wrapperRef.current?.scrollIntoView({
block: "start",
behavior: "smooth",
});
wrapperRef.current?.scrollIntoView({ block: "center", behavior: "smooth" });
textareaRef.current?.focus();
setIsHighlighting(true);
const timer = setTimeout(() => {
@@ -83,10 +79,10 @@ function MemoryItem({
<div
ref={wrapperRef}
className={cn(
"rounded-08 w-full p-0.5 border border-transparent",
"rounded-08 hover:bg-background-tint-00 w-full p-0.5",
"transition-colors ",
isHighlighting &&
"bg-action-link-01 hover:bg-action-link-01 border-action-link-05 duration-700"
"bg-action-link-01 border border-action-link-05 duration-700"
)}
>
<Section gap={0.25} alignItems="start">
@@ -104,7 +100,7 @@ function MemoryItem({
rows={3}
maxLength={MAX_MEMORY_LENGTH}
resizable={false}
className="bg-background-tint-01 hover:bg-background-tint-00 focus-within:bg-background-tint-00"
className={cn(!isFocused && "bg-transparent")}
/>
<Disabled disabled={!memory.content.trim() && memory.isNew}>
<Button
@@ -126,29 +122,13 @@ function MemoryItem({
);
}
function resolveTargetMemoryId(
targetMemoryId: number | null | undefined,
targetIndex: number | null | undefined,
memories: MemoryItem[]
): number | null {
if (targetMemoryId != null) return targetMemoryId;
if (targetIndex != null && memories.length > 0) {
// Backend index is ASC (oldest-first), frontend displays DESC (newest-first)
const descIdx = memories.length - 1 - targetIndex;
return memories[descIdx]?.id ?? null;
}
return null;
}
interface MemoriesModalProps {
memories?: MemoryItem[];
onSaveMemories?: (memories: MemoryItem[]) => Promise<boolean>;
onClose?: () => void;
initialTargetMemoryId?: number | null;
initialTargetIndex?: number | null;
highlightOnOpen?: boolean;
highlightFirstOnOpen?: boolean;
}
export default function MemoriesModal({
@@ -157,7 +137,7 @@ export default function MemoriesModal({
onClose,
initialTargetMemoryId,
initialTargetIndex,
highlightOnOpen = false,
highlightFirstOnOpen = false,
}: MemoriesModalProps) {
const close = useModalClose(onClose);
const [focusMemoryId, setFocusMemoryId] = useState<number | null>(null);
@@ -202,16 +182,24 @@ export default function MemoriesModal({
);
useEffect(() => {
const targetId = resolveTargetMemoryId(
initialTargetMemoryId,
initialTargetIndex,
effectiveMemories
);
if (targetId == null) return;
setFocusMemoryId(targetId);
if (highlightOnOpen) {
setHighlightMemoryId(targetId);
if (initialTargetMemoryId != null) {
// Direct DB id available — use it
setHighlightMemoryId(initialTargetMemoryId);
} else if (initialTargetIndex != null && effectiveMemories.length > 0) {
// Backend index is ASC (oldest-first), but the frontend displays DESC
// (newest-first). Convert: descIdx = totalCount - 1 - ascIdx
const descIdx = effectiveMemories.length - 1 - initialTargetIndex;
const target = effectiveMemories[descIdx];
if (target) {
setHighlightMemoryId(target.id);
}
} else if (
highlightFirstOnOpen &&
effectiveMemories.length > 0 &&
effectiveMemories[0]
) {
// Fallback: highlight the first displayed item (newest)
setHighlightMemoryId(effectiveMemories[0].id);
}
}, [initialTargetMemoryId, initialTargetIndex]);

View File

@@ -5,6 +5,7 @@ import { personaIncludesRetrieval } from "@/app/app/services/lib";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast, useToastFromQuery } from "@/hooks/useToast";
import { SEARCH_PARAM_NAMES } from "@/app/app/services/searchParams";
import { Section } from "@/layouts/general-layouts";
import { useFederatedConnectors, useFilters, useLlmManager } from "@/lib/hooks";
import { useForcedTools } from "@/lib/hooks/useForcedTools";
import OnyxInitializingLoader from "@/components/OnyxInitializingLoader";
@@ -62,6 +63,9 @@ import { useShowOnboarding } from "@/hooks/useShowOnboarding";
import * as AppLayouts from "@/layouts/app-layouts";
import { SvgChevronDown, SvgFileText } from "@opal/icons";
import { Button } from "@opal/components";
import { IllustrationContent } from "@opal/layouts";
import SvgNotFound from "@opal/illustrations/not-found";
import SvgNoAccess from "@opal/illustrations/no-access";
import Spacer from "@/refresh-components/Spacer";
import useAppFocus from "@/hooks/useAppFocus";
import { useQueryController } from "@/providers/QueryControllerProvider";
@@ -381,23 +385,26 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
setSelectedAgentFromId,
});
const { onMessageSelection, currentSessionFileTokenCount } =
useChatSessionController({
existingChatSessionId: currentChatSessionId,
searchParams,
filterManager,
firstMessage,
setSelectedAgentFromId,
setSelectedDocuments,
setCurrentMessageFiles,
chatSessionIdRef,
loadedIdSessionRef,
chatInputBarRef,
isInitialLoad,
submitOnLoadPerformed,
refreshChatSessions,
onSubmit,
});
const {
onMessageSelection,
currentSessionFileTokenCount,
sessionFetchError,
} = useChatSessionController({
existingChatSessionId: currentChatSessionId,
searchParams,
filterManager,
firstMessage,
setSelectedAgentFromId,
setSelectedDocuments,
setCurrentMessageFiles,
chatSessionIdRef,
loadedIdSessionRef,
chatInputBarRef,
isInitialLoad,
submitOnLoadPerformed,
refreshChatSessions,
onSubmit,
});
useSendMessageToParent();
@@ -679,7 +686,10 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
{/* ChatUI */}
<Fade
show={
appFocus.isChat() && !!currentChatSessionId && !!liveAgent
appFocus.isChat() &&
!!currentChatSessionId &&
!!liveAgent &&
!sessionFetchError
}
className="h-full w-full flex flex-col items-center"
>
@@ -708,6 +718,45 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
</ChatScrollContainer>
</Fade>
{/* Session fetch error (404 / 403) */}
<Fade
show={appFocus.isChat() && sessionFetchError !== null}
className="h-full w-full flex flex-col items-center justify-center"
>
{sessionFetchError && (
<Section
flexDirection="column"
alignItems="center"
gap={1}
>
<IllustrationContent
illustration={
sessionFetchError.type === "access_denied"
? SvgNoAccess
: SvgNotFound
}
title={
sessionFetchError.type === "not_found"
? "Chat not found"
: sessionFetchError.type === "access_denied"
? "Access denied"
: "Something went wrong"
}
description={
sessionFetchError.type === "not_found"
? "This chat session doesn't exist or has been deleted."
: sessionFetchError.type === "access_denied"
? "You don't have permission to view this chat session."
: sessionFetchError.detail
}
/>
<Button href="/app" prominence="secondary">
Start a new chat
</Button>
</Section>
)}
</Fade>
{/* ProjectUI */}
{appFocus.isProject() && (
<div className="w-full max-h-[50vh] overflow-y-auto overscroll-y-none">
@@ -736,7 +785,12 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
</div>
{/* ── Middle-center: AppInputBar ── */}
<div className="row-start-2 flex flex-col items-center px-4">
<div
className={cn(
"row-start-2 flex flex-col items-center px-4",
sessionFetchError && "hidden"
)}
>
<div className="relative w-full max-w-[var(--app-page-main-content-width)] flex flex-col">
{/* Scroll to bottom button - positioned absolutely above AppInputBar */}
{appFocus.isChat() && showScrollButton && (

View File

@@ -26,8 +26,7 @@ import type {
StatusFilter,
StatusCountMap,
} from "./interfaces";
import UserAvatar from "@/refresh-components/avatars/UserAvatar";
import type { User } from "@/lib/types";
import { getInitials } from "./utils";
// ---------------------------------------------------------------------------
// Column renderers
@@ -76,25 +75,20 @@ const tc = createTableColumns<UserRow>();
function buildColumns(onMutate: () => void) {
return [
tc.qualifier({
content: "icon",
getContent: (row) => {
const user = {
email: row.email,
personalization: row.personal_name
? { name: row.personal_name }
: undefined,
} as User;
return (props) => <UserAvatar user={user} size={props.size} />;
},
content: "avatar-user",
getInitials: (row) => getInitials(row.personal_name, row.email),
selectable: false,
}),
tc.column("email", {
header: "Name",
weight: 22,
minWidth: 140,
cell: renderNameColumn,
}),
tc.column("groups", {
header: "Groups",
weight: 24,
minWidth: 200,
enableSorting: false,
cell: (value, row) => (
<GroupsCell groups={value} user={row} onMutate={onMutate} />
@@ -103,16 +97,19 @@ function buildColumns(onMutate: () => void) {
tc.column("role", {
header: "Account Type",
weight: 16,
minWidth: 180,
cell: (_value, row) => <UserRoleCell user={row} onMutate={onMutate} />,
}),
tc.column("status", {
header: "Status",
weight: 14,
minWidth: 100,
cell: renderStatusColumn,
}),
tc.column("updated_at", {
header: "Last Updated",
weight: 14,
minWidth: 100,
cell: renderLastUpdatedColumn,
}),
tc.actions({
@@ -222,6 +219,7 @@ export default function UsersTable({
data={filteredUsers}
columns={columns}
getRowId={(row) => row.id ?? row.email}
qualifier="avatar"
pageSize={PAGE_SIZE}
searchTerm={searchTerm}
emptyState={

View File

@@ -0,0 +1,43 @@
import { getInitials } from "./utils";
describe("getInitials", () => {
it("returns first letters of first two name parts", () => {
expect(getInitials("Alice Smith", "alice@example.com")).toBe("AS");
});
it("returns first two chars of a single-word name", () => {
expect(getInitials("Alice", "alice@example.com")).toBe("AL");
});
it("handles three-word names (uses first two)", () => {
expect(getInitials("Alice B. Smith", "alice@example.com")).toBe("AB");
});
it("falls back to email local part with dot separator", () => {
expect(getInitials(null, "alice.smith@example.com")).toBe("AS");
});
it("falls back to email local part with underscore separator", () => {
expect(getInitials(null, "alice_smith@example.com")).toBe("AS");
});
it("falls back to email local part with hyphen separator", () => {
expect(getInitials(null, "alice-smith@example.com")).toBe("AS");
});
it("uses first two chars of email local if no separator", () => {
expect(getInitials(null, "alice@example.com")).toBe("AL");
});
it("returns ? for empty email local part", () => {
expect(getInitials(null, "@example.com")).toBe("?");
});
it("uppercases the result", () => {
expect(getInitials("john doe", "jd@test.com")).toBe("JD");
});
it("trims whitespace from name", () => {
expect(getInitials(" Alice Smith ", "a@test.com")).toBe("AS");
});
});

View File

@@ -0,0 +1,23 @@
/**
* Derive display initials from a user's name or email.
*
* - If a name is provided, uses the first letter of the first two words.
* - Falls back to the email local part, splitting on `.`, `_`, or `-`.
* - Returns at most 2 uppercase characters.
*/
export function getInitials(name: string | null, email: string): string {
if (name) {
const parts = name.trim().split(/\s+/);
if (parts.length >= 2) {
return ((parts[0]?.[0] ?? "") + (parts[1]?.[0] ?? "")).toUpperCase();
}
return name.slice(0, 2).toUpperCase();
}
const local = email.split("@")[0];
if (!local) return "?";
const parts = local.split(/[._-]/);
if (parts.length >= 2) {
return ((parts[0]?.[0] ?? "") + (parts[1]?.[0] ?? "")).toUpperCase();
}
return local.slice(0, 2).toUpperCase();
}

View File

@@ -0,0 +1,54 @@
import { test, expect } from "@playwright/test";
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";
const NON_EXISTENT_CHAT_ID = "00000000-0000-0000-0000-000000000000";
for (const theme of THEMES) {
test.describe(`Chat session not found (${theme} mode)`, () => {
test.beforeEach(async ({ page }) => {
await setThemeBeforeNavigation(page, theme);
});
test("should show 404 page for a non-existent chat session", async ({
page,
}) => {
await page.goto(`/app?chatId=${NON_EXISTENT_CHAT_ID}`);
await expect(page.getByText("Chat not found")).toBeVisible({
timeout: 10000,
});
await expect(
page.getByText("This chat session doesn't exist or has been deleted.")
).toBeVisible();
await expect(
page.getByRole("link", { name: "Start a new chat" })
).toBeVisible();
// Sidebar should still be visible
await expect(page.getByTestId("AppSidebar/new-session")).toBeVisible();
const container = page.locator("[data-main-container]");
await expect(container).toBeVisible();
await expectElementScreenshot(container, {
name: `chat-session-not-found-${theme}`,
});
});
test("should navigate to /app when clicking Start a new chat", async ({
page,
}) => {
await page.goto(`/app?chatId=${NON_EXISTENT_CHAT_ID}`);
await expect(page.getByText("Chat not found")).toBeVisible({
timeout: 10000,
});
await page.getByRole("link", { name: "Start a new chat" }).click();
await page.waitForLoadState("networkidle");
await expect(page).toHaveURL("/app");
await expect(page.getByText("Chat not found")).toBeHidden();
});
});
}

View File

@@ -1,47 +0,0 @@
import { test, expect } from "@playwright/test";
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
test.describe(`Chat Input Focus Retention`, () => {
test.beforeEach(async ({ page }, testInfo) => {
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
await page.goto("/app");
await page.waitForLoadState("networkidle");
});
test("clicking empty space retains focus on chat input", async ({ page }) => {
const textarea = page.locator("#onyx-chat-input-textarea");
await textarea.waitFor({ state: "visible", timeout: 10000 });
// Focus the textarea and type something
await textarea.focus();
await textarea.fill("test message");
await expect(textarea).toBeFocused();
// Click on the main container's empty space (top-left corner)
const container = page.locator("[data-main-container]");
await container.click({ position: { x: 10, y: 10 } });
// Focus should remain on the textarea
await expect(textarea).toBeFocused();
});
test("clicking interactive elements still moves focus away", async ({
page,
}) => {
const textarea = page.locator("#onyx-chat-input-textarea");
await textarea.waitFor({ state: "visible", timeout: 10000 });
// Focus the textarea
await textarea.focus();
await expect(textarea).toBeFocused();
// Click on an interactive element inside the container
const button = page.locator("[data-main-container] button").first();
await button.waitFor({ state: "visible", timeout: 5000 });
await button.click();
// Focus should have moved away from the textarea
await expect(textarea).not.toBeFocused();
});
});