mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-24 09:02:43 +00:00
Compare commits
1 Commits
fix/chat-s
...
jamison/40
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bfdeb65bbb |
@@ -25,6 +25,9 @@ from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
# Default number of pre-provisioned tenants to maintain
|
||||
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
|
||||
|
||||
# Soft time limit for tenant pre-provisioning tasks (in seconds)
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
# Hard time limit for tenant pre-provisioning tasks (in seconds)
|
||||
@@ -55,7 +58,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
lock_check: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# These tasks should never overlap
|
||||
@@ -71,7 +74,9 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
num_available_tenants = db_session.query(AvailableTenant).count()
|
||||
|
||||
# Get the target number of available tenants
|
||||
num_minimum_available_tenants = TARGET_AVAILABLE_TENANTS
|
||||
num_minimum_available_tenants = getattr(
|
||||
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
|
||||
)
|
||||
|
||||
# Calculate how many new tenants we need to provision
|
||||
if num_available_tenants < num_minimum_available_tenants:
|
||||
@@ -93,12 +98,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
task_logger.exception("Error in check_available_tenants task")
|
||||
|
||||
finally:
|
||||
try:
|
||||
lock_check.release()
|
||||
except Exception:
|
||||
task_logger.warning(
|
||||
"Could not release check lock (likely expired), continuing"
|
||||
)
|
||||
lock_check.release()
|
||||
|
||||
|
||||
def pre_provision_tenant() -> None:
|
||||
@@ -113,7 +113,7 @@ def pre_provision_tenant() -> None:
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
lock_provision: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CLOUD_PRE_PROVISION_TENANT_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
|
||||
@@ -185,9 +185,4 @@ def pre_provision_tenant() -> None:
|
||||
except Exception:
|
||||
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
|
||||
finally:
|
||||
try:
|
||||
lock_provision.release()
|
||||
except Exception:
|
||||
task_logger.warning(
|
||||
"Could not release provision lock (likely expired), continuing"
|
||||
)
|
||||
lock_provision.release()
|
||||
|
||||
@@ -9,12 +9,12 @@ from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
@@ -362,7 +362,7 @@ if not MULTI_TENANT:
|
||||
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
if HOOKS_AVAILABLE:
|
||||
if not MULTI_TENANT and HOOK_ENABLED:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
|
||||
@@ -30,8 +30,6 @@ from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.utils import plaintext_file_name_for_id
|
||||
from onyx.file_store.utils import store_plaintext
|
||||
from onyx.kg.models import KGException
|
||||
from onyx.kg.setup.kg_default_entity_definitions import (
|
||||
populate_missing_default_entity_types__commit,
|
||||
@@ -291,33 +289,6 @@ def process_kg_commands(
|
||||
raise KGException("KG setup done")
|
||||
|
||||
|
||||
def _get_or_extract_plaintext(
|
||||
file_id: str,
|
||||
extract_fn: Callable[[], str],
|
||||
) -> str:
|
||||
"""Load cached plaintext for a file, or extract and store it.
|
||||
|
||||
Tries to read pre-stored plaintext from the file store. On a miss,
|
||||
calls extract_fn to produce the text, then stores the result so
|
||||
future calls skip the expensive extraction.
|
||||
"""
|
||||
file_store = get_default_file_store()
|
||||
plaintext_key = plaintext_file_name_for_id(file_id)
|
||||
|
||||
# Try cached plaintext first.
|
||||
try:
|
||||
plaintext_io = file_store.read_file(plaintext_key, mode="b")
|
||||
return plaintext_io.read().decode("utf-8")
|
||||
except Exception:
|
||||
logger.exception(f"Error when reading file, id={file_id}")
|
||||
|
||||
# Cache miss — extract and store.
|
||||
content_text = extract_fn()
|
||||
if content_text:
|
||||
store_plaintext(file_id, content_text)
|
||||
return content_text
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def load_chat_file(
|
||||
file_descriptor: FileDescriptor, db_session: Session
|
||||
@@ -332,23 +303,12 @@ def load_chat_file(
|
||||
file_type = ChatFileType(file_descriptor["type"])
|
||||
|
||||
if file_type.is_text_file():
|
||||
file_id = file_descriptor["id"]
|
||||
|
||||
def _extract() -> str:
|
||||
return extract_file_text(
|
||||
try:
|
||||
content_text = extract_file_text(
|
||||
file=file_io,
|
||||
file_name=file_descriptor.get("name") or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
|
||||
# Use the user_file_id as cache key when available (matches what
|
||||
# the celery indexing worker stores), otherwise fall back to the
|
||||
# file store id (covers code-interpreter-generated files, etc.).
|
||||
user_file_id_str = file_descriptor.get("user_file_id")
|
||||
cache_key = user_file_id_str or file_id
|
||||
|
||||
try:
|
||||
content_text = _get_or_extract_plaintext(cache_key, _extract)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to retrieve content for file {file_descriptor['id']}: {str(e)}"
|
||||
|
||||
@@ -177,8 +177,8 @@ class ExtractedContextFiles(BaseModel):
|
||||
class SearchParams(BaseModel):
|
||||
"""Resolved search filter IDs and search-tool usage for a chat turn."""
|
||||
|
||||
project_id_filter: int | None
|
||||
persona_id_filter: int | None
|
||||
search_project_id: int | None
|
||||
search_persona_id: int | None
|
||||
search_usage: SearchToolUsage
|
||||
|
||||
|
||||
|
||||
@@ -399,13 +399,13 @@ def determine_search_params(
|
||||
"""
|
||||
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
|
||||
|
||||
project_id_filter: int | None = None
|
||||
persona_id_filter: int | None = None
|
||||
search_project_id: int | None = None
|
||||
search_persona_id: int | None = None
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
if is_custom_persona:
|
||||
persona_id_filter = persona_id
|
||||
search_persona_id = persona_id
|
||||
else:
|
||||
project_id_filter = project_id
|
||||
search_project_id = project_id
|
||||
|
||||
search_usage = SearchToolUsage.AUTO
|
||||
if not is_custom_persona and project_id:
|
||||
@@ -418,8 +418,8 @@ def determine_search_params(
|
||||
search_usage = SearchToolUsage.DISABLED
|
||||
|
||||
return SearchParams(
|
||||
project_id_filter=project_id_filter,
|
||||
persona_id_filter=persona_id_filter,
|
||||
search_project_id=search_project_id,
|
||||
search_persona_id=search_persona_id,
|
||||
search_usage=search_usage,
|
||||
)
|
||||
|
||||
@@ -474,18 +474,11 @@ def handle_stream_message_objects(
|
||||
db_session=db_session,
|
||||
)
|
||||
yield CreateChatSessionID(chat_session_id=chat_session.id)
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
else:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
|
||||
persona = chat_session.persona
|
||||
@@ -718,8 +711,8 @@ def handle_stream_message_objects(
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=new_msg_req.internal_search_filters,
|
||||
project_id_filter=search_params.project_id_filter,
|
||||
persona_id_filter=search_params.persona_id_filter,
|
||||
project_id=search_params.search_project_id,
|
||||
persona_id=search_params.search_persona_id,
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
|
||||
@@ -2,6 +2,7 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -69,13 +70,9 @@ class BaseFilters(BaseModel):
|
||||
|
||||
|
||||
class UserFileFilters(BaseModel):
|
||||
# Scopes search to user files tagged with a given project/persona in Vespa.
|
||||
# These are NOT simply the IDs of the current project or persona — they are
|
||||
# only set when the persona's/project's user files overflowed the LLM
|
||||
# context window and must be searched via vector DB instead of being loaded
|
||||
# directly into the prompt.
|
||||
project_id_filter: int | None = None
|
||||
persona_id_filter: int | None = None
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
|
||||
class AssistantKnowledgeFilters(BaseModel):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -38,8 +39,9 @@ logger = setup_logger()
|
||||
def _build_index_filters(
|
||||
user_provided_filters: BaseFilters | None,
|
||||
user: User, # Used for ACLs, anonymous users only see public docs
|
||||
project_id_filter: int | None,
|
||||
persona_id_filter: int | None,
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
user_file_ids: list[UUID] | None,
|
||||
persona_document_sets: list[str] | None,
|
||||
persona_time_cutoff: datetime | None,
|
||||
db_session: Session | None = None,
|
||||
@@ -95,6 +97,16 @@ def _build_index_filters(
|
||||
if not source_filter and detected_source_filter:
|
||||
source_filter = detected_source_filter
|
||||
|
||||
# CRITICAL FIX: If user_file_ids are present, we must ensure "user_file"
|
||||
# source type is included in the filter, otherwise user files will be excluded!
|
||||
if user_file_ids and source_filter:
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# Add user_file to the source filter if not already present
|
||||
if DocumentSource.USER_FILE not in source_filter:
|
||||
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
|
||||
logger.debug("Added USER_FILE to source_filter for user knowledge search")
|
||||
|
||||
if bypass_acl:
|
||||
user_acl_filters = None
|
||||
elif acl_filters is not None:
|
||||
@@ -105,8 +117,9 @@ def _build_index_filters(
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
|
||||
final_filters = IndexFilters(
|
||||
project_id_filter=project_id_filter,
|
||||
persona_id_filter=persona_id_filter,
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
source_type=source_filter,
|
||||
document_set=document_set_filter,
|
||||
time_cutoff=time_filter,
|
||||
@@ -252,16 +265,19 @@ def search_pipeline(
|
||||
db_session: Session | None = None,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
# Vespa metadata filters for overflowing user files. NOT the raw IDs
|
||||
# of the current project/persona — only set when user files couldn't fit
|
||||
# in the LLM context and need to be searched via vector DB.
|
||||
project_id_filter: int | None = None,
|
||||
persona_id_filter: int | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
# If a persona_id is provided, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB queries (no session needed)
|
||||
acl_filters: list[str] | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
user_uploaded_persona_files: list[UUID] | None = (
|
||||
[user_file.id for user_file in persona.user_files] if persona else None
|
||||
)
|
||||
|
||||
persona_document_sets: list[str] | None = (
|
||||
[persona_document_set.name for persona_document_set in persona.document_sets]
|
||||
if persona
|
||||
@@ -286,8 +302,9 @@ def search_pipeline(
|
||||
filters = _build_index_filters(
|
||||
user_provided_filters=chunk_search_request.user_selected_filters,
|
||||
user=user,
|
||||
project_id_filter=project_id_filter,
|
||||
persona_id_filter=persona_id_filter,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
user_file_ids=user_uploaded_persona_files,
|
||||
persona_document_sets=persona_document_sets,
|
||||
persona_time_cutoff=persona_time_cutoff,
|
||||
db_session=db_session,
|
||||
|
||||
@@ -110,6 +110,7 @@ def search_chunks(
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
|
||||
federated_sources = set(
|
||||
|
||||
@@ -28,7 +28,6 @@ from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import ChatSessionSharedStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchDoc as DBSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
@@ -54,17 +53,9 @@ def get_chat_session_by_id(
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
is_shared: bool = False,
|
||||
eager_load_persona: bool = False,
|
||||
) -> ChatSession:
|
||||
stmt = select(ChatSession).where(ChatSession.id == chat_session_id)
|
||||
|
||||
if eager_load_persona:
|
||||
stmt = stmt.options(
|
||||
selectinload(ChatSession.persona).selectinload(Persona.tools),
|
||||
selectinload(ChatSession.persona).selectinload(Persona.user_files),
|
||||
selectinload(ChatSession.project),
|
||||
)
|
||||
|
||||
if is_shared:
|
||||
stmt = stmt.where(ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC)
|
||||
else:
|
||||
|
||||
@@ -2,7 +2,6 @@ import time
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
@@ -150,9 +149,6 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
Returns None if search settings did not change, or the old search settings if they
|
||||
did change.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
return None
|
||||
|
||||
# Default CC-pair created for Ingestion API unused here
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
|
||||
|
||||
@@ -10,8 +10,8 @@ How `IndexFilters` fields combine into the final query filter. Applies to both V
|
||||
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
|
||||
| **ACL** | `access_control_list` | OR within, AND with rest |
|
||||
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
|
||||
| **Knowledge scope** | `document_set`, `attached_document_ids`, `hierarchy_node_ids`, `persona_id_filter` | OR within group, AND with rest |
|
||||
| **Additive scope** | `project_id_filter` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
|
||||
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
|
||||
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
|
||||
|
||||
## How filters combine
|
||||
|
||||
@@ -31,22 +31,12 @@ AND time >= cutoff -- if set
|
||||
|
||||
The knowledge scope filter controls **what knowledge an assistant can access**.
|
||||
|
||||
### Primary vs additive triggers
|
||||
|
||||
- **`persona_id_filter`** is a **primary** trigger. A persona with user files IS explicit
|
||||
knowledge, so `persona_id_filter` alone can start a knowledge scope. Note: this is
|
||||
NOT the raw ID of the persona being used — it is only set when the persona's
|
||||
user files overflowed the LLM context window.
|
||||
- **`project_id_filter`** is **additive**. It widens an existing scope to include project
|
||||
files but never restricts on its own — a chat inside a project should still search
|
||||
team knowledge when no other knowledge is attached.
|
||||
|
||||
### No explicit knowledge attached
|
||||
|
||||
When `document_set`, `attached_document_ids`, `hierarchy_node_ids`, and `persona_id_filter` are all empty/None:
|
||||
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
|
||||
|
||||
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
|
||||
- `project_id_filter` is ignored — it never restricts on its own.
|
||||
- `project_id` and `persona_id` are ignored — they never restrict on their own.
|
||||
|
||||
### One explicit knowledge type
|
||||
|
||||
@@ -54,40 +44,39 @@ When `document_set`, `attached_document_ids`, `hierarchy_node_ids`, and `persona
|
||||
-- Only document sets
|
||||
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
|
||||
|
||||
-- Only persona user files (overflowed context)
|
||||
AND (personas contains 42)
|
||||
-- Only user files
|
||||
AND (document_id = "uuid-1" OR document_id = "uuid-2")
|
||||
```
|
||||
|
||||
### Multiple explicit knowledge types (OR'd)
|
||||
|
||||
```
|
||||
-- Document sets + persona user files
|
||||
-- Document sets + user files
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR document_id = "uuid-1"
|
||||
)
|
||||
```
|
||||
|
||||
### Explicit knowledge + overflowing user files
|
||||
|
||||
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
|
||||
|
||||
```
|
||||
-- Document sets + persona user files overflowed
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR personas contains 42
|
||||
)
|
||||
```
|
||||
|
||||
### Explicit knowledge + overflowing project files
|
||||
|
||||
When an explicit knowledge restriction is in effect **and** `project_id_filter` is set (project files overflowed the LLM context window), `project_id_filter` widens the filter:
|
||||
|
||||
```
|
||||
-- Document sets + project files overflowed
|
||||
-- User files + project files overflowed
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR user_project contains 7
|
||||
)
|
||||
|
||||
-- Persona user files + project files (won't happen in practice;
|
||||
-- custom personas ignore project files per the precedence rule)
|
||||
AND (
|
||||
personas contains 42
|
||||
document_id = "uuid-1"
|
||||
OR user_project contains 7
|
||||
)
|
||||
```
|
||||
|
||||
### Only project_id_filter (no explicit knowledge)
|
||||
### Only project_id or persona_id (no explicit knowledge)
|
||||
|
||||
No knowledge scope filter. The assistant searches everything.
|
||||
|
||||
@@ -102,10 +91,11 @@ AND (acl contains ...)
|
||||
| Filter field | Vespa field | Vespa type | Purpose |
|
||||
|---|---|---|---|
|
||||
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
|
||||
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
|
||||
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
|
||||
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
|
||||
| `persona_id_filter` | `personas` | `array<int>` | Persona tag for overflowing user files (**primary** trigger) |
|
||||
| `project_id_filter` | `user_project` | `array<int>` | Project tag for overflowing project files (**additive** only) |
|
||||
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
|
||||
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
|
||||
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
|
||||
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
|
||||
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
|
||||
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
|
||||
@@ -218,8 +219,9 @@ class DocumentQuery:
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=min_chunk_index,
|
||||
max_chunk_index=max_chunk_index,
|
||||
@@ -284,8 +286,9 @@ class DocumentQuery:
|
||||
source_types=[],
|
||||
tags=[],
|
||||
document_sets=[],
|
||||
project_id_filter=None,
|
||||
persona_id_filter=None,
|
||||
user_file_ids=[],
|
||||
project_id=None,
|
||||
persona_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -353,8 +356,9 @@ class DocumentQuery:
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -445,8 +449,9 @@ class DocumentQuery:
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -524,8 +529,9 @@ class DocumentQuery:
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -585,8 +591,9 @@ class DocumentQuery:
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -817,8 +824,9 @@ class DocumentQuery:
|
||||
source_types: list[DocumentSource],
|
||||
tags: list[Tag],
|
||||
document_sets: list[str],
|
||||
project_id_filter: int | None,
|
||||
persona_id_filter: int | None,
|
||||
user_file_ids: list[UUID],
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
time_cutoff: datetime | None,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
@@ -849,12 +857,12 @@ class DocumentQuery:
|
||||
list corresponding to a tag will be retrieved.
|
||||
document_sets: If supplied, only documents with at least one
|
||||
document set ID from this list will be retrieved.
|
||||
project_id_filter: If not None, only documents with this project ID
|
||||
in user projects will be retrieved. Additive — only applied
|
||||
when a knowledge scope already exists.
|
||||
persona_id_filter: If not None, only documents whose personas array
|
||||
contains this persona ID will be retrieved. Primary — creates
|
||||
a knowledge scope on its own.
|
||||
user_file_ids: If supplied, only document IDs in this list will be
|
||||
retrieved.
|
||||
project_id: If not None, only documents with this project ID in user
|
||||
projects will be retrieved.
|
||||
persona_id: If not None, only documents whose personas array
|
||||
contains this persona ID will be retrieved.
|
||||
time_cutoff: Time cutoff for the documents to retrieve. If not None,
|
||||
Documents which were last updated before this date will not be
|
||||
returned. For documents which do not have a value for their last
|
||||
@@ -871,6 +879,10 @@ class DocumentQuery:
|
||||
NOTE: See DocumentChunk.max_chunk_size.
|
||||
document_id: The document ID to retrieve. If None, no filter will be
|
||||
applied for this. Defaults to None.
|
||||
WARNING: This filters on the same property as user_file_ids.
|
||||
Although it would never make sense to supply both, note that if
|
||||
user_file_ids is supplied and does not contain document_id, no
|
||||
matches will be retrieved.
|
||||
attached_document_ids: Document IDs explicitly attached to the
|
||||
assistant. If provided along with hierarchy_node_ids, documents
|
||||
matching EITHER criteria will be retrieved (OR logic).
|
||||
@@ -931,6 +943,15 @@ class DocumentQuery:
|
||||
)
|
||||
return document_set_filter
|
||||
|
||||
def _get_user_file_id_filter(user_file_ids: list[UUID]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_file_id_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for user_file_id in user_file_ids:
|
||||
user_file_id_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": str(user_file_id)}}}
|
||||
)
|
||||
return user_file_id_filter
|
||||
|
||||
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
@@ -1031,17 +1052,14 @@ class DocumentQuery:
|
||||
# assistant can see. When none are set the assistant searches
|
||||
# everything.
|
||||
#
|
||||
# persona_id_filter is a primary trigger — a persona with user files IS
|
||||
# explicit knowledge, so it can start a knowledge scope on its own.
|
||||
#
|
||||
# project_id_filter is additive — it widens the scope to also cover
|
||||
# overflowing project files but never restricts on its own (a chat
|
||||
# inside a project should still search team knowledge).
|
||||
# project_id / persona_id are additive: they make overflowing user files
|
||||
# findable but must NOT trigger the restriction on their own (an agent
|
||||
# with no explicit knowledge should search everything).
|
||||
has_knowledge_scope = (
|
||||
attached_document_ids
|
||||
or hierarchy_node_ids
|
||||
or user_file_ids
|
||||
or document_sets
|
||||
or persona_id_filter is not None
|
||||
)
|
||||
|
||||
if has_knowledge_scope:
|
||||
@@ -1056,17 +1074,23 @@ class DocumentQuery:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(hierarchy_node_ids)
|
||||
)
|
||||
if user_file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(user_file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
if persona_id_filter is not None:
|
||||
# Additive: widen scope to also cover overflowing user files, but
|
||||
# only when an explicit restriction is already in effect.
|
||||
if project_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_persona_filter(persona_id_filter)
|
||||
_get_user_project_filter(project_id)
|
||||
)
|
||||
if project_id_filter is not None:
|
||||
if persona_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_project_filter(project_id_filter)
|
||||
_get_persona_filter(persona_id)
|
||||
)
|
||||
filter_clauses.append(knowledge_filter)
|
||||
|
||||
@@ -1084,6 +1108,8 @@ class DocumentQuery:
|
||||
)
|
||||
|
||||
if document_id is not None:
|
||||
# WARNING: If user_file_ids has elements and if none of them are
|
||||
# document_id, no matches will be retrieved.
|
||||
filter_clauses.append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
)
|
||||
|
||||
@@ -199,29 +199,31 @@ def build_vespa_filters(
|
||||
]
|
||||
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
|
||||
|
||||
# Knowledge scope: explicit knowledge attachments restrict what an
|
||||
# assistant can see. When none are set, the assistant can see
|
||||
# everything.
|
||||
# Knowledge scope: explicit knowledge attachments (document_sets,
|
||||
# user_file_ids) restrict what an assistant can see. When none are
|
||||
# set, the assistant can see everything.
|
||||
#
|
||||
# persona_id_filter is a primary trigger — a persona with user files IS
|
||||
# explicit knowledge, so it can start a knowledge scope on its own.
|
||||
#
|
||||
# project_id_filter is additive — it widens the scope to also cover
|
||||
# overflowing project files but never restricts on its own (a chat
|
||||
# inside a project should still search team knowledge).
|
||||
# project_id / persona_id are additive: they make overflowing user
|
||||
# files findable in Vespa but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
knowledge_scope_parts: list[str] = []
|
||||
|
||||
_append(
|
||||
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
)
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id_filter))
|
||||
|
||||
# project_id_filter only widens an existing scope.
|
||||
user_file_ids_str = (
|
||||
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
|
||||
)
|
||||
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
|
||||
|
||||
# Only include project/persona scopes when an explicit knowledge
|
||||
# restriction is already in effect — they widen the scope to also
|
||||
# cover overflowing user files but never restrict on their own.
|
||||
if knowledge_scope_parts:
|
||||
_append(
|
||||
knowledge_scope_parts,
|
||||
_build_user_project_filter(filters.project_id_filter),
|
||||
)
|
||||
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
|
||||
|
||||
if len(knowledge_scope_parts) > 1:
|
||||
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
|
||||
|
||||
@@ -88,7 +88,6 @@ class OnyxErrorCode(Enum):
|
||||
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
|
||||
BAD_GATEWAY = ("BAD_GATEWAY", 502)
|
||||
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
|
||||
HOOK_EXECUTION_FAILED = ("HOOK_EXECUTION_FAILED", 502)
|
||||
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
|
||||
|
||||
def __init__(self, code: str, status_code: int) -> None:
|
||||
|
||||
@@ -38,7 +38,17 @@ def get_federated_retrieval_functions(
|
||||
source_types: list[DocumentSource] | None,
|
||||
document_set_names: list[str] | None,
|
||||
slack_context: SlackContext | None = None,
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
) -> list[FederatedRetrievalInfo]:
|
||||
# When User Knowledge (user files) is the only knowledge source enabled,
|
||||
# skip federated connectors entirely. User Knowledge mode means the agent
|
||||
# should ONLY use uploaded files, not team connectors like Slack.
|
||||
if user_file_ids and not document_set_names:
|
||||
logger.debug(
|
||||
"Skipping all federated connectors: User Knowledge mode enabled "
|
||||
f"with {len(user_file_ids)} user files and no document sets"
|
||||
)
|
||||
return []
|
||||
|
||||
# Check for Slack bot context first (regardless of user_id)
|
||||
if slack_context:
|
||||
|
||||
@@ -23,55 +23,45 @@ from onyx.utils.timing import log_function_time
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def plaintext_file_name_for_id(file_id: str) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a file."""
|
||||
return f"plaintext_{file_id}"
|
||||
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a user file."""
|
||||
return f"plaintext_{user_file_id}"
|
||||
|
||||
|
||||
def store_plaintext(file_id: str, plaintext_content: str) -> bool:
|
||||
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
|
||||
"""
|
||||
Store plaintext content for a file in the file store.
|
||||
Store plaintext content for a user file in the file store.
|
||||
|
||||
Args:
|
||||
file_id: The ID of the file (user_file or artifact_file)
|
||||
user_file_id: The ID of the user file
|
||||
plaintext_content: The plaintext content to store
|
||||
|
||||
Returns:
|
||||
bool: True if storage was successful, False otherwise
|
||||
"""
|
||||
# Skip empty content
|
||||
if not plaintext_content:
|
||||
return False
|
||||
|
||||
plaintext_file_name = plaintext_file_name_for_id(file_id)
|
||||
# Get plaintext file name
|
||||
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
|
||||
|
||||
try:
|
||||
file_store = get_default_file_store()
|
||||
file_content = BytesIO(plaintext_content.encode("utf-8"))
|
||||
file_store.save_file(
|
||||
content=file_content,
|
||||
display_name=f"Plaintext for {file_id}",
|
||||
display_name=f"Plaintext for user file {user_file_id}",
|
||||
file_origin=FileOrigin.PLAINTEXT_CACHE,
|
||||
file_type="text/plain",
|
||||
file_id=plaintext_file_name,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store plaintext for {file_id}: {e}")
|
||||
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# --- Convenience wrappers for callers that use user-file UUIDs ---
|
||||
|
||||
|
||||
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a user file."""
|
||||
return plaintext_file_name_for_id(str(user_file_id))
|
||||
|
||||
|
||||
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
|
||||
"""Store plaintext content for a user file (delegates to :func:`store_plaintext`)."""
|
||||
return store_plaintext(str(user_file_id), plaintext_content)
|
||||
|
||||
|
||||
def load_chat_file_by_id(file_id: str) -> InMemoryChatFile:
|
||||
"""Load a file directly from the file store using its file_record ID.
|
||||
|
||||
|
||||
@@ -1,330 +0,0 @@
|
||||
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
|
||||
|
||||
Usage (Celery tasks and FastAPI handlers):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
|
||||
)
|
||||
|
||||
if isinstance(result, HookSkipped):
|
||||
# no active hook configured — continue with original behavior
|
||||
...
|
||||
elif isinstance(result, HookSoftFailed):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is the response payload dict from the customer's endpoint
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
--------------------------
|
||||
``is_reachable`` on the Hook row is updated selectively — only when the outcome
|
||||
carries meaningful signal about physical reachability:
|
||||
|
||||
NetworkError (DNS, connection refused) → False (cannot reach the server)
|
||||
HTTP 401 / 403 → False (api_key revoked or invalid)
|
||||
TimeoutException → None (server may be slow, skip write)
|
||||
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
|
||||
Unknown exception → None (no signal, skip write)
|
||||
Non-JSON / non-dict response → None (server responded, skip write)
|
||||
Success (2xx, valid dict) → True (confirmed reachable)
|
||||
|
||||
None means "leave the current value unchanged" — no DB round-trip is made.
|
||||
|
||||
DB session design
|
||||
-----------------
|
||||
The executor uses three sessions:
|
||||
|
||||
1. Caller's session (db_session) — used only for the hook lookup read. All
|
||||
needed fields are extracted from the Hook object before the HTTP call, so
|
||||
the caller's session is not held open during the external HTTP request.
|
||||
|
||||
2. Log session — a separate short-lived session opened after the HTTP call
|
||||
completes to write the HookExecutionLog row on failure. Success runs are
|
||||
not recorded. Committed independently of everything else.
|
||||
|
||||
3. Reachable session — a second short-lived session to update is_reachable on
|
||||
the Hook. Kept separate from the log session so a concurrent hook deletion
|
||||
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
|
||||
prevent the execution log from being written. This update is best-effort.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.hook import create_hook_execution_log__no_commit
|
||||
from onyx.db.hook import get_non_deleted_hook_by_hook_point
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class HookSkipped:
|
||||
"""No active hook configured for this hook point."""
|
||||
|
||||
|
||||
class HookSoftFailed:
|
||||
"""Hook was called but failed with SOFT fail strategy — continuing."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _HttpOutcome(BaseModel):
|
||||
"""Structured result of an HTTP hook call, returned by _process_response."""
|
||||
|
||||
is_success: bool
|
||||
updated_is_reachable: (
|
||||
bool | None
|
||||
) # True/False = write to DB, None = unchanged (skip write)
|
||||
status_code: int | None
|
||||
error_message: str | None
|
||||
response_payload: dict[str, Any] | None
|
||||
|
||||
|
||||
def _lookup_hook(
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
) -> Hook | HookSkipped:
|
||||
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
|
||||
|
||||
No HTTP call is made and no DB writes are performed for any HookSkipped path.
|
||||
There is nothing to log and no reachability information to update.
|
||||
"""
|
||||
if not HOOKS_AVAILABLE:
|
||||
return HookSkipped()
|
||||
hook = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if hook is None or not hook.is_active:
|
||||
return HookSkipped()
|
||||
if not hook.endpoint_url:
|
||||
return HookSkipped()
|
||||
return hook
|
||||
|
||||
|
||||
def _process_response(
|
||||
*,
|
||||
response: httpx.Response | None,
|
||||
exc: Exception | None,
|
||||
timeout: float,
|
||||
) -> _HttpOutcome:
|
||||
"""Process the result of an HTTP call and return a structured outcome.
|
||||
|
||||
Called after the client.post() try/except. If post() raised, exc is set and
|
||||
response is None. Otherwise response is set and exc is None. Handles
|
||||
raise_for_status(), JSON decoding, and the dict shape check.
|
||||
"""
|
||||
if exc is not None:
|
||||
if isinstance(exc, httpx.NetworkError):
|
||||
msg = f"Hook network error (endpoint unreachable): {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False,
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
msg = f"Hook timed out after {timeout}s: {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # timeout doesn't indicate unreachability
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
msg = f"Hook call failed: {exc}"
|
||||
logger.exception(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # unknown error — don't make assumptions
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
"exactly one of response or exc must be non-None; both are None"
|
||||
)
|
||||
status_code = response.status_code
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
|
||||
# so the operator knows to update it. All other HTTP errors keep is_reachable
|
||||
# as-is (server is up, the request just failed for application reasons).
|
||||
auth_failed = e.response.status_code in (401, 403)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False if auth_failed else None,
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except (json.JSONDecodeError, httpx.DecodingError) as e:
|
||||
msg = f"Hook returned non-JSON response: {e}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
|
||||
logger.warning(msg)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
return _HttpOutcome(
|
||||
is_success=True,
|
||||
updated_is_reachable=True,
|
||||
status_code=status_code,
|
||||
error_message=None,
|
||||
response_payload=response_payload,
|
||||
)
|
||||
|
||||
|
||||
def _persist_result(
|
||||
*,
|
||||
hook_id: int,
|
||||
outcome: _HttpOutcome,
|
||||
duration_ms: int,
|
||||
) -> None:
|
||||
"""Write the execution log on failure and optionally update is_reachable, each
|
||||
in its own session so a failure in one does not affect the other."""
|
||||
# Only write the execution log on failure — success runs are not recorded.
|
||||
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
|
||||
# deleted between the initial lookup and here).
|
||||
if not outcome.is_success:
|
||||
try:
|
||||
with get_session_with_current_tenant() as log_session:
|
||||
create_hook_execution_log__no_commit(
|
||||
db_session=log_session,
|
||||
hook_id=hook_id,
|
||||
is_success=False,
|
||||
error_message=outcome.error_message,
|
||||
status_code=outcome.status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
log_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to persist hook execution log for hook_id={hook_id}"
|
||||
)
|
||||
|
||||
# Update is_reachable separately — best-effort, non-critical.
|
||||
# None means the value is unchanged (set by the caller to skip the no-op write).
|
||||
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
|
||||
# concurrently deleted, so keep this isolated from the log write above.
|
||||
if outcome.updated_is_reachable is not None:
|
||||
try:
|
||||
with get_session_with_current_tenant() as reachable_session:
|
||||
update_hook__no_commit(
|
||||
db_session=reachable_session,
|
||||
hook_id=hook_id,
|
||||
is_reachable=outcome.updated_is_reachable,
|
||||
)
|
||||
reachable_session.commit()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def execute_hook(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
payload: dict[str, Any],
|
||||
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
|
||||
"""Execute the hook for the given hook point synchronously."""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
"active hooks without an endpoint_url must be rejected by _lookup_hook"
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
response: httpx.Response | None = None
|
||||
exc: Exception | None = None
|
||||
try:
|
||||
api_key: str | None = (
|
||||
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
)
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(timeout=timeout) as client:
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
outcome = outcome.model_copy(update={"updated_is_reachable": None})
|
||||
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
|
||||
|
||||
if not outcome.is_success:
|
||||
if fail_strategy == HookFailStrategy.HARD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.HOOK_EXECUTION_FAILED,
|
||||
outcome.error_message or "Hook execution failed.",
|
||||
)
|
||||
logger.warning(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
if outcome.response_payload is None:
|
||||
raise ValueError(
|
||||
f"response_payload is None for successful hook call (hook_id={hook_id})"
|
||||
)
|
||||
return outcome.response_payload
|
||||
@@ -42,8 +42,12 @@ class HookUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
endpoint_url: str | None = None
|
||||
api_key: NonEmptySecretStr | None = None
|
||||
fail_strategy: HookFailStrategy | None = None
|
||||
timeout_seconds: float | None = Field(default=None, gt=0)
|
||||
fail_strategy: HookFailStrategy | None = (
|
||||
None # if None in model_fields_set, reset to spec default
|
||||
)
|
||||
timeout_seconds: float | None = Field(
|
||||
default=None, gt=0
|
||||
) # if None in model_fields_set, reset to spec default
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_at_least_one_field(self) -> "HookUpdateRequest":
|
||||
@@ -56,14 +60,6 @@ class HookUpdateRequest(BaseModel):
|
||||
and not (self.endpoint_url or "").strip()
|
||||
):
|
||||
raise ValueError("endpoint_url cannot be cleared.")
|
||||
if "fail_strategy" in self.model_fields_set and self.fail_strategy is None:
|
||||
raise ValueError(
|
||||
"fail_strategy cannot be null; omit the field to leave it unchanged."
|
||||
)
|
||||
if "timeout_seconds" in self.model_fields_set and self.timeout_seconds is None:
|
||||
raise ValueError(
|
||||
"timeout_seconds cannot be null; omit the field to leave it unchanged."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@@ -94,28 +90,38 @@ class HookResponse(BaseModel):
|
||||
fail_strategy: HookFailStrategy
|
||||
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
|
||||
is_active: bool
|
||||
is_reachable: bool | None
|
||||
creator_email: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class HookValidateStatus(str, Enum):
|
||||
passed = "passed" # server responded (any status except 401/403)
|
||||
auth_failed = "auth_failed" # server responded with 401 or 403
|
||||
timeout = (
|
||||
"timeout" # TCP connected, but read/write timed out (server exists but slow)
|
||||
)
|
||||
cannot_connect = "cannot_connect" # could not connect to the server
|
||||
|
||||
|
||||
class HookValidateResponse(BaseModel):
|
||||
status: HookValidateStatus
|
||||
success: bool
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class HookExecutionRecord(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HookHealthStatus(str, Enum):
|
||||
healthy = "healthy" # green — reachable, no failures in last 1h
|
||||
degraded = "degraded" # yellow — reachable, failures in last 1h
|
||||
unreachable = "unreachable" # red — is_reachable=false or null
|
||||
|
||||
|
||||
class HookFailureRecord(BaseModel):
|
||||
error_message: str | None = None
|
||||
status_code: int | None = None
|
||||
duration_ms: int | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class HookHealthResponse(BaseModel):
|
||||
status: HookHealthStatus
|
||||
recent_failures: list[HookFailureRecord] = Field(
|
||||
default_factory=list,
|
||||
description="Last 10 failures, newest first",
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
@@ -14,25 +13,22 @@ _REQUIRED_ATTRS = (
|
||||
"default_timeout_seconds",
|
||||
"fail_hard_description",
|
||||
"default_fail_strategy",
|
||||
"payload_model",
|
||||
"response_model",
|
||||
)
|
||||
|
||||
|
||||
class HookPointSpec:
|
||||
class HookPointSpec(ABC):
|
||||
"""Static metadata and contract for a pipeline hook point.
|
||||
|
||||
This is NOT a regular class meant for direct instantiation by callers.
|
||||
Each concrete subclass represents exactly one hook point and is instantiated
|
||||
once at startup, registered in onyx.hooks.registry._REGISTRY. Prefer
|
||||
get_hook_point_spec() or get_all_specs() from the registry over direct
|
||||
instantiation.
|
||||
once at startup, registered in onyx.hooks.registry._REGISTRY. No caller
|
||||
should ever create instances directly — use get_hook_point_spec() or
|
||||
get_all_specs() from the registry instead.
|
||||
|
||||
Each hook point is a concrete subclass of this class. Onyx engineers
|
||||
own these definitions — customers never touch this code.
|
||||
|
||||
Subclasses must define all attributes as class-level constants.
|
||||
payload_model and response_model must be Pydantic BaseModel subclasses;
|
||||
input_schema and output_schema are derived from them automatically.
|
||||
"""
|
||||
|
||||
hook_point: HookPoint
|
||||
@@ -43,33 +39,21 @@ class HookPointSpec:
|
||||
default_fail_strategy: HookFailStrategy
|
||||
docs_url: str | None = None
|
||||
|
||||
payload_model: ClassVar[type[BaseModel]]
|
||||
response_model: ClassVar[type[BaseModel]]
|
||||
|
||||
# Computed once at class definition time from payload_model / response_model.
|
||||
input_schema: ClassVar[dict[str, Any]]
|
||||
output_schema: ClassVar[dict[str, Any]]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: object) -> None:
|
||||
"""Enforce that every concrete subclass declares all required class attributes.
|
||||
|
||||
Called automatically by Python whenever a class inherits from HookPointSpec.
|
||||
Abstract subclasses (those still carrying unimplemented abstract methods) are
|
||||
skipped — they are intermediate base classes and may not yet define everything.
|
||||
Only fully concrete subclasses are validated, ensuring a clear TypeError at
|
||||
import time rather than a confusing AttributeError at runtime.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Skip intermediate abstract subclasses — they may still be partially defined.
|
||||
if getattr(cls, "__abstractmethods__", None):
|
||||
return
|
||||
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
|
||||
if missing:
|
||||
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
|
||||
for attr in ("payload_model", "response_model"):
|
||||
val = getattr(cls, attr, None)
|
||||
if val is None or not (
|
||||
isinstance(val, type) and issubclass(val, BaseModel)
|
||||
):
|
||||
raise TypeError(
|
||||
f"{cls.__name__}.{attr} must be a Pydantic BaseModel subclass, got {val!r}"
|
||||
)
|
||||
cls.input_schema = cls.payload_model.model_json_schema()
|
||||
cls.output_schema = cls.response_model.model_json_schema()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
"""JSON schema describing the request payload sent to the customer's endpoint."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
"""JSON schema describing the expected response from the customer's endpoint."""
|
||||
|
||||
@@ -1,19 +1,10 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Any
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
# TODO(@Bo-Onyx): define payload and response fields
|
||||
class DocumentIngestionPayload(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIngestionResponse(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIngestionSpec(HookPointSpec):
|
||||
"""Hook point that runs during document ingestion.
|
||||
|
||||
@@ -27,5 +18,12 @@ class DocumentIngestionSpec(HookPointSpec):
|
||||
fail_hard_description = "The document will not be indexed."
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
|
||||
payload_model = DocumentIngestionPayload
|
||||
response_model = DocumentIngestionResponse
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
# TODO(@Bo-Onyx): define input schema
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
# TODO(@Bo-Onyx): define output schema
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
@@ -1,39 +1,10 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from typing import Any
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
class QueryProcessingPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
query: str = Field(description="The raw query string exactly as the user typed it.")
|
||||
user_email: str | None = Field(
|
||||
description="Email of the user submitting the query, or null if unauthenticated."
|
||||
)
|
||||
chat_session_id: str = Field(
|
||||
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
|
||||
)
|
||||
|
||||
|
||||
class QueryProcessingResponse(BaseModel):
|
||||
# Intentionally permissive — customer endpoints may return extra fields.
|
||||
query: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The query to use in the pipeline. "
|
||||
"Null, empty string, or absent = reject the query."
|
||||
),
|
||||
)
|
||||
rejection_message: str | None = Field(
|
||||
default=None,
|
||||
description="Message shown to the user when the query is rejected. Falls back to a generic message if not provided.",
|
||||
)
|
||||
|
||||
|
||||
class QueryProcessingSpec(HookPointSpec):
|
||||
"""Hook point that runs on every user query before it enters the pipeline.
|
||||
|
||||
@@ -66,5 +37,47 @@ class QueryProcessingSpec(HookPointSpec):
|
||||
)
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
|
||||
payload_model = QueryProcessingPayload
|
||||
response_model = QueryProcessingResponse
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The raw query string exactly as the user typed it.",
|
||||
},
|
||||
"user_email": {
|
||||
"type": ["string", "null"],
|
||||
"description": "Email of the user submitting the query, or null if unauthenticated.",
|
||||
},
|
||||
"chat_session_id": {
|
||||
"type": "string",
|
||||
"description": "UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires.",
|
||||
},
|
||||
},
|
||||
"required": ["query", "user_email", "chat_session_id"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"The (optionally modified) query to use. "
|
||||
"Set to null to reject the query."
|
||||
),
|
||||
},
|
||||
"rejection_message": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"Message shown to the user when query is null. "
|
||||
"Falls back to a generic message if not provided."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT
|
||||
@@ -77,7 +77,6 @@ from onyx.server.features.default_assistant.api import (
|
||||
)
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.hierarchy.api import router as hierarchy_router
|
||||
from onyx.server.features.hooks.api import router as hook_router
|
||||
from onyx.server.features.input_prompt.api import (
|
||||
admin_router as admin_input_prompt_router,
|
||||
)
|
||||
@@ -454,7 +453,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
|
||||
register_onyx_exception_handlers(application)
|
||||
|
||||
include_router_with_global_prefix_prepended(application, hook_router)
|
||||
include_router_with_global_prefix_prepended(application, password_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
|
||||
@@ -1,453 +0,0 @@
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.constants import UNSET
|
||||
from onyx.db.constants import UnsetType
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.hook import create_hook__no_commit
|
||||
from onyx.db.hook import delete_hook__no_commit
|
||||
from onyx.db.hook import get_hook_by_id
|
||||
from onyx.db.hook import get_hook_execution_logs
|
||||
from onyx.db.hook import get_hooks
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.api_dependencies import require_hook_enabled
|
||||
from onyx.hooks.models import HookCreateRequest
|
||||
from onyx.hooks.models import HookExecutionRecord
|
||||
from onyx.hooks.models import HookPointMetaResponse
|
||||
from onyx.hooks.models import HookResponse
|
||||
from onyx.hooks.models import HookUpdateRequest
|
||||
from onyx.hooks.models import HookValidateResponse
|
||||
from onyx.hooks.models import HookValidateStatus
|
||||
from onyx.hooks.registry import get_all_specs
|
||||
from onyx.hooks.registry import get_hook_point_spec
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.url import SSRFException
|
||||
from onyx.utils.url import validate_outbound_http_url
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSRF protection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _check_ssrf_safety(endpoint_url: str) -> None:
|
||||
"""Raise OnyxError if endpoint_url could be used for SSRF.
|
||||
|
||||
Delegates to validate_outbound_http_url with https_only=True.
|
||||
"""
|
||||
try:
|
||||
validate_outbound_http_url(endpoint_url, https_only=True)
|
||||
except (SSRFException, ValueError) as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookResponse:
|
||||
return HookResponse(
|
||||
id=hook.id,
|
||||
name=hook.name,
|
||||
hook_point=hook.hook_point,
|
||||
endpoint_url=hook.endpoint_url,
|
||||
fail_strategy=hook.fail_strategy,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
is_active=hook.is_active,
|
||||
is_reachable=hook.is_reachable,
|
||||
creator_email=(
|
||||
creator_email
|
||||
if creator_email is not None
|
||||
else (hook.creator.email if hook.creator else None)
|
||||
),
|
||||
created_at=hook.created_at,
|
||||
updated_at=hook.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _get_hook_or_404(
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
include_creator: bool = False,
|
||||
) -> Hook:
|
||||
hook = get_hook_by_id(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
include_creator=include_creator,
|
||||
)
|
||||
if hook is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook {hook_id} not found.")
|
||||
return hook
|
||||
|
||||
|
||||
def _raise_for_validation_failure(validation: HookValidateResponse) -> None:
|
||||
"""Raise an appropriate OnyxError for a non-passed validation result."""
|
||||
if validation.status == HookValidateStatus.auth_failed:
|
||||
raise OnyxError(OnyxErrorCode.CREDENTIAL_INVALID, validation.error_message)
|
||||
if validation.status == HookValidateStatus.timeout:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.GATEWAY_TIMEOUT,
|
||||
f"Endpoint validation failed: {validation.error_message}",
|
||||
)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Endpoint validation failed: {validation.error_message}",
|
||||
)
|
||||
|
||||
|
||||
def _validate_endpoint(
|
||||
endpoint_url: str,
|
||||
api_key: str | None,
|
||||
timeout_seconds: float,
|
||||
) -> HookValidateResponse:
|
||||
"""Check whether endpoint_url is reachable by sending an empty POST request.
|
||||
|
||||
We use POST since hook endpoints expect POST requests. The server will typically
|
||||
respond with 4xx (missing/invalid body) — that is fine. Any HTTP response means
|
||||
the server is up and routable. A 401/403 response returns auth_failed
|
||||
(not reachable — indicates the api_key is invalid).
|
||||
|
||||
Timeout handling:
|
||||
- ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
|
||||
(operator should consider increasing timeout_seconds).
|
||||
- All other exceptions → cannot_connect.
|
||||
"""
|
||||
_check_ssrf_safety(endpoint_url)
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
try:
|
||||
with httpx.Client(timeout=timeout_seconds, follow_redirects=False) as client:
|
||||
response = client.post(endpoint_url, headers=headers)
|
||||
if response.status_code in (401, 403):
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.auth_failed,
|
||||
error_message=f"Authentication failed (HTTP {response.status_code})",
|
||||
)
|
||||
return HookValidateResponse(status=HookValidateStatus.passed)
|
||||
except httpx.TimeoutException as exc:
|
||||
# ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
|
||||
if isinstance(exc, httpx.ConnectTimeout):
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connect timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
logger.warning(
|
||||
"Hook endpoint validation: read/write timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.timeout,
|
||||
error_message="Endpoint timed out — consider increasing timeout_seconds.",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connection error for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
router = APIRouter(prefix="/admin/hooks")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hook endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/specs")
|
||||
def get_hook_point_specs(
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
) -> list[HookPointMetaResponse]:
|
||||
return [
|
||||
HookPointMetaResponse(
|
||||
hook_point=spec.hook_point,
|
||||
display_name=spec.display_name,
|
||||
description=spec.description,
|
||||
docs_url=spec.docs_url,
|
||||
input_schema=spec.input_schema,
|
||||
output_schema=spec.output_schema,
|
||||
default_timeout_seconds=spec.default_timeout_seconds,
|
||||
default_fail_strategy=spec.default_fail_strategy,
|
||||
fail_hard_description=spec.fail_hard_description,
|
||||
)
|
||||
for spec in get_all_specs()
|
||||
]
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_hooks(
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[HookResponse]:
|
||||
hooks = get_hooks(db_session=db_session, include_creator=True)
|
||||
return [_hook_to_response(h) for h in hooks]
|
||||
|
||||
|
||||
@router.post("")
|
||||
def create_hook(
|
||||
req: HookCreateRequest,
|
||||
user: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
"""Create a new hook. The endpoint is validated before persisting — creation fails if
|
||||
the endpoint cannot be reached or the api_key is invalid. Hooks are created inactive;
|
||||
use POST /{hook_id}/activate once ready to receive traffic."""
|
||||
spec = get_hook_point_spec(req.hook_point)
|
||||
api_key = req.api_key.get_secret_value() if req.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=req.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
|
||||
)
|
||||
if validation.status != HookValidateStatus.passed:
|
||||
_raise_for_validation_failure(validation)
|
||||
|
||||
hook = create_hook__no_commit(
|
||||
db_session=db_session,
|
||||
name=req.name,
|
||||
hook_point=req.hook_point,
|
||||
endpoint_url=req.endpoint_url,
|
||||
api_key=api_key,
|
||||
fail_strategy=req.fail_strategy or spec.default_fail_strategy,
|
||||
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
|
||||
creator_id=user.id,
|
||||
)
|
||||
hook.is_reachable = True
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook, creator_email=user.email)
|
||||
|
||||
|
||||
@router.get("/{hook_id}")
|
||||
def get_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id, include_creator=True)
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.patch("/{hook_id}")
|
||||
def update_hook(
|
||||
hook_id: int,
|
||||
req: HookUpdateRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
"""Update hook fields. If endpoint_url, api_key, or timeout_seconds changes, the
|
||||
endpoint is re-validated using the effective values. For active hooks the update is
|
||||
rejected on validation failure, keeping live traffic unaffected. For inactive hooks
|
||||
the update goes through regardless and is_reachable is updated to reflect the result.
|
||||
|
||||
Note: if an active hook's endpoint is currently down, even a timeout_seconds-only
|
||||
increase will be rejected. The recovery flow is: deactivate → update → reactivate.
|
||||
"""
|
||||
# api_key: UNSET = no change, None = clear, value = update
|
||||
api_key: str | None | UnsetType
|
||||
if "api_key" not in req.model_fields_set:
|
||||
api_key = UNSET
|
||||
elif req.api_key is None:
|
||||
api_key = None
|
||||
else:
|
||||
api_key = req.api_key.get_secret_value()
|
||||
|
||||
endpoint_url_changing = "endpoint_url" in req.model_fields_set
|
||||
api_key_changing = not isinstance(api_key, UnsetType)
|
||||
timeout_changing = "timeout_seconds" in req.model_fields_set
|
||||
|
||||
validated_is_reachable: bool | None = None
|
||||
if endpoint_url_changing or api_key_changing or timeout_changing:
|
||||
existing = _get_hook_or_404(db_session, hook_id)
|
||||
effective_url: str = (
|
||||
req.endpoint_url if endpoint_url_changing else existing.endpoint_url # type: ignore[assignment] # endpoint_url is required on create and cannot be cleared on update
|
||||
)
|
||||
effective_api_key: str | None = (
|
||||
(api_key if not isinstance(api_key, UnsetType) else None)
|
||||
if api_key_changing
|
||||
else (
|
||||
existing.api_key.get_value(apply_mask=False)
|
||||
if existing.api_key
|
||||
else None
|
||||
)
|
||||
)
|
||||
effective_timeout: float = (
|
||||
req.timeout_seconds if timeout_changing else existing.timeout_seconds # type: ignore[assignment] # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest)
|
||||
)
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=effective_url,
|
||||
api_key=effective_api_key,
|
||||
timeout_seconds=effective_timeout,
|
||||
)
|
||||
if existing.is_active and validation.status != HookValidateStatus.passed:
|
||||
_raise_for_validation_failure(validation)
|
||||
validated_is_reachable = validation.status == HookValidateStatus.passed
|
||||
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
name=req.name,
|
||||
endpoint_url=(req.endpoint_url if endpoint_url_changing else UNSET),
|
||||
api_key=api_key,
|
||||
fail_strategy=req.fail_strategy,
|
||||
timeout_seconds=req.timeout_seconds,
|
||||
is_reachable=validated_is_reachable,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.delete("/{hook_id}")
|
||||
def delete_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
delete_hook__no_commit(db_session=db_session, hook_id=hook_id)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@router.post("/{hook_id}/activate")
|
||||
def activate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id)
|
||||
if not hook.endpoint_url:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
|
||||
)
|
||||
|
||||
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=hook.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
)
|
||||
if validation.status != HookValidateStatus.passed:
|
||||
# Persist is_reachable=False in a separate session so the request
|
||||
# session has no commits on the failure path and the transaction
|
||||
# boundary stays clean.
|
||||
if hook.is_reachable is not False:
|
||||
with get_session_with_current_tenant() as side_session:
|
||||
update_hook__no_commit(
|
||||
db_session=side_session, hook_id=hook_id, is_reachable=False
|
||||
)
|
||||
side_session.commit()
|
||||
_raise_for_validation_failure(validation)
|
||||
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
is_active=True,
|
||||
is_reachable=True,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.post("/{hook_id}/validate")
|
||||
def validate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookValidateResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id)
|
||||
if not hook.endpoint_url:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
|
||||
)
|
||||
|
||||
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=hook.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
)
|
||||
validation_passed = validation.status == HookValidateStatus.passed
|
||||
if hook.is_reachable != validation_passed:
|
||||
update_hook__no_commit(
|
||||
db_session=db_session, hook_id=hook_id, is_reachable=validation_passed
|
||||
)
|
||||
db_session.commit()
|
||||
return validation
|
||||
|
||||
|
||||
@router.post("/{hook_id}/deactivate")
|
||||
def deactivate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
is_active=False,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Execution log endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/{hook_id}/execution-logs")
|
||||
def list_hook_execution_logs(
|
||||
hook_id: int,
|
||||
limit: int = Query(default=10, ge=1, le=100),
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[HookExecutionRecord]:
|
||||
_get_hook_or_404(db_session, hook_id)
|
||||
logs = get_hook_execution_logs(db_session=db_session, hook_id=hook_id, limit=limit)
|
||||
return [
|
||||
HookExecutionRecord(
|
||||
error_message=log.error_message,
|
||||
status_code=log.status_code,
|
||||
duration_ms=log.duration_ms,
|
||||
created_at=log.created_at,
|
||||
)
|
||||
for log in logs
|
||||
]
|
||||
@@ -53,12 +53,8 @@ logger = setup_logger()
|
||||
|
||||
class SearchToolConfig(BaseModel):
|
||||
user_selected_filters: BaseFilters | None = None
|
||||
# Vespa metadata filters for overflowing user files. These are NOT the
|
||||
# IDs of the current project/persona — they are only set when the
|
||||
# project's/persona's user files didn't fit in the LLM context window and
|
||||
# must be found via vector DB search instead.
|
||||
project_id_filter: int | None = None
|
||||
persona_id_filter: int | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
bypass_acl: bool = False
|
||||
additional_context: str | None = None
|
||||
slack_context: SlackContext | None = None
|
||||
@@ -184,8 +180,8 @@ def construct_tools(
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id_filter=search_tool_config.project_id_filter,
|
||||
persona_id_filter=search_tool_config.persona_id_filter,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
@@ -432,8 +428,8 @@ def construct_tools(
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id_filter=search_tool_config.project_id_filter,
|
||||
persona_id_filter=search_tool_config.persona_id_filter,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
|
||||
@@ -764,7 +764,8 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
|
||||
tags=None,
|
||||
access_control_list=access_control_list,
|
||||
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
|
||||
project_id_filter=None,
|
||||
user_file_ids=None,
|
||||
project_id=None,
|
||||
)
|
||||
|
||||
def _merge_indexed_and_crawled_results(
|
||||
|
||||
@@ -244,11 +244,10 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
document_index: DocumentIndex,
|
||||
# Respecting user selections
|
||||
user_selected_filters: BaseFilters | None,
|
||||
# Vespa metadata filters for overflowing user files. NOT the raw IDs
|
||||
# of the current project/persona — only set when user files couldn't
|
||||
# fit in the LLM context and need to be searched via vector DB.
|
||||
project_id_filter: int | None,
|
||||
persona_id_filter: int | None = None,
|
||||
# If the chat is part of a project
|
||||
project_id: int | None,
|
||||
# If set, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
bypass_acl: bool = False,
|
||||
# Slack context for federated Slack search (tokens fetched internally)
|
||||
slack_context: SlackContext | None = None,
|
||||
@@ -262,8 +261,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
self.llm = llm
|
||||
self.document_index = document_index
|
||||
self.user_selected_filters = user_selected_filters
|
||||
self.project_id_filter = project_id_filter
|
||||
self.persona_id_filter = persona_id_filter
|
||||
self.project_id = project_id
|
||||
self.persona_id = persona_id
|
||||
self.bypass_acl = bypass_acl
|
||||
self.slack_context = slack_context
|
||||
self.enable_slack_search = enable_slack_search
|
||||
@@ -452,15 +451,13 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
hybrid_alpha=hybrid_alpha,
|
||||
# For projects, the search scope is the project and has no other limits
|
||||
user_selected_filters=(
|
||||
self.user_selected_filters
|
||||
if self.project_id_filter is None
|
||||
else None
|
||||
self.user_selected_filters if self.project_id is None else None
|
||||
),
|
||||
bypass_acl=self.bypass_acl,
|
||||
limit=num_hits,
|
||||
),
|
||||
project_id_filter=self.project_id_filter,
|
||||
persona_id_filter=self.persona_id_filter,
|
||||
project_id=self.project_id,
|
||||
persona_id=self.persona_id,
|
||||
document_index=self.document_index,
|
||||
user=self.user,
|
||||
persona=self.persona,
|
||||
@@ -577,7 +574,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
)
|
||||
|
||||
# Federated retrieval functions (non-Slack; Slack is separate)
|
||||
if self.project_id_filter is not None:
|
||||
if self.project_id is not None:
|
||||
# Project mode ignores user filters → no federated sources
|
||||
prefetch_source_types = None
|
||||
else:
|
||||
@@ -590,12 +587,16 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
persona_document_sets = (
|
||||
[ds.name for ds in self.persona.document_sets] if self.persona else None
|
||||
)
|
||||
user_file_ids = (
|
||||
[uf.id for uf in self.persona.user_files] if self.persona else None
|
||||
)
|
||||
federated_retrieval_infos = (
|
||||
get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=self.user.id if self.user else None,
|
||||
source_types=prefetch_source_types,
|
||||
document_set_names=persona_document_sets,
|
||||
user_file_ids=user_file_ids,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
@@ -140,20 +140,10 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
|
||||
return validated_ip, hostname, port
|
||||
|
||||
|
||||
def validate_outbound_http_url(
|
||||
url: str,
|
||||
*,
|
||||
allow_private_network: bool = False,
|
||||
https_only: bool = False,
|
||||
) -> str:
|
||||
def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
|
||||
"""
|
||||
Validate a URL that will be used by backend outbound HTTP calls.
|
||||
|
||||
Args:
|
||||
url: The URL to validate.
|
||||
allow_private_network: If True, skip private/reserved IP checks.
|
||||
https_only: If True, reject http:// URLs (only https:// is allowed).
|
||||
|
||||
Returns:
|
||||
A normalized URL string with surrounding whitespace removed.
|
||||
|
||||
@@ -167,12 +157,7 @@ def validate_outbound_http_url(
|
||||
|
||||
parsed = urlparse(normalized_url)
|
||||
|
||||
if https_only:
|
||||
if parsed.scheme != "https":
|
||||
raise SSRFException(
|
||||
f"Invalid URL scheme '{parsed.scheme}'. Only https is allowed."
|
||||
)
|
||||
elif parsed.scheme not in ("http", "https"):
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise SSRFException(
|
||||
f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
|
||||
)
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.models import Persona
|
||||
|
||||
|
||||
def test_eager_load_persona_loads_relationships(db_session: Session) -> None:
|
||||
"""Verify that eager_load_persona pre-loads persona, its collections, and project."""
|
||||
persona = Persona(name="eager-load-test", description="test")
|
||||
db_session.add(persona)
|
||||
db_session.flush()
|
||||
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="test",
|
||||
user_id=None,
|
||||
persona_id=persona.id,
|
||||
)
|
||||
|
||||
loaded = get_chat_session_by_id(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
|
||||
unloaded = inspect(loaded).unloaded
|
||||
assert "persona" not in unloaded
|
||||
assert "project" not in unloaded
|
||||
|
||||
persona_unloaded = inspect(loaded.persona).unloaded
|
||||
assert "tools" not in persona_unloaded
|
||||
assert "user_files" not in persona_unloaded
|
||||
|
||||
db_session.rollback()
|
||||
@@ -1,30 +1,34 @@
|
||||
"""Tests for OpenSearch assistant knowledge filter construction.
|
||||
|
||||
These tests verify that when an assistant (persona) has knowledge attached,
|
||||
the search filter includes the appropriate scope filters with OR logic (not AND),
|
||||
ensuring documents are discoverable across knowledge types like attached documents,
|
||||
hierarchy nodes, document sets, and persona/project user files.
|
||||
These tests verify that when an assistant (persona) has user files attached,
|
||||
the search filter includes those user file IDs in the assistant knowledge filter
|
||||
with OR logic (not AND), ensuring user files are discoverable alongside other
|
||||
knowledge types like attached documents and hierarchy nodes.
|
||||
|
||||
This prevents a regression where user_file_ids were added as a separate AND
|
||||
filter, making it impossible to find user files when the assistant also had
|
||||
attached documents or hierarchy nodes (since no document could match both).
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_ID_FIELD_NAME
|
||||
from onyx.document_index.opensearch.search import DocumentQuery
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
USER_FILE_ID = UUID("6ad84e45-4450-406c-9d36-fcb5e74aca6b")
|
||||
ATTACHED_DOCUMENT_ID = "https://docs.google.com/document/d/test-doc-id"
|
||||
HIERARCHY_NODE_ID = 42
|
||||
PERSONA_ID = 7
|
||||
|
||||
|
||||
def _get_search_filters(
|
||||
source_types: list[DocumentSource],
|
||||
user_file_ids: list[UUID],
|
||||
attached_document_ids: list[str] | None,
|
||||
hierarchy_node_ids: list[int] | None,
|
||||
persona_id_filter: int | None = None,
|
||||
document_sets: list[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
return DocumentQuery._get_search_filters(
|
||||
tenant_state=TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False),
|
||||
@@ -32,14 +36,15 @@ def _get_search_filters(
|
||||
access_control_list=["user_email:test@example.com"],
|
||||
source_types=source_types,
|
||||
tags=[],
|
||||
document_sets=document_sets or [],
|
||||
project_id_filter=None,
|
||||
persona_id_filter=persona_id_filter,
|
||||
document_sets=[],
|
||||
project_id=None,
|
||||
persona_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
max_chunk_size=None,
|
||||
document_id=None,
|
||||
user_file_ids=user_file_ids,
|
||||
attached_document_ids=attached_document_ids,
|
||||
hierarchy_node_ids=hierarchy_node_ids,
|
||||
)
|
||||
@@ -48,97 +53,137 @@ def _get_search_filters(
|
||||
class TestAssistantKnowledgeFilter:
|
||||
"""Tests for assistant knowledge filter construction in OpenSearch queries."""
|
||||
|
||||
def test_persona_id_filter_added_when_knowledge_scope_exists(self) -> None:
|
||||
"""persona_id_filter should be OR'd into the knowledge scope filter
|
||||
when explicit knowledge attachments (attached_document_ids,
|
||||
hierarchy_node_ids, document_sets) are present."""
|
||||
def test_user_file_ids_included_in_assistant_knowledge_filter(self) -> None:
|
||||
"""
|
||||
Tests that user_file_ids are included in the assistant knowledge filter
|
||||
with OR logic when the assistant has both user files and attached documents.
|
||||
|
||||
This prevents the regression where user files were ANDed with other
|
||||
knowledge types, making them unfindable.
|
||||
"""
|
||||
|
||||
# Under test: Call the filter construction method directly
|
||||
filter_clauses = _get_search_filters(
|
||||
source_types=[DocumentSource.FILE],
|
||||
source_types=[DocumentSource.FILE, DocumentSource.USER_FILE],
|
||||
user_file_ids=[USER_FILE_ID],
|
||||
attached_document_ids=[ATTACHED_DOCUMENT_ID],
|
||||
hierarchy_node_ids=[HIERARCHY_NODE_ID],
|
||||
persona_id_filter=PERSONA_ID,
|
||||
)
|
||||
|
||||
# Postcondition: Find the assistant knowledge filter (bool with should clauses)
|
||||
knowledge_filter = None
|
||||
for clause in filter_clauses:
|
||||
if "bool" in clause and "should" in clause["bool"]:
|
||||
# Check if this is the knowledge filter (has minimum_should_match=1)
|
||||
if clause["bool"].get("minimum_should_match") == 1:
|
||||
knowledge_filter = clause
|
||||
break
|
||||
|
||||
assert knowledge_filter is not None, (
|
||||
"Expected to find an assistant knowledge filter with "
|
||||
"'minimum_should_match: 1'"
|
||||
)
|
||||
assert (
|
||||
knowledge_filter is not None
|
||||
), "Expected to find an assistant knowledge filter with 'minimum_should_match: 1'"
|
||||
|
||||
# The knowledge filter should have 3 should clauses (user files, attached docs, hierarchy nodes)
|
||||
should_clauses = knowledge_filter["bool"]["should"]
|
||||
persona_found = any(
|
||||
clause.get("term", {}).get(PERSONAS_FIELD_NAME, {}).get("value")
|
||||
== PERSONA_ID
|
||||
for clause in should_clauses
|
||||
)
|
||||
assert persona_found, (
|
||||
f"Expected persona_id={PERSONA_ID} filter on {PERSONAS_FIELD_NAME} "
|
||||
f"in should clauses. Got: {should_clauses}"
|
||||
assert (
|
||||
len(should_clauses) == 3
|
||||
), f"Expected 3 should clauses (user_file, attached_doc, hierarchy_node), got {len(should_clauses)}"
|
||||
|
||||
# Verify user_file_id is in one of the should clauses
|
||||
user_file_filter_found = False
|
||||
for should_clause in should_clauses:
|
||||
# The user file filter uses a nested bool with should for each file ID
|
||||
if "bool" in should_clause and "should" in should_clause["bool"]:
|
||||
for term_clause in should_clause["bool"]["should"]:
|
||||
if "term" in term_clause:
|
||||
term_value = term_clause["term"].get(DOCUMENT_ID_FIELD_NAME, {})
|
||||
if term_value.get("value") == str(USER_FILE_ID):
|
||||
user_file_filter_found = True
|
||||
break
|
||||
|
||||
assert user_file_filter_found, (
|
||||
f"Expected user_file_id {USER_FILE_ID} to be in the assistant knowledge "
|
||||
f"filter's should clauses. Filter structure: {knowledge_filter}"
|
||||
)
|
||||
|
||||
def test_persona_id_filter_alone_creates_knowledge_scope(self) -> None:
|
||||
"""persona_id_filter IS a primary knowledge scope trigger — a persona
|
||||
with user files is explicit knowledge, so it should restrict
|
||||
search on its own."""
|
||||
def test_user_file_ids_only_creates_knowledge_filter(self) -> None:
|
||||
"""
|
||||
Tests that when only user_file_ids are provided (no attached_documents or
|
||||
hierarchy_nodes), the assistant knowledge filter is still created with the
|
||||
user file IDs.
|
||||
"""
|
||||
# Precondition
|
||||
|
||||
filter_clauses = _get_search_filters(
|
||||
source_types=[],
|
||||
source_types=[DocumentSource.USER_FILE],
|
||||
user_file_ids=[USER_FILE_ID],
|
||||
attached_document_ids=None,
|
||||
hierarchy_node_ids=None,
|
||||
persona_id_filter=PERSONA_ID,
|
||||
)
|
||||
|
||||
knowledge_filter = None
|
||||
# Postcondition: Find filter that contains our user file ID
|
||||
user_file_filter_found = False
|
||||
for clause in filter_clauses:
|
||||
if "bool" in clause and "should" in clause["bool"]:
|
||||
if clause["bool"].get("minimum_should_match") == 1:
|
||||
knowledge_filter = clause
|
||||
break
|
||||
clause_str = str(clause)
|
||||
if str(USER_FILE_ID) in clause_str:
|
||||
user_file_filter_found = True
|
||||
break
|
||||
|
||||
assert (
|
||||
knowledge_filter is not None
|
||||
), "Expected persona_id_filter alone to create a knowledge scope filter"
|
||||
persona_found = any(
|
||||
clause.get("term", {}).get(PERSONAS_FIELD_NAME, {}).get("value")
|
||||
== PERSONA_ID
|
||||
for clause in knowledge_filter["bool"]["should"]
|
||||
)
|
||||
assert persona_found, (
|
||||
f"Expected persona_id={PERSONA_ID} filter in knowledge scope. "
|
||||
f"Got: {knowledge_filter}"
|
||||
)
|
||||
user_file_filter_found
|
||||
), f"Expected user_file_id {USER_FILE_ID} to be in the filter clauses. Got: {filter_clauses}"
|
||||
|
||||
def test_no_separate_user_file_filter_when_assistant_has_knowledge(self) -> None:
|
||||
"""
|
||||
Tests that user_file_ids are NOT added as a separate AND filter when the
|
||||
assistant has other knowledge attached (attached_documents or hierarchy_nodes).
|
||||
"""
|
||||
|
||||
def test_knowledge_filter_with_document_sets_and_persona_filter(self) -> None:
|
||||
"""document_sets and persona_id_filter should be OR'd together in
|
||||
the knowledge scope filter."""
|
||||
filter_clauses = _get_search_filters(
|
||||
source_types=[],
|
||||
attached_document_ids=None,
|
||||
source_types=[DocumentSource.FILE, DocumentSource.USER_FILE],
|
||||
user_file_ids=[USER_FILE_ID],
|
||||
attached_document_ids=[ATTACHED_DOCUMENT_ID],
|
||||
hierarchy_node_ids=None,
|
||||
persona_id_filter=PERSONA_ID,
|
||||
document_sets=["engineering"],
|
||||
)
|
||||
|
||||
knowledge_filter = None
|
||||
# Postcondition: Count how many times user_file_id appears in filter clauses
|
||||
# It should appear exactly once (in the knowledge filter), not twice
|
||||
user_file_id_str = str(USER_FILE_ID)
|
||||
occurrences = 0
|
||||
for clause in filter_clauses:
|
||||
if "bool" in clause and "should" in clause["bool"]:
|
||||
if clause["bool"].get("minimum_should_match") == 1:
|
||||
knowledge_filter = clause
|
||||
break
|
||||
if user_file_id_str in str(clause):
|
||||
occurrences += 1
|
||||
|
||||
assert (
|
||||
knowledge_filter is not None
|
||||
), "Expected knowledge filter when document_sets is provided"
|
||||
assert occurrences == 1, (
|
||||
f"Expected user_file_id to appear exactly once in filter clauses "
|
||||
f"(inside the assistant knowledge filter), but found {occurrences} "
|
||||
f"occurrences. This suggests user_file_ids is being added as both a "
|
||||
f"separate AND filter and inside the knowledge filter. "
|
||||
f"Filter clauses: {filter_clauses}"
|
||||
)
|
||||
|
||||
filter_str = str(knowledge_filter)
|
||||
assert (
|
||||
"engineering" in filter_str
|
||||
), "Expected document_set 'engineering' in knowledge filter"
|
||||
assert (
|
||||
str(PERSONA_ID) in filter_str
|
||||
), f"Expected persona_id_filter {PERSONA_ID} in knowledge filter"
|
||||
def test_multiple_user_files_all_included_in_filter(self) -> None:
|
||||
"""
|
||||
Tests that when multiple user files are attached to an assistant,
|
||||
all of them are included in the filter.
|
||||
"""
|
||||
# Precondition
|
||||
user_file_ids = [
|
||||
UUID("6ad84e45-4450-406c-9d36-fcb5e74aca6b"),
|
||||
UUID("7be95f56-5561-517d-ae47-acd6f85bdb7c"),
|
||||
UUID("8cf06a67-6672-628e-bf58-ade7a96cec8d"),
|
||||
]
|
||||
|
||||
filter_clauses = _get_search_filters(
|
||||
source_types=[DocumentSource.USER_FILE],
|
||||
user_file_ids=user_file_ids,
|
||||
attached_document_ids=[ATTACHED_DOCUMENT_ID],
|
||||
hierarchy_node_ids=None,
|
||||
)
|
||||
|
||||
# Postcondition: All user file IDs should be in the filter
|
||||
filter_str = str(filter_clauses)
|
||||
for user_file_id in user_file_ids:
|
||||
assert (
|
||||
str(user_file_id) in filter_str
|
||||
), f"Expected user_file_id {user_file_id} to be in the filter clauses"
|
||||
|
||||
@@ -14,7 +14,6 @@ from __future__ import annotations
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
|
||||
@@ -29,9 +28,6 @@ _BACKEND_DIR = os.path.normpath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
|
||||
)
|
||||
|
||||
_DROP_SCHEMA_MAX_RETRIES = 3
|
||||
_DROP_SCHEMA_RETRY_DELAY_SEC = 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -54,39 +50,6 @@ def _run_script(
|
||||
)
|
||||
|
||||
|
||||
def _force_drop_schema(engine: Engine, schema: str) -> None:
|
||||
"""Terminate backends using *schema* then drop it, retrying on deadlock.
|
||||
|
||||
Background Celery workers may discover test schemas (they match the
|
||||
``tenant_`` prefix) and hold locks on tables inside them. A bare
|
||||
``DROP SCHEMA … CASCADE`` can deadlock with those workers, so we
|
||||
first kill their connections and retry if we still hit a deadlock.
|
||||
"""
|
||||
for attempt in range(_DROP_SCHEMA_MAX_RETRIES):
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT pg_terminate_backend(l.pid)
|
||||
FROM pg_locks l
|
||||
JOIN pg_class c ON c.oid = l.relation
|
||||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE n.nspname = :schema
|
||||
AND l.pid != pg_backend_pid()
|
||||
"""
|
||||
),
|
||||
{"schema": schema},
|
||||
)
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
return
|
||||
except Exception:
|
||||
if attempt == _DROP_SCHEMA_MAX_RETRIES - 1:
|
||||
raise
|
||||
time.sleep(_DROP_SCHEMA_RETRY_DELAY_SEC)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -141,7 +104,9 @@ def tenant_schema_at_head(
|
||||
|
||||
yield schema
|
||||
|
||||
_force_drop_schema(engine, schema)
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -158,7 +123,9 @@ def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
|
||||
|
||||
yield schema
|
||||
|
||||
_force_drop_schema(engine, schema)
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -183,7 +150,9 @@ def tenant_schema_bad_rev(engine: Engine) -> Generator[str, None, None]:
|
||||
|
||||
yield schema
|
||||
|
||||
_force_drop_schema(engine, schema)
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import csv
|
||||
import io
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@@ -141,12 +139,12 @@ def test_chat_history_csv_export(
|
||||
assert headers["Content-Type"] == "text/csv; charset=utf-8"
|
||||
assert "Content-Disposition" in headers
|
||||
|
||||
# Use csv.reader to properly handle newlines inside quoted fields
|
||||
csv_rows = list(csv.reader(io.StringIO(csv_content)))
|
||||
assert len(csv_rows) == 3 # Header + 2 QA pairs
|
||||
assert csv_rows[0][0] == "chat_session_id"
|
||||
assert "user_message" in csv_rows[0]
|
||||
assert "ai_response" in csv_rows[0]
|
||||
# Verify CSV content
|
||||
csv_lines = csv_content.strip().split("\n")
|
||||
assert len(csv_lines) == 3 # Header + 2 QA pairs
|
||||
assert "chat_session_id" in csv_content
|
||||
assert "user_message" in csv_content
|
||||
assert "ai_response" in csv_content
|
||||
assert "What was the Q1 revenue?" in csv_content
|
||||
assert "What about Q2 revenue?" in csv_content
|
||||
|
||||
@@ -158,5 +156,5 @@ def test_chat_history_csv_export(
|
||||
end_time=past_end,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
csv_rows = list(csv.reader(io.StringIO(csv_content)))
|
||||
assert len(csv_rows) == 1 # Only header, no data rows
|
||||
csv_lines = csv_content.strip().split("\n")
|
||||
assert len(csv_lines) == 1 # Only header, no data rows
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
@@ -10,10 +11,12 @@ def test_init_subclass_raises_for_missing_attrs() -> None:
|
||||
|
||||
class IncompleteSpec(HookPointSpec):
|
||||
hook_point = HookPoint.QUERY_PROCESSING
|
||||
# missing display_name, description, payload_model, response_model, etc.
|
||||
# missing display_name, description, etc.
|
||||
|
||||
class _Payload(BaseModel):
|
||||
pass
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
payload_model = _Payload
|
||||
response_model = _Payload
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@@ -1,541 +0,0 @@
|
||||
"""Unit tests for the hook executor."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.executor import execute_hook
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
|
||||
_RESPONSE_PAYLOAD: dict[str, Any] = {"rewritten_query": "better test"}
|
||||
|
||||
|
||||
def _make_hook(
|
||||
*,
|
||||
is_active: bool = True,
|
||||
endpoint_url: str | None = "https://hook.example.com/query",
|
||||
api_key: MagicMock | None = None,
|
||||
timeout_seconds: float = 5.0,
|
||||
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
|
||||
hook_id: int = 1,
|
||||
is_reachable: bool | None = None,
|
||||
) -> MagicMock:
|
||||
hook = MagicMock()
|
||||
hook.is_active = is_active
|
||||
hook.endpoint_url = endpoint_url
|
||||
hook.api_key = api_key
|
||||
hook.timeout_seconds = timeout_seconds
|
||||
hook.id = hook_id
|
||||
hook.fail_strategy = fail_strategy
|
||||
hook.is_reachable = is_reachable
|
||||
return hook
|
||||
|
||||
|
||||
def _make_api_key(value: str) -> MagicMock:
|
||||
api_key = MagicMock()
|
||||
api_key.get_value.return_value = value
|
||||
return api_key
|
||||
|
||||
|
||||
def _make_response(
|
||||
*,
|
||||
status_code: int = 200,
|
||||
json_return: Any = _RESPONSE_PAYLOAD,
|
||||
json_side_effect: Exception | None = None,
|
||||
) -> MagicMock:
|
||||
"""Build a response mock with controllable json() behaviour."""
|
||||
response = MagicMock()
|
||||
response.status_code = status_code
|
||||
if json_side_effect is not None:
|
||||
response.json.side_effect = json_side_effect
|
||||
else:
|
||||
response.json.return_value = json_return
|
||||
return response
|
||||
|
||||
|
||||
def _setup_client(
|
||||
mock_client_cls: MagicMock,
|
||||
*,
|
||||
response: MagicMock | None = None,
|
||||
side_effect: Exception | None = None,
|
||||
) -> MagicMock:
|
||||
"""Wire up the httpx.Client mock and return the inner client.
|
||||
|
||||
If side_effect is an httpx.HTTPStatusError, it is raised from
|
||||
raise_for_status() (matching real httpx behaviour) and post() returns a
|
||||
response mock with the matching status_code set. All other exceptions are
|
||||
raised directly from post().
|
||||
"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
if isinstance(side_effect, httpx.HTTPStatusError):
|
||||
error_response = MagicMock()
|
||||
error_response.status_code = side_effect.response.status_code
|
||||
error_response.raise_for_status.side_effect = side_effect
|
||||
mock_client.post = MagicMock(return_value=error_response)
|
||||
else:
|
||||
mock_client.post = MagicMock(
|
||||
side_effect=side_effect, return_value=response if not side_effect else None
|
||||
)
|
||||
|
||||
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Early-exit guards (no HTTP call, no DB writes)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hooks_available,hook",
|
||||
[
|
||||
# HOOKS_AVAILABLE=False exits before the DB lookup — hook is irrelevant.
|
||||
pytest.param(False, None, id="hooks_not_available"),
|
||||
pytest.param(True, None, id="hook_not_found"),
|
||||
pytest.param(True, _make_hook(is_active=False), id="hook_inactive"),
|
||||
pytest.param(True, _make_hook(endpoint_url=None), id="no_endpoint_url"),
|
||||
],
|
||||
)
|
||||
def test_early_exit_returns_skipped_with_no_db_writes(
|
||||
db_session: MagicMock,
|
||||
hooks_available: bool,
|
||||
hook: MagicMock | None,
|
||||
) -> None:
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", hooks_available),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSkipped)
|
||||
mock_update.assert_not_called()
|
||||
mock_log.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Successful HTTP call
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> None:
|
||||
hook = _make_hook()
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, response=_make_response())
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert result == _RESPONSE_PAYLOAD
|
||||
_, update_kwargs = mock_update.call_args
|
||||
assert update_kwargs["is_reachable"] is True
|
||||
mock_log.assert_not_called()
|
||||
|
||||
|
||||
def test_success_skips_reachable_write_when_already_true(db_session: MagicMock) -> None:
|
||||
"""Deduplication guard: a hook already at is_reachable=True that succeeds
|
||||
must not trigger a DB write."""
|
||||
hook = _make_hook(is_reachable=True)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, response=_make_response())
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert result == _RESPONSE_PAYLOAD
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
|
||||
"""response.json() returning a non-dict (e.g. list) must be treated as failure.
|
||||
The server responded, so is_reachable is not updated."""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(
|
||||
mock_client_cls,
|
||||
response=_make_response(json_return=["unexpected", "list"]),
|
||||
)
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
_, log_kwargs = mock_log.call_args
|
||||
assert log_kwargs["is_success"] is False
|
||||
assert "non-dict" in (log_kwargs["error_message"] or "")
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
|
||||
"""response.json() raising must be treated as failure with SOFT strategy.
|
||||
The server responded, so is_reachable is not updated."""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(
|
||||
mock_client_cls,
|
||||
response=_make_response(
|
||||
json_side_effect=json.JSONDecodeError("not JSON", "", 0)
|
||||
),
|
||||
)
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
_, log_kwargs = mock_log.call_args
|
||||
assert log_kwargs["is_success"] is False
|
||||
assert "non-JSON" in (log_kwargs["error_message"] or "")
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP failure paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"exception,fail_strategy,expected_type,expected_is_reachable",
|
||||
[
|
||||
# NetworkError → is_reachable=False
|
||||
pytest.param(
|
||||
httpx.ConnectError("refused"),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
False,
|
||||
id="connect_error_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.ConnectError("refused"),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
False,
|
||||
id="connect_error_hard",
|
||||
),
|
||||
# 401/403 → is_reachable=False (api_key revoked)
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"401",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=401, text="Unauthorized"),
|
||||
),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
False,
|
||||
id="auth_401_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"403",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=403, text="Forbidden"),
|
||||
),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
False,
|
||||
id="auth_403_hard",
|
||||
),
|
||||
# TimeoutException → no is_reachable write (None)
|
||||
pytest.param(
|
||||
httpx.TimeoutException("timeout"),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
None,
|
||||
id="timeout_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.TimeoutException("timeout"),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
None,
|
||||
id="timeout_hard",
|
||||
),
|
||||
# Other HTTP errors → no is_reachable write (None)
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"500",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=500, text="error"),
|
||||
),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
None,
|
||||
id="http_status_error_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"500",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=500, text="error"),
|
||||
),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
None,
|
||||
id="http_status_error_hard",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_http_failure_paths(
|
||||
db_session: MagicMock,
|
||||
exception: Exception,
|
||||
fail_strategy: HookFailStrategy,
|
||||
expected_type: type,
|
||||
expected_is_reachable: bool | None,
|
||||
) -> None:
|
||||
hook = _make_hook(fail_strategy=fail_strategy)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, side_effect=exception)
|
||||
|
||||
if expected_type is OnyxError:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
|
||||
else:
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert isinstance(result, expected_type)
|
||||
|
||||
if expected_is_reachable is None:
|
||||
mock_update.assert_not_called()
|
||||
else:
|
||||
mock_update.assert_called_once()
|
||||
_, kwargs = mock_update.call_args
|
||||
assert kwargs["is_reachable"] is expected_is_reachable
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authorization header
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_key_value,expect_auth_header",
|
||||
[
|
||||
pytest.param("secret-token", True, id="api_key_present"),
|
||||
pytest.param(None, False, id="api_key_absent"),
|
||||
],
|
||||
)
|
||||
def test_authorization_header(
|
||||
db_session: MagicMock,
|
||||
api_key_value: str | None,
|
||||
expect_auth_header: bool,
|
||||
) -> None:
|
||||
api_key = _make_api_key(api_key_value) if api_key_value else None
|
||||
hook = _make_hook(api_key=api_key)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit"),
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
mock_client = _setup_client(mock_client_cls, response=_make_response())
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
_, call_kwargs = mock_client.post.call_args
|
||||
if expect_auth_header:
|
||||
assert call_kwargs["headers"]["Authorization"] == f"Bearer {api_key_value}"
|
||||
else:
|
||||
assert "Authorization" not in call_kwargs["headers"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persist session failure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"http_exception,expected_result",
|
||||
[
|
||||
pytest.param(None, _RESPONSE_PAYLOAD, id="success_path"),
|
||||
pytest.param(httpx.ConnectError("refused"), OnyxError, id="hard_fail_path"),
|
||||
],
|
||||
)
|
||||
def test_persist_session_failure_is_swallowed(
|
||||
db_session: MagicMock,
|
||||
http_exception: Exception | None,
|
||||
expected_result: Any,
|
||||
) -> None:
|
||||
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_session_with_current_tenant",
|
||||
side_effect=RuntimeError("DB unavailable"),
|
||||
),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(
|
||||
mock_client_cls,
|
||||
response=_make_response() if not http_exception else None,
|
||||
side_effect=http_exception,
|
||||
)
|
||||
|
||||
if expected_result is OnyxError:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
|
||||
else:
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:
|
||||
"""is_reachable update failing (e.g. concurrent hook deletion) must not
|
||||
prevent the execution log from being written.
|
||||
|
||||
Simulates the production failure path: update_hook__no_commit raises
|
||||
OnyxError(NOT_FOUND) as it would if the hook was concurrently deleted
|
||||
between the initial lookup and the reachable update.
|
||||
"""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.hooks.executor.update_hook__no_commit",
|
||||
side_effect=OnyxError(OnyxErrorCode.NOT_FOUND, "hook deleted"),
|
||||
),
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, side_effect=httpx.ConnectError("refused"))
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
mock_log.assert_called_once()
|
||||
@@ -37,20 +37,18 @@ def test_input_schema_query_is_string() -> None:
|
||||
|
||||
def test_input_schema_user_email_is_nullable() -> None:
|
||||
props = QueryProcessingSpec().input_schema["properties"]
|
||||
# Pydantic v2 emits anyOf for nullable fields
|
||||
assert any(s.get("type") == "null" for s in props["user_email"]["anyOf"])
|
||||
assert "null" in props["user_email"]["type"]
|
||||
|
||||
|
||||
def test_output_schema_query_is_optional() -> None:
|
||||
# query defaults to None (absent = reject); not required in the schema
|
||||
def test_output_schema_query_is_required() -> None:
|
||||
schema = QueryProcessingSpec().output_schema
|
||||
assert "query" not in schema.get("required", [])
|
||||
assert "query" in schema["required"]
|
||||
|
||||
|
||||
def test_output_schema_query_is_nullable() -> None:
|
||||
# null means "reject the query"; Pydantic v2 emits anyOf for nullable fields
|
||||
# null means "reject the query"
|
||||
props = QueryProcessingSpec().output_schema["properties"]
|
||||
assert any(s.get("type") == "null" for s in props["query"]["anyOf"])
|
||||
assert "null" in props["query"]["type"]
|
||||
|
||||
|
||||
def test_output_schema_rejection_message_is_optional() -> None:
|
||||
|
||||
@@ -1,278 +0,0 @@
|
||||
"""Unit tests for onyx.server.features.hooks.api helpers.
|
||||
|
||||
Covers:
|
||||
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
|
||||
- _validate_endpoint: httpx exception → HookValidateStatus mapping
|
||||
ConnectTimeout → cannot_connect (TCP handshake never completed)
|
||||
ConnectError → cannot_connect (DNS / TLS failure)
|
||||
ReadTimeout et al. → timeout (TCP connected, server slow)
|
||||
Any other exc → cannot_connect
|
||||
- _raise_for_validation_failure: HookValidateStatus → OnyxError mapping
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.models import HookValidateResponse
|
||||
from onyx.hooks.models import HookValidateStatus
|
||||
from onyx.server.features.hooks.api import _check_ssrf_safety
|
||||
from onyx.server.features.hooks.api import _raise_for_validation_failure
|
||||
from onyx.server.features.hooks.api import _validate_endpoint
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_URL = "https://example.com/hook"
|
||||
_API_KEY = "secret"
|
||||
_TIMEOUT = 5.0
|
||||
|
||||
|
||||
def _mock_response(status_code: int) -> MagicMock:
|
||||
response = MagicMock()
|
||||
response.status_code = status_code
|
||||
return response
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _check_ssrf_safety
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckSsrfSafety:
|
||||
def _call(self, url: str) -> None:
|
||||
_check_ssrf_safety(url)
|
||||
|
||||
# --- scheme checks ---
|
||||
|
||||
def test_https_is_allowed(self) -> None:
|
||||
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
|
||||
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
|
||||
self._call("https://example.com/hook") # must not raise
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url", ["http://example.com/hook", "ftp://example.com/hook"]
|
||||
)
|
||||
def test_non_https_scheme_rejected(self, url: str) -> None:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self._call(url)
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert "https" in (exc_info.value.detail or "").lower()
|
||||
|
||||
# --- private IP blocklist ---
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ip",
|
||||
[
|
||||
pytest.param("127.0.0.1", id="loopback"),
|
||||
pytest.param("10.0.0.1", id="RFC1918-A"),
|
||||
pytest.param("172.16.0.1", id="RFC1918-B"),
|
||||
pytest.param("192.168.1.1", id="RFC1918-C"),
|
||||
pytest.param("169.254.169.254", id="link-local-IMDS"),
|
||||
pytest.param("100.64.0.1", id="shared-address-space"),
|
||||
pytest.param("::1", id="IPv6-loopback"),
|
||||
pytest.param("fc00::1", id="IPv6-ULA"),
|
||||
pytest.param("fe80::1", id="IPv6-link-local"),
|
||||
],
|
||||
)
|
||||
def test_private_ip_is_blocked(self, ip: str) -> None:
|
||||
with (
|
||||
patch("onyx.utils.url.socket.getaddrinfo") as mock_dns,
|
||||
pytest.raises(OnyxError) as exc_info,
|
||||
):
|
||||
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
|
||||
self._call("https://internal.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert ip in (exc_info.value.detail or "")
|
||||
|
||||
def test_public_ip_is_allowed(self) -> None:
|
||||
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
|
||||
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
|
||||
self._call("https://example.com/hook") # must not raise
|
||||
|
||||
def test_dns_resolution_failure_raises(self) -> None:
|
||||
import socket
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.utils.url.socket.getaddrinfo",
|
||||
side_effect=socket.gaierror("name not found"),
|
||||
),
|
||||
pytest.raises(OnyxError) as exc_info,
|
||||
):
|
||||
self._call("https://no-such-host.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateEndpoint:
|
||||
def _call(self, *, api_key: str | None = _API_KEY) -> HookValidateResponse:
|
||||
# Bypass SSRF check — tested separately in TestCheckSsrfSafety.
|
||||
with patch("onyx.server.features.hooks.api._check_ssrf_safety"):
|
||||
return _validate_endpoint(
|
||||
endpoint_url=_URL,
|
||||
api_key=api_key,
|
||||
timeout_seconds=_TIMEOUT,
|
||||
)
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_2xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(200)
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_5xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(500)
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@pytest.mark.parametrize("status_code", [401, 403])
|
||||
def test_401_403_returns_auth_failed(
|
||||
self, mock_client_cls: MagicMock, status_code: int
|
||||
) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(status_code)
|
||||
)
|
||||
result = self._call()
|
||||
assert result.status == HookValidateStatus.auth_failed
|
||||
assert str(status_code) in (result.error_message or "")
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_4xx_non_auth_returns_passed(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(422)
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_connect_timeout_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
httpx.ConnectTimeout("timed out")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@pytest.mark.parametrize(
|
||||
"exc",
|
||||
[
|
||||
httpx.ReadTimeout("read timeout"),
|
||||
httpx.WriteTimeout("write timeout"),
|
||||
httpx.PoolTimeout("pool timeout"),
|
||||
],
|
||||
)
|
||||
def test_read_write_pool_timeout_returns_timeout(
|
||||
self, mock_client_cls: MagicMock, exc: httpx.TimeoutException
|
||||
) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = exc
|
||||
assert self._call().status == HookValidateStatus.timeout
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_connect_error_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
# Covers DNS failures, TLS errors, and other connection-level errors.
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
httpx.ConnectError("name resolution failed")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_arbitrary_exception_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
ConnectionRefusedError("refused")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_api_key_sent_as_bearer(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_post = mock_client_cls.return_value.__enter__.return_value.post
|
||||
mock_post.return_value = _mock_response(200)
|
||||
self._call(api_key="mykey")
|
||||
_, kwargs = mock_post.call_args
|
||||
assert kwargs["headers"]["Authorization"] == "Bearer mykey"
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_no_api_key_omits_auth_header(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_post = mock_client_cls.return_value.__enter__.return_value.post
|
||||
mock_post.return_value = _mock_response(200)
|
||||
self._call(api_key=None)
|
||||
_, kwargs = mock_post.call_args
|
||||
assert "Authorization" not in kwargs["headers"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _raise_for_validation_failure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRaiseForValidationFailure:
|
||||
@pytest.mark.parametrize(
|
||||
"status, expected_code",
|
||||
[
|
||||
(HookValidateStatus.auth_failed, OnyxErrorCode.CREDENTIAL_INVALID),
|
||||
(HookValidateStatus.timeout, OnyxErrorCode.GATEWAY_TIMEOUT),
|
||||
(HookValidateStatus.cannot_connect, OnyxErrorCode.BAD_GATEWAY),
|
||||
],
|
||||
)
|
||||
def test_raises_correct_error_code(
|
||||
self, status: HookValidateStatus, expected_code: OnyxErrorCode
|
||||
) -> None:
|
||||
validation = HookValidateResponse(status=status, error_message="some error")
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_raise_for_validation_failure(validation)
|
||||
assert exc_info.value.error_code == expected_code
|
||||
|
||||
def test_auth_failed_passes_error_message_directly(self) -> None:
|
||||
validation = HookValidateResponse(
|
||||
status=HookValidateStatus.auth_failed, error_message="bad credentials"
|
||||
)
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_raise_for_validation_failure(validation)
|
||||
assert exc_info.value.detail == "bad credentials"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status", [HookValidateStatus.timeout, HookValidateStatus.cannot_connect]
|
||||
)
|
||||
def test_timeout_and_cannot_connect_wrap_error_message(
|
||||
self, status: HookValidateStatus
|
||||
) -> None:
|
||||
validation = HookValidateResponse(status=status, error_message="raw error")
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_raise_for_validation_failure(validation)
|
||||
assert exc_info.value.detail == "Endpoint validation failed: raw error"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HookValidateStatus enum string values (API contract)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHookValidateStatusValues:
|
||||
@pytest.mark.parametrize(
|
||||
"status, expected",
|
||||
[
|
||||
(HookValidateStatus.passed, "passed"),
|
||||
(HookValidateStatus.auth_failed, "auth_failed"),
|
||||
(HookValidateStatus.timeout, "timeout"),
|
||||
(HookValidateStatus.cannot_connect, "cannot_connect"),
|
||||
],
|
||||
)
|
||||
def test_string_values(self, status: HookValidateStatus, expected: str) -> None:
|
||||
assert status == expected
|
||||
@@ -1,6 +1,7 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
@@ -10,10 +11,10 @@ from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
|
||||
build_vespa_filters,
|
||||
)
|
||||
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_SETS
|
||||
from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import SOURCE_TYPE
|
||||
from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import USER_PROJECT
|
||||
@@ -150,30 +151,56 @@ class TestBuildVespaFilters:
|
||||
result = build_vespa_filters(filters)
|
||||
assert f"!({HIDDEN}=true) and " == result
|
||||
|
||||
def test_user_project_filter(self) -> None:
|
||||
"""Test user project filtering.
|
||||
def test_user_file_ids_filter(self) -> None:
|
||||
"""Test user file IDs filtering."""
|
||||
id1 = UUID("00000000-0000-0000-0000-000000000123")
|
||||
id2 = UUID("00000000-0000-0000-0000-000000000456")
|
||||
|
||||
project_id_filter alone does NOT trigger a knowledge scope restriction
|
||||
(an agent with no explicit knowledge should search everything).
|
||||
It only participates when explicit knowledge filters are present.
|
||||
"""
|
||||
# project_id_filter alone → no restriction
|
||||
filters = IndexFilters(access_control_list=[], project_id_filter=789)
|
||||
result = build_vespa_filters(filters)
|
||||
assert f"!({HIDDEN}=true) and " == result
|
||||
|
||||
# project_id_filter with document_set → both OR'd
|
||||
filters = IndexFilters(
|
||||
access_control_list=[], project_id_filter=789, document_set=["set1"]
|
||||
)
|
||||
# Single user file ID (UUID)
|
||||
filters = IndexFilters(access_control_list=[], user_file_ids=[id1])
|
||||
result = build_vespa_filters(filters)
|
||||
assert (
|
||||
f'!({HIDDEN}=true) and (({DOCUMENT_SETS} contains "set1") or ({USER_PROJECT} contains "789")) and '
|
||||
f'!({HIDDEN}=true) and ({DOCUMENT_ID} contains "{str(id1)}") and ' == result
|
||||
)
|
||||
|
||||
# Multiple user file IDs (UUIDs)
|
||||
filters = IndexFilters(access_control_list=[], user_file_ids=[id1, id2])
|
||||
result = build_vespa_filters(filters)
|
||||
assert (
|
||||
f'!({HIDDEN}=true) and ({DOCUMENT_ID} contains "{str(id1)}" or {DOCUMENT_ID} contains "{str(id2)}") and '
|
||||
== result
|
||||
)
|
||||
|
||||
# No project id filter
|
||||
filters = IndexFilters(access_control_list=[], project_id_filter=None)
|
||||
# Empty user file IDs
|
||||
filters = IndexFilters(access_control_list=[], user_file_ids=[])
|
||||
result = build_vespa_filters(filters)
|
||||
assert f"!({HIDDEN}=true) and " == result
|
||||
|
||||
def test_user_project_filter(self) -> None:
|
||||
"""Test user project filtering.
|
||||
|
||||
project_id alone does NOT trigger a knowledge scope restriction
|
||||
(an agent with no explicit knowledge should search everything).
|
||||
It only participates when explicit knowledge filters are present.
|
||||
"""
|
||||
# project_id alone → no restriction
|
||||
filters = IndexFilters(access_control_list=[], project_id=789)
|
||||
result = build_vespa_filters(filters)
|
||||
assert f"!({HIDDEN}=true) and " == result
|
||||
|
||||
# project_id with user_file_ids → both OR'd
|
||||
id1 = UUID("00000000-0000-0000-0000-000000000123")
|
||||
filters = IndexFilters(
|
||||
access_control_list=[], project_id=789, user_file_ids=[id1]
|
||||
)
|
||||
result = build_vespa_filters(filters)
|
||||
assert (
|
||||
f'!({HIDDEN}=true) and (({DOCUMENT_ID} contains "{str(id1)}") or ({USER_PROJECT} contains "789")) and '
|
||||
== result
|
||||
)
|
||||
|
||||
# No project id
|
||||
filters = IndexFilters(access_control_list=[], project_id=None)
|
||||
result = build_vespa_filters(filters)
|
||||
assert f"!({HIDDEN}=true) and " == result
|
||||
|
||||
@@ -206,16 +233,17 @@ class TestBuildVespaFilters:
|
||||
def test_combined_filters(self) -> None:
|
||||
"""Test combining multiple filter types.
|
||||
|
||||
Knowledge-scope filters (document_set, project_id_filter, persona_id_filter)
|
||||
are OR'd together, while all other filters are AND'd.
|
||||
Knowledge-scope filters (document_set, user_file_ids, project_id,
|
||||
persona_id) are OR'd together, while all other filters are AND'd.
|
||||
"""
|
||||
id1 = UUID("00000000-0000-0000-0000-000000000123")
|
||||
filters = IndexFilters(
|
||||
access_control_list=["user1", "group1"],
|
||||
source_type=[DocumentSource.WEB],
|
||||
tags=[Tag(tag_key="color", tag_value="red")],
|
||||
document_set=["set1"],
|
||||
project_id_filter=789,
|
||||
persona_id_filter=42,
|
||||
user_file_ids=[id1],
|
||||
project_id=789,
|
||||
time_cutoff=datetime(2023, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
@@ -226,10 +254,9 @@ class TestBuildVespaFilters:
|
||||
expected += f'({SOURCE_TYPE} contains "web") and '
|
||||
expected += f'({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
|
||||
# Knowledge scope filters are OR'd together
|
||||
# (persona_id_filter is primary, project_id_filter is additive — order reflects this)
|
||||
expected += (
|
||||
f'(({DOCUMENT_SETS} contains "set1")'
|
||||
f' or ({PERSONAS} contains "42")'
|
||||
f' or ({DOCUMENT_ID} contains "{str(id1)}")'
|
||||
f' or ({USER_PROJECT} contains "789")'
|
||||
f") and "
|
||||
)
|
||||
@@ -249,37 +276,18 @@ class TestBuildVespaFilters:
|
||||
result = build_vespa_filters(filters)
|
||||
assert f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1") and ' == result
|
||||
|
||||
def test_persona_id_filter_is_primary_knowledge_scope(self) -> None:
|
||||
"""persona_id_filter alone should trigger a knowledge scope restriction
|
||||
(a persona with user files IS explicit knowledge)."""
|
||||
filters = IndexFilters(access_control_list=[], persona_id_filter=42)
|
||||
result = build_vespa_filters(filters)
|
||||
assert f'!({HIDDEN}=true) and ({PERSONAS} contains "42") and ' == result
|
||||
|
||||
def test_persona_id_filter_with_project_id_filter(self) -> None:
|
||||
"""When persona_id_filter triggers the scope, project_id_filter should be
|
||||
OR'd in additively."""
|
||||
filters = IndexFilters(
|
||||
access_control_list=[], persona_id_filter=42, project_id_filter=789
|
||||
)
|
||||
result = build_vespa_filters(filters)
|
||||
expected = (
|
||||
f"!({HIDDEN}=true) and "
|
||||
f'(({PERSONAS} contains "42") or ({USER_PROJECT} contains "789")) and '
|
||||
)
|
||||
assert expected == result
|
||||
|
||||
def test_knowledge_scope_document_set_and_persona_filter_ored(self) -> None:
|
||||
"""Document set filter and persona_id_filter must be OR'd so that
|
||||
connector documents (in the set) and persona user files can
|
||||
both be found."""
|
||||
def test_knowledge_scope_document_set_and_user_files_ored(self) -> None:
|
||||
"""Document set filter and user file IDs must be OR'd so that
|
||||
connector documents (in the set) and user files (with specific
|
||||
IDs) can both be found."""
|
||||
id1 = UUID("00000000-0000-0000-0000-000000000123")
|
||||
filters = IndexFilters(
|
||||
access_control_list=[],
|
||||
document_set=["engineering"],
|
||||
persona_id_filter=42,
|
||||
user_file_ids=[id1],
|
||||
)
|
||||
result = build_vespa_filters(filters)
|
||||
expected = f'!({HIDDEN}=true) and (({DOCUMENT_SETS} contains "engineering") or ({PERSONAS} contains "42")) and '
|
||||
expected = f'!({HIDDEN}=true) and (({DOCUMENT_SETS} contains "engineering") or ({DOCUMENT_ID} contains "{str(id1)}")) and '
|
||||
assert expected == result
|
||||
|
||||
def test_acl_large_list_uses_weighted_set(self) -> None:
|
||||
|
||||
@@ -489,18 +489,20 @@ services:
|
||||
- "${HOST_PORT_80:-80}:80"
|
||||
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
|
||||
volumes:
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template"
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
|
||||
|
||||
minio:
|
||||
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1
|
||||
|
||||
@@ -290,20 +290,25 @@ services:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
volumes:
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
- ../data/certbot/conf:/etc/letsencrypt
|
||||
- ../data/certbot/www:/var/www/certbot
|
||||
# sleep a little bit to allow the web_server / api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template.prod"
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
|
||||
env_file:
|
||||
- .env.nginx
|
||||
environment:
|
||||
|
||||
@@ -314,19 +314,21 @@ services:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
volumes:
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
- ../data/sslcerts:/etc/nginx/sslcerts
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template.prod.no-letsencrypt"
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod.no-letsencrypt"
|
||||
env_file:
|
||||
- .env.nginx
|
||||
environment:
|
||||
|
||||
@@ -333,20 +333,25 @@ services:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
volumes:
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
- ../data/certbot/conf:/etc/letsencrypt
|
||||
- ../data/certbot/www:/var/www/certbot
|
||||
# sleep a little bit to allow the web_server / api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template.prod"
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
|
||||
env_file:
|
||||
- .env.nginx
|
||||
environment:
|
||||
|
||||
@@ -202,18 +202,20 @@ services:
|
||||
ports:
|
||||
- "${NGINX_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
|
||||
volumes:
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template"
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
|
||||
|
||||
minio:
|
||||
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1
|
||||
|
||||
@@ -477,10 +477,7 @@ services:
|
||||
- "${HOST_PORT_80:-80}:80"
|
||||
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
|
||||
volumes:
|
||||
# Mount templates read-only; the startup command copies them into
|
||||
# the writable /etc/nginx/conf.d/ inside the container. This avoids
|
||||
# "Permission denied" errors on Windows Docker bind mounts.
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
# PRODUCTION: Add SSL certificate volumes for HTTPS support:
|
||||
# - ../data/certbot/conf:/etc/letsencrypt
|
||||
# - ../data/certbot/www:/var/www/certbot
|
||||
@@ -492,13 +489,12 @@ services:
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not receive any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
# PRODUCTION: Change to app.conf.template.prod for production nginx config
|
||||
command: >
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template"
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
|
||||
|
||||
cache:
|
||||
image: redis:7.4-alpine
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Usage: .\install.ps1 [OPTIONS]
|
||||
# Remote (with params):
|
||||
# & ([scriptblock]::Create((irm https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.ps1))) -Lite -NoPrompt
|
||||
# Remote (defaults only, configure via interaction during script):
|
||||
# Remote (defaults only):
|
||||
# irm https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.ps1 | iex
|
||||
|
||||
param(
|
||||
@@ -57,7 +57,11 @@ function Print-Step {
|
||||
}
|
||||
|
||||
function Test-Interactive {
|
||||
return -not $NoPrompt
|
||||
if ($NoPrompt) { return $false }
|
||||
try {
|
||||
if ([Console]::IsInputRedirected) { return $false }
|
||||
return $true
|
||||
} catch { return [Environment]::UserInteractive }
|
||||
}
|
||||
|
||||
function Prompt-OrDefault {
|
||||
@@ -70,8 +74,8 @@ function Prompt-OrDefault {
|
||||
|
||||
function Confirm-Action {
|
||||
param([string]$Description)
|
||||
$reply = (Prompt-OrDefault "Install $Description? (Y/n) [default: Y]" "Y").Trim().ToLower()
|
||||
if ($reply -match '^n') {
|
||||
$reply = Prompt-OrDefault "Install $Description? (Y/n) [default: Y]" "Y"
|
||||
if ($reply -match '^[Nn]') {
|
||||
Print-Warning "Skipping: $Description"
|
||||
return $false
|
||||
}
|
||||
@@ -85,12 +89,12 @@ function Prompt-VersionTag {
|
||||
Write-Host " - Type a specific tag (e.g., craft-v1.0.0)"
|
||||
$version = Prompt-OrDefault "Enter tag [default: craft-latest]" "craft-latest"
|
||||
} else {
|
||||
Write-Host " - Press Enter for edge (recommended)"
|
||||
Write-Host " - Press Enter for latest (recommended)"
|
||||
Write-Host " - Type a specific tag (e.g., v0.1.0)"
|
||||
$version = Prompt-OrDefault "Enter tag [default: edge]" "edge"
|
||||
$version = Prompt-OrDefault "Enter tag [default: latest]" "latest"
|
||||
}
|
||||
if ($script:IncludeCraftMode -and $version -eq "craft-latest") { Print-Info "Selected: craft-latest (Craft enabled)" }
|
||||
elseif ($version -eq "edge") { Print-Info "Selected: edge (latest nightly)" }
|
||||
elseif ($version -eq "latest") { Print-Info "Selected: Latest tag" }
|
||||
else { Print-Info "Selected: $version" }
|
||||
return $version
|
||||
}
|
||||
@@ -99,16 +103,16 @@ function Prompt-DeploymentMode {
|
||||
param([string]$LiteOverlayPath)
|
||||
if ($script:LiteMode) { Print-Info "Deployment mode: Lite (set via -Lite flag)"; return }
|
||||
Print-Info "Which deployment mode would you like?"
|
||||
Write-Host " 1) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
|
||||
Write-Host " 1) Standard - Full deployment with search, connectors, and RAG"
|
||||
Write-Host " 2) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
|
||||
Write-Host " LLM chat, tools, file uploads, and Projects still work"
|
||||
Write-Host " 2) Standard - Full deployment with search, connectors, and RAG"
|
||||
$modeChoice = Prompt-OrDefault "Choose a mode (1 or 2) [default: 1]" "1"
|
||||
if ($modeChoice -eq "2") {
|
||||
Print-Info "Selected: Standard mode"
|
||||
} else {
|
||||
$script:LiteMode = $true
|
||||
Print-Info "Selected: Lite mode"
|
||||
if (-not (Ensure-OnyxFile $LiteOverlayPath "$($script:GitHubRawUrl)/$($script:LiteComposeFile)" $script:LiteComposeFile)) { exit 1 }
|
||||
} else {
|
||||
Print-Info "Selected: Standard mode"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -354,8 +358,7 @@ function Invoke-OnyxShutdown {
|
||||
return
|
||||
}
|
||||
if (-not (Initialize-ComposeCommand)) { Print-OnyxError "Docker Compose not found."; exit 1 }
|
||||
$stopArgs = @("stop")
|
||||
$result = Invoke-Compose -AutoDetect @stopArgs
|
||||
$result = Invoke-Compose -AutoDetect stop
|
||||
if ($result -ne 0) { Print-OnyxError "Failed to stop containers"; exit 1 }
|
||||
Print-Success "Onyx containers stopped (paused)"
|
||||
}
|
||||
@@ -364,7 +367,7 @@ function Invoke-OnyxDeleteData {
|
||||
Write-Host "`n=== WARNING: This will permanently delete all Onyx data ===`n" -ForegroundColor Red
|
||||
Print-Warning "This action will remove all Onyx containers, volumes, files, and user data."
|
||||
if (Test-Interactive) {
|
||||
$confirm = Prompt-OrDefault "Type 'DELETE' to confirm" ""
|
||||
$confirm = Read-Host "Type 'DELETE' to confirm"
|
||||
if ($confirm -ne "DELETE") { Print-Info "Operation cancelled."; return }
|
||||
} else {
|
||||
Print-OnyxError "Cannot confirm destructive operation in non-interactive mode."
|
||||
@@ -372,8 +375,7 @@ function Invoke-OnyxDeleteData {
|
||||
}
|
||||
$deployDir = Join-Path $script:InstallRoot "deployment"
|
||||
if ((Test-Path (Join-Path $deployDir "docker-compose.yml")) -and (Initialize-ComposeCommand)) {
|
||||
$downArgs = @("down", "-v")
|
||||
$result = Invoke-Compose -AutoDetect @downArgs
|
||||
$result = Invoke-Compose -AutoDetect down -v
|
||||
if ($result -eq 0) { Print-Success "Containers and volumes removed" }
|
||||
else { Print-OnyxError "Failed to remove containers" }
|
||||
}
|
||||
@@ -720,7 +722,6 @@ function Invoke-WslInstall {
|
||||
# Ensure WSL2 is available
|
||||
Invoke-NativeQuiet { wsl --status }
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
if (-not (Confirm-Action "WSL2 (Windows Subsystem for Linux)")) { exit 1 }
|
||||
Print-Info "Installing WSL2..."
|
||||
try {
|
||||
$proc = Start-Process wsl -ArgumentList "--install", "--no-distribution" -Wait -PassThru -NoNewWindow
|
||||
@@ -807,7 +808,7 @@ function Main {
|
||||
|
||||
if (Test-Interactive) {
|
||||
Write-Host "`nPlease acknowledge and press Enter to continue..." -ForegroundColor Yellow
|
||||
$null = Prompt-OrDefault "" ""
|
||||
Read-Host | Out-Null
|
||||
} else {
|
||||
Write-Host "`nRunning in non-interactive mode - proceeding automatically..." -ForegroundColor Yellow
|
||||
}
|
||||
@@ -903,8 +904,8 @@ function Main {
|
||||
if ($resourceWarning) {
|
||||
Print-Warning "Onyx recommends at least $($script:ExpectedDockerRamGB)GB RAM and $($script:ExpectedDiskGB)GB disk for standard mode."
|
||||
Print-Warning "Lite mode requires less (1-4GB RAM, 8-16GB disk) but has no vector database."
|
||||
$reply = (Prompt-OrDefault "Do you want to continue anyway? (Y/n)" "y").Trim().ToLower()
|
||||
if ($reply -notmatch '^y') { Print-Info "Installation cancelled."; exit 1 }
|
||||
$reply = Prompt-OrDefault "Do you want to continue anyway? (Y/n)" "y"
|
||||
if ($reply -notmatch '^[Yy]') { Print-Info "Installation cancelled."; exit 1 }
|
||||
Print-Info "Proceeding despite resource limitations..."
|
||||
}
|
||||
|
||||
@@ -926,13 +927,22 @@ function Main {
|
||||
if ($composeVersion -ne "unknown" -and (Compare-SemVer $composeVersion "2.24.0") -lt 0) {
|
||||
Print-Warning "Docker Compose $composeVersion is older than 2.24.0 (required for env_file format)."
|
||||
Print-Info "Update Docker Desktop or install a newer Docker Compose. Installation may fail."
|
||||
$reply = (Prompt-OrDefault "Continue anyway? (Y/n)" "y").Trim().ToLower()
|
||||
if ($reply -notmatch '^y') { exit 1 }
|
||||
$reply = Prompt-OrDefault "Continue anyway? (Y/n)" "y"
|
||||
if ($reply -notmatch '^[Yy]') { exit 1 }
|
||||
}
|
||||
|
||||
$liteOverlayPath = Join-Path $deploymentDir $script:LiteComposeFile
|
||||
if ($script:LiteMode) {
|
||||
if (-not (Ensure-OnyxFile $liteOverlayPath "$($script:GitHubRawUrl)/$($script:LiteComposeFile)" $script:LiteComposeFile)) { exit 1 }
|
||||
} elseif (Test-Path $liteOverlayPath) {
|
||||
if (Test-Path (Join-Path $deploymentDir ".env")) {
|
||||
Print-Warning "Existing lite overlay found but -Lite was not passed."
|
||||
$reply = Prompt-OrDefault "Remove lite overlay and switch to standard mode? (y/N)" "n"
|
||||
if ($reply -match '^[Yy]') { Remove-Item -Force $liteOverlayPath; Print-Info "Switched to standard mode" }
|
||||
else { $script:LiteMode = $true; Print-Info "Keeping lite mode" }
|
||||
} else {
|
||||
Remove-Item -Force $liteOverlayPath
|
||||
}
|
||||
}
|
||||
|
||||
$envTemplateDest = Join-Path $deploymentDir "env.template"
|
||||
@@ -952,8 +962,7 @@ function Main {
|
||||
# Check if services are already running
|
||||
if ((Test-Path $composeDest) -and (Initialize-ComposeCommand)) {
|
||||
$running = @()
|
||||
$psArgs = @("ps", "-q")
|
||||
try { $running = @(Invoke-Compose -AutoDetect @psArgs 2>$null | Where-Object { $_ }) } catch { }
|
||||
try { $running = @(Invoke-Compose -AutoDetect ps -q 2>$null | Where-Object { $_ }) } catch { }
|
||||
if ($running.Count -gt 0) {
|
||||
Print-OnyxError "Onyx services are currently running!"
|
||||
Print-Info "Run '.\install.ps1 -Shutdown' first, then re-run this script."
|
||||
@@ -1019,12 +1028,6 @@ function Main {
|
||||
Print-Info "You can customize .env later for OAuth/SAML, AI models, domain settings, and Craft."
|
||||
}
|
||||
|
||||
# Clean up stale lite overlay if standard mode was selected
|
||||
if (-not $script:LiteMode -and (Test-Path $liteOverlayPath)) {
|
||||
Remove-Item -Force $liteOverlayPath
|
||||
Print-Info "Removed previous lite overlay (switching to standard mode)"
|
||||
}
|
||||
|
||||
# ── Step 6: Check Ports ───────────────────────────────────────────────
|
||||
Print-Step "Checking for available ports"
|
||||
$availablePort = Find-AvailablePort 3000
|
||||
@@ -1034,7 +1037,7 @@ function Main {
|
||||
Print-Success "Using port $availablePort for nginx"
|
||||
|
||||
$currentImageTag = Get-EnvFileValue -Path $envFile -Key "IMAGE_TAG"
|
||||
$useLatest = ($currentImageTag -eq "edge" -or $currentImageTag -eq "latest" -or $currentImageTag -match '^craft-')
|
||||
$useLatest = ($currentImageTag -eq "latest" -or $currentImageTag -match '^craft-')
|
||||
if ($useLatest) { Print-Info "Using '$currentImageTag' tag - will force pull and recreate containers" }
|
||||
|
||||
# For pinned version tags, re-download config files from that tag so the
|
||||
@@ -1066,9 +1069,8 @@ function Main {
|
||||
# ── Step 8: Start Services ────────────────────────────────────────────
|
||||
Print-Step "Starting Onyx services"
|
||||
Print-Info "Launching containers..."
|
||||
$upArgs = @("up", "-d")
|
||||
if ($useLatest) { $upArgs += @("--pull", "always", "--force-recreate") }
|
||||
$upResult = Invoke-Compose @upArgs
|
||||
if ($useLatest) { $upResult = Invoke-Compose up -d --pull always --force-recreate }
|
||||
else { $upResult = Invoke-Compose up -d }
|
||||
if ($upResult -ne 0) { Print-OnyxError "Failed to start Onyx services"; exit 1 }
|
||||
|
||||
# ── Step 9: Container Health ──────────────────────────────────────────
|
||||
@@ -1076,8 +1078,7 @@ function Main {
|
||||
Start-Sleep -Seconds 10
|
||||
$restartIssues = $false
|
||||
$containerIds = @()
|
||||
$psArgs = @("ps", "-q")
|
||||
try { $containerIds = @(Invoke-Compose @psArgs 2>$null | Where-Object { $_ }) } catch { }
|
||||
try { $containerIds = @(Invoke-Compose ps -q 2>$null | Where-Object { $_ }) } catch { }
|
||||
|
||||
foreach ($cid in $containerIds) {
|
||||
if ([string]::IsNullOrWhiteSpace($cid)) { continue }
|
||||
|
||||
@@ -96,8 +96,8 @@ fi
|
||||
|
||||
# When --lite is passed as a flag, lower resource thresholds early (before the
|
||||
# resource check). When lite is chosen interactively, the thresholds are adjusted
|
||||
# after the resource check has already passed with the standard thresholds —
|
||||
# which is the safer direction.
|
||||
# inside the new-deployment flow, after the resource check has already passed
|
||||
# with the standard thresholds — which is the safer direction.
|
||||
if [[ "$LITE_MODE" = true ]]; then
|
||||
EXPECTED_DOCKER_RAM_GB=4
|
||||
EXPECTED_DISK_GB=16
|
||||
@@ -110,6 +110,9 @@ LITE_COMPOSE_FILE="docker-compose.onyx-lite.yml"
|
||||
# Build the -f flags for docker compose.
|
||||
# Pass "true" as $1 to auto-detect a previously-downloaded lite overlay
|
||||
# (used by shutdown/delete-data so users don't need to remember --lite).
|
||||
# Without the argument, the lite overlay is only included when --lite was
|
||||
# explicitly passed — preventing install/start from silently staying in
|
||||
# lite mode just because the file exists on disk from a prior run.
|
||||
compose_file_args() {
|
||||
local auto_detect="${1:-false}"
|
||||
local args="-f docker-compose.yml"
|
||||
@@ -174,42 +177,34 @@ ensure_file() {
|
||||
|
||||
# --- Interactive prompt helpers ---
|
||||
is_interactive() {
|
||||
[[ "$NO_PROMPT" = false ]] && [[ -r /dev/tty ]] && [[ -w /dev/tty ]]
|
||||
}
|
||||
|
||||
read_prompt_line() {
|
||||
local prompt_text="$1"
|
||||
if ! is_interactive; then
|
||||
REPLY=""
|
||||
return
|
||||
fi
|
||||
[[ -n "$prompt_text" ]] && printf "%s" "$prompt_text" > /dev/tty
|
||||
IFS= read -r REPLY < /dev/tty || REPLY=""
|
||||
}
|
||||
|
||||
read_prompt_char() {
|
||||
local prompt_text="$1"
|
||||
if ! is_interactive; then
|
||||
REPLY=""
|
||||
return
|
||||
fi
|
||||
[[ -n "$prompt_text" ]] && printf "%s" "$prompt_text" > /dev/tty
|
||||
IFS= read -r -n 1 REPLY < /dev/tty || REPLY=""
|
||||
printf "\n" > /dev/tty
|
||||
[[ "$NO_PROMPT" = false ]] && [[ -t 0 ]]
|
||||
}
|
||||
|
||||
prompt_or_default() {
|
||||
local prompt_text="$1"
|
||||
local default_value="$2"
|
||||
read_prompt_line "$prompt_text"
|
||||
[[ -z "$REPLY" ]] && REPLY="$default_value"
|
||||
if is_interactive; then
|
||||
read -p "$prompt_text" -r REPLY
|
||||
if [[ -z "$REPLY" ]]; then
|
||||
REPLY="$default_value"
|
||||
fi
|
||||
else
|
||||
REPLY="$default_value"
|
||||
fi
|
||||
}
|
||||
|
||||
prompt_yn_or_default() {
|
||||
local prompt_text="$1"
|
||||
local default_value="$2"
|
||||
read_prompt_char "$prompt_text"
|
||||
[[ -z "$REPLY" ]] && REPLY="$default_value"
|
||||
if is_interactive; then
|
||||
read -p "$prompt_text" -n 1 -r
|
||||
echo ""
|
||||
if [[ -z "$REPLY" ]]; then
|
||||
REPLY="$default_value"
|
||||
fi
|
||||
else
|
||||
REPLY="$default_value"
|
||||
fi
|
||||
}
|
||||
|
||||
confirm_action() {
|
||||
@@ -310,8 +305,8 @@ if [ "$DELETE_DATA_MODE" = true ]; then
|
||||
echo " • All user data and documents"
|
||||
echo ""
|
||||
if is_interactive; then
|
||||
prompt_or_default "Are you sure you want to continue? Type 'DELETE' to confirm: " ""
|
||||
echo "" > /dev/tty
|
||||
read -p "Are you sure you want to continue? Type 'DELETE' to confirm: " -r
|
||||
echo ""
|
||||
if [ "$REPLY" != "DELETE" ]; then
|
||||
print_info "Operation cancelled."
|
||||
exit 0
|
||||
@@ -505,7 +500,7 @@ echo ""
|
||||
|
||||
if is_interactive; then
|
||||
echo -e "${YELLOW}${BOLD}Please acknowledge and press Enter to continue...${NC}"
|
||||
read_prompt_line ""
|
||||
read -r
|
||||
echo ""
|
||||
else
|
||||
echo -e "${YELLOW}${BOLD}Running in non-interactive mode - proceeding automatically...${NC}"
|
||||
@@ -750,48 +745,25 @@ if [ "$COMPOSE_VERSION" != "dev" ] && version_compare "$COMPOSE_VERSION" "2.24.0
|
||||
print_info "Proceeding with installation despite Docker Compose version compatibility issues..."
|
||||
fi
|
||||
|
||||
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
|
||||
if [[ "$LITE_MODE" = false ]]; then
|
||||
print_info "Which deployment mode would you like?"
|
||||
echo ""
|
||||
echo " 1) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
|
||||
echo " LLM chat, tools, file uploads, and Projects still work"
|
||||
echo " 2) Standard - Full deployment with search, connectors, and RAG"
|
||||
echo ""
|
||||
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
|
||||
echo ""
|
||||
|
||||
case "$REPLY" in
|
||||
2)
|
||||
print_info "Selected: Standard mode"
|
||||
;;
|
||||
*)
|
||||
LITE_MODE=true
|
||||
print_info "Selected: Lite mode"
|
||||
;;
|
||||
esac
|
||||
else
|
||||
print_info "Deployment mode: Lite (set via --lite flag)"
|
||||
fi
|
||||
|
||||
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
|
||||
print_error "--include-craft cannot be used with Lite mode."
|
||||
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$LITE_MODE" = true ]]; then
|
||||
EXPECTED_DOCKER_RAM_GB=4
|
||||
EXPECTED_DISK_GB=16
|
||||
fi
|
||||
|
||||
# Handle lite overlay file based on selected mode
|
||||
# Handle lite overlay: ensure it if --lite, clean up stale copies otherwise
|
||||
if [[ "$LITE_MODE" = true ]]; then
|
||||
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
|
||||
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
|
||||
elif [[ -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" ]]; then
|
||||
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
|
||||
print_info "Removed previous lite overlay (switching to standard mode)"
|
||||
if [[ -f "${INSTALL_ROOT}/deployment/.env" ]]; then
|
||||
print_warning "Existing lite overlay found but --lite was not passed."
|
||||
prompt_yn_or_default "Remove lite overlay and switch to standard mode? (y/N): " "n"
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
print_info "Keeping existing lite overlay. Pass --lite to keep using lite mode."
|
||||
LITE_MODE=true
|
||||
else
|
||||
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
|
||||
print_info "Removed lite overlay (switching to standard mode)"
|
||||
fi
|
||||
else
|
||||
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
|
||||
print_info "Removed previous lite overlay (switching to standard mode)"
|
||||
fi
|
||||
fi
|
||||
|
||||
ensure_file "${INSTALL_ROOT}/deployment/env.template" \
|
||||
@@ -854,22 +826,22 @@ if [ -f "$ENV_FILE" ]; then
|
||||
if [ "$REPLY" = "update" ]; then
|
||||
print_info "Update selected. Which tag would you like to deploy?"
|
||||
echo ""
|
||||
echo "• Press Enter for edge (recommended)"
|
||||
echo "• Press Enter for latest (recommended)"
|
||||
echo "• Type a specific tag (e.g., v0.1.0)"
|
||||
echo ""
|
||||
if [ "$INCLUDE_CRAFT" = true ]; then
|
||||
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
|
||||
VERSION="$REPLY"
|
||||
else
|
||||
prompt_or_default "Enter tag [default: edge]: " "edge"
|
||||
prompt_or_default "Enter tag [default: latest]: " "latest"
|
||||
VERSION="$REPLY"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
|
||||
print_info "Selected: craft-latest (Craft enabled)"
|
||||
elif [ "$VERSION" = "edge" ]; then
|
||||
print_info "Selected: edge (latest nightly)"
|
||||
elif [ "$VERSION" = "latest" ]; then
|
||||
print_info "Selected: Latest version"
|
||||
else
|
||||
print_info "Selected: $VERSION"
|
||||
fi
|
||||
@@ -921,6 +893,45 @@ else
|
||||
print_info "No existing .env file found. Setting up new deployment..."
|
||||
echo ""
|
||||
|
||||
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
|
||||
if [[ "$LITE_MODE" = false ]]; then
|
||||
print_info "Which deployment mode would you like?"
|
||||
echo ""
|
||||
echo " 1) Standard - Full deployment with search, connectors, and RAG"
|
||||
echo " 2) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
|
||||
echo " LLM chat, tools, file uploads, and Projects still work"
|
||||
echo ""
|
||||
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
|
||||
echo ""
|
||||
|
||||
case "$REPLY" in
|
||||
2)
|
||||
LITE_MODE=true
|
||||
print_info "Selected: Lite mode"
|
||||
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
|
||||
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
|
||||
;;
|
||||
*)
|
||||
print_info "Selected: Standard mode"
|
||||
;;
|
||||
esac
|
||||
else
|
||||
print_info "Deployment mode: Lite (set via --lite flag)"
|
||||
fi
|
||||
|
||||
# Validate lite + craft combination (could now be set interactively)
|
||||
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
|
||||
print_error "--include-craft cannot be used with Lite mode."
|
||||
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Adjust resource expectations for lite mode
|
||||
if [[ "$LITE_MODE" = true ]]; then
|
||||
EXPECTED_DOCKER_RAM_GB=4
|
||||
EXPECTED_DISK_GB=16
|
||||
fi
|
||||
|
||||
# Ask for version
|
||||
print_info "Which tag would you like to deploy?"
|
||||
echo ""
|
||||
@@ -931,18 +942,18 @@ else
|
||||
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
|
||||
VERSION="$REPLY"
|
||||
else
|
||||
echo "• Press Enter for edge (recommended)"
|
||||
echo "• Press Enter for latest (recommended)"
|
||||
echo "• Type a specific tag (e.g., v0.1.0)"
|
||||
echo ""
|
||||
prompt_or_default "Enter tag [default: edge]: " "edge"
|
||||
prompt_or_default "Enter tag [default: latest]: " "latest"
|
||||
VERSION="$REPLY"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
|
||||
print_info "Selected: craft-latest (Craft enabled)"
|
||||
elif [ "$VERSION" = "edge" ]; then
|
||||
print_info "Selected: edge (latest nightly)"
|
||||
elif [ "$VERSION" = "latest" ]; then
|
||||
print_info "Selected: Latest tag"
|
||||
else
|
||||
print_info "Selected: $VERSION"
|
||||
fi
|
||||
@@ -1100,15 +1111,15 @@ fi
|
||||
export HOST_PORT=$AVAILABLE_PORT
|
||||
print_success "Using port $AVAILABLE_PORT for nginx"
|
||||
|
||||
# Determine if we're using a floating tag (edge, latest, craft-*) that should force pull
|
||||
# Determine if we're using the latest tag or a craft tag (both should force pull)
|
||||
# Read IMAGE_TAG from .env file and remove any quotes or whitespace
|
||||
CURRENT_IMAGE_TAG=$(grep "^IMAGE_TAG=" "$ENV_FILE" | head -1 | cut -d'=' -f2 | tr -d ' "'"'"'')
|
||||
if [ "$CURRENT_IMAGE_TAG" = "edge" ] || [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
|
||||
if [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
|
||||
USE_LATEST=true
|
||||
if [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
|
||||
print_info "Using craft tag '$CURRENT_IMAGE_TAG' - will force pull and recreate containers"
|
||||
else
|
||||
print_info "Using '$CURRENT_IMAGE_TAG' tag - will force pull and recreate containers"
|
||||
print_info "Using 'latest' tag - will force pull and recreate containers"
|
||||
fi
|
||||
else
|
||||
USE_LATEST=false
|
||||
|
||||
@@ -127,7 +127,6 @@ Inputs (common):
|
||||
- `name` (default `onyx`), `region` (default `us-west-2`), `tags`
|
||||
- `postgres_username`, `postgres_password`
|
||||
- `create_vpc` (default true) or existing VPC details and `s3_vpc_endpoint_id`
|
||||
- WAF controls such as `waf_allowed_ip_cidrs`, `waf_common_rule_set_count_rules`, rate limits, geo restrictions, and logging retention
|
||||
|
||||
### `vpc`
|
||||
- Builds a VPC sized for EKS with multiple private and public subnets
|
||||
|
||||
@@ -88,8 +88,6 @@ module "waf" {
|
||||
tags = local.merged_tags
|
||||
|
||||
# WAF configuration with sensible defaults
|
||||
allowed_ip_cidrs = var.waf_allowed_ip_cidrs
|
||||
common_rule_set_count_rules = var.waf_common_rule_set_count_rules
|
||||
rate_limit_requests_per_5_minutes = var.waf_rate_limit_requests_per_5_minutes
|
||||
api_rate_limit_requests_per_5_minutes = var.waf_api_rate_limit_requests_per_5_minutes
|
||||
geo_restriction_countries = var.waf_geo_restriction_countries
|
||||
|
||||
@@ -117,18 +117,6 @@ variable "waf_rate_limit_requests_per_5_minutes" {
|
||||
default = 2000
|
||||
}
|
||||
|
||||
variable "waf_allowed_ip_cidrs" {
|
||||
type = list(string)
|
||||
description = "Optional IPv4 CIDR ranges allowed through the WAF. Leave empty to disable IP allowlisting."
|
||||
default = []
|
||||
}
|
||||
|
||||
variable "waf_common_rule_set_count_rules" {
|
||||
type = list(string)
|
||||
description = "Subrules within AWSManagedRulesCommonRuleSet to override to COUNT instead of BLOCK."
|
||||
default = []
|
||||
}
|
||||
|
||||
variable "waf_api_rate_limit_requests_per_5_minutes" {
|
||||
type = number
|
||||
description = "Rate limit for API requests per 5 minutes per IP address"
|
||||
|
||||
@@ -1,20 +1,6 @@
|
||||
locals {
|
||||
name = var.name
|
||||
tags = var.tags
|
||||
ip_allowlist_enabled = length(var.allowed_ip_cidrs) > 0
|
||||
managed_rule_priority = local.ip_allowlist_enabled ? 1 : 0
|
||||
}
|
||||
|
||||
resource "aws_wafv2_ip_set" "allowed_ips" {
|
||||
count = local.ip_allowlist_enabled ? 1 : 0
|
||||
|
||||
name = "${local.name}-allowed-ips"
|
||||
description = "IP allowlist for ${local.name}"
|
||||
scope = "REGIONAL"
|
||||
ip_address_version = "IPV4"
|
||||
addresses = var.allowed_ip_cidrs
|
||||
|
||||
tags = local.tags
|
||||
name = var.name
|
||||
tags = var.tags
|
||||
}
|
||||
|
||||
# AWS WAFv2 Web ACL
|
||||
@@ -27,38 +13,10 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
allow {}
|
||||
}
|
||||
|
||||
dynamic "rule" {
|
||||
for_each = local.ip_allowlist_enabled ? [1] : []
|
||||
content {
|
||||
name = "BlockRequestsOutsideAllowedIPs"
|
||||
priority = 1
|
||||
|
||||
action {
|
||||
block {}
|
||||
}
|
||||
|
||||
statement {
|
||||
not_statement {
|
||||
statement {
|
||||
ip_set_reference_statement {
|
||||
arn = aws_wafv2_ip_set.allowed_ips[0].arn
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
visibility_config {
|
||||
cloudwatch_metrics_enabled = true
|
||||
metric_name = "BlockRequestsOutsideAllowedIPsMetric"
|
||||
sampled_requests_enabled = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# AWS Managed Rules - Core Rule Set
|
||||
rule {
|
||||
name = "AWSManagedRulesCommonRuleSet"
|
||||
priority = 1 + local.managed_rule_priority
|
||||
priority = 1
|
||||
|
||||
override_action {
|
||||
none {}
|
||||
@@ -68,16 +26,6 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
managed_rule_group_statement {
|
||||
name = "AWSManagedRulesCommonRuleSet"
|
||||
vendor_name = "AWS"
|
||||
|
||||
dynamic "rule_action_override" {
|
||||
for_each = var.common_rule_set_count_rules
|
||||
content {
|
||||
name = rule_action_override.value
|
||||
action_to_use {
|
||||
count {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,7 +39,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# AWS Managed Rules - Known Bad Inputs
|
||||
rule {
|
||||
name = "AWSManagedRulesKnownBadInputsRuleSet"
|
||||
priority = 2 + local.managed_rule_priority
|
||||
priority = 2
|
||||
|
||||
override_action {
|
||||
none {}
|
||||
@@ -114,7 +62,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# Rate Limiting Rule
|
||||
rule {
|
||||
name = "RateLimitRule"
|
||||
priority = 3 + local.managed_rule_priority
|
||||
priority = 3
|
||||
|
||||
action {
|
||||
block {}
|
||||
@@ -139,7 +87,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
for_each = length(var.geo_restriction_countries) > 0 ? [1] : []
|
||||
content {
|
||||
name = "GeoRestrictionRule"
|
||||
priority = 4 + local.managed_rule_priority
|
||||
priority = 4
|
||||
|
||||
action {
|
||||
block {}
|
||||
@@ -162,7 +110,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# IP Rate Limiting
|
||||
rule {
|
||||
name = "APIRateLimitRule"
|
||||
priority = 5 + local.managed_rule_priority
|
||||
priority = 5
|
||||
|
||||
action {
|
||||
block {}
|
||||
@@ -185,7 +133,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# SQL Injection Protection
|
||||
rule {
|
||||
name = "AWSManagedRulesSQLiRuleSet"
|
||||
priority = 6 + local.managed_rule_priority
|
||||
priority = 6
|
||||
|
||||
override_action {
|
||||
none {}
|
||||
@@ -208,7 +156,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# Anonymous IP Protection
|
||||
rule {
|
||||
name = "AWSManagedRulesAnonymousIpList"
|
||||
priority = 7 + local.managed_rule_priority
|
||||
priority = 7
|
||||
|
||||
override_action {
|
||||
none {}
|
||||
|
||||
@@ -9,18 +9,6 @@ variable "tags" {
|
||||
default = {}
|
||||
}
|
||||
|
||||
variable "allowed_ip_cidrs" {
|
||||
type = list(string)
|
||||
description = "Optional IPv4 CIDR ranges allowed to reach the application. Leave empty to disable IP allowlisting."
|
||||
default = []
|
||||
}
|
||||
|
||||
variable "common_rule_set_count_rules" {
|
||||
type = list(string)
|
||||
description = "Subrules within AWSManagedRulesCommonRuleSet to override to COUNT instead of BLOCK."
|
||||
default = []
|
||||
}
|
||||
|
||||
variable "rate_limit_requests_per_5_minutes" {
|
||||
type = number
|
||||
description = "Rate limit for requests per 5 minutes per IP address"
|
||||
|
||||
6
examples/widget/package-lock.json
generated
6
examples/widget/package-lock.json
generated
@@ -3839,9 +3839,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/flatted": {
|
||||
"version": "3.4.2",
|
||||
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz",
|
||||
"integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==",
|
||||
"version": "3.3.3",
|
||||
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz",
|
||||
"integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==",
|
||||
"dev": true,
|
||||
"license": "ISC"
|
||||
},
|
||||
|
||||
@@ -1,38 +1,33 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@opal/utils";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
|
||||
interface ActionsContainerProps {
|
||||
type: "head" | "cell";
|
||||
children: React.ReactNode;
|
||||
size?: TableSize;
|
||||
/** Pass-through click handler (e.g. stopPropagation on body cells). */
|
||||
onClick?: (e: React.MouseEvent) => void;
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export default function ActionsContainer({
|
||||
type,
|
||||
children,
|
||||
size,
|
||||
onClick,
|
||||
}: ActionsContainerProps) {
|
||||
const size = useTableSize();
|
||||
const contextSize = useTableSize();
|
||||
const resolvedSize = size ?? contextSize;
|
||||
|
||||
const Tag = type === "head" ? "th" : "td";
|
||||
|
||||
return (
|
||||
<Tag
|
||||
className="tbl-actions"
|
||||
data-type={type}
|
||||
data-size={size}
|
||||
data-size={resolvedSize}
|
||||
onClick={onClick}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"flex h-full items-center",
|
||||
type === "cell" ? "justify-end" : "justify-center"
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
<div className="flex h-full items-center justify-center">{children}</div>
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import {
|
||||
type SortingState,
|
||||
} from "@tanstack/react-table";
|
||||
import { Button, LineItemButton } from "@opal/components";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import { SvgArrowUpDown, SvgSortOrder, SvgCheck } from "@opal/icons";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import Divider from "@/refresh-components/Divider";
|
||||
@@ -21,6 +20,7 @@ import Text from "@/refresh-components/texts/Text";
|
||||
interface SortingPopoverProps<TData extends RowData = RowData> {
|
||||
table: Table<TData>;
|
||||
sorting: SortingState;
|
||||
size?: "md" | "lg";
|
||||
footerText?: string;
|
||||
ascendingLabel?: string;
|
||||
descendingLabel?: string;
|
||||
@@ -29,11 +29,11 @@ interface SortingPopoverProps<TData extends RowData = RowData> {
|
||||
function SortingPopover<TData extends RowData>({
|
||||
table,
|
||||
sorting,
|
||||
size = "lg",
|
||||
footerText,
|
||||
ascendingLabel = "Ascending",
|
||||
descendingLabel = "Descending",
|
||||
}: SortingPopoverProps<TData>) {
|
||||
const size = useTableSize();
|
||||
const [open, setOpen] = useState(false);
|
||||
const sortableColumns = table
|
||||
.getAllLeafColumns()
|
||||
@@ -158,6 +158,7 @@ function SortingPopover<TData extends RowData>({
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface CreateSortingColumnOptions {
|
||||
size?: "md" | "lg";
|
||||
footerText?: string;
|
||||
ascendingLabel?: string;
|
||||
descendingLabel?: string;
|
||||
@@ -176,6 +177,7 @@ function createSortingColumn<TData>(
|
||||
<SortingPopover
|
||||
table={table}
|
||||
sorting={table.getState().sorting}
|
||||
size={options?.size}
|
||||
footerText={options?.footerText}
|
||||
ascendingLabel={options?.ascendingLabel}
|
||||
descendingLabel={options?.descendingLabel}
|
||||
|
||||
@@ -8,7 +8,6 @@ import {
|
||||
type VisibilityState,
|
||||
} from "@tanstack/react-table";
|
||||
import { Button, LineItemButton, Tag } from "@opal/components";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import { SvgColumn, SvgCheck } from "@opal/icons";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import Divider from "@/refresh-components/Divider";
|
||||
@@ -20,13 +19,14 @@ import Divider from "@/refresh-components/Divider";
|
||||
interface ColumnVisibilityPopoverProps<TData extends RowData = RowData> {
|
||||
table: Table<TData>;
|
||||
columnVisibility: VisibilityState;
|
||||
size?: "md" | "lg";
|
||||
}
|
||||
|
||||
function ColumnVisibilityPopover<TData extends RowData>({
|
||||
table,
|
||||
columnVisibility,
|
||||
size = "lg",
|
||||
}: ColumnVisibilityPopoverProps<TData>) {
|
||||
const size = useTableSize();
|
||||
const [open, setOpen] = useState(false);
|
||||
|
||||
// User-defined columns only (exclude internal qualifier/actions)
|
||||
@@ -87,7 +87,13 @@ function ColumnVisibilityPopover<TData extends RowData>({
|
||||
// Column definition factory
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function createColumnVisibilityColumn<TData>(): ColumnDef<TData, unknown> {
|
||||
interface CreateColumnVisibilityColumnOptions {
|
||||
size?: "md" | "lg";
|
||||
}
|
||||
|
||||
function createColumnVisibilityColumn<TData>(
|
||||
options?: CreateColumnVisibilityColumnOptions
|
||||
): ColumnDef<TData, unknown> {
|
||||
return {
|
||||
id: "__columnVisibility",
|
||||
size: 44,
|
||||
@@ -98,6 +104,7 @@ function createColumnVisibilityColumn<TData>(): ColumnDef<TData, unknown> {
|
||||
<ColumnVisibilityPopover
|
||||
table={table}
|
||||
columnVisibility={table.getState().columnVisibility}
|
||||
size={options?.size}
|
||||
/>
|
||||
),
|
||||
cell: () => null,
|
||||
|
||||
@@ -57,10 +57,9 @@ function DragOverlayRowInner<TData>({
|
||||
<QualifierContainer key={cell.id} type="cell">
|
||||
<TableQualifier
|
||||
content={qualifierColumn.content}
|
||||
icon={qualifierColumn.getContent?.(row.original)}
|
||||
initials={qualifierColumn.getInitials?.(row.original)}
|
||||
icon={qualifierColumn.getIcon?.(row.original)}
|
||||
imageSrc={qualifierColumn.getImageSrc?.(row.original)}
|
||||
imageAlt={qualifierColumn.getImageAlt?.(row.original)}
|
||||
background={qualifierColumn.background}
|
||||
selectable={isSelectable}
|
||||
selected={isSelectable && row.getIsSelected()}
|
||||
/>
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@opal/utils";
|
||||
import { Button, Pagination, SelectButton } from "@opal/components";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
import { SvgEye, SvgXCircle } from "@opal/icons";
|
||||
import type { ReactNode } from "react";
|
||||
|
||||
@@ -43,6 +45,9 @@ interface FooterSelectionModeProps {
|
||||
onPageChange: (page: number) => void;
|
||||
/** Unit label for count pagination. @default "items" */
|
||||
units?: string;
|
||||
/** Controls overall footer sizing. `"lg"` (default) or `"md"`. */
|
||||
size?: TableSize;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -68,6 +73,7 @@ interface FooterSummaryModeProps {
|
||||
leftExtra?: ReactNode;
|
||||
/** Unit label for the summary text, e.g. "users". */
|
||||
units?: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -104,7 +110,11 @@ export default function Footer(props: FooterProps) {
|
||||
const isSmall = resolvedSize === "md";
|
||||
return (
|
||||
<div
|
||||
className="table-footer flex w-full items-center justify-between border-t border-border-01"
|
||||
className={cn(
|
||||
"table-footer",
|
||||
"flex w-full items-center justify-between border-t border-border-01",
|
||||
props.className
|
||||
)}
|
||||
data-size={resolvedSize}
|
||||
>
|
||||
{/* Left side */}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"use client";
|
||||
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
|
||||
interface QualifierContainerProps {
|
||||
type: "head" | "cell";
|
||||
children?: React.ReactNode;
|
||||
size?: TableSize;
|
||||
/** Pass-through click handler (e.g. stopPropagation on body cells). */
|
||||
onClick?: (e: React.MouseEvent) => void;
|
||||
}
|
||||
@@ -12,9 +12,11 @@ interface QualifierContainerProps {
|
||||
export default function QualifierContainer({
|
||||
type,
|
||||
children,
|
||||
size,
|
||||
onClick,
|
||||
}: QualifierContainerProps) {
|
||||
const resolvedSize = useTableSize();
|
||||
const contextSize = useTableSize();
|
||||
const resolvedSize = size ?? contextSize;
|
||||
|
||||
const Tag = type === "head" ? "th" : "td";
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ row selection, drag-and-drop reordering, and server-side mode.
|
||||
|
||||
```tsx
|
||||
import { Table, createTableColumns } from "@opal/components";
|
||||
import { SvgUser } from "@opal/icons";
|
||||
|
||||
interface User {
|
||||
id: string;
|
||||
@@ -19,10 +18,11 @@ interface User {
|
||||
const tc = createTableColumns<User>();
|
||||
|
||||
const columns = [
|
||||
tc.qualifier({ content: "icon", getContent: () => SvgUser }),
|
||||
tc.qualifier({ content: "avatar-user", getInitials: (r) => r.name?.[0] ?? "?" }),
|
||||
tc.column("email", {
|
||||
header: "Name",
|
||||
weight: 22,
|
||||
minWidth: 140,
|
||||
cell: (email, row) => <span>{row.name ?? email}</span>,
|
||||
}),
|
||||
tc.column("status", {
|
||||
@@ -40,7 +40,7 @@ function UsersTable({ users }: { users: User[] }) {
|
||||
columns={columns}
|
||||
getRowId={(r) => r.id}
|
||||
pageSize={10}
|
||||
footer={{}}
|
||||
footer={{ mode: "summary" }}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -55,7 +55,7 @@ function UsersTable({ users }: { users: User[] }) {
|
||||
| `getRowId` | `(row: TData) => string` | required | Unique row identifier |
|
||||
| `pageSize` | `number` | `10` | Rows per page (`Infinity` disables pagination) |
|
||||
| `size` | `"md" \| "lg"` | `"lg"` | Density variant |
|
||||
| `footer` | `DataTableFooterConfig` | — | Footer configuration (mode is derived from `selectionBehavior`) |
|
||||
| `footer` | `DataTableFooterConfig` | — | Footer mode (`"selection"` or `"summary"`) |
|
||||
| `initialSorting` | `SortingState` | — | Initial sort state |
|
||||
| `initialColumnVisibility` | `VisibilityState` | — | Initial column visibility |
|
||||
| `draggable` | `DataTableDraggableConfig` | — | Enable drag-and-drop reordering |
|
||||
@@ -63,6 +63,7 @@ function UsersTable({ users }: { users: User[] }) {
|
||||
| `onRowClick` | `(row: TData) => void` | — | Row click handler |
|
||||
| `searchTerm` | `string` | — | Global text filter |
|
||||
| `height` | `number \| string` | — | Max scrollable height |
|
||||
| `headerBackground` | `string` | — | Sticky header background |
|
||||
| `serverSide` | `ServerSideConfig` | — | Server-side pagination/sorting/filtering |
|
||||
| `emptyState` | `ReactNode` | — | Empty state content |
|
||||
|
||||
@@ -75,8 +76,7 @@ function UsersTable({ users }: { users: User[] }) {
|
||||
- `tc.displayColumn(opts)` — non-accessor custom column
|
||||
- `tc.actions(opts)` — trailing actions column with visibility/sorting popovers
|
||||
|
||||
## Footer
|
||||
## Footer Modes
|
||||
|
||||
The footer mode is derived automatically from `selectionBehavior`:
|
||||
- **Selection footer** (when `selectionBehavior` is `"single-select"` or `"multi-select"`) — shows selection count, optional view/clear buttons, count pagination
|
||||
- **Summary footer** (when `selectionBehavior` is `"no-select"` or omitted) — shows "Showing X\~Y of Z", list pagination, optional extra element
|
||||
- **`"selection"`** — shows selection count, optional view/clear buttons, count pagination
|
||||
- **`"summary"`** — shows "Showing X~Y of Z", list pagination, optional extra element
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react";
|
||||
import { Table, createTableColumns } from "@opal/components";
|
||||
import { SvgUser } from "@opal/icons";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Sample data
|
||||
@@ -109,14 +108,17 @@ const tc = createTableColumns<User>();
|
||||
|
||||
const columns = [
|
||||
tc.qualifier({
|
||||
content: "icon",
|
||||
getContent: () => SvgUser,
|
||||
background: true,
|
||||
content: "avatar-user",
|
||||
getInitials: (r) =>
|
||||
r.name
|
||||
.split(" ")
|
||||
.map((n) => n[0])
|
||||
.join(""),
|
||||
}),
|
||||
tc.column("name", { header: "Name", weight: 25 }),
|
||||
tc.column("email", { header: "Email", weight: 30 }),
|
||||
tc.column("role", { header: "Role", weight: 15 }),
|
||||
tc.column("status", { header: "Status", weight: 15 }),
|
||||
tc.column("name", { header: "Name", weight: 25, minWidth: 120 }),
|
||||
tc.column("email", { header: "Email", weight: 30, minWidth: 160 }),
|
||||
tc.column("role", { header: "Role", weight: 15, minWidth: 80 }),
|
||||
tc.column("status", { header: "Status", weight: 15, minWidth: 80 }),
|
||||
tc.actions(),
|
||||
];
|
||||
|
||||
@@ -140,7 +142,7 @@ export const Default: Story = {
|
||||
columns={columns}
|
||||
getRowId={(r) => r.id}
|
||||
pageSize={8}
|
||||
footer={{}}
|
||||
footer={{ mode: "summary" }}
|
||||
/>
|
||||
),
|
||||
};
|
||||
|
||||
@@ -1,20 +1,24 @@
|
||||
import { cn } from "@opal/utils";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { WithoutStyles } from "@/types";
|
||||
|
||||
interface TableCellProps
|
||||
extends WithoutStyles<React.TdHTMLAttributes<HTMLTableCellElement>> {
|
||||
children: React.ReactNode;
|
||||
size?: TableSize;
|
||||
/** Explicit pixel width for the cell. */
|
||||
width?: number;
|
||||
}
|
||||
|
||||
export default function TableCell({
|
||||
size,
|
||||
width,
|
||||
children,
|
||||
...props
|
||||
}: TableCellProps) {
|
||||
const resolvedSize = useTableSize();
|
||||
const contextSize = useTableSize();
|
||||
const resolvedSize = size ?? contextSize;
|
||||
return (
|
||||
<td
|
||||
className="tbl-cell overflow-hidden"
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { WithoutStyles } from "@/types";
|
||||
import type { ExtremaSizeVariants, SizeVariants } from "@opal/types";
|
||||
|
||||
@@ -12,15 +9,20 @@ import type { ExtremaSizeVariants, SizeVariants } from "@opal/types";
|
||||
|
||||
type TableSize = Extract<SizeVariants, "md" | "lg">;
|
||||
type TableVariant = "rows" | "cards";
|
||||
type TableQualifier = "simple" | "avatar" | "icon";
|
||||
type SelectionBehavior = "no-select" | "single-select" | "multi-select";
|
||||
|
||||
interface TableProps
|
||||
extends WithoutStyles<React.TableHTMLAttributes<HTMLTableElement>> {
|
||||
ref?: React.Ref<HTMLTableElement>;
|
||||
/** Size preset for the table. @default "lg" */
|
||||
size?: TableSize;
|
||||
/** Visual row variant. @default "cards" */
|
||||
variant?: TableVariant;
|
||||
/** Row selection behavior. @default "no-select" */
|
||||
selectionBehavior?: SelectionBehavior;
|
||||
/** Leading qualifier column type. @default null */
|
||||
qualifier?: TableQualifier;
|
||||
/** Height behavior. `"fit"` = shrink to content, `"full"` = fill available space. */
|
||||
heightVariant?: ExtremaSizeVariants;
|
||||
/** Explicit pixel width for the table (e.g. from `table.getTotalSize()`).
|
||||
@@ -36,13 +38,14 @@ interface TableProps
|
||||
|
||||
function Table({
|
||||
ref,
|
||||
size = "lg",
|
||||
variant = "cards",
|
||||
selectionBehavior = "no-select",
|
||||
qualifier = "simple",
|
||||
heightVariant,
|
||||
width,
|
||||
...props
|
||||
}: TableProps) {
|
||||
const size = useTableSize();
|
||||
return (
|
||||
<table
|
||||
ref={ref}
|
||||
@@ -51,6 +54,7 @@ function Table({
|
||||
data-size={size}
|
||||
data-variant={variant}
|
||||
data-selection={selectionBehavior}
|
||||
data-qualifier={qualifier}
|
||||
data-height={heightVariant}
|
||||
{...props}
|
||||
/>
|
||||
@@ -58,4 +62,10 @@ function Table({
|
||||
}
|
||||
|
||||
export default Table;
|
||||
export type { TableProps, TableSize, TableVariant, SelectionBehavior };
|
||||
export type {
|
||||
TableProps,
|
||||
TableSize,
|
||||
TableVariant,
|
||||
TableQualifier,
|
||||
SelectionBehavior,
|
||||
};
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { cn } from "@opal/utils";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { WithoutStyles } from "@/types";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgChevronDown, SvgChevronUp, SvgHandle, SvgSort } from "@opal/icons";
|
||||
@@ -29,6 +30,8 @@ interface TableHeadCustomProps {
|
||||
icon?: (sorted: SortDirection) => IconFunctionComponent;
|
||||
/** Text alignment for the column. Defaults to `"left"`. */
|
||||
alignment?: "left" | "center" | "right";
|
||||
/** Cell density. `"md"` uses tighter padding for denser layouts. */
|
||||
size?: TableSize;
|
||||
/** Column width in pixels. Applied as an inline style on the `<th>`. */
|
||||
width?: number;
|
||||
/** When `true`, shows a bottom border on hover. Defaults to `true`. */
|
||||
@@ -78,11 +81,13 @@ export default function TableHead({
|
||||
resizable,
|
||||
onResizeStart,
|
||||
alignment = "left",
|
||||
size,
|
||||
width,
|
||||
bottomBorder = true,
|
||||
...thProps
|
||||
}: TableHeadProps) {
|
||||
const resolvedSize = useTableSize();
|
||||
const contextSize = useTableSize();
|
||||
const resolvedSize = size ?? contextSize;
|
||||
const isSmall = resolvedSize === "md";
|
||||
return (
|
||||
<th
|
||||
|
||||
@@ -3,13 +3,19 @@
|
||||
import React from "react";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
import { SvgUser } from "@opal/icons";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import type { QualifierContentType } from "@opal/components/table/types";
|
||||
import Checkbox from "@/refresh-components/inputs/Checkbox";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
interface TableQualifierProps {
|
||||
className?: string;
|
||||
/** Content type displayed in the qualifier */
|
||||
content: QualifierContentType;
|
||||
/** Size variant */
|
||||
size?: TableSize;
|
||||
/** Disables interaction */
|
||||
disabled?: boolean;
|
||||
/** Whether to show a selection checkbox overlay */
|
||||
@@ -18,33 +24,54 @@ interface TableQualifierProps {
|
||||
selected?: boolean;
|
||||
/** Called when the checkbox is toggled */
|
||||
onSelectChange?: (selected: boolean) => void;
|
||||
/** Icon component to render (for "icon" content). */
|
||||
/** Icon component to render (for "icon" content type) */
|
||||
icon?: IconFunctionComponent;
|
||||
/** Image source URL (for "image" content). */
|
||||
/** Image source URL (for "image" content type) */
|
||||
imageSrc?: string;
|
||||
/** Image alt text (for "image" content). */
|
||||
/** Image alt text */
|
||||
imageAlt?: string;
|
||||
/** Show a tinted background container behind the content. */
|
||||
background?: boolean;
|
||||
/** User initials (for "avatar-user" content type) */
|
||||
initials?: string;
|
||||
}
|
||||
|
||||
const iconSizes = {
|
||||
lg: 28,
|
||||
md: 24,
|
||||
lg: 16,
|
||||
md: 14,
|
||||
} as const;
|
||||
|
||||
function getOverlayStyles(selected: boolean, disabled: boolean) {
|
||||
function getQualifierStyles(selected: boolean, disabled: boolean) {
|
||||
if (disabled) {
|
||||
return selected ? "flex bg-action-link-00" : "hidden";
|
||||
return {
|
||||
container: "bg-background-neutral-03",
|
||||
icon: "stroke-text-02",
|
||||
overlay: selected ? "flex bg-action-link-00" : "hidden",
|
||||
overlayImage: selected ? "flex bg-mask-01 backdrop-blur-02" : "hidden",
|
||||
};
|
||||
}
|
||||
|
||||
if (selected) {
|
||||
return "flex bg-action-link-00";
|
||||
return {
|
||||
container: "bg-action-link-00",
|
||||
icon: "stroke-text-03",
|
||||
overlay: "flex bg-action-link-00",
|
||||
overlayImage: "flex bg-mask-01 backdrop-blur-02",
|
||||
};
|
||||
}
|
||||
return "flex opacity-0 group-hover/row:opacity-100 group-focus-within/row:opacity-100 bg-background-tint-01";
|
||||
|
||||
return {
|
||||
container: "bg-background-tint-01",
|
||||
icon: "stroke-text-03",
|
||||
overlay:
|
||||
"flex opacity-0 group-hover/row:opacity-100 group-focus-within/row:opacity-100 bg-background-tint-01",
|
||||
overlayImage:
|
||||
"flex opacity-0 group-hover/row:opacity-100 group-focus-within/row:opacity-100 bg-mask-01 group-hover/row:backdrop-blur-02 group-focus-within/row:backdrop-blur-02",
|
||||
};
|
||||
}
|
||||
|
||||
function TableQualifier({
|
||||
className,
|
||||
content,
|
||||
size,
|
||||
disabled = false,
|
||||
selectable = false,
|
||||
selected = false,
|
||||
@@ -52,67 +79,100 @@ function TableQualifier({
|
||||
icon: Icon,
|
||||
imageSrc,
|
||||
imageAlt = "",
|
||||
background = false,
|
||||
initials,
|
||||
}: TableQualifierProps) {
|
||||
const resolvedSize = useTableSize();
|
||||
const contextSize = useTableSize();
|
||||
const resolvedSize = size ?? contextSize;
|
||||
const isRound = content === "avatar-icon" || content === "avatar-user";
|
||||
const iconSize = iconSizes[resolvedSize];
|
||||
const overlayStyles = getOverlayStyles(selected, disabled);
|
||||
const styles = getQualifierStyles(selected, disabled);
|
||||
|
||||
function renderContent() {
|
||||
switch (content) {
|
||||
case "icon":
|
||||
return Icon ? <Icon size={iconSize} /> : null;
|
||||
return Icon ? <Icon size={iconSize} className={styles.icon} /> : null;
|
||||
|
||||
case "simple":
|
||||
return null;
|
||||
|
||||
case "image":
|
||||
return imageSrc ? (
|
||||
<img
|
||||
src={imageSrc}
|
||||
alt={imageAlt}
|
||||
className="h-full w-full rounded-08 object-cover"
|
||||
className={cn(
|
||||
"h-full w-full object-cover",
|
||||
isRound ? "rounded-full" : "rounded-08"
|
||||
)}
|
||||
/>
|
||||
) : null;
|
||||
|
||||
case "simple":
|
||||
case "avatar-icon":
|
||||
return <SvgUser size={iconSize} className={styles.icon} />;
|
||||
|
||||
case "avatar-user":
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center rounded-full bg-background-neutral-inverted-00",
|
||||
resolvedSize === "lg" ? "h-7 w-7" : "h-6 w-6"
|
||||
)}
|
||||
>
|
||||
<Text
|
||||
inverted
|
||||
secondaryAction
|
||||
text05
|
||||
className="select-none uppercase"
|
||||
>
|
||||
{initials}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
const inner = renderContent();
|
||||
const showBackground = background && content !== "simple";
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"group relative inline-flex shrink-0 items-center justify-center",
|
||||
resolvedSize === "lg" ? "h-9 w-9" : "h-7 w-7",
|
||||
disabled ? "cursor-not-allowed" : "cursor-default"
|
||||
disabled ? "cursor-not-allowed" : "cursor-default",
|
||||
className
|
||||
)}
|
||||
>
|
||||
{showBackground ? (
|
||||
{/* Inner qualifier container — no background for "simple" */}
|
||||
{content !== "simple" && (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center overflow-hidden rounded-08 transition-colors",
|
||||
"flex items-center justify-center overflow-hidden transition-colors",
|
||||
resolvedSize === "lg" ? "h-9 w-9" : "h-7 w-7",
|
||||
disabled
|
||||
? "bg-background-neutral-03"
|
||||
: selected
|
||||
? "bg-action-link-00"
|
||||
: "bg-background-tint-01"
|
||||
isRound ? "rounded-full" : "rounded-08",
|
||||
styles.container,
|
||||
content === "image" && disabled && !selected && "opacity-50"
|
||||
)}
|
||||
>
|
||||
{inner}
|
||||
{renderContent()}
|
||||
</div>
|
||||
) : (
|
||||
inner
|
||||
)}
|
||||
|
||||
{/* Selection overlay */}
|
||||
{selectable && (
|
||||
<div
|
||||
className={cn(
|
||||
"absolute inset-0 items-center justify-center rounded-08",
|
||||
content === "simple" ? "flex" : overlayStyles
|
||||
"absolute inset-0 items-center justify-center",
|
||||
content === "simple"
|
||||
? "flex"
|
||||
: isRound
|
||||
? "rounded-full"
|
||||
: "rounded-08",
|
||||
content === "simple"
|
||||
? "flex"
|
||||
: content === "image"
|
||||
? styles.overlayImage
|
||||
: styles.overlay
|
||||
)}
|
||||
>
|
||||
<Checkbox
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import { cn } from "@opal/utils";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { WithoutStyles } from "@/types";
|
||||
import { useSortable } from "@dnd-kit/sortable";
|
||||
import { CSS } from "@dnd-kit/utilities";
|
||||
@@ -11,7 +12,7 @@ import { SvgHandle } from "@opal/icons";
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export interface TableRowProps
|
||||
interface TableRowProps
|
||||
extends WithoutStyles<React.HTMLAttributes<HTMLTableRowElement>> {
|
||||
ref?: React.Ref<HTMLTableRowElement>;
|
||||
selected?: boolean;
|
||||
@@ -21,6 +22,8 @@ export interface TableRowProps
|
||||
sortableId?: string;
|
||||
/** Show drag handle overlay. Defaults to true when sortableId is set. */
|
||||
showDragHandle?: boolean;
|
||||
/** Size variant for the drag handle */
|
||||
size?: TableSize;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -30,13 +33,15 @@ export interface TableRowProps
|
||||
function SortableTableRow({
|
||||
sortableId,
|
||||
showDragHandle = true,
|
||||
size,
|
||||
selected,
|
||||
disabled,
|
||||
ref: _externalRef,
|
||||
children,
|
||||
...props
|
||||
}: TableRowProps) {
|
||||
const resolvedSize = useTableSize();
|
||||
const contextSize = useTableSize();
|
||||
const resolvedSize = size ?? contextSize;
|
||||
|
||||
const {
|
||||
attributes,
|
||||
@@ -100,9 +105,10 @@ function SortableTableRow({
|
||||
// Main component
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export default function TableRow({
|
||||
function TableRow({
|
||||
sortableId,
|
||||
showDragHandle,
|
||||
size,
|
||||
selected,
|
||||
disabled,
|
||||
ref,
|
||||
@@ -113,6 +119,7 @@ export default function TableRow({
|
||||
<SortableTableRow
|
||||
sortableId={sortableId}
|
||||
showDragHandle={showDragHandle}
|
||||
size={size}
|
||||
selected={selected}
|
||||
disabled={disabled}
|
||||
ref={ref}
|
||||
@@ -131,3 +138,6 @@ export default function TableRow({
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export default TableRow;
|
||||
export type { TableRowProps };
|
||||
|
||||
@@ -25,14 +25,18 @@ import type { SortDirection } from "@opal/components/table/TableHead";
|
||||
interface QualifierConfig<TData> {
|
||||
/** Content type for body-row `<TableQualifier>`. @default "simple" */
|
||||
content?: QualifierContentType;
|
||||
/** Return the icon component to render for a row (for "icon" content). */
|
||||
getContent?: (row: TData) => IconFunctionComponent;
|
||||
/** Return the image URL to render for a row (for "image" content). */
|
||||
/** Content type for the header `<TableQualifier>`. @default "simple" */
|
||||
headerContentType?: QualifierContentType;
|
||||
/** Extract initials from a row (for "avatar-user" content). */
|
||||
getInitials?: (row: TData) => string;
|
||||
/** Extract icon from a row (for "icon" / "avatar-icon" content). */
|
||||
getIcon?: (row: TData) => IconFunctionComponent;
|
||||
/** Extract image src from a row (for "image" content). */
|
||||
getImageSrc?: (row: TData) => string;
|
||||
/** Return the image alt text for a row (for "image" content). @default "" */
|
||||
getImageAlt?: (row: TData) => string;
|
||||
/** Show a tinted background container behind the content. @default false */
|
||||
background?: boolean;
|
||||
/** Whether to show selection checkboxes on the qualifier. @default true */
|
||||
selectable?: boolean;
|
||||
/** Whether to render qualifier content in the header. @default true */
|
||||
header?: boolean;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -54,6 +58,8 @@ interface DataColumnConfig<TData, TValue> {
|
||||
icon?: (sorted: SortDirection) => IconFunctionComponent;
|
||||
/** Column weight for proportional distribution. @default 20 */
|
||||
weight?: number;
|
||||
/** Minimum column width in pixels. @default 50 */
|
||||
minWidth?: number;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -126,9 +132,9 @@ interface TableColumnsBuilder<TData> {
|
||||
* ```ts
|
||||
* const tc = createTableColumns<TeamMember>();
|
||||
* const columns = [
|
||||
* tc.qualifier({ content: "icon", getContent: (r) => UserIcon }),
|
||||
* tc.column("name", { header: "Name", weight: 23 }),
|
||||
* tc.column("email", { header: "Email", weight: 28 }),
|
||||
* tc.qualifier({ content: "avatar-user", getInitials: (r) => r.initials }),
|
||||
* tc.column("name", { header: "Name", weight: 23, minWidth: 120 }),
|
||||
* tc.column("email", { header: "Email", weight: 28, minWidth: 150 }),
|
||||
* tc.actions(),
|
||||
* ];
|
||||
* ```
|
||||
@@ -156,10 +162,12 @@ export function createTableColumns<TData>(): TableColumnsBuilder<TData> {
|
||||
width: (size: TableSize) =>
|
||||
size === "md" ? { fixed: 36 } : { fixed: 44 },
|
||||
content,
|
||||
getContent: config?.getContent,
|
||||
headerContentType: config?.headerContentType,
|
||||
getInitials: config?.getInitials,
|
||||
getIcon: config?.getIcon,
|
||||
getImageSrc: config?.getImageSrc,
|
||||
getImageAlt: config?.getImageAlt,
|
||||
background: config?.background,
|
||||
selectable: config?.selectable,
|
||||
header: config?.header,
|
||||
};
|
||||
},
|
||||
|
||||
@@ -175,6 +183,7 @@ export function createTableColumns<TData>(): TableColumnsBuilder<TData> {
|
||||
enableHiding = true,
|
||||
icon,
|
||||
weight = 20,
|
||||
minWidth = 50,
|
||||
} = config;
|
||||
|
||||
const def = helper.accessor(accessor as any, {
|
||||
@@ -192,7 +201,7 @@ export function createTableColumns<TData>(): TableColumnsBuilder<TData> {
|
||||
kind: "data",
|
||||
id: accessor as string,
|
||||
def,
|
||||
width: { weight, minWidth: Math.max(header.length * 8 + 40, 80) },
|
||||
width: { weight, minWidth },
|
||||
icon,
|
||||
};
|
||||
},
|
||||
|
||||
@@ -39,12 +39,15 @@ import type {
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SelectionBehavior
|
||||
// Qualifier × SelectionBehavior
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type Qualifier = "simple" | "avatar" | "icon";
|
||||
type SelectionBehavior = "no-select" | "single-select" | "multi-select";
|
||||
|
||||
export type DataTableProps<TData> = BaseDataTableProps<TData> & {
|
||||
/** Leading qualifier column type. @default "simple" */
|
||||
qualifier?: Qualifier;
|
||||
/** Row selection behavior. @default "no-select" */
|
||||
selectionBehavior?: SelectionBehavior;
|
||||
};
|
||||
@@ -128,8 +131,8 @@ function processColumns<TData>(
|
||||
* ```tsx
|
||||
* const tc = createTableColumns<TeamMember>();
|
||||
* const columns = [
|
||||
* tc.qualifier({ content: "icon", getContent: (r) => UserIcon }),
|
||||
* tc.column("name", { header: "Name", weight: 23 }),
|
||||
* tc.qualifier({ content: "avatar-user", getInitials: (r) => r.initials }),
|
||||
* tc.column("name", { header: "Name", weight: 23, minWidth: 120 }),
|
||||
* tc.column("email", { header: "Email", weight: 28 }),
|
||||
* tc.actions(),
|
||||
* ];
|
||||
@@ -149,11 +152,13 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
footer,
|
||||
size = "lg",
|
||||
variant = "cards",
|
||||
qualifier = "simple",
|
||||
selectionBehavior = "no-select",
|
||||
onSelectionChange,
|
||||
onRowClick,
|
||||
searchTerm,
|
||||
height,
|
||||
headerBackground,
|
||||
serverSide,
|
||||
emptyState,
|
||||
} = props;
|
||||
@@ -161,15 +166,11 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
const effectivePageSize = pageSize ?? (footer ? 10 : data.length);
|
||||
|
||||
// Whether the qualifier column should exist in the DOM.
|
||||
// Derived from the column definitions: if a qualifier column exists with
|
||||
// content !== "simple", always show it. If content === "simple" (or no
|
||||
// qualifier column defined), show only for multi-select (checkboxes).
|
||||
const qualifierColDef = columns.find(
|
||||
(c): c is OnyxQualifierColumn<TData> => c.kind === "qualifier"
|
||||
);
|
||||
// "simple" only gets a qualifier column for multi-select (checkboxes).
|
||||
// "simple" + no-select/single-select = no qualifier column — single-select
|
||||
// uses row-level background coloring instead.
|
||||
const hasQualifierColumn =
|
||||
(qualifierColDef != null && qualifierColDef.content !== "simple") ||
|
||||
selectionBehavior === "multi-select";
|
||||
qualifier !== "simple" || selectionBehavior === "multi-select";
|
||||
|
||||
// 1. Process columns (memoized on columns + size)
|
||||
const { tanstackColumns, widthConfig, qualifierColumn, columnKindMap } =
|
||||
@@ -348,9 +349,15 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
overflowY: "auto" as const,
|
||||
}
|
||||
: undefined),
|
||||
...(headerBackground
|
||||
? ({
|
||||
"--table-header-bg": headerBackground,
|
||||
} as React.CSSProperties)
|
||||
: undefined),
|
||||
}}
|
||||
>
|
||||
<TableElement
|
||||
size={size}
|
||||
variant={variant}
|
||||
selectionBehavior={selectionBehavior}
|
||||
width={
|
||||
@@ -412,12 +419,14 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
columnVisibility={
|
||||
table.getState().columnVisibility
|
||||
}
|
||||
size={size}
|
||||
/>
|
||||
)}
|
||||
{actionsDef.showSorting !== false && (
|
||||
<SortingPopover
|
||||
table={table}
|
||||
sorting={table.getState().sorting}
|
||||
size={size}
|
||||
footerText={actionsDef.sortingFooterText}
|
||||
/>
|
||||
)}
|
||||
@@ -532,6 +541,12 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
if (cellColDef?.kind === "qualifier") {
|
||||
const qDef = cellColDef as OnyxQualifierColumn<TData>;
|
||||
|
||||
// Resolve content based on the qualifier prop:
|
||||
// - "simple" renders nothing (checkbox only when selectable)
|
||||
// - "avatar"/"icon" render from column config
|
||||
const qualifierContent =
|
||||
qualifier === "simple" ? "simple" : qDef.content;
|
||||
|
||||
return (
|
||||
<QualifierContainer
|
||||
key={cell.id}
|
||||
@@ -539,11 +554,10 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<TableQualifier
|
||||
content={qDef.content}
|
||||
icon={qDef.getContent?.(row.original)}
|
||||
content={qualifierContent}
|
||||
initials={qDef.getInitials?.(row.original)}
|
||||
icon={qDef.getIcon?.(row.original)}
|
||||
imageSrc={qDef.getImageSrc?.(row.original)}
|
||||
imageAlt={qDef.getImageAlt?.(row.original)}
|
||||
background={qDef.background}
|
||||
selectable={showQualifierCheckbox}
|
||||
selected={
|
||||
showQualifierCheckbox && row.getIsSelected()
|
||||
|
||||
@@ -277,7 +277,7 @@ function createSplitterResizeHandler(
|
||||
* const { containerRef, columnWidths, createResizeHandler } = useColumnWidths({
|
||||
* headers: table.getHeaderGroups()[0].headers,
|
||||
* fixedColumnIds: new Set(["actions"]),
|
||||
* columnMinWidths: { name: 72, status: 80 },
|
||||
* columnMinWidths: { name: 120, status: 80 },
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
|
||||
@@ -25,7 +25,8 @@
|
||||
/* ---- TableHead ---- */
|
||||
|
||||
.table-head {
|
||||
@apply relative;
|
||||
@apply relative sticky top-0 z-20;
|
||||
background: var(--table-header-bg, transparent);
|
||||
}
|
||||
.table-head[data-size="lg"] {
|
||||
@apply px-2 py-1;
|
||||
@@ -129,7 +130,8 @@ table[data-variant="cards"] .tbl-row:has(:focus-visible) > td {
|
||||
/* ---- QualifierContainer ---- */
|
||||
|
||||
.tbl-qualifier[data-type="head"] {
|
||||
@apply w-px whitespace-nowrap py-1;
|
||||
@apply w-px whitespace-nowrap py-1 sticky top-0 z-20;
|
||||
background: var(--table-header-bg, transparent);
|
||||
}
|
||||
.tbl-qualifier[data-type="head"][data-size="md"] {
|
||||
@apply py-0.5;
|
||||
@@ -145,10 +147,11 @@ table[data-variant="cards"] .tbl-row:has(:focus-visible) > td {
|
||||
/* ---- ActionsContainer ---- */
|
||||
|
||||
.tbl-actions {
|
||||
@apply w-px whitespace-nowrap px-1;
|
||||
@apply sticky right-0 w-px whitespace-nowrap px-1;
|
||||
}
|
||||
.tbl-actions[data-type="head"] {
|
||||
@apply px-2 py-1;
|
||||
@apply z-30 sticky top-0 px-2 py-1;
|
||||
background: var(--table-header-bg, transparent);
|
||||
}
|
||||
|
||||
/* ---- Footer ---- */
|
||||
|
||||
@@ -30,7 +30,12 @@ export type ColumnWidth = DataColumnWidth | FixedColumnWidth;
|
||||
// Column kind discriminant
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export type QualifierContentType = "simple" | "icon" | "image";
|
||||
export type QualifierContentType =
|
||||
| "icon"
|
||||
| "simple"
|
||||
| "image"
|
||||
| "avatar-icon"
|
||||
| "avatar-user";
|
||||
|
||||
export type OnyxColumnKind = "qualifier" | "data" | "display" | "actions";
|
||||
|
||||
@@ -51,14 +56,18 @@ export interface OnyxQualifierColumn<TData> extends OnyxColumnBase<TData> {
|
||||
kind: "qualifier";
|
||||
/** Content type for body-row `<TableQualifier>`. */
|
||||
content: QualifierContentType;
|
||||
/** Return the icon component to render for a row (for "icon" content). */
|
||||
getContent?: (row: TData) => IconFunctionComponent;
|
||||
/** Return the image URL to render for a row (for "image" content). */
|
||||
/** Content type for the header `<TableQualifier>`. @default "simple" */
|
||||
headerContentType?: QualifierContentType;
|
||||
/** Extract initials from a row (for "avatar-user" content). */
|
||||
getInitials?: (row: TData) => string;
|
||||
/** Extract icon from a row (for "icon" / "avatar-icon" content). */
|
||||
getIcon?: (row: TData) => IconFunctionComponent;
|
||||
/** Extract image src from a row (for "image" content). */
|
||||
getImageSrc?: (row: TData) => string;
|
||||
/** Return the image alt text for a row (for "image" content). @default "" */
|
||||
getImageAlt?: (row: TData) => string;
|
||||
/** Show a tinted background container behind the content. @default false */
|
||||
background?: boolean;
|
||||
/** Whether to show selection checkboxes on the qualifier. @default true */
|
||||
selectable?: boolean;
|
||||
/** Whether to render qualifier content in the header. @default true */
|
||||
header?: boolean;
|
||||
}
|
||||
|
||||
/** Data column — accessor-based column with sorting/resizing. */
|
||||
@@ -165,6 +174,9 @@ export interface DataTableProps<TData> {
|
||||
* Accepts a pixel number (e.g. `300`) or a CSS value string (e.g. `"50vh"`).
|
||||
*/
|
||||
height?: number | string;
|
||||
/** Background color for the sticky header row, preventing rows from showing
|
||||
* through when scrolling. Accepts any CSS color value. */
|
||||
headerBackground?: string;
|
||||
/**
|
||||
* Enable server-side mode. When provided:
|
||||
* - TanStack uses manualPagination/manualSorting/manualFiltering
|
||||
|
||||
6
web/package-lock.json
generated
6
web/package-lock.json
generated
@@ -10309,9 +10309,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/flatted": {
|
||||
"version": "3.4.2",
|
||||
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz",
|
||||
"integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==",
|
||||
"version": "3.4.1",
|
||||
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.1.tgz",
|
||||
"integrity": "sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ==",
|
||||
"dev": true,
|
||||
"license": "ISC"
|
||||
},
|
||||
|
||||
@@ -45,7 +45,7 @@ function MemoryTagWithTooltip({
|
||||
<MemoriesModal
|
||||
initialTargetMemoryId={memoryId}
|
||||
initialTargetIndex={memoryIndex}
|
||||
highlightOnOpen
|
||||
highlightFirstOnOpen
|
||||
/>
|
||||
</memoriesModal.Provider>
|
||||
{memoriesModal.isOpen ? (
|
||||
@@ -56,14 +56,8 @@ function MemoryTagWithTooltip({
|
||||
side="bottom"
|
||||
className="bg-background-neutral-00 text-text-01 shadow-md max-w-[17.5rem] p-1"
|
||||
tooltip={
|
||||
<Section
|
||||
flexDirection="column"
|
||||
alignItems="start"
|
||||
padding={0.25}
|
||||
gap={0.25}
|
||||
height="auto"
|
||||
>
|
||||
<div className="p-1">
|
||||
<Section flexDirection="column" gap={0.25} height="auto">
|
||||
<div className="p-1 w-full">
|
||||
<Text as="p" secondaryBody text03>
|
||||
{memoryText}
|
||||
</Text>
|
||||
@@ -72,7 +66,6 @@ function MemoryTagWithTooltip({
|
||||
icon={SvgAddLines}
|
||||
title={operationLabel}
|
||||
sizePreset="secondary"
|
||||
paddingVariant="sm"
|
||||
variant="body"
|
||||
prominence="muted"
|
||||
rightChildren={
|
||||
|
||||
@@ -103,7 +103,7 @@ export const MemoryToolRenderer: MessageRenderer<MemoryToolPacket, {}> = ({
|
||||
<MemoriesModal
|
||||
initialTargetMemoryId={memoryId}
|
||||
initialTargetIndex={index}
|
||||
highlightOnOpen
|
||||
highlightFirstOnOpen
|
||||
/>
|
||||
</memoriesModal.Provider>
|
||||
{memoryText ? (
|
||||
|
||||
@@ -7,8 +7,11 @@ import { processRawChatHistory } from "@/app/app/services/lib";
|
||||
import { getLatestMessageChain } from "@/app/app/services/messageTree";
|
||||
import HumanMessage from "@/app/app/message/HumanMessage";
|
||||
import AgentMessage from "@/app/app/message/messageComponents/AgentMessage";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import OnyxInitializingLoader from "@/components/OnyxInitializingLoader";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { IllustrationContent } from "@opal/layouts";
|
||||
import SvgNotFound from "@opal/illustrations/not-found";
|
||||
import { Button } from "@opal/components";
|
||||
import { Persona } from "@/app/admin/agents/interfaces";
|
||||
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import PreviewModal from "@/sections/modals/PreviewModal";
|
||||
@@ -33,12 +36,17 @@ export default function SharedChatDisplay({
|
||||
|
||||
if (!chatSession) {
|
||||
return (
|
||||
<div className="min-h-full w-full">
|
||||
<div className="mx-auto w-fit pt-8">
|
||||
<Callout type="danger" title="Shared Chat Not Found">
|
||||
Did not find a shared chat with the specified ID.
|
||||
</Callout>
|
||||
</div>
|
||||
<div className="h-full w-full flex flex-col items-center justify-center">
|
||||
<Section flexDirection="column" alignItems="center" gap={1}>
|
||||
<IllustrationContent
|
||||
illustration={SvgNotFound}
|
||||
title="Shared chat not found"
|
||||
description="Did not find a shared chat with the specified ID."
|
||||
/>
|
||||
<Button href="/app" prominence="secondary">
|
||||
Start a new chat
|
||||
</Button>
|
||||
</Section>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -51,12 +59,17 @@ export default function SharedChatDisplay({
|
||||
|
||||
if (firstMessage === undefined) {
|
||||
return (
|
||||
<div className="min-h-full w-full">
|
||||
<div className="mx-auto w-fit pt-8">
|
||||
<Callout type="danger" title="Shared Chat Not Found">
|
||||
No messages found in shared chat.
|
||||
</Callout>
|
||||
</div>
|
||||
<div className="h-full w-full flex flex-col items-center justify-center">
|
||||
<Section flexDirection="column" alignItems="center" gap={1}>
|
||||
<IllustrationContent
|
||||
illustration={SvgNotFound}
|
||||
title="Shared chat not found"
|
||||
description="No messages found in shared chat."
|
||||
/>
|
||||
<Button href="/app" prominence="secondary">
|
||||
Start a new chat
|
||||
</Button>
|
||||
</Section>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -61,6 +61,11 @@ interface UseChatSessionControllerProps {
|
||||
}) => Promise<void>;
|
||||
}
|
||||
|
||||
export type SessionFetchError = {
|
||||
type: "not_found" | "access_denied" | "unknown";
|
||||
detail: string;
|
||||
} | null;
|
||||
|
||||
export default function useChatSessionController({
|
||||
existingChatSessionId,
|
||||
searchParams,
|
||||
@@ -80,6 +85,8 @@ export default function useChatSessionController({
|
||||
const [currentSessionFileTokenCount, setCurrentSessionFileTokenCount] =
|
||||
useState<number>(0);
|
||||
const [projectFiles, setProjectFiles] = useState<ProjectFile[]>([]);
|
||||
const [sessionFetchError, setSessionFetchError] =
|
||||
useState<SessionFetchError>(null);
|
||||
// Store actions
|
||||
const updateSessionAndMessageTree = useChatSessionStore(
|
||||
(state) => state.updateSessionAndMessageTree
|
||||
@@ -151,6 +158,8 @@ export default function useChatSessionController({
|
||||
}
|
||||
|
||||
async function initialSessionFetch() {
|
||||
setSessionFetchError(null);
|
||||
|
||||
if (existingChatSessionId === null) {
|
||||
// Clear the current session in the store to show intro messages
|
||||
setCurrentSession(null);
|
||||
@@ -178,9 +187,42 @@ export default function useChatSessionController({
|
||||
setCurrentSession(existingChatSessionId);
|
||||
setIsFetchingChatMessages(existingChatSessionId, true);
|
||||
|
||||
const response = await fetch(
|
||||
`/api/chat/get-chat-session/${existingChatSessionId}`
|
||||
);
|
||||
let response: Response;
|
||||
try {
|
||||
response = await fetch(
|
||||
`/api/chat/get-chat-session/${existingChatSessionId}`
|
||||
);
|
||||
} catch (error) {
|
||||
setIsFetchingChatMessages(existingChatSessionId, false);
|
||||
console.error("Failed to fetch chat session", {
|
||||
chatSessionId: existingChatSessionId,
|
||||
error,
|
||||
});
|
||||
setSessionFetchError({
|
||||
type: "unknown",
|
||||
detail: "Failed to load chat session. Please check your connection.",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
setIsFetchingChatMessages(existingChatSessionId, false);
|
||||
let detail = "An unexpected error occurred.";
|
||||
try {
|
||||
const errorBody = await response.json();
|
||||
detail = errorBody.detail || detail;
|
||||
} catch {
|
||||
// ignore parse errors
|
||||
}
|
||||
const type =
|
||||
response.status === 404
|
||||
? "not_found"
|
||||
: response.status === 403
|
||||
? "access_denied"
|
||||
: "unknown";
|
||||
setSessionFetchError({ type, detail });
|
||||
return;
|
||||
}
|
||||
|
||||
const session = await response.json();
|
||||
const chatSession = session as BackendChatSession;
|
||||
@@ -356,5 +398,6 @@ export default function useChatSessionController({
|
||||
currentSessionFileTokenCount,
|
||||
onMessageSelection,
|
||||
projectFiles,
|
||||
sessionFetchError,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -20,15 +20,10 @@
|
||||
|
||||
"use client";
|
||||
|
||||
import {
|
||||
cn,
|
||||
ensureHrefProtocol,
|
||||
INTERACTIVE_SELECTOR,
|
||||
noProp,
|
||||
} from "@/lib/utils";
|
||||
import { cn, ensureHrefProtocol, noProp } from "@/lib/utils";
|
||||
import type { Components } from "react-markdown";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { useCallback, useMemo, useRef, useState, useEffect } from "react";
|
||||
import { useCallback, useMemo, useState, useEffect } from "react";
|
||||
import { useAppBackground } from "@/providers/AppBackgroundProvider";
|
||||
import { useTheme } from "next-themes";
|
||||
import ShareChatSessionModal from "@/sections/modals/ShareChatSessionModal";
|
||||
@@ -537,37 +532,6 @@ function Root({ children, enableBackground }: AppRootProps) {
|
||||
const { isSafari } = useBrowserInfo();
|
||||
const isLightMode = resolvedTheme === "light";
|
||||
const showBackground = hasBackground && enableBackground;
|
||||
|
||||
// Track whether the chat input was focused before a mousedown, so we can
|
||||
// restore focus on mouseup if no text was selected. This preserves
|
||||
// click-drag text selection while keeping the input focused on plain clicks.
|
||||
const inputWasFocused = useRef(false);
|
||||
|
||||
const handleMouseDown = useCallback(
|
||||
(event: React.MouseEvent<HTMLDivElement>) => {
|
||||
const activeEl = document.activeElement;
|
||||
const isFocused =
|
||||
activeEl instanceof HTMLElement &&
|
||||
activeEl.id === "onyx-chat-input-textarea";
|
||||
const target = event.target;
|
||||
const isInteractive =
|
||||
target instanceof HTMLElement && !!target.closest(INTERACTIVE_SELECTOR);
|
||||
inputWasFocused.current = isFocused && !isInteractive;
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const handleMouseUp = useCallback(() => {
|
||||
if (!inputWasFocused.current) return;
|
||||
inputWasFocused.current = false;
|
||||
const sel = window.getSelection();
|
||||
if (sel && !sel.isCollapsed) return;
|
||||
const textarea = document.getElementById("onyx-chat-input-textarea");
|
||||
// Only restore focus if no other element has grabbed it since mousedown.
|
||||
if (textarea && document.activeElement !== textarea) {
|
||||
textarea.focus();
|
||||
}
|
||||
}, []);
|
||||
const horizontalBlurMask = `linear-gradient(
|
||||
to right,
|
||||
transparent 0%,
|
||||
@@ -585,8 +549,6 @@ function Root({ children, enableBackground }: AppRootProps) {
|
||||
*/
|
||||
<div
|
||||
data-main-container
|
||||
onMouseDown={handleMouseDown}
|
||||
onMouseUp={handleMouseUp}
|
||||
className={cn(
|
||||
"@container flex flex-col h-full w-full relative overflow-hidden",
|
||||
showBackground && "bg-cover bg-center bg-fixed"
|
||||
|
||||
@@ -125,7 +125,7 @@ export const MAX_FILES_TO_SHOW = 3;
|
||||
export const MOBILE_SIDEBAR_BREAKPOINT_PX = 640;
|
||||
export const DESKTOP_SMALL_BREAKPOINT_PX = 912;
|
||||
export const DESKTOP_MEDIUM_BREAKPOINT_PX = 1232;
|
||||
export const DEFAULT_AVATAR_SIZE_PX = 18;
|
||||
export const DEFAULT_AGENT_AVATAR_SIZE_PX = 18;
|
||||
export const HORIZON_DISTANCE_PX = 800;
|
||||
export const LOGO_FOLDED_SIZE_PX = 24;
|
||||
export const LOGO_UNFOLDED_SIZE_PX = 88;
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
import { getUserInitials } from "@/lib/user";
|
||||
|
||||
describe("getUserInitials", () => {
|
||||
it("returns first letters of first two name parts", () => {
|
||||
expect(getUserInitials("Alice Smith", "alice@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("returns first two chars of a single-word name", () => {
|
||||
expect(getUserInitials("Alice", "alice@example.com")).toBe("AL");
|
||||
});
|
||||
|
||||
it("handles three-word names (uses first two)", () => {
|
||||
expect(getUserInitials("Alice B. Smith", "alice@example.com")).toBe("AB");
|
||||
});
|
||||
|
||||
it("falls back to email local part with dot separator", () => {
|
||||
expect(getUserInitials(null, "alice.smith@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("falls back to email local part with underscore separator", () => {
|
||||
expect(getUserInitials(null, "alice_smith@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("falls back to email local part with hyphen separator", () => {
|
||||
expect(getUserInitials(null, "alice-smith@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("uses first two chars of email local if no separator", () => {
|
||||
expect(getUserInitials(null, "alice@example.com")).toBe("AL");
|
||||
});
|
||||
|
||||
it("returns null for empty email local part", () => {
|
||||
expect(getUserInitials(null, "@example.com")).toBeNull();
|
||||
});
|
||||
|
||||
it("uppercases the result", () => {
|
||||
expect(getUserInitials("john doe", "jd@test.com")).toBe("JD");
|
||||
});
|
||||
|
||||
it("trims whitespace from name", () => {
|
||||
expect(getUserInitials(" Alice Smith ", "a@test.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("returns null for numeric name parts", () => {
|
||||
expect(getUserInitials("Alice 1st", "x@test.com")).toBeNull();
|
||||
});
|
||||
|
||||
it("returns null for numeric email", () => {
|
||||
expect(getUserInitials(null, "42@domain.com")).toBeNull();
|
||||
});
|
||||
|
||||
it("falls back to email when name has non-alpha chars", () => {
|
||||
expect(getUserInitials("A1", "alice@example.com")).toBe("AL");
|
||||
});
|
||||
});
|
||||
@@ -128,54 +128,3 @@ export function getUserDisplayName(user: User | null): string {
|
||||
// If nothing works, then fall back to anonymous user name
|
||||
return "Anonymous";
|
||||
}
|
||||
|
||||
/**
|
||||
* Derive display initials from a user's name or email.
|
||||
*
|
||||
* - If a name is provided, uses the first letter of the first two words.
|
||||
* - Falls back to the email local part, splitting on `.`, `_`, or `-`.
|
||||
* - Returns `null` when no valid alpha initials can be derived.
|
||||
*/
|
||||
export function getUserInitials(
|
||||
name: string | null,
|
||||
email: string
|
||||
): string | null {
|
||||
if (name) {
|
||||
const words = name.trim().split(/\s+/);
|
||||
if (words.length >= 2) {
|
||||
const first = words[0]?.[0];
|
||||
const second = words[1]?.[0];
|
||||
if (first && second) {
|
||||
const result = (first + second).toUpperCase();
|
||||
if (/^[A-Z]{2}$/.test(result)) return result;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
if (name.trim().length >= 1) {
|
||||
const result = name.trim().slice(0, 2).toUpperCase();
|
||||
if (/^[A-Z]{1,2}$/.test(result)) return result;
|
||||
}
|
||||
}
|
||||
|
||||
const local = email.split("@")[0];
|
||||
if (!local || local.length === 0) return null;
|
||||
const parts = local.split(/[._-]/);
|
||||
if (parts.length >= 2) {
|
||||
const first = parts[0]?.[0];
|
||||
const second = parts[1]?.[0];
|
||||
if (first && second) {
|
||||
const result = (first + second).toUpperCase();
|
||||
if (/^[A-Z]{2}$/.test(result)) return result;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
if (local.length >= 2) {
|
||||
const result = local.slice(0, 2).toUpperCase();
|
||||
if (/^[A-Z]{2}$/.test(result)) return result;
|
||||
}
|
||||
if (local.length === 1) {
|
||||
const result = local.toUpperCase();
|
||||
if (/^[A-Z]$/.test(result)) return result;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -13,9 +13,6 @@ import { ALLOWED_URL_PROTOCOLS } from "./constants";
|
||||
const URI_SCHEME_REGEX = /^[a-zA-Z][a-zA-Z\d+.-]*:/;
|
||||
const BARE_EMAIL_REGEX = /^[^\s@/]+@[^\s@/:]+\.[^\s@/:]+$/;
|
||||
|
||||
export const INTERACTIVE_SELECTOR =
|
||||
"a, button, input, textarea, select, label, [role='button'], [tabindex]:not([tabindex='-1']), [contenteditable]:not([contenteditable='false'])";
|
||||
|
||||
export function cn(...inputs: ClassValue[]) {
|
||||
return twMerge(clsx(inputs));
|
||||
}
|
||||
|
||||
@@ -4,7 +4,10 @@ import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
|
||||
import { buildImgUrl } from "@/app/app/components/files/images/utils";
|
||||
import { OnyxIcon } from "@/components/icons/icons";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import { DEFAULT_AVATAR_SIZE_PX, DEFAULT_AGENT_ID } from "@/lib/constants";
|
||||
import {
|
||||
DEFAULT_AGENT_AVATAR_SIZE_PX,
|
||||
DEFAULT_AGENT_ID,
|
||||
} from "@/lib/constants";
|
||||
import CustomAgentAvatar from "@/refresh-components/avatars/CustomAgentAvatar";
|
||||
import Image from "next/image";
|
||||
|
||||
@@ -15,7 +18,7 @@ export interface AgentAvatarProps {
|
||||
|
||||
export default function AgentAvatar({
|
||||
agent,
|
||||
size = DEFAULT_AVATAR_SIZE_PX,
|
||||
size = DEFAULT_AGENT_AVATAR_SIZE_PX,
|
||||
...props
|
||||
}: AgentAvatarProps) {
|
||||
const settings = useSettingsContext();
|
||||
|
||||
@@ -4,7 +4,7 @@ import { cn } from "@/lib/utils";
|
||||
import type { IconProps } from "@opal/types";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Image from "next/image";
|
||||
import { DEFAULT_AVATAR_SIZE_PX } from "@/lib/constants";
|
||||
import { DEFAULT_AGENT_AVATAR_SIZE_PX } from "@/lib/constants";
|
||||
import {
|
||||
SvgActivitySmall,
|
||||
SvgAudioEqSmall,
|
||||
@@ -96,7 +96,7 @@ export default function CustomAgentAvatar({
|
||||
src,
|
||||
iconName,
|
||||
|
||||
size = DEFAULT_AVATAR_SIZE_PX,
|
||||
size = DEFAULT_AGENT_AVATAR_SIZE_PX,
|
||||
}: CustomAgentAvatarProps) {
|
||||
if (src) {
|
||||
return (
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
import { SvgUser } from "@opal/icons";
|
||||
import { DEFAULT_AVATAR_SIZE_PX } from "@/lib/constants";
|
||||
import { getUserInitials } from "@/lib/user";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import type { User } from "@/lib/types";
|
||||
|
||||
export interface UserAvatarProps {
|
||||
user: User;
|
||||
size?: number;
|
||||
}
|
||||
|
||||
export default function UserAvatar({
|
||||
user,
|
||||
size = DEFAULT_AVATAR_SIZE_PX,
|
||||
}: UserAvatarProps) {
|
||||
const initials = getUserInitials(
|
||||
user.personalization?.name ?? null,
|
||||
user.email
|
||||
);
|
||||
|
||||
if (!initials) {
|
||||
return (
|
||||
<div
|
||||
role="img"
|
||||
aria-label={`${user.email} avatar`}
|
||||
className="flex items-center justify-center rounded-full bg-background-tint-01"
|
||||
style={{ width: size, height: size }}
|
||||
>
|
||||
<SvgUser size={size * 0.55} className="stroke-text-03" aria-hidden />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
role="img"
|
||||
aria-label={`${user.email} avatar`}
|
||||
className="flex items-center justify-center rounded-full bg-background-neutral-inverted-00"
|
||||
style={{ width: size, height: size }}
|
||||
>
|
||||
<Text
|
||||
inverted
|
||||
secondaryAction
|
||||
text05
|
||||
className="select-none"
|
||||
style={{ fontSize: size * 0.4 }}
|
||||
>
|
||||
{initials}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -54,10 +54,8 @@ function MemoryItem({
|
||||
const wrapperRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (shouldFocus && textareaRef.current) {
|
||||
const el = textareaRef.current;
|
||||
el.focus();
|
||||
el.selectionStart = el.selectionEnd = el.value.length;
|
||||
if (shouldFocus) {
|
||||
textareaRef.current?.focus();
|
||||
onFocused?.();
|
||||
}
|
||||
}, [shouldFocus, onFocused]);
|
||||
@@ -65,10 +63,8 @@ function MemoryItem({
|
||||
useEffect(() => {
|
||||
if (!shouldHighlight) return;
|
||||
|
||||
wrapperRef.current?.scrollIntoView({
|
||||
block: "start",
|
||||
behavior: "smooth",
|
||||
});
|
||||
wrapperRef.current?.scrollIntoView({ block: "center", behavior: "smooth" });
|
||||
textareaRef.current?.focus();
|
||||
setIsHighlighting(true);
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
@@ -83,10 +79,10 @@ function MemoryItem({
|
||||
<div
|
||||
ref={wrapperRef}
|
||||
className={cn(
|
||||
"rounded-08 w-full p-0.5 border border-transparent",
|
||||
"rounded-08 hover:bg-background-tint-00 w-full p-0.5",
|
||||
"transition-colors ",
|
||||
isHighlighting &&
|
||||
"bg-action-link-01 hover:bg-action-link-01 border-action-link-05 duration-700"
|
||||
"bg-action-link-01 border border-action-link-05 duration-700"
|
||||
)}
|
||||
>
|
||||
<Section gap={0.25} alignItems="start">
|
||||
@@ -104,7 +100,7 @@ function MemoryItem({
|
||||
rows={3}
|
||||
maxLength={MAX_MEMORY_LENGTH}
|
||||
resizable={false}
|
||||
className="bg-background-tint-01 hover:bg-background-tint-00 focus-within:bg-background-tint-00"
|
||||
className={cn(!isFocused && "bg-transparent")}
|
||||
/>
|
||||
<Disabled disabled={!memory.content.trim() && memory.isNew}>
|
||||
<Button
|
||||
@@ -126,29 +122,13 @@ function MemoryItem({
|
||||
);
|
||||
}
|
||||
|
||||
function resolveTargetMemoryId(
|
||||
targetMemoryId: number | null | undefined,
|
||||
targetIndex: number | null | undefined,
|
||||
memories: MemoryItem[]
|
||||
): number | null {
|
||||
if (targetMemoryId != null) return targetMemoryId;
|
||||
|
||||
if (targetIndex != null && memories.length > 0) {
|
||||
// Backend index is ASC (oldest-first), frontend displays DESC (newest-first)
|
||||
const descIdx = memories.length - 1 - targetIndex;
|
||||
return memories[descIdx]?.id ?? null;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
interface MemoriesModalProps {
|
||||
memories?: MemoryItem[];
|
||||
onSaveMemories?: (memories: MemoryItem[]) => Promise<boolean>;
|
||||
onClose?: () => void;
|
||||
initialTargetMemoryId?: number | null;
|
||||
initialTargetIndex?: number | null;
|
||||
highlightOnOpen?: boolean;
|
||||
highlightFirstOnOpen?: boolean;
|
||||
}
|
||||
|
||||
export default function MemoriesModal({
|
||||
@@ -157,7 +137,7 @@ export default function MemoriesModal({
|
||||
onClose,
|
||||
initialTargetMemoryId,
|
||||
initialTargetIndex,
|
||||
highlightOnOpen = false,
|
||||
highlightFirstOnOpen = false,
|
||||
}: MemoriesModalProps) {
|
||||
const close = useModalClose(onClose);
|
||||
const [focusMemoryId, setFocusMemoryId] = useState<number | null>(null);
|
||||
@@ -202,16 +182,24 @@ export default function MemoriesModal({
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const targetId = resolveTargetMemoryId(
|
||||
initialTargetMemoryId,
|
||||
initialTargetIndex,
|
||||
effectiveMemories
|
||||
);
|
||||
if (targetId == null) return;
|
||||
|
||||
setFocusMemoryId(targetId);
|
||||
if (highlightOnOpen) {
|
||||
setHighlightMemoryId(targetId);
|
||||
if (initialTargetMemoryId != null) {
|
||||
// Direct DB id available — use it
|
||||
setHighlightMemoryId(initialTargetMemoryId);
|
||||
} else if (initialTargetIndex != null && effectiveMemories.length > 0) {
|
||||
// Backend index is ASC (oldest-first), but the frontend displays DESC
|
||||
// (newest-first). Convert: descIdx = totalCount - 1 - ascIdx
|
||||
const descIdx = effectiveMemories.length - 1 - initialTargetIndex;
|
||||
const target = effectiveMemories[descIdx];
|
||||
if (target) {
|
||||
setHighlightMemoryId(target.id);
|
||||
}
|
||||
} else if (
|
||||
highlightFirstOnOpen &&
|
||||
effectiveMemories.length > 0 &&
|
||||
effectiveMemories[0]
|
||||
) {
|
||||
// Fallback: highlight the first displayed item (newest)
|
||||
setHighlightMemoryId(effectiveMemories[0].id);
|
||||
}
|
||||
}, [initialTargetMemoryId, initialTargetIndex]);
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import { personaIncludesRetrieval } from "@/app/app/services/lib";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { toast, useToastFromQuery } from "@/hooks/useToast";
|
||||
import { SEARCH_PARAM_NAMES } from "@/app/app/services/searchParams";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { useFederatedConnectors, useFilters, useLlmManager } from "@/lib/hooks";
|
||||
import { useForcedTools } from "@/lib/hooks/useForcedTools";
|
||||
import OnyxInitializingLoader from "@/components/OnyxInitializingLoader";
|
||||
@@ -62,6 +63,9 @@ import { useShowOnboarding } from "@/hooks/useShowOnboarding";
|
||||
import * as AppLayouts from "@/layouts/app-layouts";
|
||||
import { SvgChevronDown, SvgFileText } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import { IllustrationContent } from "@opal/layouts";
|
||||
import SvgNotFound from "@opal/illustrations/not-found";
|
||||
import SvgNoAccess from "@opal/illustrations/no-access";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
@@ -381,23 +385,26 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
setSelectedAgentFromId,
|
||||
});
|
||||
|
||||
const { onMessageSelection, currentSessionFileTokenCount } =
|
||||
useChatSessionController({
|
||||
existingChatSessionId: currentChatSessionId,
|
||||
searchParams,
|
||||
filterManager,
|
||||
firstMessage,
|
||||
setSelectedAgentFromId,
|
||||
setSelectedDocuments,
|
||||
setCurrentMessageFiles,
|
||||
chatSessionIdRef,
|
||||
loadedIdSessionRef,
|
||||
chatInputBarRef,
|
||||
isInitialLoad,
|
||||
submitOnLoadPerformed,
|
||||
refreshChatSessions,
|
||||
onSubmit,
|
||||
});
|
||||
const {
|
||||
onMessageSelection,
|
||||
currentSessionFileTokenCount,
|
||||
sessionFetchError,
|
||||
} = useChatSessionController({
|
||||
existingChatSessionId: currentChatSessionId,
|
||||
searchParams,
|
||||
filterManager,
|
||||
firstMessage,
|
||||
setSelectedAgentFromId,
|
||||
setSelectedDocuments,
|
||||
setCurrentMessageFiles,
|
||||
chatSessionIdRef,
|
||||
loadedIdSessionRef,
|
||||
chatInputBarRef,
|
||||
isInitialLoad,
|
||||
submitOnLoadPerformed,
|
||||
refreshChatSessions,
|
||||
onSubmit,
|
||||
});
|
||||
|
||||
useSendMessageToParent();
|
||||
|
||||
@@ -679,7 +686,10 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
{/* ChatUI */}
|
||||
<Fade
|
||||
show={
|
||||
appFocus.isChat() && !!currentChatSessionId && !!liveAgent
|
||||
appFocus.isChat() &&
|
||||
!!currentChatSessionId &&
|
||||
!!liveAgent &&
|
||||
!sessionFetchError
|
||||
}
|
||||
className="h-full w-full flex flex-col items-center"
|
||||
>
|
||||
@@ -708,6 +718,45 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
</ChatScrollContainer>
|
||||
</Fade>
|
||||
|
||||
{/* Session fetch error (404 / 403) */}
|
||||
<Fade
|
||||
show={appFocus.isChat() && sessionFetchError !== null}
|
||||
className="h-full w-full flex flex-col items-center justify-center"
|
||||
>
|
||||
{sessionFetchError && (
|
||||
<Section
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
gap={1}
|
||||
>
|
||||
<IllustrationContent
|
||||
illustration={
|
||||
sessionFetchError.type === "access_denied"
|
||||
? SvgNoAccess
|
||||
: SvgNotFound
|
||||
}
|
||||
title={
|
||||
sessionFetchError.type === "not_found"
|
||||
? "Chat not found"
|
||||
: sessionFetchError.type === "access_denied"
|
||||
? "Access denied"
|
||||
: "Something went wrong"
|
||||
}
|
||||
description={
|
||||
sessionFetchError.type === "not_found"
|
||||
? "This chat session doesn't exist or has been deleted."
|
||||
: sessionFetchError.type === "access_denied"
|
||||
? "You don't have permission to view this chat session."
|
||||
: sessionFetchError.detail
|
||||
}
|
||||
/>
|
||||
<Button href="/app" prominence="secondary">
|
||||
Start a new chat
|
||||
</Button>
|
||||
</Section>
|
||||
)}
|
||||
</Fade>
|
||||
|
||||
{/* ProjectUI */}
|
||||
{appFocus.isProject() && (
|
||||
<div className="w-full max-h-[50vh] overflow-y-auto overscroll-y-none">
|
||||
@@ -736,7 +785,12 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
</div>
|
||||
|
||||
{/* ── Middle-center: AppInputBar ── */}
|
||||
<div className="row-start-2 flex flex-col items-center px-4">
|
||||
<div
|
||||
className={cn(
|
||||
"row-start-2 flex flex-col items-center px-4",
|
||||
sessionFetchError && "hidden"
|
||||
)}
|
||||
>
|
||||
<div className="relative w-full max-w-[var(--app-page-main-content-width)] flex flex-col">
|
||||
{/* Scroll to bottom button - positioned absolutely above AppInputBar */}
|
||||
{appFocus.isChat() && showScrollButton && (
|
||||
|
||||
@@ -26,8 +26,7 @@ import type {
|
||||
StatusFilter,
|
||||
StatusCountMap,
|
||||
} from "./interfaces";
|
||||
import UserAvatar from "@/refresh-components/avatars/UserAvatar";
|
||||
import type { User } from "@/lib/types";
|
||||
import { getInitials } from "./utils";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Column renderers
|
||||
@@ -76,25 +75,20 @@ const tc = createTableColumns<UserRow>();
|
||||
function buildColumns(onMutate: () => void) {
|
||||
return [
|
||||
tc.qualifier({
|
||||
content: "icon",
|
||||
getContent: (row) => {
|
||||
const user = {
|
||||
email: row.email,
|
||||
personalization: row.personal_name
|
||||
? { name: row.personal_name }
|
||||
: undefined,
|
||||
} as User;
|
||||
return (props) => <UserAvatar user={user} size={props.size} />;
|
||||
},
|
||||
content: "avatar-user",
|
||||
getInitials: (row) => getInitials(row.personal_name, row.email),
|
||||
selectable: false,
|
||||
}),
|
||||
tc.column("email", {
|
||||
header: "Name",
|
||||
weight: 22,
|
||||
minWidth: 140,
|
||||
cell: renderNameColumn,
|
||||
}),
|
||||
tc.column("groups", {
|
||||
header: "Groups",
|
||||
weight: 24,
|
||||
minWidth: 200,
|
||||
enableSorting: false,
|
||||
cell: (value, row) => (
|
||||
<GroupsCell groups={value} user={row} onMutate={onMutate} />
|
||||
@@ -103,16 +97,19 @@ function buildColumns(onMutate: () => void) {
|
||||
tc.column("role", {
|
||||
header: "Account Type",
|
||||
weight: 16,
|
||||
minWidth: 180,
|
||||
cell: (_value, row) => <UserRoleCell user={row} onMutate={onMutate} />,
|
||||
}),
|
||||
tc.column("status", {
|
||||
header: "Status",
|
||||
weight: 14,
|
||||
minWidth: 100,
|
||||
cell: renderStatusColumn,
|
||||
}),
|
||||
tc.column("updated_at", {
|
||||
header: "Last Updated",
|
||||
weight: 14,
|
||||
minWidth: 100,
|
||||
cell: renderLastUpdatedColumn,
|
||||
}),
|
||||
tc.actions({
|
||||
@@ -222,6 +219,7 @@ export default function UsersTable({
|
||||
data={filteredUsers}
|
||||
columns={columns}
|
||||
getRowId={(row) => row.id ?? row.email}
|
||||
qualifier="avatar"
|
||||
pageSize={PAGE_SIZE}
|
||||
searchTerm={searchTerm}
|
||||
emptyState={
|
||||
|
||||
43
web/src/refresh-pages/admin/UsersPage/utils.test.ts
Normal file
43
web/src/refresh-pages/admin/UsersPage/utils.test.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
import { getInitials } from "./utils";
|
||||
|
||||
describe("getInitials", () => {
|
||||
it("returns first letters of first two name parts", () => {
|
||||
expect(getInitials("Alice Smith", "alice@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("returns first two chars of a single-word name", () => {
|
||||
expect(getInitials("Alice", "alice@example.com")).toBe("AL");
|
||||
});
|
||||
|
||||
it("handles three-word names (uses first two)", () => {
|
||||
expect(getInitials("Alice B. Smith", "alice@example.com")).toBe("AB");
|
||||
});
|
||||
|
||||
it("falls back to email local part with dot separator", () => {
|
||||
expect(getInitials(null, "alice.smith@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("falls back to email local part with underscore separator", () => {
|
||||
expect(getInitials(null, "alice_smith@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("falls back to email local part with hyphen separator", () => {
|
||||
expect(getInitials(null, "alice-smith@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("uses first two chars of email local if no separator", () => {
|
||||
expect(getInitials(null, "alice@example.com")).toBe("AL");
|
||||
});
|
||||
|
||||
it("returns ? for empty email local part", () => {
|
||||
expect(getInitials(null, "@example.com")).toBe("?");
|
||||
});
|
||||
|
||||
it("uppercases the result", () => {
|
||||
expect(getInitials("john doe", "jd@test.com")).toBe("JD");
|
||||
});
|
||||
|
||||
it("trims whitespace from name", () => {
|
||||
expect(getInitials(" Alice Smith ", "a@test.com")).toBe("AS");
|
||||
});
|
||||
});
|
||||
23
web/src/refresh-pages/admin/UsersPage/utils.ts
Normal file
23
web/src/refresh-pages/admin/UsersPage/utils.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
/**
|
||||
* Derive display initials from a user's name or email.
|
||||
*
|
||||
* - If a name is provided, uses the first letter of the first two words.
|
||||
* - Falls back to the email local part, splitting on `.`, `_`, or `-`.
|
||||
* - Returns at most 2 uppercase characters.
|
||||
*/
|
||||
export function getInitials(name: string | null, email: string): string {
|
||||
if (name) {
|
||||
const parts = name.trim().split(/\s+/);
|
||||
if (parts.length >= 2) {
|
||||
return ((parts[0]?.[0] ?? "") + (parts[1]?.[0] ?? "")).toUpperCase();
|
||||
}
|
||||
return name.slice(0, 2).toUpperCase();
|
||||
}
|
||||
const local = email.split("@")[0];
|
||||
if (!local) return "?";
|
||||
const parts = local.split(/[._-]/);
|
||||
if (parts.length >= 2) {
|
||||
return ((parts[0]?.[0] ?? "") + (parts[1]?.[0] ?? "")).toUpperCase();
|
||||
}
|
||||
return local.slice(0, 2).toUpperCase();
|
||||
}
|
||||
54
web/tests/e2e/chat/chat_session_not_found.spec.ts
Normal file
54
web/tests/e2e/chat/chat_session_not_found.spec.ts
Normal file
@@ -0,0 +1,54 @@
|
||||
import { test, expect } from "@playwright/test";
|
||||
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";
|
||||
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";
|
||||
|
||||
const NON_EXISTENT_CHAT_ID = "00000000-0000-0000-0000-000000000000";
|
||||
|
||||
for (const theme of THEMES) {
|
||||
test.describe(`Chat session not found (${theme} mode)`, () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await setThemeBeforeNavigation(page, theme);
|
||||
});
|
||||
|
||||
test("should show 404 page for a non-existent chat session", async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.goto(`/app?chatId=${NON_EXISTENT_CHAT_ID}`);
|
||||
|
||||
await expect(page.getByText("Chat not found")).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
await expect(
|
||||
page.getByText("This chat session doesn't exist or has been deleted.")
|
||||
).toBeVisible();
|
||||
await expect(
|
||||
page.getByRole("link", { name: "Start a new chat" })
|
||||
).toBeVisible();
|
||||
|
||||
// Sidebar should still be visible
|
||||
await expect(page.getByTestId("AppSidebar/new-session")).toBeVisible();
|
||||
|
||||
const container = page.locator("[data-main-container]");
|
||||
await expect(container).toBeVisible();
|
||||
await expectElementScreenshot(container, {
|
||||
name: `chat-session-not-found-${theme}`,
|
||||
});
|
||||
});
|
||||
|
||||
test("should navigate to /app when clicking Start a new chat", async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.goto(`/app?chatId=${NON_EXISTENT_CHAT_ID}`);
|
||||
|
||||
await expect(page.getByText("Chat not found")).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
|
||||
await page.getByRole("link", { name: "Start a new chat" }).click();
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
await expect(page).toHaveURL("/app");
|
||||
await expect(page.getByText("Chat not found")).toBeHidden();
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
import { test, expect } from "@playwright/test";
|
||||
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
|
||||
|
||||
test.describe(`Chat Input Focus Retention`, () => {
|
||||
test.beforeEach(async ({ page }, testInfo) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAsWorkerUser(page, testInfo.workerIndex);
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
});
|
||||
|
||||
test("clicking empty space retains focus on chat input", async ({ page }) => {
|
||||
const textarea = page.locator("#onyx-chat-input-textarea");
|
||||
await textarea.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
// Focus the textarea and type something
|
||||
await textarea.focus();
|
||||
await textarea.fill("test message");
|
||||
await expect(textarea).toBeFocused();
|
||||
|
||||
// Click on the main container's empty space (top-left corner)
|
||||
const container = page.locator("[data-main-container]");
|
||||
await container.click({ position: { x: 10, y: 10 } });
|
||||
|
||||
// Focus should remain on the textarea
|
||||
await expect(textarea).toBeFocused();
|
||||
});
|
||||
|
||||
test("clicking interactive elements still moves focus away", async ({
|
||||
page,
|
||||
}) => {
|
||||
const textarea = page.locator("#onyx-chat-input-textarea");
|
||||
await textarea.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
// Focus the textarea
|
||||
await textarea.focus();
|
||||
await expect(textarea).toBeFocused();
|
||||
|
||||
// Click on an interactive element inside the container
|
||||
const button = page.locator("[data-main-container] button").first();
|
||||
await button.waitFor({ state: "visible", timeout: 5000 });
|
||||
await button.click();
|
||||
|
||||
// Focus should have moved away from the textarea
|
||||
await expect(textarea).not.toBeFocused();
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user