Compare commits

..

20 Commits

Author SHA1 Message Date
Raunak Bhagat
aa11813cc0 feat: UserAvatar (#9527) 2026-03-21 02:05:00 +00:00
Evan Lohn
6235f49b49 fix: csv test with newlines (#9534) 2026-03-21 01:30:11 +00:00
Evan Lohn
fd6a110794 feat: installer invocable from other bash script (#9531) 2026-03-21 01:18:20 +00:00
Jamison Lahman
bd42c459d6 chore(fe): update memories dropdown padding (#9526) 2026-03-20 23:38:32 +00:00
Danelegend
aede532e63 fix(chat): Cache plaintext file results (#9511) 2026-03-20 23:21:12 +00:00
Evan Lohn
068ac543ad fix: deadlock in multitenant test (#9530) 2026-03-20 23:05:20 +00:00
Bo-Onyx
30e7a831a5 feat(hook): Add hook management API (#9513)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-20 22:53:59 +00:00
Evan Lohn
276261c96d fix: windows installer (#9507) 2026-03-20 22:53:46 +00:00
Bo-Onyx
205f1410e4 chore(hook): Hook executor. (#9467) 2026-03-20 22:47:01 +00:00
Bo-Onyx
a93d154c27 feat(hook): improve on hook point definition (#9522) 2026-03-20 22:20:42 +00:00
Jamison Lahman
1361879bd0 fix(fe): clicking outside chat area keeps chat input focused (#9521) 2026-03-20 19:22:11 +00:00
Justin Tahara
c58cc320b2 feat(tf): Port over WAF updates (#9520) 2026-03-20 18:45:09 +00:00
Jamison Lahman
461350958a fix(fe): dim project name in sidebar color (#9519) 2026-03-20 17:47:49 +00:00
Raunak Bhagat
50dde0be1a chore: edit AGENTS.md and CLAUDE.md files (#9486) 2026-03-20 00:59:30 +00:00
acaprau
199e1df453 feat(opensearch): Add functions for keyword and semantic retrieval (#9479)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-20 00:48:01 +00:00
Justin Tahara
996b674840 feat(backend): Adding procps (#9509) 2026-03-19 23:26:36 +00:00
Justin Tahara
5413723ccc feat(ods): Rerun run-ci workflow (#9501) 2026-03-19 22:11:59 +00:00
Evan Lohn
9660056a51 fix: drive rate limit retry (#9498) 2026-03-19 21:32:08 +00:00
Fizza Mukhtar
3105177238 fix(llm): don't send tool_choice when no tools are provided (#9224) 2026-03-19 21:26:46 +00:00
Evan Lohn
24bb4bda8b feat: windows installer and install improvements (#9476) 2026-03-19 20:47:44 +00:00
80 changed files with 4841 additions and 1059 deletions

279
AGENTS.md
View File

@@ -167,284 +167,7 @@ web/
## Frontend Standards
### 1. Import Standards
**Always use absolute imports with the `@` prefix.**
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
```typescript
// ✅ Good
import { Button } from "@/components/ui/button";
import { useAuth } from "@/hooks/useAuth";
import { Text } from "@/refresh-components/texts/Text";
// ❌ Bad
import { Button } from "../../../components/ui/button";
import { useAuth } from "./hooks/useAuth";
```
### 2. React Component Functions
**Prefer regular functions over arrow functions for React components.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
function UserProfile({ userId }: UserProfileProps) {
return <div>User Profile</div>
}
// ❌ Bad
const UserProfile = ({ userId }: UserProfileProps) => {
return <div>User Profile</div>
}
```
### 3. Props Interface Extraction
**Extract prop types into their own interface definitions.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
interface UserCardProps {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
return <div>User Card</div>
}
// ❌ Bad
function UserCard({
user,
showActions = false,
onEdit
}: {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}) {
return <div>User Card</div>
}
```
### 4. Spacing Guidelines
**Prefer padding over margins for spacing.**
**Reason:** We want to consolidate usage to paddings instead of margins.
```typescript
// ✅ Good
<div className="p-4 space-y-2">
<div className="p-2">Content</div>
</div>
// ❌ Bad
<div className="m-4 space-y-2">
<div className="m-2">Content</div>
</div>
```
### 5. Tailwind Dark Mode
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
```typescript
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
<div className="bg-background-neutral-03 text-text-02">
Content
</div>
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
export const GithubIcon = createLogoIcon(githubLightIcon, {
monochromatic: true, // Will apply dark:invert internally
});
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
});
// ❌ Bad - Manual dark mode overrides
<div className="bg-white dark:bg-black text-black dark:text-white">
Content
</div>
```
### 6. Class Name Utilities
**Use the `cn` utility instead of raw string formatting for classNames.**
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
```typescript
import { cn } from '@/lib/utils'
// ✅ Good
<div className={cn(
'base-class',
isActive && 'active-class',
className
)}>
Content
</div>
// ❌ Bad
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
Content
</div>
```
### 7. Custom Hooks Organization
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
**Reason:** This is just a layout preference. Keeps code clean.
```typescript
// web/src/hooks/useUserData.ts
export function useUserData(userId: string) {
// hook implementation
}
// web/src/hooks/useLocalStorage.ts
export function useLocalStorage<T>(key: string, initialValue: T) {
// hook implementation
}
```
### 8. Icon Usage
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
```typescript
// ✅ Good
import SvgX from "@/icons/x";
import SvgMoreHorizontal from "@/icons/more-horizontal";
// ❌ Bad
import { User } from "lucide-react";
import { FiSearch } from "react-icons/fi";
```
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
If you need help with this step, reach out to `raunak@onyx.app`.
### 9. Text Rendering
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
```typescript
// ✅ Good
import { Text } from '@/refresh-components/texts/Text'
function UserCard({ name }: { name: string }) {
return (
<Text
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
text03
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
mainAction
>
{name}
</Text>
)
}
// ❌ Bad
function UserCard({ name }: { name: string }) {
return (
<div>
<h2>{name}</h2>
<p>User details</p>
</div>
)
}
```
### 10. Component Usage
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
```typescript
// ✅ Good
import Button from '@/refresh-components/buttons/Button'
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
import SvgPlusCircle from '@/icons/plus-circle'
function ContactForm() {
return (
<form>
<InputTypeIn placeholder="Search..." />
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
</form>
)
}
// ❌ Bad
function ContactForm() {
return (
<form>
<input placeholder="Name" />
<textarea placeholder="Message" />
<button type="submit">Submit</button>
</form>
)
}
```
### 11. Colors
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
**Available color categories:**
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
- **Actions:** `action-link-XX`, `action-danger-XX`
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
```typescript
// ✅ Good - Use custom Onyx color classes
<div className="bg-background-neutral-01 border border-border-02" />
<div className="bg-background-tint-02 border border-border-01" />
<div className="bg-status-success-01" />
<div className="bg-action-link-01" />
<div className="bg-theme-primary-05" />
// ❌ Bad - Do NOT use standard Tailwind colors
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
<div className="bg-white border border-slate-200" />
<div className="bg-green-100 text-green-700" />
<div className="bg-blue-100 text-blue-600" />
<div className="bg-indigo-500" />
```
### 12. Data Fetching
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
Frontend standards for the `web/` and `desktop/` projects live in `web/AGENTS.md`.
## Database & Migrations

View File

@@ -47,6 +47,8 @@ RUN apt-get update && \
gcc \
nano \
vim \
# Install procps so kubernetes exec sessions can use ps aux for debugging
procps \
libjemalloc2 \
&& \
rm -rf /var/lib/apt/lists/* && \

View File

@@ -9,12 +9,12 @@ from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.app_configs import HOOK_ENABLED
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.hooks.utils import HOOKS_AVAILABLE
from shared_configs.configs import MULTI_TENANT
# choosing 15 minutes because it roughly gives us enough time to process many tasks
@@ -362,7 +362,7 @@ if not MULTI_TENANT:
tasks_to_schedule.extend(beat_task_templates)
if not MULTI_TENANT and HOOK_ENABLED:
if HOOKS_AVAILABLE:
tasks_to_schedule.append(
{
"name": "hook-execution-log-cleanup",

View File

@@ -30,6 +30,8 @@ from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.utils import plaintext_file_name_for_id
from onyx.file_store.utils import store_plaintext
from onyx.kg.models import KGException
from onyx.kg.setup.kg_default_entity_definitions import (
populate_missing_default_entity_types__commit,
@@ -289,6 +291,33 @@ def process_kg_commands(
raise KGException("KG setup done")
def _get_or_extract_plaintext(
file_id: str,
extract_fn: Callable[[], str],
) -> str:
"""Load cached plaintext for a file, or extract and store it.
Tries to read pre-stored plaintext from the file store. On a miss,
calls extract_fn to produce the text, then stores the result so
future calls skip the expensive extraction.
"""
file_store = get_default_file_store()
plaintext_key = plaintext_file_name_for_id(file_id)
# Try cached plaintext first.
try:
plaintext_io = file_store.read_file(plaintext_key, mode="b")
return plaintext_io.read().decode("utf-8")
except Exception:
logger.exception(f"Error when reading file, id={file_id}")
# Cache miss — extract and store.
content_text = extract_fn()
if content_text:
store_plaintext(file_id, content_text)
return content_text
@log_function_time(print_only=True)
def load_chat_file(
file_descriptor: FileDescriptor, db_session: Session
@@ -303,12 +332,23 @@ def load_chat_file(
file_type = ChatFileType(file_descriptor["type"])
if file_type.is_text_file():
try:
content_text = extract_file_text(
file_id = file_descriptor["id"]
def _extract() -> str:
return extract_file_text(
file=file_io,
file_name=file_descriptor.get("name") or "",
break_on_unprocessable=False,
)
# Use the user_file_id as cache key when available (matches what
# the celery indexing worker stores), otherwise fall back to the
# file store id (covers code-interpreter-generated files, etc.).
user_file_id_str = file_descriptor.get("user_file_id")
cache_key = user_file_id_str or file_id
try:
content_text = _get_or_extract_plaintext(cache_key, _extract)
except Exception as e:
logger.warning(
f"Failed to retrieve content for file {file_descriptor['id']}: {str(e)}"

View File

@@ -157,9 +157,7 @@ def _execute_single_retrieval(
logger.error(f"Error executing request: {e}")
raise e
elif _is_rate_limit_error(e):
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
results = _execute_with_retry(retrieval_function(**request_kwargs))
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.debug(f"Error executing request: {e}")

View File

@@ -2,6 +2,7 @@ import time
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.db.connector_credential_pair import get_connector_credential_pairs
@@ -149,6 +150,9 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
Returns None if search settings did not change, or the old search settings if they
did change.
"""
if DISABLE_VECTOR_DB:
return None
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = max(len(all_cc_pairs) - 1, 0)

View File

@@ -1,3 +1,4 @@
import json
import logging
import time
from contextlib import AbstractContextManager
@@ -1062,7 +1063,7 @@ class OpenSearchIndexClient(OpenSearchClient):
f"Body: {get_new_body_without_vectors(body)}\n"
f"Search pipeline ID: {search_pipeline_id}\n"
f"Phase took: {phase_took}\n"
f"Profile: {profile}\n"
f"Profile: {json.dumps(profile, indent=2)}\n"
)
if timed_out:
error_str = f"OpenSearch client error: Search timed out for index {self._index_name}."

View File

@@ -950,7 +950,86 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_pipeline_id=normalization_pipeline_name,
)
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
# Good place for a breakpoint to inspect the search hits if you have
# "explain" enabled.
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
)
for search_hit in search_hits
]
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
inference_chunks_uncleaned
)
return inference_chunks
def keyword_retrieval(
self,
query: str,
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
logger.debug(
f"[OpenSearchDocumentIndex] Keyword retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
query_body = DocumentQuery.get_keyword_search_query(
query_text=query,
num_hits=num_to_retrieve,
tenant_state=self._tenant_state,
# NOTE: Index filters includes metadata tags which were filtered
# for invalid unicode at indexing time. In theory it would be
# ideal to do filtering here as well, in practice we never did
# that in the Vespa codepath and have not seen issues in
# production, so we deliberately conform to the existing logic
# in order to not unknowningly introduce a possible bug.
index_filters=filters,
include_hidden=False,
)
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
)
for search_hit in search_hits
]
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
inference_chunks_uncleaned
)
return inference_chunks
def semantic_retrieval(
self,
query_embedding: Embedding,
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
logger.debug(
f"[OpenSearchDocumentIndex] Semantic retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
query_body = DocumentQuery.get_semantic_search_query(
query_embedding=query_embedding,
num_hits=num_to_retrieve,
tenant_state=self._tenant_state,
# NOTE: Index filters includes metadata tags which were filtered
# for invalid unicode at indexing time. In theory it would be
# ideal to do filtering here as well, in practice we never did
# that in the Vespa codepath and have not seen issues in
# production, so we deliberately conform to the existing logic
# in order to not unknowningly introduce a possible bug.
index_filters=filters,
include_hidden=False,
)
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights

View File

@@ -404,12 +404,170 @@ class DocumentQuery:
DocumentQuery._get_match_highlights_configuration()
)
# Explain is for scoring breakdowns.
# Explain is for scoring breakdowns. Setting this significantly
# increases query latency.
if OPENSEARCH_EXPLAIN_ENABLED:
final_hybrid_search_body["explain"] = True
return final_hybrid_search_body
@staticmethod
def get_keyword_search_query(
query_text: str,
num_hits: int,
tenant_state: TenantState,
index_filters: IndexFilters,
include_hidden: bool,
) -> dict[str, Any]:
"""Returns a final keyword search query.
This query can be directly supplied to the OpenSearch client.
Args:
query_text: The text to query for.
num_hits: The final number of hits to return.
tenant_state: Tenant state containing the tenant ID.
index_filters: Filters for the keyword search query.
include_hidden: Whether to include hidden documents.
Returns:
A dictionary representing the final keyword search query.
"""
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
raise ValueError(
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
)
keyword_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
include_hidden=include_hidden,
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
# now. This should not cause any issues but it can introduce
# redundant filters in queries that may affect performance.
access_control_list=index_filters.access_control_list,
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
attached_document_ids=index_filters.attached_document_ids,
hierarchy_node_ids=index_filters.hierarchy_node_ids,
)
keyword_search_query = (
DocumentQuery._get_title_content_combined_keyword_search_query(
query_text, search_filters=keyword_search_filters
)
)
final_keyword_search_query: dict[str, Any] = {
"query": keyword_search_query,
"size": num_hits,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
# Exclude retrieving the vector fields in order to save on
# retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
}
if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED:
final_keyword_search_query["highlight"] = (
DocumentQuery._get_match_highlights_configuration()
)
if not OPENSEARCH_PROFILING_DISABLED:
final_keyword_search_query["profile"] = True
# Explain is for scoring breakdowns. Setting this significantly
# increases query latency.
if OPENSEARCH_EXPLAIN_ENABLED:
final_keyword_search_query["explain"] = True
return final_keyword_search_query
@staticmethod
def get_semantic_search_query(
query_embedding: list[float],
num_hits: int,
tenant_state: TenantState,
index_filters: IndexFilters,
include_hidden: bool,
) -> dict[str, Any]:
"""Returns a final semantic search query.
This query can be directly supplied to the OpenSearch client.
Args:
query_embedding: The vector embedding of the text to query for.
num_hits: The final number of hits to return.
tenant_state: Tenant state containing the tenant ID.
index_filters: Filters for the semantic search query.
include_hidden: Whether to include hidden documents.
Returns:
A dictionary representing the final semantic search query.
"""
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
raise ValueError(
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
)
semantic_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
include_hidden=include_hidden,
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
# now. This should not cause any issues but it can introduce
# redundant filters in queries that may affect performance.
access_control_list=index_filters.access_control_list,
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
attached_document_ids=index_filters.attached_document_ids,
hierarchy_node_ids=index_filters.hierarchy_node_ids,
)
semantic_search_query = (
DocumentQuery._get_content_vector_similarity_search_query(
query_embedding,
vector_candidates=num_hits,
search_filters=semantic_search_filters,
)
)
final_semantic_search_query: dict[str, Any] = {
"query": semantic_search_query,
"size": num_hits,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
# Exclude retrieving the vector fields in order to save on
# retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
}
if not OPENSEARCH_PROFILING_DISABLED:
final_semantic_search_query["profile"] = True
# Explain is for scoring breakdowns. Setting this significantly
# increases query latency.
if OPENSEARCH_EXPLAIN_ENABLED:
final_semantic_search_query["explain"] = True
return final_semantic_search_query
@staticmethod
def get_random_search_query(
tenant_state: TenantState,
@@ -581,8 +739,9 @@ class DocumentQuery:
def _get_content_vector_similarity_search_query(
query_vector: list[float],
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
search_filters: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
return {
query = {
"knn": {
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
@@ -591,11 +750,19 @@ class DocumentQuery:
}
}
if search_filters is not None:
query["knn"][CONTENT_VECTOR_FIELD_NAME]["filter"] = {
"bool": {"filter": search_filters}
}
return query
@staticmethod
def _get_title_content_combined_keyword_search_query(
query_text: str,
search_filters: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
return {
query = {
"bool": {
"should": [
{
@@ -636,10 +803,19 @@ class DocumentQuery:
}
}
},
]
],
# Ensure at least one term from the query is present in the
# document. This defaults to 1, unless a filter or must clause
# is supplied, in which case it defaults to 0.
"minimum_should_match": 1,
}
}
if search_filters is not None:
query["bool"]["filter"] = search_filters
return query
@staticmethod
def _get_search_filters(
tenant_state: TenantState,

View File

@@ -88,6 +88,7 @@ class OnyxErrorCode(Enum):
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
BAD_GATEWAY = ("BAD_GATEWAY", 502)
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
HOOK_EXECUTION_FAILED = ("HOOK_EXECUTION_FAILED", 502)
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
def __init__(self, code: str, status_code: int) -> None:

View File

@@ -23,45 +23,55 @@ from onyx.utils.timing import log_function_time
logger = setup_logger()
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
"""Generate a consistent file name for storing plaintext content of a user file."""
return f"plaintext_{user_file_id}"
def plaintext_file_name_for_id(file_id: str) -> str:
"""Generate a consistent file name for storing plaintext content of a file."""
return f"plaintext_{file_id}"
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
def store_plaintext(file_id: str, plaintext_content: str) -> bool:
"""
Store plaintext content for a user file in the file store.
Store plaintext content for a file in the file store.
Args:
user_file_id: The ID of the user file
file_id: The ID of the file (user_file or artifact_file)
plaintext_content: The plaintext content to store
Returns:
bool: True if storage was successful, False otherwise
"""
# Skip empty content
if not plaintext_content:
return False
# Get plaintext file name
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
plaintext_file_name = plaintext_file_name_for_id(file_id)
try:
file_store = get_default_file_store()
file_content = BytesIO(plaintext_content.encode("utf-8"))
file_store.save_file(
content=file_content,
display_name=f"Plaintext for user file {user_file_id}",
display_name=f"Plaintext for {file_id}",
file_origin=FileOrigin.PLAINTEXT_CACHE,
file_type="text/plain",
file_id=plaintext_file_name,
)
return True
except Exception as e:
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
logger.warning(f"Failed to store plaintext for {file_id}: {e}")
return False
# --- Convenience wrappers for callers that use user-file UUIDs ---
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
"""Generate a consistent file name for storing plaintext content of a user file."""
return plaintext_file_name_for_id(str(user_file_id))
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
"""Store plaintext content for a user file (delegates to :func:`store_plaintext`)."""
return store_plaintext(str(user_file_id), plaintext_content)
def load_chat_file_by_id(file_id: str) -> InMemoryChatFile:
"""Load a file directly from the file store using its file_record ID.

View File

@@ -0,0 +1,330 @@
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
Usage (Celery tasks and FastAPI handlers):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
)
if isinstance(result, HookSkipped):
# no active hook configured — continue with original behavior
...
elif isinstance(result, HookSoftFailed):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is the response payload dict from the customer's endpoint
...
is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:
NetworkError (DNS, connection refused) → False (cannot reach the server)
HTTP 401 / 403 → False (api_key revoked or invalid)
TimeoutException → None (server may be slow, skip write)
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
Unknown exception → None (no signal, skip write)
Non-JSON / non-dict response → None (server responded, skip write)
Success (2xx, valid dict) → True (confirmed reachable)
None means "leave the current value unchanged" — no DB round-trip is made.
DB session design
-----------------
The executor uses three sessions:
1. Caller's session (db_session) — used only for the hook lookup read. All
needed fields are extracted from the Hook object before the HTTP call, so
the caller's session is not held open during the external HTTP request.
2. Log session — a separate short-lived session opened after the HTTP call
completes to write the HookExecutionLog row on failure. Success runs are
not recorded. Committed independently of everything else.
3. Reachable session — a second short-lived session to update is_reachable on
the Hook. Kept separate from the log session so a concurrent hook deletion
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
prevent the execution log from being written. This update is best-effort.
"""
import json
import time
from typing import Any
import httpx
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.utils.logger import setup_logger
logger = setup_logger()
class HookSkipped:
"""No active hook configured for this hook point."""
class HookSoftFailed:
"""Hook was called but failed with SOFT fail strategy — continuing."""
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
class _HttpOutcome(BaseModel):
"""Structured result of an HTTP hook call, returned by _process_response."""
is_success: bool
updated_is_reachable: (
bool | None
) # True/False = write to DB, None = unchanged (skip write)
status_code: int | None
error_message: str | None
response_payload: dict[str, Any] | None
def _lookup_hook(
db_session: Session,
hook_point: HookPoint,
) -> Hook | HookSkipped:
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
No HTTP call is made and no DB writes are performed for any HookSkipped path.
There is nothing to log and no reachability information to update.
"""
if not HOOKS_AVAILABLE:
return HookSkipped()
hook = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if hook is None or not hook.is_active:
return HookSkipped()
if not hook.endpoint_url:
return HookSkipped()
return hook
def _process_response(
*,
response: httpx.Response | None,
exc: Exception | None,
timeout: float,
) -> _HttpOutcome:
"""Process the result of an HTTP call and return a structured outcome.
Called after the client.post() try/except. If post() raised, exc is set and
response is None. Otherwise response is set and exc is None. Handles
raise_for_status(), JSON decoding, and the dict shape check.
"""
if exc is not None:
if isinstance(exc, httpx.NetworkError):
msg = f"Hook network error (endpoint unreachable): {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False,
status_code=None,
error_message=msg,
response_payload=None,
)
if isinstance(exc, httpx.TimeoutException):
msg = f"Hook timed out after {timeout}s: {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # timeout doesn't indicate unreachability
status_code=None,
error_message=msg,
response_payload=None,
)
msg = f"Hook call failed: {exc}"
logger.exception(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # unknown error — don't make assumptions
status_code=None,
error_message=msg,
response_payload=None,
)
if response is None:
raise ValueError(
"exactly one of response or exc must be non-None; both are None"
)
status_code = response.status_code
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
logger.warning(msg, exc_info=e)
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
# so the operator knows to update it. All other HTTP errors keep is_reachable
# as-is (server is up, the request just failed for application reasons).
auth_failed = e.response.status_code in (401, 403)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False if auth_failed else None,
status_code=status_code,
error_message=msg,
response_payload=None,
)
try:
response_payload = response.json()
except (json.JSONDecodeError, httpx.DecodingError) as e:
msg = f"Hook returned non-JSON response: {e}"
logger.warning(msg, exc_info=e)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
if not isinstance(response_payload, dict):
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
logger.warning(msg)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
return _HttpOutcome(
is_success=True,
updated_is_reachable=True,
status_code=status_code,
error_message=None,
response_payload=response_payload,
)
def _persist_result(
*,
hook_id: int,
outcome: _HttpOutcome,
duration_ms: int,
) -> None:
"""Write the execution log on failure and optionally update is_reachable, each
in its own session so a failure in one does not affect the other."""
# Only write the execution log on failure — success runs are not recorded.
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
# deleted between the initial lookup and here).
if not outcome.is_success:
try:
with get_session_with_current_tenant() as log_session:
create_hook_execution_log__no_commit(
db_session=log_session,
hook_id=hook_id,
is_success=False,
error_message=outcome.error_message,
status_code=outcome.status_code,
duration_ms=duration_ms,
)
log_session.commit()
except Exception:
logger.exception(
f"Failed to persist hook execution log for hook_id={hook_id}"
)
# Update is_reachable separately — best-effort, non-critical.
# None means the value is unchanged (set by the caller to skip the no-op write).
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
# concurrently deleted, so keep this isolated from the log write above.
if outcome.updated_is_reachable is not None:
try:
with get_session_with_current_tenant() as reachable_session:
update_hook__no_commit(
db_session=reachable_session,
hook_id=hook_id,
is_reachable=outcome.updated_is_reachable,
)
reachable_session.commit()
except Exception:
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously."""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
"active hooks without an endpoint_url must be rejected by _lookup_hook"
)
start = time.monotonic()
response: httpx.Response | None = None
exc: Exception | None = None
try:
api_key: str | None = (
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
)
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(timeout=timeout) as client:
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
outcome = outcome.model_copy(update={"updated_is_reachable": None})
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
if not outcome.is_success:
if fail_strategy == HookFailStrategy.HARD:
raise OnyxError(
OnyxErrorCode.HOOK_EXECUTION_FAILED,
outcome.error_message or "Hook execution failed.",
)
logger.warning(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if outcome.response_payload is None:
raise ValueError(
f"response_payload is None for successful hook call (hook_id={hook_id})"
)
return outcome.response_payload

View File

@@ -42,12 +42,8 @@ class HookUpdateRequest(BaseModel):
name: str | None = None
endpoint_url: str | None = None
api_key: NonEmptySecretStr | None = None
fail_strategy: HookFailStrategy | None = (
None # if None in model_fields_set, reset to spec default
)
timeout_seconds: float | None = Field(
default=None, gt=0
) # if None in model_fields_set, reset to spec default
fail_strategy: HookFailStrategy | None = None
timeout_seconds: float | None = Field(default=None, gt=0)
@model_validator(mode="after")
def require_at_least_one_field(self) -> "HookUpdateRequest":
@@ -60,6 +56,14 @@ class HookUpdateRequest(BaseModel):
and not (self.endpoint_url or "").strip()
):
raise ValueError("endpoint_url cannot be cleared.")
if "fail_strategy" in self.model_fields_set and self.fail_strategy is None:
raise ValueError(
"fail_strategy cannot be null; omit the field to leave it unchanged."
)
if "timeout_seconds" in self.model_fields_set and self.timeout_seconds is None:
raise ValueError(
"timeout_seconds cannot be null; omit the field to leave it unchanged."
)
return self
@@ -90,38 +94,28 @@ class HookResponse(BaseModel):
fail_strategy: HookFailStrategy
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
is_active: bool
is_reachable: bool | None
creator_email: str | None
created_at: datetime
updated_at: datetime
class HookValidateStatus(str, Enum):
passed = "passed" # server responded (any status except 401/403)
auth_failed = "auth_failed" # server responded with 401 or 403
timeout = (
"timeout" # TCP connected, but read/write timed out (server exists but slow)
)
cannot_connect = "cannot_connect" # could not connect to the server
class HookValidateResponse(BaseModel):
success: bool
status: HookValidateStatus
error_message: str | None = None
# ---------------------------------------------------------------------------
# Health models
# ---------------------------------------------------------------------------
class HookHealthStatus(str, Enum):
healthy = "healthy" # green — reachable, no failures in last 1h
degraded = "degraded" # yellow — reachable, failures in last 1h
unreachable = "unreachable" # red — is_reachable=false or null
class HookFailureRecord(BaseModel):
class HookExecutionRecord(BaseModel):
error_message: str | None = None
status_code: int | None = None
duration_ms: int | None = None
created_at: datetime
class HookHealthResponse(BaseModel):
status: HookHealthStatus
recent_failures: list[HookFailureRecord] = Field(
default_factory=list,
description="Last 10 failures, newest first",
max_length=10,
)

View File

@@ -1,6 +1,7 @@
from abc import ABC
from abc import abstractmethod
from typing import Any
from typing import ClassVar
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
@@ -13,22 +14,25 @@ _REQUIRED_ATTRS = (
"default_timeout_seconds",
"fail_hard_description",
"default_fail_strategy",
"payload_model",
"response_model",
)
class HookPointSpec(ABC):
class HookPointSpec:
"""Static metadata and contract for a pipeline hook point.
This is NOT a regular class meant for direct instantiation by callers.
Each concrete subclass represents exactly one hook point and is instantiated
once at startup, registered in onyx.hooks.registry._REGISTRY. No caller
should ever create instances directly — use get_hook_point_spec() or
get_all_specs() from the registry instead.
once at startup, registered in onyx.hooks.registry._REGISTRY. Prefer
get_hook_point_spec() or get_all_specs() from the registry over direct
instantiation.
Each hook point is a concrete subclass of this class. Onyx engineers
own these definitions — customers never touch this code.
Subclasses must define all attributes as class-level constants.
payload_model and response_model must be Pydantic BaseModel subclasses;
input_schema and output_schema are derived from them automatically.
"""
hook_point: HookPoint
@@ -39,21 +43,33 @@ class HookPointSpec(ABC):
default_fail_strategy: HookFailStrategy
docs_url: str | None = None
payload_model: ClassVar[type[BaseModel]]
response_model: ClassVar[type[BaseModel]]
# Computed once at class definition time from payload_model / response_model.
input_schema: ClassVar[dict[str, Any]]
output_schema: ClassVar[dict[str, Any]]
def __init_subclass__(cls, **kwargs: object) -> None:
"""Enforce that every concrete subclass declares all required class attributes.
Called automatically by Python whenever a class inherits from HookPointSpec.
Abstract subclasses (those still carrying unimplemented abstract methods) are
skipped — they are intermediate base classes and may not yet define everything.
Only fully concrete subclasses are validated, ensuring a clear TypeError at
import time rather than a confusing AttributeError at runtime.
"""
super().__init_subclass__(**kwargs)
# Skip intermediate abstract subclasses — they may still be partially defined.
if getattr(cls, "__abstractmethods__", None):
return
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
if missing:
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
@property
@abstractmethod
def input_schema(self) -> dict[str, Any]:
"""JSON schema describing the request payload sent to the customer's endpoint."""
@property
@abstractmethod
def output_schema(self) -> dict[str, Any]:
"""JSON schema describing the expected response from the customer's endpoint."""
for attr in ("payload_model", "response_model"):
val = getattr(cls, attr, None)
if val is None or not (
isinstance(val, type) and issubclass(val, BaseModel)
):
raise TypeError(
f"{cls.__name__}.{attr} must be a Pydantic BaseModel subclass, got {val!r}"
)
cls.input_schema = cls.payload_model.model_json_schema()
cls.output_schema = cls.response_model.model_json_schema()

View File

@@ -1,10 +1,19 @@
from typing import Any
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
# TODO(@Bo-Onyx): define payload and response fields
class DocumentIngestionPayload(BaseModel):
pass
class DocumentIngestionResponse(BaseModel):
pass
class DocumentIngestionSpec(HookPointSpec):
"""Hook point that runs during document ingestion.
@@ -18,12 +27,5 @@ class DocumentIngestionSpec(HookPointSpec):
fail_hard_description = "The document will not be indexed."
default_fail_strategy = HookFailStrategy.HARD
@property
def input_schema(self) -> dict[str, Any]:
# TODO(@Bo-Onyx): define input schema
return {"type": "object", "properties": {}}
@property
def output_schema(self) -> dict[str, Any]:
# TODO(@Bo-Onyx): define output schema
return {"type": "object", "properties": {}}
payload_model = DocumentIngestionPayload
response_model = DocumentIngestionResponse

View File

@@ -1,10 +1,39 @@
from typing import Any
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
class QueryProcessingPayload(BaseModel):
model_config = ConfigDict(extra="forbid")
query: str = Field(description="The raw query string exactly as the user typed it.")
user_email: str | None = Field(
description="Email of the user submitting the query, or null if unauthenticated."
)
chat_session_id: str = Field(
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
)
class QueryProcessingResponse(BaseModel):
# Intentionally permissive — customer endpoints may return extra fields.
query: str | None = Field(
default=None,
description=(
"The query to use in the pipeline. "
"Null, empty string, or absent = reject the query."
),
)
rejection_message: str | None = Field(
default=None,
description="Message shown to the user when the query is rejected. Falls back to a generic message if not provided.",
)
class QueryProcessingSpec(HookPointSpec):
"""Hook point that runs on every user query before it enters the pipeline.
@@ -37,47 +66,5 @@ class QueryProcessingSpec(HookPointSpec):
)
default_fail_strategy = HookFailStrategy.HARD
@property
def input_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The raw query string exactly as the user typed it.",
},
"user_email": {
"type": ["string", "null"],
"description": "Email of the user submitting the query, or null if unauthenticated.",
},
"chat_session_id": {
"type": "string",
"description": "UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires.",
},
},
"required": ["query", "user_email", "chat_session_id"],
"additionalProperties": False,
}
@property
def output_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": ["string", "null"],
"description": (
"The (optionally modified) query to use. "
"Set to null to reject the query."
),
},
"rejection_message": {
"type": ["string", "null"],
"description": (
"Message shown to the user when query is null. "
"Falls back to a generic message if not provided."
),
},
},
"required": ["query"],
}
payload_model = QueryProcessingPayload
response_model = QueryProcessingResponse

View File

@@ -0,0 +1,5 @@
from onyx.configs.app_configs import HOOK_ENABLED
from shared_configs.configs import MULTI_TENANT
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT

View File

@@ -530,6 +530,11 @@ class LitellmLLM(LLM):
):
messages = _strip_tool_content_from_messages(messages)
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
# reject requests where tool_choice is explicitly null.
if tools and tool_choice is not None:
optional_kwargs["tool_choice"] = tool_choice
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
@@ -538,7 +543,6 @@ class LitellmLLM(LLM):
custom_llm_provider=self._custom_llm_provider or None,
messages=messages,
tools=tools,
tool_choice=tool_choice,
stream=stream,
temperature=temperature,
timeout=timeout_override or self._timeout,

View File

@@ -77,6 +77,7 @@ from onyx.server.features.default_assistant.api import (
)
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.hierarchy.api import router as hierarchy_router
from onyx.server.features.hooks.api import router as hook_router
from onyx.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
@@ -453,6 +454,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
register_onyx_exception_handlers(application)
include_router_with_global_prefix_prepended(application, hook_router)
include_router_with_global_prefix_prepended(application, password_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, query_router)

View File

@@ -0,0 +1,453 @@
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Query
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.constants import UNSET
from onyx.db.constants import UnsetType
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.hook import create_hook__no_commit
from onyx.db.hook import delete_hook__no_commit
from onyx.db.hook import get_hook_by_id
from onyx.db.hook import get_hook_execution_logs
from onyx.db.hook import get_hooks
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.api_dependencies import require_hook_enabled
from onyx.hooks.models import HookCreateRequest
from onyx.hooks.models import HookExecutionRecord
from onyx.hooks.models import HookPointMetaResponse
from onyx.hooks.models import HookResponse
from onyx.hooks.models import HookUpdateRequest
from onyx.hooks.models import HookValidateResponse
from onyx.hooks.models import HookValidateStatus
from onyx.hooks.registry import get_all_specs
from onyx.hooks.registry import get_hook_point_spec
from onyx.utils.logger import setup_logger
from onyx.utils.url import SSRFException
from onyx.utils.url import validate_outbound_http_url
logger = setup_logger()
# ---------------------------------------------------------------------------
# SSRF protection
# ---------------------------------------------------------------------------
def _check_ssrf_safety(endpoint_url: str) -> None:
"""Raise OnyxError if endpoint_url could be used for SSRF.
Delegates to validate_outbound_http_url with https_only=True.
"""
try:
validate_outbound_http_url(endpoint_url, https_only=True)
except (SSRFException, ValueError) as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookResponse:
return HookResponse(
id=hook.id,
name=hook.name,
hook_point=hook.hook_point,
endpoint_url=hook.endpoint_url,
fail_strategy=hook.fail_strategy,
timeout_seconds=hook.timeout_seconds,
is_active=hook.is_active,
is_reachable=hook.is_reachable,
creator_email=(
creator_email
if creator_email is not None
else (hook.creator.email if hook.creator else None)
),
created_at=hook.created_at,
updated_at=hook.updated_at,
)
def _get_hook_or_404(
db_session: Session,
hook_id: int,
include_creator: bool = False,
) -> Hook:
hook = get_hook_by_id(
db_session=db_session,
hook_id=hook_id,
include_creator=include_creator,
)
if hook is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook {hook_id} not found.")
return hook
def _raise_for_validation_failure(validation: HookValidateResponse) -> None:
"""Raise an appropriate OnyxError for a non-passed validation result."""
if validation.status == HookValidateStatus.auth_failed:
raise OnyxError(OnyxErrorCode.CREDENTIAL_INVALID, validation.error_message)
if validation.status == HookValidateStatus.timeout:
raise OnyxError(
OnyxErrorCode.GATEWAY_TIMEOUT,
f"Endpoint validation failed: {validation.error_message}",
)
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Endpoint validation failed: {validation.error_message}",
)
def _validate_endpoint(
endpoint_url: str,
api_key: str | None,
timeout_seconds: float,
) -> HookValidateResponse:
"""Check whether endpoint_url is reachable by sending an empty POST request.
We use POST since hook endpoints expect POST requests. The server will typically
respond with 4xx (missing/invalid body) — that is fine. Any HTTP response means
the server is up and routable. A 401/403 response returns auth_failed
(not reachable — indicates the api_key is invalid).
Timeout handling:
- ConnectTimeout: TCP handshake never completed → cannot_connect.
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
(operator should consider increasing timeout_seconds).
- All other exceptions → cannot_connect.
"""
_check_ssrf_safety(endpoint_url)
headers: dict[str, str] = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
try:
with httpx.Client(timeout=timeout_seconds, follow_redirects=False) as client:
response = client.post(endpoint_url, headers=headers)
if response.status_code in (401, 403):
return HookValidateResponse(
status=HookValidateStatus.auth_failed,
error_message=f"Authentication failed (HTTP {response.status_code})",
)
return HookValidateResponse(status=HookValidateStatus.passed)
except httpx.TimeoutException as exc:
# ConnectTimeout: TCP handshake never completed → cannot_connect.
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
if isinstance(exc, httpx.ConnectTimeout):
logger.warning(
"Hook endpoint validation: connect timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
logger.warning(
"Hook endpoint validation: read/write timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.timeout,
error_message="Endpoint timed out — consider increasing timeout_seconds.",
)
except Exception as exc:
logger.warning(
"Hook endpoint validation: connection error for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
# ---------------------------------------------------------------------------
# Routers
# ---------------------------------------------------------------------------
router = APIRouter(prefix="/admin/hooks")
# ---------------------------------------------------------------------------
# Hook endpoints
# ---------------------------------------------------------------------------
@router.get("/specs")
def get_hook_point_specs(
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
) -> list[HookPointMetaResponse]:
return [
HookPointMetaResponse(
hook_point=spec.hook_point,
display_name=spec.display_name,
description=spec.description,
docs_url=spec.docs_url,
input_schema=spec.input_schema,
output_schema=spec.output_schema,
default_timeout_seconds=spec.default_timeout_seconds,
default_fail_strategy=spec.default_fail_strategy,
fail_hard_description=spec.fail_hard_description,
)
for spec in get_all_specs()
]
@router.get("")
def list_hooks(
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> list[HookResponse]:
hooks = get_hooks(db_session=db_session, include_creator=True)
return [_hook_to_response(h) for h in hooks]
@router.post("")
def create_hook(
req: HookCreateRequest,
user: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
"""Create a new hook. The endpoint is validated before persisting — creation fails if
the endpoint cannot be reached or the api_key is invalid. Hooks are created inactive;
use POST /{hook_id}/activate once ready to receive traffic."""
spec = get_hook_point_spec(req.hook_point)
api_key = req.api_key.get_secret_value() if req.api_key else None
validation = _validate_endpoint(
endpoint_url=req.endpoint_url,
api_key=api_key,
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
)
if validation.status != HookValidateStatus.passed:
_raise_for_validation_failure(validation)
hook = create_hook__no_commit(
db_session=db_session,
name=req.name,
hook_point=req.hook_point,
endpoint_url=req.endpoint_url,
api_key=api_key,
fail_strategy=req.fail_strategy or spec.default_fail_strategy,
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
creator_id=user.id,
)
hook.is_reachable = True
db_session.commit()
return _hook_to_response(hook, creator_email=user.email)
@router.get("/{hook_id}")
def get_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = _get_hook_or_404(db_session, hook_id, include_creator=True)
return _hook_to_response(hook)
@router.patch("/{hook_id}")
def update_hook(
hook_id: int,
req: HookUpdateRequest,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
"""Update hook fields. If endpoint_url, api_key, or timeout_seconds changes, the
endpoint is re-validated using the effective values. For active hooks the update is
rejected on validation failure, keeping live traffic unaffected. For inactive hooks
the update goes through regardless and is_reachable is updated to reflect the result.
Note: if an active hook's endpoint is currently down, even a timeout_seconds-only
increase will be rejected. The recovery flow is: deactivate → update → reactivate.
"""
# api_key: UNSET = no change, None = clear, value = update
api_key: str | None | UnsetType
if "api_key" not in req.model_fields_set:
api_key = UNSET
elif req.api_key is None:
api_key = None
else:
api_key = req.api_key.get_secret_value()
endpoint_url_changing = "endpoint_url" in req.model_fields_set
api_key_changing = not isinstance(api_key, UnsetType)
timeout_changing = "timeout_seconds" in req.model_fields_set
validated_is_reachable: bool | None = None
if endpoint_url_changing or api_key_changing or timeout_changing:
existing = _get_hook_or_404(db_session, hook_id)
effective_url: str = (
req.endpoint_url if endpoint_url_changing else existing.endpoint_url # type: ignore[assignment] # endpoint_url is required on create and cannot be cleared on update
)
effective_api_key: str | None = (
(api_key if not isinstance(api_key, UnsetType) else None)
if api_key_changing
else (
existing.api_key.get_value(apply_mask=False)
if existing.api_key
else None
)
)
effective_timeout: float = (
req.timeout_seconds if timeout_changing else existing.timeout_seconds # type: ignore[assignment] # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest)
)
validation = _validate_endpoint(
endpoint_url=effective_url,
api_key=effective_api_key,
timeout_seconds=effective_timeout,
)
if existing.is_active and validation.status != HookValidateStatus.passed:
_raise_for_validation_failure(validation)
validated_is_reachable = validation.status == HookValidateStatus.passed
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
name=req.name,
endpoint_url=(req.endpoint_url if endpoint_url_changing else UNSET),
api_key=api_key,
fail_strategy=req.fail_strategy,
timeout_seconds=req.timeout_seconds,
is_reachable=validated_is_reachable,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
@router.delete("/{hook_id}")
def delete_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> None:
delete_hook__no_commit(db_session=db_session, hook_id=hook_id)
db_session.commit()
@router.post("/{hook_id}/activate")
def activate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = _get_hook_or_404(db_session, hook_id)
if not hook.endpoint_url:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
)
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
validation = _validate_endpoint(
endpoint_url=hook.endpoint_url,
api_key=api_key,
timeout_seconds=hook.timeout_seconds,
)
if validation.status != HookValidateStatus.passed:
# Persist is_reachable=False in a separate session so the request
# session has no commits on the failure path and the transaction
# boundary stays clean.
if hook.is_reachable is not False:
with get_session_with_current_tenant() as side_session:
update_hook__no_commit(
db_session=side_session, hook_id=hook_id, is_reachable=False
)
side_session.commit()
_raise_for_validation_failure(validation)
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
is_active=True,
is_reachable=True,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
@router.post("/{hook_id}/validate")
def validate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookValidateResponse:
hook = _get_hook_or_404(db_session, hook_id)
if not hook.endpoint_url:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
)
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
validation = _validate_endpoint(
endpoint_url=hook.endpoint_url,
api_key=api_key,
timeout_seconds=hook.timeout_seconds,
)
validation_passed = validation.status == HookValidateStatus.passed
if hook.is_reachable != validation_passed:
update_hook__no_commit(
db_session=db_session, hook_id=hook_id, is_reachable=validation_passed
)
db_session.commit()
return validation
@router.post("/{hook_id}/deactivate")
def deactivate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
is_active=False,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
# ---------------------------------------------------------------------------
# Execution log endpoints
# ---------------------------------------------------------------------------
@router.get("/{hook_id}/execution-logs")
def list_hook_execution_logs(
hook_id: int,
limit: int = Query(default=10, ge=1, le=100),
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> list[HookExecutionRecord]:
_get_hook_or_404(db_session, hook_id)
logs = get_hook_execution_logs(db_session=db_session, hook_id=hook_id, limit=limit)
return [
HookExecutionRecord(
error_message=log.error_message,
status_code=log.status_code,
duration_ms=log.duration_ms,
created_at=log.created_at,
)
for log in logs
]

View File

@@ -140,10 +140,20 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
return validated_ip, hostname, port
def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
def validate_outbound_http_url(
url: str,
*,
allow_private_network: bool = False,
https_only: bool = False,
) -> str:
"""
Validate a URL that will be used by backend outbound HTTP calls.
Args:
url: The URL to validate.
allow_private_network: If True, skip private/reserved IP checks.
https_only: If True, reject http:// URLs (only https:// is allowed).
Returns:
A normalized URL string with surrounding whitespace removed.
@@ -157,7 +167,12 @@ def validate_outbound_http_url(url: str, *, allow_private_network: bool = False)
parsed = urlparse(normalized_url)
if parsed.scheme not in ("http", "https"):
if https_only:
if parsed.scheme != "https":
raise SSRFException(
f"Invalid URL scheme '{parsed.scheme}'. Only https is allowed."
)
elif parsed.scheme not in ("http", "https"):
raise SSRFException(
f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
)

View File

@@ -1640,3 +1640,275 @@ class TestOpenSearchClient:
for k in DocumentChunkWithoutVectors.model_fields
}
)
def test_keyword_search(
self,
test_client: OpenSearchIndexClient,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""
Tests keyword search with filters for ACL, hidden documents, and tenant
isolation.
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
docs = {
"public-doc": _create_test_document_chunk(
document_id="public-doc",
chunk_index=0,
content="Public document content",
hidden=False,
tenant_state=tenant_x,
),
"hidden-doc": _create_test_document_chunk(
document_id="hidden-doc",
chunk_index=0,
content="Hidden document content, spooky",
hidden=True,
tenant_state=tenant_x,
),
"private-doc-user-a": _create_test_document_chunk(
document_id="private-doc-user-a",
chunk_index=0,
content="Private document content, btw my SSN is 123-45-6789",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
# Tests that we don't return documents that don't match keywords at
# all, even if they match filters.
"private-but-not-relevant-doc-user-a": _create_test_document_chunk(
document_id="private-but-not-relevant-doc-user-a",
chunk_index=0,
content="This text should not match the query at all",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
"private-doc-user-b": _create_test_document_chunk(
document_id="private-doc-user-b",
chunk_index=0,
content="Private document content, btw my SSN is 987-65-4321",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-b@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
document_id="should-not-exist-from-tenant-x-pov",
chunk_index=0,
content="This is an entirely different tenant, x should never see this",
# Make this as permissive as possible to exercise tenant
# isolation.
hidden=False,
tenant_state=tenant_y,
),
}
for doc in docs.values():
test_client.index_document(document=doc, tenant_state=doc.tenant_id)
# Refresh index to make documents searchable.
test_client.refresh_index()
# Should not match private-but-not-relevant-doc-user-a.
query_text = "document content"
search_body = DocumentQuery.get_keyword_search_query(
query_text=query_text,
num_hits=5,
tenant_state=tenant_x,
# The user should only be able to see their private docs. tenant_id
# in this object is not relevant.
index_filters=IndexFilters(
access_control_list=[prefix_user_email("user-a@example.com")],
tenant_id=None,
),
include_hidden=False,
)
# Under test.
results = test_client.search(body=search_body, search_pipeline_id=None)
# Postcondition.
# Should only get the public, non-hidden document, and the private
# document for which the user has access.
assert len(results) == 2
# This should be the highest-ranked result, as a higher percentage of
# the content matches the query.
assert results[0].document_chunk.document_id == "public-doc"
# Make sure the chunk contents are preserved.
assert results[0].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["public-doc"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
# Make sure score reporting seems reasonable (it should not be None
# or 0).
assert results[0].score
# Make sure there is some kind of match highlight.
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
# Same for the second result.
assert results[1].document_chunk.document_id == "private-doc-user-a"
assert results[1].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["private-doc-user-a"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[1].score
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
assert results[1].score < results[0].score
def test_semantic_search(
self,
test_client: OpenSearchIndexClient,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""
Tests semantic search with filters for ACL, hidden documents, and tenant
isolation.
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
docs = {
"public-doc": _create_test_document_chunk(
document_id="public-doc",
chunk_index=0,
content="Public document content",
hidden=False,
tenant_state=tenant_x,
# Make this identical to the query vector to test that this
# result is returned first.
content_vector=_generate_test_vector(0.6),
),
"hidden-doc": _create_test_document_chunk(
document_id="hidden-doc",
chunk_index=0,
content="Hidden document content, spooky",
hidden=True,
tenant_state=tenant_x,
),
"private-doc-user-a": _create_test_document_chunk(
document_id="private-doc-user-a",
chunk_index=0,
content="Private document content, btw my SSN is 123-45-6789",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
# Make this different from the query vector to test that this
# result is returned second.
content_vector=_generate_test_vector(0.5),
),
"private-doc-user-b": _create_test_document_chunk(
document_id="private-doc-user-b",
chunk_index=0,
content="Private document content, btw my SSN is 987-65-4321",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-b@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
document_id="should-not-exist-from-tenant-x-pov",
chunk_index=0,
content="This is an entirely different tenant, x should never see this",
# Make this as permissive as possible to exercise tenant
# isolation.
hidden=False,
tenant_state=tenant_y,
),
}
for doc in docs.values():
test_client.index_document(document=doc, tenant_state=doc.tenant_id)
# Refresh index to make documents searchable.
test_client.refresh_index()
query_vector = _generate_test_vector(0.6)
search_body = DocumentQuery.get_semantic_search_query(
query_embedding=query_vector,
num_hits=5,
tenant_state=tenant_x,
# The user should only be able to see their private docs. tenant_id
# in this object is not relevant.
index_filters=IndexFilters(
access_control_list=[prefix_user_email("user-a@example.com")],
tenant_id=None,
),
include_hidden=False,
)
# Under test.
results = test_client.search(body=search_body, search_pipeline_id=None)
# Postcondition.
# Should only get the public, non-hidden document, and the private
# document for which the user has access.
assert len(results) == 2
# We explicitly expect this to be the highest-ranked result.
assert results[0].document_chunk.document_id == "public-doc"
# Make sure the chunk contents are preserved.
assert results[0].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["public-doc"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[0].score == 1.0
# Same for the second result.
assert results[1].document_chunk.document_id == "private-doc-user-a"
assert results[1].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["private-doc-user-a"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[1].score
assert 0.0 < results[1].score < 1.0

View File

@@ -14,6 +14,7 @@ from __future__ import annotations
import os
import subprocess
import sys
import time
import uuid
from collections.abc import Generator
@@ -28,6 +29,9 @@ _BACKEND_DIR = os.path.normpath(
os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
)
_DROP_SCHEMA_MAX_RETRIES = 3
_DROP_SCHEMA_RETRY_DELAY_SEC = 2
# ---------------------------------------------------------------------------
# Helpers
@@ -50,6 +54,39 @@ def _run_script(
)
def _force_drop_schema(engine: Engine, schema: str) -> None:
"""Terminate backends using *schema* then drop it, retrying on deadlock.
Background Celery workers may discover test schemas (they match the
``tenant_`` prefix) and hold locks on tables inside them. A bare
``DROP SCHEMA … CASCADE`` can deadlock with those workers, so we
first kill their connections and retry if we still hit a deadlock.
"""
for attempt in range(_DROP_SCHEMA_MAX_RETRIES):
try:
with engine.connect() as conn:
conn.execute(
text(
"""
SELECT pg_terminate_backend(l.pid)
FROM pg_locks l
JOIN pg_class c ON c.oid = l.relation
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = :schema
AND l.pid != pg_backend_pid()
"""
),
{"schema": schema},
)
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
return
except Exception:
if attempt == _DROP_SCHEMA_MAX_RETRIES - 1:
raise
time.sleep(_DROP_SCHEMA_RETRY_DELAY_SEC)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@@ -104,9 +141,7 @@ def tenant_schema_at_head(
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
_force_drop_schema(engine, schema)
@pytest.fixture
@@ -123,9 +158,7 @@ def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
_force_drop_schema(engine, schema)
@pytest.fixture
@@ -150,9 +183,7 @@ def tenant_schema_bad_rev(engine: Engine) -> Generator[str, None, None]:
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
_force_drop_schema(engine, schema)
# ---------------------------------------------------------------------------

View File

@@ -1,3 +1,5 @@
import csv
import io
import os
from datetime import datetime
from datetime import timedelta
@@ -139,12 +141,12 @@ def test_chat_history_csv_export(
assert headers["Content-Type"] == "text/csv; charset=utf-8"
assert "Content-Disposition" in headers
# Verify CSV content
csv_lines = csv_content.strip().split("\n")
assert len(csv_lines) == 3 # Header + 2 QA pairs
assert "chat_session_id" in csv_content
assert "user_message" in csv_content
assert "ai_response" in csv_content
# Use csv.reader to properly handle newlines inside quoted fields
csv_rows = list(csv.reader(io.StringIO(csv_content)))
assert len(csv_rows) == 3 # Header + 2 QA pairs
assert csv_rows[0][0] == "chat_session_id"
assert "user_message" in csv_rows[0]
assert "ai_response" in csv_rows[0]
assert "What was the Q1 revenue?" in csv_content
assert "What about Q2 revenue?" in csv_content
@@ -156,5 +158,5 @@ def test_chat_history_csv_export(
end_time=past_end,
user_performing_action=admin_user,
)
csv_lines = csv_content.strip().split("\n")
assert len(csv_lines) == 1 # Only header, no data rows
csv_rows = list(csv.reader(io.StringIO(csv_content)))
assert len(csv_rows) == 1 # Only header, no data rows

View File

@@ -1,6 +1,5 @@
from typing import Any
import pytest
from pydantic import BaseModel
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
@@ -11,12 +10,10 @@ def test_init_subclass_raises_for_missing_attrs() -> None:
class IncompleteSpec(HookPointSpec):
hook_point = HookPoint.QUERY_PROCESSING
# missing display_name, description, etc.
# missing display_name, description, payload_model, response_model, etc.
@property
def input_schema(self) -> dict[str, Any]:
return {}
class _Payload(BaseModel):
pass
@property
def output_schema(self) -> dict[str, Any]:
return {}
payload_model = _Payload
response_model = _Payload

View File

@@ -0,0 +1,541 @@
"""Unit tests for the hook executor."""
import json
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
_RESPONSE_PAYLOAD: dict[str, Any] = {"rewritten_query": "better test"}
def _make_hook(
*,
is_active: bool = True,
endpoint_url: str | None = "https://hook.example.com/query",
api_key: MagicMock | None = None,
timeout_seconds: float = 5.0,
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
hook_id: int = 1,
is_reachable: bool | None = None,
) -> MagicMock:
hook = MagicMock()
hook.is_active = is_active
hook.endpoint_url = endpoint_url
hook.api_key = api_key
hook.timeout_seconds = timeout_seconds
hook.id = hook_id
hook.fail_strategy = fail_strategy
hook.is_reachable = is_reachable
return hook
def _make_api_key(value: str) -> MagicMock:
api_key = MagicMock()
api_key.get_value.return_value = value
return api_key
def _make_response(
*,
status_code: int = 200,
json_return: Any = _RESPONSE_PAYLOAD,
json_side_effect: Exception | None = None,
) -> MagicMock:
"""Build a response mock with controllable json() behaviour."""
response = MagicMock()
response.status_code = status_code
if json_side_effect is not None:
response.json.side_effect = json_side_effect
else:
response.json.return_value = json_return
return response
def _setup_client(
mock_client_cls: MagicMock,
*,
response: MagicMock | None = None,
side_effect: Exception | None = None,
) -> MagicMock:
"""Wire up the httpx.Client mock and return the inner client.
If side_effect is an httpx.HTTPStatusError, it is raised from
raise_for_status() (matching real httpx behaviour) and post() returns a
response mock with the matching status_code set. All other exceptions are
raised directly from post().
"""
mock_client = MagicMock()
if isinstance(side_effect, httpx.HTTPStatusError):
error_response = MagicMock()
error_response.status_code = side_effect.response.status_code
error_response.raise_for_status.side_effect = side_effect
mock_client.post = MagicMock(return_value=error_response)
else:
mock_client.post = MagicMock(
side_effect=side_effect, return_value=response if not side_effect else None
)
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
return mock_client
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def db_session() -> MagicMock:
return MagicMock()
# ---------------------------------------------------------------------------
# Early-exit guards (no HTTP call, no DB writes)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"hooks_available,hook",
[
# HOOKS_AVAILABLE=False exits before the DB lookup — hook is irrelevant.
pytest.param(False, None, id="hooks_not_available"),
pytest.param(True, None, id="hook_not_found"),
pytest.param(True, _make_hook(is_active=False), id="hook_inactive"),
pytest.param(True, _make_hook(endpoint_url=None), id="no_endpoint_url"),
],
)
def test_early_exit_returns_skipped_with_no_db_writes(
db_session: MagicMock,
hooks_available: bool,
hook: MagicMock | None,
) -> None:
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", hooks_available),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSkipped)
mock_update.assert_not_called()
mock_log.assert_not_called()
# ---------------------------------------------------------------------------
# Successful HTTP call
# ---------------------------------------------------------------------------
def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> None:
hook = _make_hook()
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, response=_make_response())
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert result == _RESPONSE_PAYLOAD
_, update_kwargs = mock_update.call_args
assert update_kwargs["is_reachable"] is True
mock_log.assert_not_called()
def test_success_skips_reachable_write_when_already_true(db_session: MagicMock) -> None:
"""Deduplication guard: a hook already at is_reachable=True that succeeds
must not trigger a DB write."""
hook = _make_hook(is_reachable=True)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, response=_make_response())
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert result == _RESPONSE_PAYLOAD
mock_update.assert_not_called()
def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
"""response.json() returning a non-dict (e.g. list) must be treated as failure.
The server responded, so is_reachable is not updated."""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response(json_return=["unexpected", "list"]),
)
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "non-dict" in (log_kwargs["error_message"] or "")
mock_update.assert_not_called()
def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
"""response.json() raising must be treated as failure with SOFT strategy.
The server responded, so is_reachable is not updated."""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response(
json_side_effect=json.JSONDecodeError("not JSON", "", 0)
),
)
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "non-JSON" in (log_kwargs["error_message"] or "")
mock_update.assert_not_called()
# ---------------------------------------------------------------------------
# HTTP failure paths
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"exception,fail_strategy,expected_type,expected_is_reachable",
[
# NetworkError → is_reachable=False
pytest.param(
httpx.ConnectError("refused"),
HookFailStrategy.SOFT,
HookSoftFailed,
False,
id="connect_error_soft",
),
pytest.param(
httpx.ConnectError("refused"),
HookFailStrategy.HARD,
OnyxError,
False,
id="connect_error_hard",
),
# 401/403 → is_reachable=False (api_key revoked)
pytest.param(
httpx.HTTPStatusError(
"401",
request=MagicMock(),
response=MagicMock(status_code=401, text="Unauthorized"),
),
HookFailStrategy.SOFT,
HookSoftFailed,
False,
id="auth_401_soft",
),
pytest.param(
httpx.HTTPStatusError(
"403",
request=MagicMock(),
response=MagicMock(status_code=403, text="Forbidden"),
),
HookFailStrategy.HARD,
OnyxError,
False,
id="auth_403_hard",
),
# TimeoutException → no is_reachable write (None)
pytest.param(
httpx.TimeoutException("timeout"),
HookFailStrategy.SOFT,
HookSoftFailed,
None,
id="timeout_soft",
),
pytest.param(
httpx.TimeoutException("timeout"),
HookFailStrategy.HARD,
OnyxError,
None,
id="timeout_hard",
),
# Other HTTP errors → no is_reachable write (None)
pytest.param(
httpx.HTTPStatusError(
"500",
request=MagicMock(),
response=MagicMock(status_code=500, text="error"),
),
HookFailStrategy.SOFT,
HookSoftFailed,
None,
id="http_status_error_soft",
),
pytest.param(
httpx.HTTPStatusError(
"500",
request=MagicMock(),
response=MagicMock(status_code=500, text="error"),
),
HookFailStrategy.HARD,
OnyxError,
None,
id="http_status_error_hard",
),
],
)
def test_http_failure_paths(
db_session: MagicMock,
exception: Exception,
fail_strategy: HookFailStrategy,
expected_type: type,
expected_is_reachable: bool | None,
) -> None:
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, side_effect=exception)
if expected_type is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, expected_type)
if expected_is_reachable is None:
mock_update.assert_not_called()
else:
mock_update.assert_called_once()
_, kwargs = mock_update.call_args
assert kwargs["is_reachable"] is expected_is_reachable
# ---------------------------------------------------------------------------
# Authorization header
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"api_key_value,expect_auth_header",
[
pytest.param("secret-token", True, id="api_key_present"),
pytest.param(None, False, id="api_key_absent"),
],
)
def test_authorization_header(
db_session: MagicMock,
api_key_value: str | None,
expect_auth_header: bool,
) -> None:
api_key = _make_api_key(api_key_value) if api_key_value else None
hook = _make_hook(api_key=api_key)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit"),
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
mock_client = _setup_client(mock_client_cls, response=_make_response())
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
_, call_kwargs = mock_client.post.call_args
if expect_auth_header:
assert call_kwargs["headers"]["Authorization"] == f"Bearer {api_key_value}"
else:
assert "Authorization" not in call_kwargs["headers"]
# ---------------------------------------------------------------------------
# Persist session failure
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"http_exception,expected_result",
[
pytest.param(None, _RESPONSE_PAYLOAD, id="success_path"),
pytest.param(httpx.ConnectError("refused"), OnyxError, id="hard_fail_path"),
],
)
def test_persist_session_failure_is_swallowed(
db_session: MagicMock,
http_exception: Exception | None,
expected_result: Any,
) -> None:
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor.get_session_with_current_tenant",
side_effect=RuntimeError("DB unavailable"),
),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response() if not http_exception else None,
side_effect=http_exception,
)
if expected_result is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert result == expected_result
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:
"""is_reachable update failing (e.g. concurrent hook deletion) must not
prevent the execution log from being written.
Simulates the production failure path: update_hook__no_commit raises
OnyxError(NOT_FOUND) as it would if the hook was concurrently deleted
between the initial lookup and the reachable update.
"""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch(
"onyx.hooks.executor.update_hook__no_commit",
side_effect=OnyxError(OnyxErrorCode.NOT_FOUND, "hook deleted"),
),
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, side_effect=httpx.ConnectError("refused"))
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
mock_log.assert_called_once()

View File

@@ -37,18 +37,20 @@ def test_input_schema_query_is_string() -> None:
def test_input_schema_user_email_is_nullable() -> None:
props = QueryProcessingSpec().input_schema["properties"]
assert "null" in props["user_email"]["type"]
# Pydantic v2 emits anyOf for nullable fields
assert any(s.get("type") == "null" for s in props["user_email"]["anyOf"])
def test_output_schema_query_is_required() -> None:
def test_output_schema_query_is_optional() -> None:
# query defaults to None (absent = reject); not required in the schema
schema = QueryProcessingSpec().output_schema
assert "query" in schema["required"]
assert "query" not in schema.get("required", [])
def test_output_schema_query_is_nullable() -> None:
# null means "reject the query"
# null means "reject the query"; Pydantic v2 emits anyOf for nullable fields
props = QueryProcessingSpec().output_schema["properties"]
assert "null" in props["query"]["type"]
assert any(s.get("type") == "null" for s in props["query"]["anyOf"])
def test_output_schema_rejection_message_is_optional() -> None:

View File

@@ -256,7 +256,6 @@ def test_multiple_tool_calls(default_multi_llm: LitellmLLM) -> None:
{"role": "user", "content": "What's the weather and time in New York?"}
],
tools=tools,
tool_choice=None,
stream=True,
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
@@ -412,7 +411,6 @@ def test_multiple_tool_calls_streaming(default_multi_llm: LitellmLLM) -> None:
{"role": "user", "content": "What's the weather and time in New York?"}
],
tools=tools,
tool_choice=None,
stream=True,
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
@@ -1431,3 +1429,36 @@ def test_strip_tool_content_merges_consecutive_tool_results() -> None:
assert "sunny 72F" in merged
assert "tc_2" in merged
assert "headline news" in merged
def test_no_tool_choice_sent_when_no_tools(default_multi_llm: LitellmLLM) -> None:
"""Regression test for providers (e.g. Fireworks) that reject tool_choice=null.
When no tools are provided, tool_choice must not be forwarded to
litellm.completion() at all — not even as None.
"""
messages: LanguageModelInput = [UserMessage(content="Hello!")]
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello!"),
finish_reason="stop",
index=0,
)
],
model="gpt-3.5-turbo",
),
]
with patch("litellm.completion") as mock_completion:
mock_completion.return_value = mock_stream_chunks
default_multi_llm.invoke(messages, tools=None)
_, kwargs = mock_completion.call_args
assert (
"tool_choice" not in kwargs
), "tool_choice must not be sent to providers when no tools are provided"

View File

@@ -0,0 +1,278 @@
"""Unit tests for onyx.server.features.hooks.api helpers.
Covers:
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
- _validate_endpoint: httpx exception → HookValidateStatus mapping
ConnectTimeout → cannot_connect (TCP handshake never completed)
ConnectError → cannot_connect (DNS / TLS failure)
ReadTimeout et al. → timeout (TCP connected, server slow)
Any other exc → cannot_connect
- _raise_for_validation_failure: HookValidateStatus → OnyxError mapping
"""
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.models import HookValidateResponse
from onyx.hooks.models import HookValidateStatus
from onyx.server.features.hooks.api import _check_ssrf_safety
from onyx.server.features.hooks.api import _raise_for_validation_failure
from onyx.server.features.hooks.api import _validate_endpoint
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_URL = "https://example.com/hook"
_API_KEY = "secret"
_TIMEOUT = 5.0
def _mock_response(status_code: int) -> MagicMock:
response = MagicMock()
response.status_code = status_code
return response
# ---------------------------------------------------------------------------
# _check_ssrf_safety
# ---------------------------------------------------------------------------
class TestCheckSsrfSafety:
def _call(self, url: str) -> None:
_check_ssrf_safety(url)
# --- scheme checks ---
def test_https_is_allowed(self) -> None:
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
self._call("https://example.com/hook") # must not raise
@pytest.mark.parametrize(
"url", ["http://example.com/hook", "ftp://example.com/hook"]
)
def test_non_https_scheme_rejected(self, url: str) -> None:
with pytest.raises(OnyxError) as exc_info:
self._call(url)
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert "https" in (exc_info.value.detail or "").lower()
# --- private IP blocklist ---
@pytest.mark.parametrize(
"ip",
[
pytest.param("127.0.0.1", id="loopback"),
pytest.param("10.0.0.1", id="RFC1918-A"),
pytest.param("172.16.0.1", id="RFC1918-B"),
pytest.param("192.168.1.1", id="RFC1918-C"),
pytest.param("169.254.169.254", id="link-local-IMDS"),
pytest.param("100.64.0.1", id="shared-address-space"),
pytest.param("::1", id="IPv6-loopback"),
pytest.param("fc00::1", id="IPv6-ULA"),
pytest.param("fe80::1", id="IPv6-link-local"),
],
)
def test_private_ip_is_blocked(self, ip: str) -> None:
with (
patch("onyx.utils.url.socket.getaddrinfo") as mock_dns,
pytest.raises(OnyxError) as exc_info,
):
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
self._call("https://internal.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert ip in (exc_info.value.detail or "")
def test_public_ip_is_allowed(self) -> None:
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
self._call("https://example.com/hook") # must not raise
def test_dns_resolution_failure_raises(self) -> None:
import socket
with (
patch(
"onyx.utils.url.socket.getaddrinfo",
side_effect=socket.gaierror("name not found"),
),
pytest.raises(OnyxError) as exc_info,
):
self._call("https://no-such-host.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
# ---------------------------------------------------------------------------
# _validate_endpoint
# ---------------------------------------------------------------------------
class TestValidateEndpoint:
def _call(self, *, api_key: str | None = _API_KEY) -> HookValidateResponse:
# Bypass SSRF check — tested separately in TestCheckSsrfSafety.
with patch("onyx.server.features.hooks.api._check_ssrf_safety"):
return _validate_endpoint(
endpoint_url=_URL,
api_key=api_key,
timeout_seconds=_TIMEOUT,
)
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_2xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(200)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_5xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(500)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize("status_code", [401, 403])
def test_401_403_returns_auth_failed(
self, mock_client_cls: MagicMock, status_code: int
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(status_code)
)
result = self._call()
assert result.status == HookValidateStatus.auth_failed
assert str(status_code) in (result.error_message or "")
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_4xx_non_auth_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(422)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_timeout_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectTimeout("timed out")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize(
"exc",
[
httpx.ReadTimeout("read timeout"),
httpx.WriteTimeout("write timeout"),
httpx.PoolTimeout("pool timeout"),
],
)
def test_read_write_pool_timeout_returns_timeout(
self, mock_client_cls: MagicMock, exc: httpx.TimeoutException
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = exc
assert self._call().status == HookValidateStatus.timeout
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_error_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
# Covers DNS failures, TLS errors, and other connection-level errors.
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectError("name resolution failed")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_arbitrary_exception_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
ConnectionRefusedError("refused")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_api_key_sent_as_bearer(self, mock_client_cls: MagicMock) -> None:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = _mock_response(200)
self._call(api_key="mykey")
_, kwargs = mock_post.call_args
assert kwargs["headers"]["Authorization"] == "Bearer mykey"
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_no_api_key_omits_auth_header(self, mock_client_cls: MagicMock) -> None:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = _mock_response(200)
self._call(api_key=None)
_, kwargs = mock_post.call_args
assert "Authorization" not in kwargs["headers"]
# ---------------------------------------------------------------------------
# _raise_for_validation_failure
# ---------------------------------------------------------------------------
class TestRaiseForValidationFailure:
@pytest.mark.parametrize(
"status, expected_code",
[
(HookValidateStatus.auth_failed, OnyxErrorCode.CREDENTIAL_INVALID),
(HookValidateStatus.timeout, OnyxErrorCode.GATEWAY_TIMEOUT),
(HookValidateStatus.cannot_connect, OnyxErrorCode.BAD_GATEWAY),
],
)
def test_raises_correct_error_code(
self, status: HookValidateStatus, expected_code: OnyxErrorCode
) -> None:
validation = HookValidateResponse(status=status, error_message="some error")
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.error_code == expected_code
def test_auth_failed_passes_error_message_directly(self) -> None:
validation = HookValidateResponse(
status=HookValidateStatus.auth_failed, error_message="bad credentials"
)
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.detail == "bad credentials"
@pytest.mark.parametrize(
"status", [HookValidateStatus.timeout, HookValidateStatus.cannot_connect]
)
def test_timeout_and_cannot_connect_wrap_error_message(
self, status: HookValidateStatus
) -> None:
validation = HookValidateResponse(status=status, error_message="raw error")
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.detail == "Endpoint validation failed: raw error"
# ---------------------------------------------------------------------------
# HookValidateStatus enum string values (API contract)
# ---------------------------------------------------------------------------
class TestHookValidateStatusValues:
@pytest.mark.parametrize(
"status, expected",
[
(HookValidateStatus.passed, "passed"),
(HookValidateStatus.auth_failed, "auth_failed"),
(HookValidateStatus.timeout, "timeout"),
(HookValidateStatus.cannot_connect, "cannot_connect"),
],
)
def test_string_values(self, status: HookValidateStatus, expected: str) -> None:
assert status == expected

93
cubic.yaml Normal file
View File

@@ -0,0 +1,93 @@
# yaml-language-server: $schema=https://cubic.dev/schema/cubic-repository-config.schema.json
version: 1
reviews:
enabled: true
sensitivity: medium
incremental_commits: true
check_drafts: false
custom_instructions: |
Use explicit type annotations for variables to enhance code clarity,
especially when moving type hints around in the code.
Use `contributing_guides/best_practices.md` as core review context.
Prefer consistency with existing patterns, fix issues in code you touch,
avoid tacking new features onto muddy interfaces, fail loudly instead of
silently swallowing errors, keep code strictly typed, preserve clear state
boundaries, remove duplicate or dead logic, break up overly long functions,
avoid hidden import-time side effects, respect module boundaries, and favor
correctness-by-construction over relying on callers to use an API correctly.
Reference these files for additional context:
- `contributing_guides/best_practices.md` — Best practices for contributing to the codebase
- `CLAUDE.md` — Project instructions and coding standards
- `backend/alembic/README.md` — Migration guidance, including multi-tenant migration behavior
- `deployment/helm/charts/onyx/values-lite.yaml` — Lite deployment Helm values and service assumptions
- `deployment/docker_compose/docker-compose.onyx-lite.yml` — Lite deployment Docker Compose overlay and disabled service behavior
ignore:
files:
- greptile.json
- cubic.yaml
custom_rules:
- name: TODO format
description: >
Whenever a TODO is added, there must always be an associated name or
ticket in the style of TODO(name): ... or TODO(1234): ...
- name: Frontend standards
description: >
For frontend changes, enforce all standards described in the
web/AGENTS.md file.
include:
- web/**
- desktop/**
- name: No debugging code
description: >
Remove temporary debugging code before merging to production,
especially tenant-specific debugging logs.
- name: No hardcoded booleans
description: >
When hardcoding a boolean variable to a constant value, remove the
variable entirely and clean up all places where it's used rather than
just setting it to a constant.
- name: Multi-tenant awareness
description: >
Code changes must consider both multi-tenant and single-tenant
deployments. In multi-tenant mode, preserve tenant isolation, ensure
tenant context is propagated correctly, and avoid assumptions that only
hold for a single shared schema or globally shared state. In
single-tenant mode, avoid introducing unnecessary tenant-specific
requirements or cloud-only control-plane dependencies.
- name: Onyx lite compatibility
description: >
Code changes must consider both regular Onyx deployments and Onyx lite
deployments. Lite deployments disable the vector DB, Redis, model
servers, and background workers by default, use PostgreSQL-backed
cache/auth/file storage, and rely on the API server to handle
background work. Do not assume those services are available unless the
code path is explicitly limited to full deployments.
- name: OnyxError over HTTPException
description: >
Never raise HTTPException directly in business code. Use
`raise OnyxError(OnyxErrorCode.XXX, "message")` from
`onyx.error_handling.exceptions`. A global FastAPI exception handler
converts OnyxError into structured JSON responses with
{"error_code": "...", "detail": "..."}. Error codes are defined in
`onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors
with dynamic HTTP status codes, use `status_code_override`:
`raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`.
include:
- backend/**/*.py
issues:
fix_with_cubic_buttons: true
pr_comment_fixes: true
fix_commits_to_pr: true

View File

@@ -489,20 +489,18 @@ services:
- "${HOST_PORT_80:-80}:80"
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
minio:
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1

View File

@@ -290,25 +290,20 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
- ../data/certbot/conf:/etc/letsencrypt
- ../data/certbot/www:/var/www/certbot
# sleep a little bit to allow the web_server / api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod"
env_file:
- .env.nginx
environment:

View File

@@ -314,21 +314,19 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
- ../data/sslcerts:/etc/nginx/sslcerts
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod.no-letsencrypt"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod.no-letsencrypt"
env_file:
- .env.nginx
environment:

View File

@@ -333,25 +333,20 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
- ../data/certbot/conf:/etc/letsencrypt
- ../data/certbot/www:/var/www/certbot
# sleep a little bit to allow the web_server / api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod"
env_file:
- .env.nginx
environment:

View File

@@ -202,20 +202,18 @@ services:
ports:
- "${NGINX_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
minio:
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1

View File

@@ -477,7 +477,10 @@ services:
- "${HOST_PORT_80:-80}:80"
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
# Mount templates read-only; the startup command copies them into
# the writable /etc/nginx/conf.d/ inside the container. This avoids
# "Permission denied" errors on Windows Docker bind mounts.
- ../data/nginx:/nginx-templates:ro
# PRODUCTION: Add SSL certificate volumes for HTTPS support:
# - ../data/certbot/conf:/etc/letsencrypt
# - ../data/certbot/www:/var/www/certbot
@@ -489,12 +492,13 @@ services:
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not receive any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
# PRODUCTION: Change to app.conf.template.prod for production nginx config
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
cache:
image: redis:7.4-alpine

File diff suppressed because it is too large Load Diff

View File

@@ -96,8 +96,8 @@ fi
# When --lite is passed as a flag, lower resource thresholds early (before the
# resource check). When lite is chosen interactively, the thresholds are adjusted
# inside the new-deployment flow, after the resource check has already passed
# with the standard thresholds — which is the safer direction.
# after the resource check has already passed with the standard thresholds —
# which is the safer direction.
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
@@ -110,9 +110,6 @@ LITE_COMPOSE_FILE="docker-compose.onyx-lite.yml"
# Build the -f flags for docker compose.
# Pass "true" as $1 to auto-detect a previously-downloaded lite overlay
# (used by shutdown/delete-data so users don't need to remember --lite).
# Without the argument, the lite overlay is only included when --lite was
# explicitly passed — preventing install/start from silently staying in
# lite mode just because the file exists on disk from a prior run.
compose_file_args() {
local auto_detect="${1:-false}"
local args="-f docker-compose.yml"
@@ -177,34 +174,52 @@ ensure_file() {
# --- Interactive prompt helpers ---
is_interactive() {
[[ "$NO_PROMPT" = false ]] && [[ -t 0 ]]
[[ "$NO_PROMPT" = false ]] && [[ -r /dev/tty ]] && [[ -w /dev/tty ]]
}
read_prompt_line() {
local prompt_text="$1"
if ! is_interactive; then
REPLY=""
return
fi
[[ -n "$prompt_text" ]] && printf "%s" "$prompt_text" > /dev/tty
IFS= read -r REPLY < /dev/tty || REPLY=""
}
read_prompt_char() {
local prompt_text="$1"
if ! is_interactive; then
REPLY=""
return
fi
[[ -n "$prompt_text" ]] && printf "%s" "$prompt_text" > /dev/tty
IFS= read -r -n 1 REPLY < /dev/tty || REPLY=""
printf "\n" > /dev/tty
}
prompt_or_default() {
local prompt_text="$1"
local default_value="$2"
if is_interactive; then
read -p "$prompt_text" -r REPLY
if [[ -z "$REPLY" ]]; then
REPLY="$default_value"
fi
else
REPLY="$default_value"
fi
read_prompt_line "$prompt_text"
[[ -z "$REPLY" ]] && REPLY="$default_value"
}
prompt_yn_or_default() {
local prompt_text="$1"
local default_value="$2"
if is_interactive; then
read -p "$prompt_text" -n 1 -r
echo ""
if [[ -z "$REPLY" ]]; then
REPLY="$default_value"
fi
else
REPLY="$default_value"
read_prompt_char "$prompt_text"
[[ -z "$REPLY" ]] && REPLY="$default_value"
}
confirm_action() {
local description="$1"
prompt_yn_or_default "Install ${description}? (Y/n) [default: Y] " "Y"
if [[ "$REPLY" =~ ^[Nn] ]]; then
print_warning "Skipping: ${description}"
return 1
fi
return 0
}
# Colors for output
@@ -295,8 +310,8 @@ if [ "$DELETE_DATA_MODE" = true ]; then
echo " • All user data and documents"
echo ""
if is_interactive; then
read -p "Are you sure you want to continue? Type 'DELETE' to confirm: " -r
echo ""
prompt_or_default "Are you sure you want to continue? Type 'DELETE' to confirm: " ""
echo "" > /dev/tty
if [ "$REPLY" != "DELETE" ]; then
print_info "Operation cancelled."
exit 0
@@ -395,6 +410,11 @@ fi
if ! command -v docker &> /dev/null; then
if [[ "$OSTYPE" == "linux-gnu"* ]] || [[ -n "${WSL_DISTRO_NAME:-}" ]]; then
print_info "Docker is required but not installed."
if ! confirm_action "Docker Engine"; then
print_error "Docker is required to run Onyx."
exit 1
fi
install_docker_linux
if ! command -v docker &> /dev/null; then
print_error "Docker installation failed."
@@ -411,7 +431,11 @@ if command -v docker &> /dev/null \
&& ! command -v docker-compose &> /dev/null \
&& { [[ "$OSTYPE" == "linux-gnu"* ]] || [[ -n "${WSL_DISTRO_NAME:-}" ]]; }; then
print_info "Docker Compose not found — installing plugin..."
print_info "Docker Compose is required but not installed."
if ! confirm_action "Docker Compose plugin"; then
print_error "Docker Compose is required to run Onyx."
exit 1
fi
COMPOSE_ARCH="$(uname -m)"
COMPOSE_URL="https://github.com/docker/compose/releases/latest/download/docker-compose-linux-${COMPOSE_ARCH}"
COMPOSE_DIR="/usr/local/lib/docker/cli-plugins"
@@ -481,7 +505,7 @@ echo ""
if is_interactive; then
echo -e "${YELLOW}${BOLD}Please acknowledge and press Enter to continue...${NC}"
read -r
read_prompt_line ""
echo ""
else
echo -e "${YELLOW}${BOLD}Running in non-interactive mode - proceeding automatically...${NC}"
@@ -562,10 +586,31 @@ version_compare() {
# Check Docker daemon
if ! docker info &> /dev/null; then
print_error "Docker daemon is not running. Please start Docker."
exit 1
if [[ "$OSTYPE" == "darwin"* ]]; then
print_info "Docker daemon is not running. Starting Docker Desktop..."
open -a Docker
# Wait up to 120 seconds for Docker to be ready
DOCKER_WAIT=0
DOCKER_MAX_WAIT=120
while ! docker info &> /dev/null; do
if [ $DOCKER_WAIT -ge $DOCKER_MAX_WAIT ]; then
print_error "Docker Desktop did not start within ${DOCKER_MAX_WAIT} seconds."
print_info "Please start Docker Desktop manually and re-run this script."
exit 1
fi
printf "\r\033[KWaiting for Docker Desktop to start... (%ds)" "$DOCKER_WAIT"
sleep 2
DOCKER_WAIT=$((DOCKER_WAIT + 2))
done
echo ""
print_success "Docker Desktop is now running"
else
print_error "Docker daemon is not running. Please start Docker."
exit 1
fi
else
print_success "Docker daemon is running"
fi
print_success "Docker daemon is running"
# Check Docker resources
print_step "Verifying Docker resources"
@@ -705,25 +750,48 @@ if [ "$COMPOSE_VERSION" != "dev" ] && version_compare "$COMPOSE_VERSION" "2.24.0
print_info "Proceeding with installation despite Docker Compose version compatibility issues..."
fi
# Handle lite overlay: ensure it if --lite, clean up stale copies otherwise
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
if [[ "$LITE_MODE" = false ]]; then
print_info "Which deployment mode would you like?"
echo ""
echo " 1) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
echo " LLM chat, tools, file uploads, and Projects still work"
echo " 2) Standard - Full deployment with search, connectors, and RAG"
echo ""
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
echo ""
case "$REPLY" in
2)
print_info "Selected: Standard mode"
;;
*)
LITE_MODE=true
print_info "Selected: Lite mode"
;;
esac
else
print_info "Deployment mode: Lite (set via --lite flag)"
fi
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
print_error "--include-craft cannot be used with Lite mode."
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
exit 1
fi
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
fi
# Handle lite overlay file based on selected mode
if [[ "$LITE_MODE" = true ]]; then
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
elif [[ -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" ]]; then
if [[ -f "${INSTALL_ROOT}/deployment/.env" ]]; then
print_warning "Existing lite overlay found but --lite was not passed."
prompt_yn_or_default "Remove lite overlay and switch to standard mode? (y/N): " "n"
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
print_info "Keeping existing lite overlay. Pass --lite to keep using lite mode."
LITE_MODE=true
else
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed lite overlay (switching to standard mode)"
fi
else
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed previous lite overlay (switching to standard mode)"
fi
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed previous lite overlay (switching to standard mode)"
fi
ensure_file "${INSTALL_ROOT}/deployment/env.template" \
@@ -745,6 +813,7 @@ print_success "All configuration files ready"
# Set up deployment configuration
print_step "Setting up deployment configs"
ENV_FILE="${INSTALL_ROOT}/deployment/.env"
ENV_TEMPLATE="${INSTALL_ROOT}/deployment/env.template"
# Check if services are already running
if [ -d "${INSTALL_ROOT}/deployment" ] && [ -f "${INSTALL_ROOT}/deployment/docker-compose.yml" ]; then
# Determine compose command
@@ -785,22 +854,22 @@ if [ -f "$ENV_FILE" ]; then
if [ "$REPLY" = "update" ]; then
print_info "Update selected. Which tag would you like to deploy?"
echo ""
echo "• Press Enter for latest (recommended)"
echo "• Press Enter for edge (recommended)"
echo "• Type a specific tag (e.g., v0.1.0)"
echo ""
if [ "$INCLUDE_CRAFT" = true ]; then
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
VERSION="$REPLY"
else
prompt_or_default "Enter tag [default: latest]: " "latest"
prompt_or_default "Enter tag [default: edge]: " "edge"
VERSION="$REPLY"
fi
echo ""
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
print_info "Selected: craft-latest (Craft enabled)"
elif [ "$VERSION" = "latest" ]; then
print_info "Selected: Latest version"
elif [ "$VERSION" = "edge" ]; then
print_info "Selected: edge (latest nightly)"
else
print_info "Selected: $VERSION"
fi
@@ -852,45 +921,6 @@ else
print_info "No existing .env file found. Setting up new deployment..."
echo ""
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
if [[ "$LITE_MODE" = false ]]; then
print_info "Which deployment mode would you like?"
echo ""
echo " 1) Standard - Full deployment with search, connectors, and RAG"
echo " 2) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
echo " LLM chat, tools, file uploads, and Projects still work"
echo ""
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
echo ""
case "$REPLY" in
2)
LITE_MODE=true
print_info "Selected: Lite mode"
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
;;
*)
print_info "Selected: Standard mode"
;;
esac
else
print_info "Deployment mode: Lite (set via --lite flag)"
fi
# Validate lite + craft combination (could now be set interactively)
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
print_error "--include-craft cannot be used with Lite mode."
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
exit 1
fi
# Adjust resource expectations for lite mode
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
fi
# Ask for version
print_info "Which tag would you like to deploy?"
echo ""
@@ -901,18 +931,18 @@ else
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
VERSION="$REPLY"
else
echo "• Press Enter for latest (recommended)"
echo "• Press Enter for edge (recommended)"
echo "• Type a specific tag (e.g., v0.1.0)"
echo ""
prompt_or_default "Enter tag [default: latest]: " "latest"
prompt_or_default "Enter tag [default: edge]: " "edge"
VERSION="$REPLY"
fi
echo ""
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
print_info "Selected: craft-latest (Craft enabled)"
elif [ "$VERSION" = "latest" ]; then
print_info "Selected: Latest tag"
elif [ "$VERSION" = "edge" ]; then
print_info "Selected: edge (latest nightly)"
else
print_info "Selected: $VERSION"
fi
@@ -1070,20 +1100,39 @@ fi
export HOST_PORT=$AVAILABLE_PORT
print_success "Using port $AVAILABLE_PORT for nginx"
# Determine if we're using the latest tag or a craft tag (both should force pull)
# Determine if we're using a floating tag (edge, latest, craft-*) that should force pull
# Read IMAGE_TAG from .env file and remove any quotes or whitespace
CURRENT_IMAGE_TAG=$(grep "^IMAGE_TAG=" "$ENV_FILE" | head -1 | cut -d'=' -f2 | tr -d ' "'"'"'')
if [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
if [ "$CURRENT_IMAGE_TAG" = "edge" ] || [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
USE_LATEST=true
if [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
print_info "Using craft tag '$CURRENT_IMAGE_TAG' - will force pull and recreate containers"
else
print_info "Using 'latest' tag - will force pull and recreate containers"
print_info "Using '$CURRENT_IMAGE_TAG' tag - will force pull and recreate containers"
fi
else
USE_LATEST=false
fi
# For pinned version tags, re-download config files from that tag so the
# compose file matches the images being pulled (the initial download used main).
if [[ "$USE_LATEST" = false ]] && [[ "$USE_LOCAL_FILES" = false ]]; then
PINNED_BASE="https://raw.githubusercontent.com/onyx-dot-app/onyx/${CURRENT_IMAGE_TAG}/deployment"
print_info "Fetching config files matching tag ${CURRENT_IMAGE_TAG}..."
if download_file "${PINNED_BASE}/docker_compose/docker-compose.yml" "${INSTALL_ROOT}/deployment/docker-compose.yml" 2>/dev/null; then
download_file "${PINNED_BASE}/data/nginx/app.conf.template" "${INSTALL_ROOT}/data/nginx/app.conf.template" 2>/dev/null || true
download_file "${PINNED_BASE}/data/nginx/run-nginx.sh" "${INSTALL_ROOT}/data/nginx/run-nginx.sh" 2>/dev/null || true
chmod +x "${INSTALL_ROOT}/data/nginx/run-nginx.sh"
if [[ "$LITE_MODE" = true ]]; then
download_file "${PINNED_BASE}/docker_compose/${LITE_COMPOSE_FILE}" \
"${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" 2>/dev/null || true
fi
print_success "Config files updated to match ${CURRENT_IMAGE_TAG}"
else
print_warning "Tag ${CURRENT_IMAGE_TAG} not found on GitHub — using main branch configs"
fi
fi
# Pull Docker images with reduced output
print_step "Pulling Docker images"
print_info "This may take several minutes depending on your internet connection..."

View File

@@ -127,6 +127,7 @@ Inputs (common):
- `name` (default `onyx`), `region` (default `us-west-2`), `tags`
- `postgres_username`, `postgres_password`
- `create_vpc` (default true) or existing VPC details and `s3_vpc_endpoint_id`
- WAF controls such as `waf_allowed_ip_cidrs`, `waf_common_rule_set_count_rules`, rate limits, geo restrictions, and logging retention
### `vpc`
- Builds a VPC sized for EKS with multiple private and public subnets

View File

@@ -88,6 +88,8 @@ module "waf" {
tags = local.merged_tags
# WAF configuration with sensible defaults
allowed_ip_cidrs = var.waf_allowed_ip_cidrs
common_rule_set_count_rules = var.waf_common_rule_set_count_rules
rate_limit_requests_per_5_minutes = var.waf_rate_limit_requests_per_5_minutes
api_rate_limit_requests_per_5_minutes = var.waf_api_rate_limit_requests_per_5_minutes
geo_restriction_countries = var.waf_geo_restriction_countries

View File

@@ -117,6 +117,18 @@ variable "waf_rate_limit_requests_per_5_minutes" {
default = 2000
}
variable "waf_allowed_ip_cidrs" {
type = list(string)
description = "Optional IPv4 CIDR ranges allowed through the WAF. Leave empty to disable IP allowlisting."
default = []
}
variable "waf_common_rule_set_count_rules" {
type = list(string)
description = "Subrules within AWSManagedRulesCommonRuleSet to override to COUNT instead of BLOCK."
default = []
}
variable "waf_api_rate_limit_requests_per_5_minutes" {
type = number
description = "Rate limit for API requests per 5 minutes per IP address"

View File

@@ -1,6 +1,20 @@
locals {
name = var.name
tags = var.tags
name = var.name
tags = var.tags
ip_allowlist_enabled = length(var.allowed_ip_cidrs) > 0
managed_rule_priority = local.ip_allowlist_enabled ? 1 : 0
}
resource "aws_wafv2_ip_set" "allowed_ips" {
count = local.ip_allowlist_enabled ? 1 : 0
name = "${local.name}-allowed-ips"
description = "IP allowlist for ${local.name}"
scope = "REGIONAL"
ip_address_version = "IPV4"
addresses = var.allowed_ip_cidrs
tags = local.tags
}
# AWS WAFv2 Web ACL
@@ -13,10 +27,38 @@ resource "aws_wafv2_web_acl" "main" {
allow {}
}
dynamic "rule" {
for_each = local.ip_allowlist_enabled ? [1] : []
content {
name = "BlockRequestsOutsideAllowedIPs"
priority = 1
action {
block {}
}
statement {
not_statement {
statement {
ip_set_reference_statement {
arn = aws_wafv2_ip_set.allowed_ips[0].arn
}
}
}
}
visibility_config {
cloudwatch_metrics_enabled = true
metric_name = "BlockRequestsOutsideAllowedIPsMetric"
sampled_requests_enabled = true
}
}
}
# AWS Managed Rules - Core Rule Set
rule {
name = "AWSManagedRulesCommonRuleSet"
priority = 1
priority = 1 + local.managed_rule_priority
override_action {
none {}
@@ -26,6 +68,16 @@ resource "aws_wafv2_web_acl" "main" {
managed_rule_group_statement {
name = "AWSManagedRulesCommonRuleSet"
vendor_name = "AWS"
dynamic "rule_action_override" {
for_each = var.common_rule_set_count_rules
content {
name = rule_action_override.value
action_to_use {
count {}
}
}
}
}
}
@@ -39,7 +91,7 @@ resource "aws_wafv2_web_acl" "main" {
# AWS Managed Rules - Known Bad Inputs
rule {
name = "AWSManagedRulesKnownBadInputsRuleSet"
priority = 2
priority = 2 + local.managed_rule_priority
override_action {
none {}
@@ -62,7 +114,7 @@ resource "aws_wafv2_web_acl" "main" {
# Rate Limiting Rule
rule {
name = "RateLimitRule"
priority = 3
priority = 3 + local.managed_rule_priority
action {
block {}
@@ -87,7 +139,7 @@ resource "aws_wafv2_web_acl" "main" {
for_each = length(var.geo_restriction_countries) > 0 ? [1] : []
content {
name = "GeoRestrictionRule"
priority = 4
priority = 4 + local.managed_rule_priority
action {
block {}
@@ -110,7 +162,7 @@ resource "aws_wafv2_web_acl" "main" {
# IP Rate Limiting
rule {
name = "APIRateLimitRule"
priority = 5
priority = 5 + local.managed_rule_priority
action {
block {}
@@ -133,7 +185,7 @@ resource "aws_wafv2_web_acl" "main" {
# SQL Injection Protection
rule {
name = "AWSManagedRulesSQLiRuleSet"
priority = 6
priority = 6 + local.managed_rule_priority
override_action {
none {}
@@ -156,7 +208,7 @@ resource "aws_wafv2_web_acl" "main" {
# Anonymous IP Protection
rule {
name = "AWSManagedRulesAnonymousIpList"
priority = 7
priority = 7 + local.managed_rule_priority
override_action {
none {}

View File

@@ -9,6 +9,18 @@ variable "tags" {
default = {}
}
variable "allowed_ip_cidrs" {
type = list(string)
description = "Optional IPv4 CIDR ranges allowed to reach the application. Leave empty to disable IP allowlisting."
default = []
}
variable "common_rule_set_count_rules" {
type = list(string)
description = "Subrules within AWSManagedRulesCommonRuleSet to override to COUNT instead of BLOCK."
default = []
}
variable "rate_limit_requests_per_5_minutes" {
type = number
description = "Rate limit for requests per 5 minutes per IP address"

1
desktop/AGENTS.md Symbolic link
View File

@@ -0,0 +1 @@
../web/AGENTS.md

1
desktop/CLAUDE.md Symbolic link
View File

@@ -0,0 +1 @@
AGENTS.md

View File

@@ -65,7 +65,7 @@
},
{
"scope": ["web/**"],
"rule": "For frontend changes (changes that touch the /web directory), make sure to enforce all standards described in the web/STANDARDS.md file."
"rule": "For frontend changes (changes that touch the /web directory), make sure to enforce all standards described in the web/AGENTS.md file."
},
{
"scope": [],
@@ -85,7 +85,7 @@
},
{
"scope": ["backend/**/*.py"],
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"message\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
}
],
"files": [

View File

@@ -17,6 +17,7 @@ import (
type RunCIOptions struct {
DryRun bool
Yes bool
Rerun bool
}
// NewRunCICommand creates a new run-ci command
@@ -49,6 +50,7 @@ Example usage:
cmd.Flags().BoolVar(&opts.DryRun, "dry-run", false, "Perform all local operations but skip pushing to remote and creating PRs")
cmd.Flags().BoolVar(&opts.Yes, "yes", false, "Skip confirmation prompts and automatically proceed")
cmd.Flags().BoolVar(&opts.Rerun, "rerun", false, "Update an existing CI PR with the latest fork changes to re-trigger CI")
return cmd
}
@@ -107,19 +109,44 @@ func runCI(cmd *cobra.Command, args []string, opts *RunCIOptions) {
log.Fatalf("PR #%s is not from a fork - CI should already run automatically", prNumber)
}
// Confirm before proceeding
if !opts.Yes {
if !prompt.Confirm(fmt.Sprintf("Create CI branch for PR #%s? (yes/no): ", prNumber)) {
log.Info("Exiting...")
return
}
}
// Create the CI branch
ciBranch := fmt.Sprintf("run-ci/%s", prNumber)
prTitle := fmt.Sprintf("chore: [Running GitHub actions for #%s]", prNumber)
prBody := fmt.Sprintf("This PR runs GitHub Actions CI for #%s.\n\n- [x] Override Linear Check\n\n**This PR should be closed (not merged) after CI completes.**", prNumber)
// Check if a CI PR already exists for this branch
existingPRURL, err := findExistingCIPR(ciBranch)
if err != nil {
log.Fatalf("Failed to check for existing CI PR: %v", err)
}
if existingPRURL != "" && !opts.Rerun {
log.Infof("A CI PR already exists for #%s: %s", prNumber, existingPRURL)
log.Info("Run with --rerun to update it with the latest fork changes and re-trigger CI.")
return
}
if opts.Rerun && existingPRURL == "" {
log.Warn("--rerun was specified but no existing open CI PR was found. A new PR will be created.")
}
if existingPRURL != "" && opts.Rerun {
log.Infof("Existing CI PR found: %s", existingPRURL)
log.Info("Will update the CI branch with the latest fork changes to re-trigger CI.")
}
// Confirm before proceeding
if !opts.Yes {
action := "Create CI branch"
if existingPRURL != "" {
action = "Update existing CI branch"
}
if !prompt.Confirm(fmt.Sprintf("%s for PR #%s? (yes/no): ", action, prNumber)) {
log.Info("Exiting...")
return
}
}
// Fetch the fork's branch
if forkRepo == "" {
log.Fatalf("Could not determine fork repository - headRepositoryOwner or headRepository.name is empty")
@@ -158,7 +185,11 @@ func runCI(cmd *cobra.Command, args []string, opts *RunCIOptions) {
if opts.DryRun {
log.Warnf("[DRY RUN] Would push CI branch: %s", ciBranch)
log.Warnf("[DRY RUN] Would create PR: %s", prTitle)
if existingPRURL == "" {
log.Warnf("[DRY RUN] Would create PR: %s", prTitle)
} else {
log.Warnf("[DRY RUN] Would update existing PR: %s", existingPRURL)
}
// Switch back to original branch
if err := git.RunCommand("switch", "--quiet", originalBranch); err != nil {
log.Warnf("Failed to switch back to original branch: %v", err)
@@ -176,6 +207,17 @@ func runCI(cmd *cobra.Command, args []string, opts *RunCIOptions) {
log.Fatalf("Failed to push CI branch: %v", err)
}
if existingPRURL != "" {
// PR already exists - force push is enough to re-trigger CI
log.Infof("Switching back to original branch: %s", originalBranch)
if err := git.RunCommand("switch", "--quiet", originalBranch); err != nil {
log.Warnf("Failed to switch back to original branch: %v", err)
}
log.Infof("CI PR updated successfully: %s", existingPRURL)
log.Info("The force push will re-trigger CI. Remember to close (not merge) this PR after CI completes!")
return
}
// Create PR using GitHub CLI
log.Info("Creating PR...")
prURL, err := createCIPR(ciBranch, prInfo.BaseRefName, prTitle, prBody)
@@ -217,6 +259,39 @@ func getPRInfo(prNumber string) (*PRInfo, error) {
return &prInfo, nil
}
// findExistingCIPR checks if an open PR already exists for the given CI branch.
// Returns the PR URL if found, or empty string if not.
func findExistingCIPR(headBranch string) (string, error) {
cmd := exec.Command("gh", "pr", "list",
"--head", headBranch,
"--state", "open",
"--json", "url",
)
output, err := cmd.Output()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
return "", fmt.Errorf("%w: %s", err, string(exitErr.Stderr))
}
return "", err
}
var prs []struct {
URL string `json:"url"`
}
if err := json.Unmarshal(output, &prs); err != nil {
log.Debugf("Failed to parse PR list JSON: %v (raw: %s)", err, string(output))
return "", fmt.Errorf("failed to parse PR list: %w", err)
}
if len(prs) == 0 {
log.Debugf("No existing open PRs found for branch %s", headBranch)
return "", nil
}
log.Debugf("Found existing PR for branch %s: %s", headBranch, prs[0].URL)
return prs[0].URL, nil
}
// createCIPR creates a pull request for CI using the GitHub CLI
func createCIPR(headBranch, baseBranch, title, body string) (string, error) {
cmd := exec.Command("gh", "pr", "create",

540
web/AGENTS.md Normal file
View File

@@ -0,0 +1,540 @@
# Frontend Standards
This file is the single source of truth for frontend coding standards across all Onyx frontend
projects (including, but not limited to, `/web`, `/desktop`).
# Components
UI components are spread across several directories while the codebase migrates to Opal:
- **`web/lib/opal/src/`** — The Opal design system. Preferred for all new components.
- **`web/src/refresh-components/`** — Production components not yet migrated to Opal.
- **`web/src/sections/`** — Feature-specific composite components (cards, modals, etc.).
- **`web/src/layouts/`** — Page-level layout components (settings pages, etc.).
**Do NOT use anything from `web/src/components/`** — this directory contains legacy components
that are being phased out. Always prefer Opal first; fall back to `refresh-components` only for
components not yet available in Opal.
## Opal Layouts (`lib/opal/src/layouts/`)
All layout primitives are imported from `@opal/layouts`. They handle sizing, font selection, icon
alignment, and optional inline editing.
```typescript
import { Content, ContentAction, IllustrationContent } from "@opal/layouts";
```
### Content
**Use this for any combination of icon + title + description.**
A two-axis layout component that automatically routes to the correct internal layout
(`ContentXl`, `ContentLg`, `ContentMd`, `ContentSm`) based on `sizePreset` and `variant`:
| sizePreset | variant | Routes to | Layout |
|---|---|---|---|
| `headline` / `section` | `heading` | `ContentXl` | Icon on top (flex-col) |
| `headline` / `section` | `section` | `ContentLg` | Icon inline (flex-row) |
| `main-content` / `main-ui` / `secondary` | `section` / `heading` | `ContentMd` | Compact inline |
| `main-content` / `main-ui` / `secondary` | `body` | `ContentSm` | Body text layout |
```typescript
<Content
sizePreset="main-ui"
variant="section"
icon={SvgSettings}
title="Settings"
description="Manage your preferences"
/>
```
### ContentAction
**Use this when a Content block needs right-side actions** (buttons, badges, icons, etc.).
Wraps `Content` and adds a `rightChildren` slot. Accepts all `Content` props plus:
- `rightChildren`: `ReactNode` — actions rendered on the right
- `paddingVariant`: `SizeVariant` — controls outer padding
```typescript
<ContentAction
sizePreset="main-ui"
variant="section"
icon={SvgUser}
title="John Doe"
description="Admin"
rightChildren={<Button icon={SvgEdit}>Edit</Button>}
/>
```
### IllustrationContent
**Use this for empty states, error pages, and informational placeholders.**
A vertically-stacked, center-aligned layout that pairs a large illustration (7.5rem x 7.5rem)
with a title and optional description.
```typescript
import SvgNoResult from "@opal/illustrations/no-result";
<IllustrationContent
illustration={SvgNoResult}
title="No results found"
description="Try adjusting your search or filters."
/>
```
Props:
- `illustration`: `IconFunctionComponent` — optional, from `@opal/illustrations`
- `title`: `string` — required
- `description`: `string` — optional
## Settings Page Layout (`src/layouts/settings-layouts.tsx`)
**Use this for all admin/settings pages.** Provides a standardized layout with scroll-aware
sticky headers, centered content containers, and responsive behavior.
```typescript
import SettingsLayouts from "@/layouts/settings-layouts";
function MySettingsPage() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgSettings}
title="Account Settings"
description="Manage your account preferences"
rightChildren={<Button>Save</Button>}
>
<InputTypeIn placeholder="Search settings..." />
</SettingsLayouts.Header>
<SettingsLayouts.Body>
<Card>Settings content here</Card>
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}
```
Sub-components:
- **`SettingsLayouts.Root`** — Wrapper with centered, scrollable container. Width options:
`"sm"` (672px), `"sm-md"` (752px), `"md"` (872px, default), `"lg"` (992px), `"full"` (100%).
- **`SettingsLayouts.Header`** — Sticky header with icon, title, description, optional
`rightChildren` actions, optional `children` below (e.g., search/filter), optional `backButton`,
and optional `separator`. Automatically shows a scroll shadow when scrolled.
- **`SettingsLayouts.Body`** — Content container with consistent padding and vertical spacing.
## Cards (`src/sections/cards/`)
**When building a card that displays information about a specific entity (agent, document set,
file, connector, etc.), add it to `web/src/sections/cards/`.**
Each card is a self-contained component focused on a single entity type. Cards typically include
entity identification (name, avatar, icon), summary information, and quick actions.
```typescript
import AgentCard from "@/sections/cards/AgentCard";
import DocumentSetCard from "@/sections/cards/DocumentSetCard";
import FileCard from "@/sections/cards/FileCard";
```
Guidelines:
- One card per entity type — keep card-specific logic within the card component.
- Cards should be reusable across different pages and contexts.
- Use shared components from `@opal/components`, `@opal/layouts`, and `@/refresh-components`
inside cards — do not duplicate layout or styling logic.
## Button (`components/buttons/button/`)
**Always use the Opal `Button`.** Do not use raw `<button>` elements.
Built on `Interactive.Stateless` > `Interactive.Container`, so it inherits the full color/state
system automatically.
```typescript
import { Button } from "@opal/components/buttons/button/components";
// Labeled button
<Button variant="default" prominence="primary" icon={SvgPlus}>
Create
</Button>
// Icon-only button (omit children)
<Button variant="default" prominence="tertiary" icon={SvgTrash} size="sm" />
```
Key props:
- `variant`: `"default"` | `"action"` | `"danger"` | `"none"`
- `prominence`: `"primary"` | `"secondary"` | `"tertiary"` | `"internal"`
- `size`: `"lg"` | `"md"` | `"sm"` | `"xs"` | `"2xs"` | `"fit"`
- `icon`, `rightIcon`, `children`, `disabled`, `href`, `tooltip`
## Core Primitives (`core/`)
The `core/` directory contains the lowest-level building blocks that power all Opal components.
**Most code should not interface with these directly** — use higher-level components like `Button`,
`Content`, and `ContentAction` instead. These are documented here for understanding, not everyday use.
### Interactive (`core/interactive/`)
The foundational layer for all clickable/interactive surfaces. Defines the color matrix for
hover, active, and disabled states.
- **`Interactive.Stateless`** — Color system for stateless elements (buttons, links). Applies
variant/prominence/state combinations via CSS custom properties.
- **`Interactive.Stateful`** — Color system for stateful elements (toggles, sidebar items, selects).
Uses `state` (`"empty"` | `"filled"` | `"selected"`) instead of prominence.
- **`Interactive.Container`** — Structural box providing height, rounding, padding, and border.
Shared by both Stateless and Stateful. Renders as `<div>`, `<button>`, or `<Link>` depending
on context.
- **`Interactive.Foldable`** — Zero-width collapsible wrapper with CSS grid animation.
### Disabled (`core/disabled/`)
Propagates disabled state via React context. `Interactive.Stateless` and `Interactive.Stateful`
consume this automatically, so wrapping a subtree in `<Disabled disabled={true}>` disables all
interactive descendants.
### Hoverable (`core/animations/`)
A standardized way to provide "opacity-100 on hover" behavior. Instead of manually wiring
`opacity-0 group-hover:opacity-100` with Tailwind, use `Hoverable` for consistent, coordinated
hover-to-reveal patterns.
- **`Hoverable.Root`** — Wraps a hover group. Tracks mouse enter/leave and broadcasts hover
state to descendants via a per-group React context.
- **`Hoverable.Item`** — Marks an element that should appear on hover. Supports two modes:
- **Group mode** (`group` prop provided): visibility driven by a matching `Hoverable.Root`
ancestor. Throws if no matching Root is found.
- **Local mode** (`group` omitted): uses CSS `:hover` on the item itself.
```typescript
import { Hoverable } from "@opal/core";
// Group mode — hovering anywhere on the row reveals the trash icon
<Hoverable.Root group="row">
<div className="flex items-center gap-2">
<span>Row content</span>
<Hoverable.Item group="row" variant="opacity-on-hover">
<SvgTrash />
</Hoverable.Item>
</div>
</Hoverable.Root>
// Local mode — hovering the item itself reveals it
<Hoverable.Item variant="opacity-on-hover">
<SvgTrash />
</Hoverable.Item>
```
# Best Practices
## 1. Tailwind Dark Mode
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
```typescript
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
<div className="bg-background-neutral-03 text-text-02">
Content
</div>
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
export const GithubIcon = createLogoIcon(githubLightIcon, {
monochromatic: true, // Will apply dark:invert internally
});
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
});
// ❌ Bad - Manual dark mode overrides
<div className="bg-white dark:bg-black text-black dark:text-white">
Content
</div>
```
## 2. Icon Usage
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
```typescript
// ✅ Good
import SvgX from "@/icons/x";
import SvgMoreHorizontal from "@/icons/more-horizontal";
// ❌ Bad
import { User } from "lucide-react";
import { FiSearch } from "react-icons/fi";
```
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
If you need help with this step, reach out to `raunak@onyx.app`.
## 3. Text Rendering
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
```typescript
// ✅ Good
import { Text } from '@/refresh-components/texts/Text'
function UserCard({ name }: { name: string }) {
return (
<Text
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
text03
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
mainAction
>
{name}
</Text>
)
}
// ❌ Bad
function UserCard({ name }: { name: string }) {
return (
<div>
<h2>{name}</h2>
<p>User details</p>
</div>
)
}
```
## 4. Component Usage
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
```typescript
// ✅ Good
import Button from '@/refresh-components/buttons/Button'
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
import SvgPlusCircle from '@/icons/plus-circle'
function ContactForm() {
return (
<form>
<InputTypeIn placeholder="Search..." />
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
</form>
)
}
// ❌ Bad
function ContactForm() {
return (
<form>
<input placeholder="Name" />
<textarea placeholder="Message" />
<button type="submit">Submit</button>
</form>
)
}
```
## 5. Colors
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
**Available color categories:**
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
- **Actions:** `action-link-XX`, `action-danger-XX`
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
```typescript
// ✅ Good - Use custom Onyx color classes
<div className="bg-background-neutral-01 border border-border-02" />
<div className="bg-background-tint-02 border border-border-01" />
<div className="bg-status-success-01" />
<div className="bg-action-link-01" />
<div className="bg-theme-primary-05" />
// ❌ Bad - Do NOT use standard Tailwind colors
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
<div className="bg-white border border-slate-200" />
<div className="bg-green-100 text-green-700" />
<div className="bg-blue-100 text-blue-600" />
<div className="bg-indigo-500" />
```
## 6. Data Fetching
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
# Stylistic Preferences
## 1. Import Standards
**Always use absolute imports with the `@` prefix.**
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
```typescript
// ✅ Good
import { Button } from "@/components/ui/button";
import { useAuth } from "@/hooks/useAuth";
import { Text } from "@/refresh-components/texts/Text";
// ❌ Bad
import { Button } from "../../../components/ui/button";
import { useAuth } from "./hooks/useAuth";
```
## 2. React Component Functions
**Prefer regular functions over arrow functions for React components.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
function UserProfile({ userId }: UserProfileProps) {
return <div>User Profile</div>
}
// ❌ Bad
const UserProfile = ({ userId }: UserProfileProps) => {
return <div>User Profile</div>
}
```
## 3. Props Interface Extraction
**Extract prop types into their own interface definitions. Keep prop interfaces in the same file
as the component they belong to. Non-prop types (shared models, API response shapes, enums, etc.)
should be placed in a co-located `interfaces.ts` file.**
**Reason:** Prop interfaces are tightly coupled to their component and rarely imported elsewhere,
so co-location keeps things simple. Shared types belong in `interfaces.ts` so they can be
imported without pulling in component code.
```typescript
// ✅ Good — props interface in the same file as the component
// UserCard.tsx
interface UserCardProps {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
return <div>User Card</div>
}
// ✅ Good — shared types in interfaces.ts
// interfaces.ts
export interface User {
id: string
name: string
role: UserRole
}
export type UserRole = "admin" | "member" | "viewer"
// ❌ Bad — inline prop types
function UserCard({
user,
showActions = false,
onEdit
}: {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}) {
return <div>User Card</div>
}
```
## 4. Spacing Guidelines
**Prefer padding over margins for spacing. When a library component exposes a padding prop
(e.g., `paddingVariant`), use that prop instead of wrapping it in a `<div>` with padding classes.
If a library component does not expose a padding override and you find yourself adding a wrapper
div for spacing, consider updating the library component to accept one.**
**Reason:** We want to consolidate usage to paddings instead of margins, and minimize wrapper
divs that exist solely for spacing.
```typescript
// ✅ Good — use the component's padding prop
<ContentAction paddingVariant="md" ... />
// ✅ Good — padding utilities when no component prop exists
<div className="p-4 space-y-2">
<div className="p-2">Content</div>
</div>
// ❌ Bad — wrapper div just for spacing
<div className="p-4">
<ContentAction ... />
</div>
// ❌ Bad — margins
<div className="m-4 space-y-2">
<div className="m-2">Content</div>
</div>
```
## 5. Class Name Utilities
**Use the `cn` utility instead of raw string formatting for classNames.**
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
```typescript
import { cn } from '@/lib/utils'
// ✅ Good
<div className={cn(
'base-class',
isActive && 'active-class',
className
)}>
Content
</div>
// ❌ Bad
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
Content
</div>
```
## 6. Custom Hooks Organization
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
**Reason:** This is just a layout preference. Keeps code clean.
```typescript
// web/src/hooks/useUserData.ts
export function useUserData(userId: string) {
// hook implementation
}
// web/src/hooks/useLocalStorage.ts
export function useLocalStorage<T>(key: string, initialValue: T) {
// hook implementation
}
```

1
web/CLAUDE.md Symbolic link
View File

@@ -0,0 +1 @@
AGENTS.md

View File

@@ -1,281 +0,0 @@
# Web Standards
This document outlines the coding standards and best practices for the `web` directory Next.js project.
## 1. Import Standards
**Always use absolute imports with the `@` prefix.**
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
```typescript
// ✅ Good
import { Button } from "@/components/ui/button";
import { useAuth } from "@/hooks/useAuth";
import { Text } from "@/refresh-components/texts/Text";
// ❌ Bad
import { Button } from "../../../components/ui/button";
import { useAuth } from "./hooks/useAuth";
```
## 2. React Component Functions
**Prefer regular functions over arrow functions for React components.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
function UserProfile({ userId }: UserProfileProps) {
return <div>User Profile</div>
}
// ❌ Bad
const UserProfile = ({ userId }: UserProfileProps) => {
return <div>User Profile</div>
}
```
## 3. Props Interface Extraction
**Extract prop types into their own interface definitions.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
interface UserCardProps {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
return <div>User Card</div>
}
// ❌ Bad
function UserCard({
user,
showActions = false,
onEdit
}: {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}) {
return <div>User Card</div>
}
```
## 4. Spacing Guidelines
**Prefer padding over margins for spacing.**
**Reason:** We want to consolidate usage to paddings instead of margins.
```typescript
// ✅ Good
<div className="p-4 space-y-2">
<div className="p-2">Content</div>
</div>
// ❌ Bad
<div className="m-4 space-y-2">
<div className="m-2">Content</div>
</div>
```
## 5. Tailwind Dark Mode
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
```typescript
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
<div className="bg-background-neutral-03 text-text-02">
Content
</div>
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
export const GithubIcon = createLogoIcon(githubLightIcon, {
monochromatic: true, // Will apply dark:invert internally
});
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
});
// ❌ Bad - Manual dark mode overrides
<div className="bg-white dark:bg-black text-black dark:text-white">
Content
</div>
```
## 6. Class Name Utilities
**Use the `cn` utility instead of raw string formatting for classNames.**
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
```typescript
import { cn } from '@/lib/utils'
// ✅ Good
<div className={cn(
'base-class',
isActive && 'active-class',
className
)}>
Content
</div>
// ❌ Bad
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
Content
</div>
```
## 7. Custom Hooks Organization
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
**Reason:** This is just a layout preference. Keeps code clean.
```typescript
// web/src/hooks/useUserData.ts
export function useUserData(userId: string) {
// hook implementation
}
// web/src/hooks/useLocalStorage.ts
export function useLocalStorage<T>(key: string, initialValue: T) {
// hook implementation
}
```
## 8. Icon Usage
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
```typescript
// ✅ Good
import SvgX from "@/icons/x";
import SvgMoreHorizontal from "@/icons/more-horizontal";
// ❌ Bad
import { User } from "lucide-react";
import { FiSearch } from "react-icons/fi";
```
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
If you need help with this step, reach out to `raunak@onyx.app`.
## 9. Text Rendering
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
```typescript
// ✅ Good
import { Text } from '@/refresh-components/texts/Text'
function UserCard({ name }: { name: string }) {
return (
<Text
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
text03
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
mainAction
>
{name}
</Text>
)
}
// ❌ Bad
function UserCard({ name }: { name: string }) {
return (
<div>
<h2>{name}</h2>
<p>User details</p>
</div>
)
}
```
## 10. Component Usage
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
```typescript
// ✅ Good
import Button from '@/refresh-components/buttons/Button'
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
import SvgPlusCircle from '@/icons/plus-circle'
function ContactForm() {
return (
<form>
<InputTypeIn placeholder="Search..." />
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
</form>
)
}
// ❌ Bad
function ContactForm() {
return (
<form>
<input placeholder="Name" />
<textarea placeholder="Message" />
<button type="submit">Submit</button>
</form>
)
}
```
## 11. Colors
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
**Available color categories:**
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
- **Actions:** `action-link-XX`, `action-danger-XX`
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
```typescript
// ✅ Good - Use custom Onyx color classes
<div className="bg-background-neutral-01 border border-border-02" />
<div className="bg-background-tint-02 border border-border-01" />
<div className="bg-status-success-01" />
<div className="bg-action-link-01" />
<div className="bg-theme-primary-05" />
// ❌ Bad - Do NOT use standard Tailwind colors
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
<div className="bg-white border border-slate-200" />
<div className="bg-green-100 text-green-700" />
<div className="bg-blue-100 text-blue-600" />
<div className="bg-indigo-500" />
```
## 12. Data Fetching
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).

View File

@@ -56,8 +56,14 @@ function MemoryTagWithTooltip({
side="bottom"
className="bg-background-neutral-00 text-text-01 shadow-md max-w-[17.5rem] p-1"
tooltip={
<Section flexDirection="column" gap={0.25} height="auto">
<div className="p-1 w-full">
<Section
flexDirection="column"
alignItems="start"
padding={0.25}
gap={0.25}
height="auto"
>
<div className="p-1">
<Text as="p" secondaryBody text03>
{memoryText}
</Text>
@@ -66,6 +72,7 @@ function MemoryTagWithTooltip({
icon={SvgAddLines}
title={operationLabel}
sizePreset="secondary"
paddingVariant="sm"
variant="body"
prominence="muted"
rightChildren={

View File

@@ -123,9 +123,6 @@ export interface LLMProviderFormProps {
open?: boolean;
onOpenChange?: (open: boolean) => void;
/** The current default model name for this provider (from the global default). */
defaultModelName?: string;
// Onboarding-specific (only when variant === "onboarding")
onboardingState?: OnboardingState;
onboardingActions?: OnboardingActions;

View File

@@ -20,10 +20,15 @@
"use client";
import { cn, ensureHrefProtocol, noProp } from "@/lib/utils";
import {
cn,
ensureHrefProtocol,
INTERACTIVE_SELECTOR,
noProp,
} from "@/lib/utils";
import type { Components } from "react-markdown";
import Text from "@/refresh-components/texts/Text";
import { useCallback, useMemo, useState, useEffect } from "react";
import { useCallback, useMemo, useRef, useState, useEffect } from "react";
import { useAppBackground } from "@/providers/AppBackgroundProvider";
import { useTheme } from "next-themes";
import ShareChatSessionModal from "@/sections/modals/ShareChatSessionModal";
@@ -532,6 +537,37 @@ function Root({ children, enableBackground }: AppRootProps) {
const { isSafari } = useBrowserInfo();
const isLightMode = resolvedTheme === "light";
const showBackground = hasBackground && enableBackground;
// Track whether the chat input was focused before a mousedown, so we can
// restore focus on mouseup if no text was selected. This preserves
// click-drag text selection while keeping the input focused on plain clicks.
const inputWasFocused = useRef(false);
const handleMouseDown = useCallback(
(event: React.MouseEvent<HTMLDivElement>) => {
const activeEl = document.activeElement;
const isFocused =
activeEl instanceof HTMLElement &&
activeEl.id === "onyx-chat-input-textarea";
const target = event.target;
const isInteractive =
target instanceof HTMLElement && !!target.closest(INTERACTIVE_SELECTOR);
inputWasFocused.current = isFocused && !isInteractive;
},
[]
);
const handleMouseUp = useCallback(() => {
if (!inputWasFocused.current) return;
inputWasFocused.current = false;
const sel = window.getSelection();
if (sel && !sel.isCollapsed) return;
const textarea = document.getElementById("onyx-chat-input-textarea");
// Only restore focus if no other element has grabbed it since mousedown.
if (textarea && document.activeElement !== textarea) {
textarea.focus();
}
}, []);
const horizontalBlurMask = `linear-gradient(
to right,
transparent 0%,
@@ -549,6 +585,8 @@ function Root({ children, enableBackground }: AppRootProps) {
*/
<div
data-main-container
onMouseDown={handleMouseDown}
onMouseUp={handleMouseUp}
className={cn(
"@container flex flex-col h-full w-full relative overflow-hidden",
showBackground && "bg-cover bg-center bg-fixed"

View File

@@ -125,7 +125,7 @@ export const MAX_FILES_TO_SHOW = 3;
export const MOBILE_SIDEBAR_BREAKPOINT_PX = 640;
export const DESKTOP_SMALL_BREAKPOINT_PX = 912;
export const DESKTOP_MEDIUM_BREAKPOINT_PX = 1232;
export const DEFAULT_AGENT_AVATAR_SIZE_PX = 18;
export const DEFAULT_AVATAR_SIZE_PX = 18;
export const HORIZON_DISTANCE_PX = 800;
export const LOGO_FOLDED_SIZE_PX = 24;
export const LOGO_UNFOLDED_SIZE_PX = 88;

55
web/src/lib/user.test.ts Normal file
View File

@@ -0,0 +1,55 @@
import { getUserInitials } from "@/lib/user";
describe("getUserInitials", () => {
it("returns first letters of first two name parts", () => {
expect(getUserInitials("Alice Smith", "alice@example.com")).toBe("AS");
});
it("returns first two chars of a single-word name", () => {
expect(getUserInitials("Alice", "alice@example.com")).toBe("AL");
});
it("handles three-word names (uses first two)", () => {
expect(getUserInitials("Alice B. Smith", "alice@example.com")).toBe("AB");
});
it("falls back to email local part with dot separator", () => {
expect(getUserInitials(null, "alice.smith@example.com")).toBe("AS");
});
it("falls back to email local part with underscore separator", () => {
expect(getUserInitials(null, "alice_smith@example.com")).toBe("AS");
});
it("falls back to email local part with hyphen separator", () => {
expect(getUserInitials(null, "alice-smith@example.com")).toBe("AS");
});
it("uses first two chars of email local if no separator", () => {
expect(getUserInitials(null, "alice@example.com")).toBe("AL");
});
it("returns null for empty email local part", () => {
expect(getUserInitials(null, "@example.com")).toBeNull();
});
it("uppercases the result", () => {
expect(getUserInitials("john doe", "jd@test.com")).toBe("JD");
});
it("trims whitespace from name", () => {
expect(getUserInitials(" Alice Smith ", "a@test.com")).toBe("AS");
});
it("returns null for numeric name parts", () => {
expect(getUserInitials("Alice 1st", "x@test.com")).toBeNull();
});
it("returns null for numeric email", () => {
expect(getUserInitials(null, "42@domain.com")).toBeNull();
});
it("falls back to email when name has non-alpha chars", () => {
expect(getUserInitials("A1", "alice@example.com")).toBe("AL");
});
});

View File

@@ -128,3 +128,54 @@ export function getUserDisplayName(user: User | null): string {
// If nothing works, then fall back to anonymous user name
return "Anonymous";
}
/**
* Derive display initials from a user's name or email.
*
* - If a name is provided, uses the first letter of the first two words.
* - Falls back to the email local part, splitting on `.`, `_`, or `-`.
* - Returns `null` when no valid alpha initials can be derived.
*/
export function getUserInitials(
name: string | null,
email: string
): string | null {
if (name) {
const words = name.trim().split(/\s+/);
if (words.length >= 2) {
const first = words[0]?.[0];
const second = words[1]?.[0];
if (first && second) {
const result = (first + second).toUpperCase();
if (/^[A-Z]{2}$/.test(result)) return result;
}
return null;
}
if (name.trim().length >= 1) {
const result = name.trim().slice(0, 2).toUpperCase();
if (/^[A-Z]{1,2}$/.test(result)) return result;
}
}
const local = email.split("@")[0];
if (!local || local.length === 0) return null;
const parts = local.split(/[._-]/);
if (parts.length >= 2) {
const first = parts[0]?.[0];
const second = parts[1]?.[0];
if (first && second) {
const result = (first + second).toUpperCase();
if (/^[A-Z]{2}$/.test(result)) return result;
}
return null;
}
if (local.length >= 2) {
const result = local.slice(0, 2).toUpperCase();
if (/^[A-Z]{2}$/.test(result)) return result;
}
if (local.length === 1) {
const result = local.toUpperCase();
if (/^[A-Z]$/.test(result)) return result;
}
return null;
}

View File

@@ -13,6 +13,9 @@ import { ALLOWED_URL_PROTOCOLS } from "./constants";
const URI_SCHEME_REGEX = /^[a-zA-Z][a-zA-Z\d+.-]*:/;
const BARE_EMAIL_REGEX = /^[^\s@/]+@[^\s@/:]+\.[^\s@/:]+$/;
export const INTERACTIVE_SELECTOR =
"a, button, input, textarea, select, label, [role='button'], [tabindex]:not([tabindex='-1']), [contenteditable]:not([contenteditable='false'])";
export function cn(...inputs: ClassValue[]) {
return twMerge(clsx(inputs));
}

View File

@@ -4,10 +4,7 @@ import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
import { buildImgUrl } from "@/app/app/components/files/images/utils";
import { OnyxIcon } from "@/components/icons/icons";
import { useSettingsContext } from "@/providers/SettingsProvider";
import {
DEFAULT_AGENT_AVATAR_SIZE_PX,
DEFAULT_AGENT_ID,
} from "@/lib/constants";
import { DEFAULT_AVATAR_SIZE_PX, DEFAULT_AGENT_ID } from "@/lib/constants";
import CustomAgentAvatar from "@/refresh-components/avatars/CustomAgentAvatar";
import Image from "next/image";
@@ -18,7 +15,7 @@ export interface AgentAvatarProps {
export default function AgentAvatar({
agent,
size = DEFAULT_AGENT_AVATAR_SIZE_PX,
size = DEFAULT_AVATAR_SIZE_PX,
...props
}: AgentAvatarProps) {
const settings = useSettingsContext();

View File

@@ -4,7 +4,7 @@ import { cn } from "@/lib/utils";
import type { IconProps } from "@opal/types";
import Text from "@/refresh-components/texts/Text";
import Image from "next/image";
import { DEFAULT_AGENT_AVATAR_SIZE_PX } from "@/lib/constants";
import { DEFAULT_AVATAR_SIZE_PX } from "@/lib/constants";
import {
SvgActivitySmall,
SvgAudioEqSmall,
@@ -96,7 +96,7 @@ export default function CustomAgentAvatar({
src,
iconName,
size = DEFAULT_AGENT_AVATAR_SIZE_PX,
size = DEFAULT_AVATAR_SIZE_PX,
}: CustomAgentAvatarProps) {
if (src) {
return (

View File

@@ -0,0 +1,52 @@
import { SvgUser } from "@opal/icons";
import { DEFAULT_AVATAR_SIZE_PX } from "@/lib/constants";
import { getUserInitials } from "@/lib/user";
import Text from "@/refresh-components/texts/Text";
import type { User } from "@/lib/types";
export interface UserAvatarProps {
user: User;
size?: number;
}
export default function UserAvatar({
user,
size = DEFAULT_AVATAR_SIZE_PX,
}: UserAvatarProps) {
const initials = getUserInitials(
user.personalization?.name ?? null,
user.email
);
if (!initials) {
return (
<div
role="img"
aria-label={`${user.email} avatar`}
className="flex items-center justify-center rounded-full bg-background-tint-01"
style={{ width: size, height: size }}
>
<SvgUser size={size * 0.55} className="stroke-text-03" aria-hidden />
</div>
);
}
return (
<div
role="img"
aria-label={`${user.email} avatar`}
className="flex items-center justify-center rounded-full bg-background-neutral-inverted-00"
style={{ width: size, height: size }}
>
<Text
inverted
secondaryAction
text05
className="select-none"
style={{ fontSize: size * 0.4 }}
>
{initials}
</Text>
</div>
);
}

View File

@@ -148,14 +148,12 @@ interface ExistingProviderCardProps {
provider: LLMProviderView;
isDefault: boolean;
isLastProvider: boolean;
defaultModelName?: string;
}
function ExistingProviderCard({
provider,
isDefault,
isLastProvider,
defaultModelName,
}: ExistingProviderCardProps) {
const { mutate } = useSWRConfig();
const [isOpen, setIsOpen] = useState(false);
@@ -232,12 +230,7 @@ function ExistingProviderCard({
</Section>
}
/>
{getModalForExistingProvider(
provider,
isOpen,
setIsOpen,
defaultModelName
)}
{getModalForExistingProvider(provider, isOpen, setIsOpen)}
</Card>
</Hoverable.Root>
</>
@@ -453,11 +446,6 @@ export default function LLMConfigurationPage() {
provider={provider}
isDefault={defaultText?.provider_id === provider.id}
isLastProvider={sortedProviders.length === 1}
defaultModelName={
defaultText?.provider_id === provider.id
? defaultText.model_name
: undefined
}
/>
))}
</div>

View File

@@ -26,7 +26,7 @@ import type {
StatusFilter,
StatusCountMap,
} from "./interfaces";
import { getInitials } from "./utils";
import { getUserInitials } from "@/lib/user";
// ---------------------------------------------------------------------------
// Column renderers
@@ -76,7 +76,8 @@ function buildColumns(onMutate: () => void) {
return [
tc.qualifier({
content: "avatar-user",
getInitials: (row) => getInitials(row.personal_name, row.email),
getInitials: (row) =>
getUserInitials(row.personal_name, row.email) ?? "?",
selectable: false,
}),
tc.column("email", {

View File

@@ -1,43 +0,0 @@
import { getInitials } from "./utils";
describe("getInitials", () => {
it("returns first letters of first two name parts", () => {
expect(getInitials("Alice Smith", "alice@example.com")).toBe("AS");
});
it("returns first two chars of a single-word name", () => {
expect(getInitials("Alice", "alice@example.com")).toBe("AL");
});
it("handles three-word names (uses first two)", () => {
expect(getInitials("Alice B. Smith", "alice@example.com")).toBe("AB");
});
it("falls back to email local part with dot separator", () => {
expect(getInitials(null, "alice.smith@example.com")).toBe("AS");
});
it("falls back to email local part with underscore separator", () => {
expect(getInitials(null, "alice_smith@example.com")).toBe("AS");
});
it("falls back to email local part with hyphen separator", () => {
expect(getInitials(null, "alice-smith@example.com")).toBe("AS");
});
it("uses first two chars of email local if no separator", () => {
expect(getInitials(null, "alice@example.com")).toBe("AL");
});
it("returns ? for empty email local part", () => {
expect(getInitials(null, "@example.com")).toBe("?");
});
it("uppercases the result", () => {
expect(getInitials("john doe", "jd@test.com")).toBe("JD");
});
it("trims whitespace from name", () => {
expect(getInitials(" Alice Smith ", "a@test.com")).toBe("AS");
});
});

View File

@@ -1,23 +0,0 @@
/**
* Derive display initials from a user's name or email.
*
* - If a name is provided, uses the first letter of the first two words.
* - Falls back to the email local part, splitting on `.`, `_`, or `-`.
* - Returns at most 2 uppercase characters.
*/
export function getInitials(name: string | null, email: string): string {
if (name) {
const parts = name.trim().split(/\s+/);
if (parts.length >= 2) {
return ((parts[0]?.[0] ?? "") + (parts[1]?.[0] ?? "")).toUpperCase();
}
return name.slice(0, 2).toUpperCase();
}
const local = email.split("@")[0];
if (!local) return "?";
const parts = local.split(/[._-]/);
if (parts.length >= 2) {
return ((parts[0]?.[0] ?? "") + (parts[1]?.[0] ?? "")).toUpperCase();
}
return local.slice(0, 2).toUpperCase();
}

View File

@@ -35,7 +35,6 @@ export default function AnthropicModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
@@ -65,15 +64,10 @@ export default function AnthropicModal({
default_model_name: DEFAULT_DEFAULT_MODEL_NAME,
}
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? undefined,
default_model_name:
defaultModelName ??
wellKnownLLMProvider?.recommended_default_model?.name ??
DEFAULT_DEFAULT_MODEL_NAME,
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,

View File

@@ -81,7 +81,6 @@ export default function AzureModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
@@ -110,11 +109,7 @@ export default function AzureModal({
default_model_name: "",
} as AzureModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
api_key: existingLlmProvider?.api_key ?? "",
target_uri: buildTargetUri(existingLlmProvider),
};

View File

@@ -315,7 +315,6 @@ export default function BedrockModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
@@ -352,11 +351,7 @@ export default function BedrockModal({
},
} as BedrockModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
custom_config: {
AWS_REGION_NAME:
(existingLlmProvider?.custom_config?.AWS_REGION_NAME as string) ??

View File

@@ -197,7 +197,6 @@ export default function CustomModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
}: LLMProviderFormProps) {
@@ -210,11 +209,7 @@ export default function CustomModal({
const onClose = () => onOpenChange?.(false);
const initialValues = {
...buildDefaultInitialValues(
existingLlmProvider,
undefined,
defaultModelName
),
...buildDefaultInitialValues(existingLlmProvider),
...(isOnboarding ? buildOnboardingInitialValues() : {}),
provider: existingLlmProvider?.provider ?? "",
model_configurations: existingLlmProvider?.model_configurations.map(

View File

@@ -192,7 +192,6 @@ export default function LMStudioForm({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
@@ -226,11 +225,7 @@ export default function LMStudioForm({
},
} as LMStudioFormValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
custom_config: {
LM_STUDIO_API_KEY:

View File

@@ -159,7 +159,6 @@ export default function LiteLLMProxyModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
@@ -191,11 +190,7 @@ export default function LiteLLMProxyModal({
default_model_name: "",
} as LiteLLMProxyModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
};

View File

@@ -212,7 +212,6 @@ export default function OllamaModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
@@ -245,11 +244,7 @@ export default function OllamaModal({
},
} as OllamaModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
custom_config: {
OLLAMA_API_KEY:

View File

@@ -35,7 +35,6 @@ export default function OpenAIModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
@@ -64,14 +63,9 @@ export default function OpenAIModal({
default_model_name: DEFAULT_DEFAULT_MODEL_NAME,
}
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
api_key: existingLlmProvider?.api_key ?? "",
default_model_name:
defaultModelName ??
wellKnownLLMProvider?.recommended_default_model?.name ??
DEFAULT_DEFAULT_MODEL_NAME,
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,

View File

@@ -158,7 +158,6 @@ export default function OpenRouterModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
@@ -190,11 +189,7 @@ export default function OpenRouterModal({
default_model_name: "",
} as OpenRouterModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
};

View File

@@ -48,7 +48,6 @@ export default function VertexAIModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
@@ -81,13 +80,8 @@ export default function VertexAIModal({
},
} as VertexAIModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
default_model_name:
defaultModelName ??
wellKnownLLMProvider?.recommended_default_model?.name ??
VERTEXAI_DEFAULT_MODEL,
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,

View File

@@ -22,15 +22,9 @@ function detectIfRealOpenAIProvider(provider: LLMProviderView) {
export function getModalForExistingProvider(
provider: LLMProviderView,
open?: boolean,
onOpenChange?: (open: boolean) => void,
defaultModelName?: string
onOpenChange?: (open: boolean) => void
) {
const props = {
existingLlmProvider: provider,
open,
onOpenChange,
defaultModelName,
};
const props = { existingLlmProvider: provider, open, onOpenChange };
switch (provider.provider) {
case LLMProviderName.OPENAI:

View File

@@ -12,11 +12,9 @@ export const LLM_FORM_CLASS_NAME = "flex flex-col gap-y-4 items-stretch mt-6";
export const buildDefaultInitialValues = (
existingLlmProvider?: LLMProviderView,
modelConfigurations?: ModelConfiguration[],
currentDefaultModelName?: string
modelConfigurations?: ModelConfiguration[]
) => {
const defaultModelName =
currentDefaultModelName ??
existingLlmProvider?.model_configurations?.[0]?.name ??
modelConfigurations?.[0]?.name ??
"";

View File

@@ -182,7 +182,7 @@ const ProjectFolderButton = memo(({ project }: ProjectFolderButtonProps) => {
onClose={() => setIsEditing(false)}
/>
) : (
<Truncated>{project.name}</Truncated>
<Truncated text03>{project.name}</Truncated>
)}
</SidebarTab>
</Popover.Anchor>

View File

@@ -0,0 +1,47 @@
import { test, expect } from "@playwright/test";
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
test.describe(`Chat Input Focus Retention`, () => {
test.beforeEach(async ({ page }, testInfo) => {
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
await page.goto("/app");
await page.waitForLoadState("networkidle");
});
test("clicking empty space retains focus on chat input", async ({ page }) => {
const textarea = page.locator("#onyx-chat-input-textarea");
await textarea.waitFor({ state: "visible", timeout: 10000 });
// Focus the textarea and type something
await textarea.focus();
await textarea.fill("test message");
await expect(textarea).toBeFocused();
// Click on the main container's empty space (top-left corner)
const container = page.locator("[data-main-container]");
await container.click({ position: { x: 10, y: 10 } });
// Focus should remain on the textarea
await expect(textarea).toBeFocused();
});
test("clicking interactive elements still moves focus away", async ({
page,
}) => {
const textarea = page.locator("#onyx-chat-input-textarea");
await textarea.waitFor({ state: "visible", timeout: 10000 });
// Focus the textarea
await textarea.focus();
await expect(textarea).toBeFocused();
// Click on an interactive element inside the container
const button = page.locator("[data-main-container] button").first();
await button.waitFor({ state: "visible", timeout: 5000 });
await button.click();
// Focus should have moved away from the textarea
await expect(textarea).not.toBeFocused();
});
});