mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-20 23:22:43 +00:00
Compare commits
19 Commits
dane/cache
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
068ac543ad | ||
|
|
30e7a831a5 | ||
|
|
276261c96d | ||
|
|
205f1410e4 | ||
|
|
a93d154c27 | ||
|
|
1361879bd0 | ||
|
|
c58cc320b2 | ||
|
|
461350958a | ||
|
|
50dde0be1a | ||
|
|
199e1df453 | ||
|
|
996b674840 | ||
|
|
5413723ccc | ||
|
|
9660056a51 | ||
|
|
3105177238 | ||
|
|
24bb4bda8b | ||
|
|
9532af4ceb | ||
|
|
0a913f6af5 | ||
|
|
fe30c55199 | ||
|
|
2cf0a65dd3 |
279
AGENTS.md
279
AGENTS.md
@@ -167,284 +167,7 @@ web/
|
||||
|
||||
## Frontend Standards
|
||||
|
||||
### 1. Import Standards
|
||||
|
||||
**Always use absolute imports with the `@` prefix.**
|
||||
|
||||
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useAuth } from "@/hooks/useAuth";
|
||||
import { Text } from "@/refresh-components/texts/Text";
|
||||
|
||||
// ❌ Bad
|
||||
import { Button } from "../../../components/ui/button";
|
||||
import { useAuth } from "./hooks/useAuth";
|
||||
```
|
||||
|
||||
### 2. React Component Functions
|
||||
|
||||
**Prefer regular functions over arrow functions for React components.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
function UserProfile({ userId }: UserProfileProps) {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
const UserProfile = ({ userId }: UserProfileProps) => {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Props Interface Extraction
|
||||
|
||||
**Extract prop types into their own interface definitions.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
interface UserCardProps {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}
|
||||
|
||||
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({
|
||||
user,
|
||||
showActions = false,
|
||||
onEdit
|
||||
}: {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Spacing Guidelines
|
||||
|
||||
**Prefer padding over margins for spacing.**
|
||||
|
||||
**Reason:** We want to consolidate usage to paddings instead of margins.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
<div className="p-4 space-y-2">
|
||||
<div className="p-2">Content</div>
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className="m-4 space-y-2">
|
||||
<div className="m-2">Content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
### 5. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
|
||||
|
||||
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
|
||||
<div className="bg-background-neutral-03 text-text-02">
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true, // Will apply dark:invert internally
|
||||
});
|
||||
|
||||
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
|
||||
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
|
||||
});
|
||||
|
||||
// ❌ Bad - Manual dark mode overrides
|
||||
<div className="bg-white dark:bg-black text-black dark:text-white">
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 6. Class Name Utilities
|
||||
|
||||
**Use the `cn` utility instead of raw string formatting for classNames.**
|
||||
|
||||
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
|
||||
|
||||
```typescript
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// ✅ Good
|
||||
<div className={cn(
|
||||
'base-class',
|
||||
isActive && 'active-class',
|
||||
className
|
||||
)}>
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 7. Custom Hooks Organization
|
||||
|
||||
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
|
||||
|
||||
**Reason:** This is just a layout preference. Keeps code clean.
|
||||
|
||||
```typescript
|
||||
// web/src/hooks/useUserData.ts
|
||||
export function useUserData(userId: string) {
|
||||
// hook implementation
|
||||
}
|
||||
|
||||
// web/src/hooks/useLocalStorage.ts
|
||||
export function useLocalStorage<T>(key: string, initialValue: T) {
|
||||
// hook implementation
|
||||
}
|
||||
```
|
||||
|
||||
### 8. Icon Usage
|
||||
|
||||
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
|
||||
|
||||
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import SvgX from "@/icons/x";
|
||||
import SvgMoreHorizontal from "@/icons/more-horizontal";
|
||||
|
||||
// ❌ Bad
|
||||
import { User } from "lucide-react";
|
||||
import { FiSearch } from "react-icons/fi";
|
||||
```
|
||||
|
||||
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
|
||||
If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
### 9. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 10. Component Usage
|
||||
|
||||
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
|
||||
|
||||
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import Button from '@/refresh-components/buttons/Button'
|
||||
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
|
||||
import SvgPlusCircle from '@/icons/plus-circle'
|
||||
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<InputTypeIn placeholder="Search..." />
|
||||
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<input placeholder="Name" />
|
||||
<textarea placeholder="Message" />
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 11. Colors
|
||||
|
||||
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
|
||||
|
||||
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
|
||||
|
||||
**Available color categories:**
|
||||
|
||||
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
|
||||
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
|
||||
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
|
||||
- **Actions:** `action-link-XX`, `action-danger-XX`
|
||||
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
|
||||
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Use custom Onyx color classes
|
||||
<div className="bg-background-neutral-01 border border-border-02" />
|
||||
<div className="bg-background-tint-02 border border-border-01" />
|
||||
<div className="bg-status-success-01" />
|
||||
<div className="bg-action-link-01" />
|
||||
<div className="bg-theme-primary-05" />
|
||||
|
||||
// ❌ Bad - Do NOT use standard Tailwind colors
|
||||
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
|
||||
<div className="bg-white border border-slate-200" />
|
||||
<div className="bg-green-100 text-green-700" />
|
||||
<div className="bg-blue-100 text-blue-600" />
|
||||
<div className="bg-indigo-500" />
|
||||
```
|
||||
|
||||
### 12. Data Fetching
|
||||
|
||||
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
|
||||
|
||||
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
|
||||
Frontend standards for the `web/` and `desktop/` projects live in `web/AGENTS.md`.
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
|
||||
@@ -47,6 +47,8 @@ RUN apt-get update && \
|
||||
gcc \
|
||||
nano \
|
||||
vim \
|
||||
# Install procps so kubernetes exec sessions can use ps aux for debugging
|
||||
procps \
|
||||
libjemalloc2 \
|
||||
&& \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
|
||||
@@ -9,12 +9,12 @@ from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
@@ -362,7 +362,7 @@ if not MULTI_TENANT:
|
||||
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
if not MULTI_TENANT and HOOK_ENABLED:
|
||||
if HOOKS_AVAILABLE:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
|
||||
@@ -30,8 +30,6 @@ from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.utils import plaintext_file_name_for_id
|
||||
from onyx.file_store.utils import store_plaintext
|
||||
from onyx.kg.models import KGException
|
||||
from onyx.kg.setup.kg_default_entity_definitions import (
|
||||
populate_missing_default_entity_types__commit,
|
||||
@@ -291,33 +289,6 @@ def process_kg_commands(
|
||||
raise KGException("KG setup done")
|
||||
|
||||
|
||||
def _get_or_extract_plaintext(
|
||||
file_id: str,
|
||||
extract_fn: Callable[[], str],
|
||||
) -> str:
|
||||
"""Load cached plaintext for a file, or extract and store it.
|
||||
|
||||
Tries to read pre-stored plaintext from the file store. On a miss,
|
||||
calls extract_fn to produce the text, then stores the result so
|
||||
future calls skip the expensive extraction.
|
||||
"""
|
||||
file_store = get_default_file_store()
|
||||
plaintext_key = plaintext_file_name_for_id(file_id)
|
||||
|
||||
# Try cached plaintext first.
|
||||
try:
|
||||
plaintext_io = file_store.read_file(plaintext_key, mode="b")
|
||||
return plaintext_io.read().decode("utf-8")
|
||||
except Exception:
|
||||
logger.exception(f"Error when reading file, id={file_id}")
|
||||
|
||||
# Cache miss — extract and store.
|
||||
content_text = extract_fn()
|
||||
if content_text:
|
||||
store_plaintext(file_id, content_text)
|
||||
return content_text
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def load_chat_file(
|
||||
file_descriptor: FileDescriptor, db_session: Session
|
||||
@@ -332,23 +303,12 @@ def load_chat_file(
|
||||
file_type = ChatFileType(file_descriptor["type"])
|
||||
|
||||
if file_type.is_text_file():
|
||||
file_id = file_descriptor["id"]
|
||||
|
||||
def _extract() -> str:
|
||||
return extract_file_text(
|
||||
try:
|
||||
content_text = extract_file_text(
|
||||
file=file_io,
|
||||
file_name=file_descriptor.get("name") or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
|
||||
# Use the user_file_id as cache key when available (matches what
|
||||
# the celery indexing worker stores), otherwise fall back to the
|
||||
# file store id (covers code-interpreter-generated files, etc.).
|
||||
user_file_id_str = file_descriptor.get("user_file_id")
|
||||
cache_key = user_file_id_str or file_id
|
||||
|
||||
try:
|
||||
content_text = _get_or_extract_plaintext(cache_key, _extract)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to retrieve content for file {file_descriptor['id']}: {str(e)}"
|
||||
|
||||
@@ -157,9 +157,7 @@ def _execute_single_retrieval(
|
||||
logger.error(f"Error executing request: {e}")
|
||||
raise e
|
||||
elif _is_rate_limit_error(e):
|
||||
results = _execute_with_retry(
|
||||
lambda: retrieval_function(**request_kwargs).execute()
|
||||
)
|
||||
results = _execute_with_retry(retrieval_function(**request_kwargs))
|
||||
elif e.resp.status == 404 or e.resp.status == 403:
|
||||
if continue_on_404_or_403:
|
||||
logger.debug(f"Error executing request: {e}")
|
||||
|
||||
@@ -2,6 +2,7 @@ import time
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
@@ -149,6 +150,9 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
Returns None if search settings did not change, or the old search settings if they
|
||||
did change.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
return None
|
||||
|
||||
# Default CC-pair created for Ingestion API unused here
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from contextlib import AbstractContextManager
|
||||
@@ -1062,7 +1063,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
f"Body: {get_new_body_without_vectors(body)}\n"
|
||||
f"Search pipeline ID: {search_pipeline_id}\n"
|
||||
f"Phase took: {phase_took}\n"
|
||||
f"Profile: {profile}\n"
|
||||
f"Profile: {json.dumps(profile, indent=2)}\n"
|
||||
)
|
||||
if timed_out:
|
||||
error_str = f"OpenSearch client error: Search timed out for index {self._index_name}."
|
||||
|
||||
@@ -950,7 +950,86 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_pipeline_id=normalization_pipeline_name,
|
||||
)
|
||||
|
||||
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
|
||||
# Good place for a breakpoint to inspect the search hits if you have
|
||||
# "explain" enabled.
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
)
|
||||
for search_hit in search_hits
|
||||
]
|
||||
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
|
||||
inference_chunks_uncleaned
|
||||
)
|
||||
|
||||
return inference_chunks
|
||||
|
||||
def keyword_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Keyword retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_keyword_search_query(
|
||||
query_text=query,
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
# NOTE: Index filters includes metadata tags which were filtered
|
||||
# for invalid unicode at indexing time. In theory it would be
|
||||
# ideal to do filtering here as well, in practice we never did
|
||||
# that in the Vespa codepath and have not seen issues in
|
||||
# production, so we deliberately conform to the existing logic
|
||||
# in order to not unknowningly introduce a possible bug.
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
)
|
||||
for search_hit in search_hits
|
||||
]
|
||||
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
|
||||
inference_chunks_uncleaned
|
||||
)
|
||||
|
||||
return inference_chunks
|
||||
|
||||
def semantic_retrieval(
|
||||
self,
|
||||
query_embedding: Embedding,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Semantic retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_semantic_search_query(
|
||||
query_embedding=query_embedding,
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
# NOTE: Index filters includes metadata tags which were filtered
|
||||
# for invalid unicode at indexing time. In theory it would be
|
||||
# ideal to do filtering here as well, in practice we never did
|
||||
# that in the Vespa codepath and have not seen issues in
|
||||
# production, so we deliberately conform to the existing logic
|
||||
# in order to not unknowningly introduce a possible bug.
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
|
||||
@@ -404,12 +404,170 @@ class DocumentQuery:
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
# Explain is for scoring breakdowns.
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_hybrid_search_body["explain"] = True
|
||||
|
||||
return final_hybrid_search_body
|
||||
|
||||
@staticmethod
|
||||
def get_keyword_search_query(
|
||||
query_text: str,
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final keyword search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the keyword search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final keyword search query.
|
||||
"""
|
||||
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
|
||||
raise ValueError(
|
||||
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
|
||||
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
|
||||
)
|
||||
|
||||
keyword_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
attached_document_ids=index_filters.attached_document_ids,
|
||||
hierarchy_node_ids=index_filters.hierarchy_node_ids,
|
||||
)
|
||||
|
||||
keyword_search_query = (
|
||||
DocumentQuery._get_title_content_combined_keyword_search_query(
|
||||
query_text, search_filters=keyword_search_filters
|
||||
)
|
||||
)
|
||||
|
||||
final_keyword_search_query: dict[str, Any] = {
|
||||
"query": keyword_search_query,
|
||||
"size": num_hits,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
# Exclude retrieving the vector fields in order to save on
|
||||
# retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
}
|
||||
|
||||
if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED:
|
||||
final_keyword_search_query["highlight"] = (
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
if not OPENSEARCH_PROFILING_DISABLED:
|
||||
final_keyword_search_query["profile"] = True
|
||||
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_keyword_search_query["explain"] = True
|
||||
|
||||
return final_keyword_search_query
|
||||
|
||||
@staticmethod
|
||||
def get_semantic_search_query(
|
||||
query_embedding: list[float],
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final semantic search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
|
||||
Args:
|
||||
query_embedding: The vector embedding of the text to query for.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the semantic search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final semantic search query.
|
||||
"""
|
||||
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
|
||||
raise ValueError(
|
||||
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
|
||||
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
|
||||
)
|
||||
|
||||
semantic_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
attached_document_ids=index_filters.attached_document_ids,
|
||||
hierarchy_node_ids=index_filters.hierarchy_node_ids,
|
||||
)
|
||||
|
||||
semantic_search_query = (
|
||||
DocumentQuery._get_content_vector_similarity_search_query(
|
||||
query_embedding,
|
||||
vector_candidates=num_hits,
|
||||
search_filters=semantic_search_filters,
|
||||
)
|
||||
)
|
||||
|
||||
final_semantic_search_query: dict[str, Any] = {
|
||||
"query": semantic_search_query,
|
||||
"size": num_hits,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
# Exclude retrieving the vector fields in order to save on
|
||||
# retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
}
|
||||
|
||||
if not OPENSEARCH_PROFILING_DISABLED:
|
||||
final_semantic_search_query["profile"] = True
|
||||
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_semantic_search_query["explain"] = True
|
||||
|
||||
return final_semantic_search_query
|
||||
|
||||
@staticmethod
|
||||
def get_random_search_query(
|
||||
tenant_state: TenantState,
|
||||
@@ -581,8 +739,9 @@ class DocumentQuery:
|
||||
def _get_content_vector_similarity_search_query(
|
||||
query_vector: list[float],
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
|
||||
search_filters: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
query = {
|
||||
"knn": {
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
@@ -591,11 +750,19 @@ class DocumentQuery:
|
||||
}
|
||||
}
|
||||
|
||||
if search_filters is not None:
|
||||
query["knn"][CONTENT_VECTOR_FIELD_NAME]["filter"] = {
|
||||
"bool": {"filter": search_filters}
|
||||
}
|
||||
|
||||
return query
|
||||
|
||||
@staticmethod
|
||||
def _get_title_content_combined_keyword_search_query(
|
||||
query_text: str,
|
||||
search_filters: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
query = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{
|
||||
@@ -636,10 +803,19 @@ class DocumentQuery:
|
||||
}
|
||||
}
|
||||
},
|
||||
]
|
||||
],
|
||||
# Ensure at least one term from the query is present in the
|
||||
# document. This defaults to 1, unless a filter or must clause
|
||||
# is supplied, in which case it defaults to 0.
|
||||
"minimum_should_match": 1,
|
||||
}
|
||||
}
|
||||
|
||||
if search_filters is not None:
|
||||
query["bool"]["filter"] = search_filters
|
||||
|
||||
return query
|
||||
|
||||
@staticmethod
|
||||
def _get_search_filters(
|
||||
tenant_state: TenantState,
|
||||
|
||||
@@ -88,6 +88,7 @@ class OnyxErrorCode(Enum):
|
||||
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
|
||||
BAD_GATEWAY = ("BAD_GATEWAY", 502)
|
||||
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
|
||||
HOOK_EXECUTION_FAILED = ("HOOK_EXECUTION_FAILED", 502)
|
||||
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
|
||||
|
||||
def __init__(self, code: str, status_code: int) -> None:
|
||||
|
||||
@@ -23,55 +23,45 @@ from onyx.utils.timing import log_function_time
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def plaintext_file_name_for_id(file_id: str) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a file."""
|
||||
return f"plaintext_{file_id}"
|
||||
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a user file."""
|
||||
return f"plaintext_{user_file_id}"
|
||||
|
||||
|
||||
def store_plaintext(file_id: str, plaintext_content: str) -> bool:
|
||||
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
|
||||
"""
|
||||
Store plaintext content for a file in the file store.
|
||||
Store plaintext content for a user file in the file store.
|
||||
|
||||
Args:
|
||||
file_id: The ID of the file (user_file or artifact_file)
|
||||
user_file_id: The ID of the user file
|
||||
plaintext_content: The plaintext content to store
|
||||
|
||||
Returns:
|
||||
bool: True if storage was successful, False otherwise
|
||||
"""
|
||||
# Skip empty content
|
||||
if not plaintext_content:
|
||||
return False
|
||||
|
||||
plaintext_file_name = plaintext_file_name_for_id(file_id)
|
||||
# Get plaintext file name
|
||||
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
|
||||
|
||||
try:
|
||||
file_store = get_default_file_store()
|
||||
file_content = BytesIO(plaintext_content.encode("utf-8"))
|
||||
file_store.save_file(
|
||||
content=file_content,
|
||||
display_name=f"Plaintext for {file_id}",
|
||||
display_name=f"Plaintext for user file {user_file_id}",
|
||||
file_origin=FileOrigin.PLAINTEXT_CACHE,
|
||||
file_type="text/plain",
|
||||
file_id=plaintext_file_name,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store plaintext for {file_id}: {e}")
|
||||
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# --- Convenience wrappers for callers that use user-file UUIDs ---
|
||||
|
||||
|
||||
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a user file."""
|
||||
return plaintext_file_name_for_id(str(user_file_id))
|
||||
|
||||
|
||||
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
|
||||
"""Store plaintext content for a user file (delegates to :func:`store_plaintext`)."""
|
||||
return store_plaintext(str(user_file_id), plaintext_content)
|
||||
|
||||
|
||||
def load_chat_file_by_id(file_id: str) -> InMemoryChatFile:
|
||||
"""Load a file directly from the file store using its file_record ID.
|
||||
|
||||
|
||||
330
backend/onyx/hooks/executor.py
Normal file
330
backend/onyx/hooks/executor.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
|
||||
|
||||
Usage (Celery tasks and FastAPI handlers):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
|
||||
)
|
||||
|
||||
if isinstance(result, HookSkipped):
|
||||
# no active hook configured — continue with original behavior
|
||||
...
|
||||
elif isinstance(result, HookSoftFailed):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is the response payload dict from the customer's endpoint
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
--------------------------
|
||||
``is_reachable`` on the Hook row is updated selectively — only when the outcome
|
||||
carries meaningful signal about physical reachability:
|
||||
|
||||
NetworkError (DNS, connection refused) → False (cannot reach the server)
|
||||
HTTP 401 / 403 → False (api_key revoked or invalid)
|
||||
TimeoutException → None (server may be slow, skip write)
|
||||
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
|
||||
Unknown exception → None (no signal, skip write)
|
||||
Non-JSON / non-dict response → None (server responded, skip write)
|
||||
Success (2xx, valid dict) → True (confirmed reachable)
|
||||
|
||||
None means "leave the current value unchanged" — no DB round-trip is made.
|
||||
|
||||
DB session design
|
||||
-----------------
|
||||
The executor uses three sessions:
|
||||
|
||||
1. Caller's session (db_session) — used only for the hook lookup read. All
|
||||
needed fields are extracted from the Hook object before the HTTP call, so
|
||||
the caller's session is not held open during the external HTTP request.
|
||||
|
||||
2. Log session — a separate short-lived session opened after the HTTP call
|
||||
completes to write the HookExecutionLog row on failure. Success runs are
|
||||
not recorded. Committed independently of everything else.
|
||||
|
||||
3. Reachable session — a second short-lived session to update is_reachable on
|
||||
the Hook. Kept separate from the log session so a concurrent hook deletion
|
||||
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
|
||||
prevent the execution log from being written. This update is best-effort.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.hook import create_hook_execution_log__no_commit
|
||||
from onyx.db.hook import get_non_deleted_hook_by_hook_point
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class HookSkipped:
|
||||
"""No active hook configured for this hook point."""
|
||||
|
||||
|
||||
class HookSoftFailed:
|
||||
"""Hook was called but failed with SOFT fail strategy — continuing."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _HttpOutcome(BaseModel):
|
||||
"""Structured result of an HTTP hook call, returned by _process_response."""
|
||||
|
||||
is_success: bool
|
||||
updated_is_reachable: (
|
||||
bool | None
|
||||
) # True/False = write to DB, None = unchanged (skip write)
|
||||
status_code: int | None
|
||||
error_message: str | None
|
||||
response_payload: dict[str, Any] | None
|
||||
|
||||
|
||||
def _lookup_hook(
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
) -> Hook | HookSkipped:
|
||||
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
|
||||
|
||||
No HTTP call is made and no DB writes are performed for any HookSkipped path.
|
||||
There is nothing to log and no reachability information to update.
|
||||
"""
|
||||
if not HOOKS_AVAILABLE:
|
||||
return HookSkipped()
|
||||
hook = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if hook is None or not hook.is_active:
|
||||
return HookSkipped()
|
||||
if not hook.endpoint_url:
|
||||
return HookSkipped()
|
||||
return hook
|
||||
|
||||
|
||||
def _process_response(
|
||||
*,
|
||||
response: httpx.Response | None,
|
||||
exc: Exception | None,
|
||||
timeout: float,
|
||||
) -> _HttpOutcome:
|
||||
"""Process the result of an HTTP call and return a structured outcome.
|
||||
|
||||
Called after the client.post() try/except. If post() raised, exc is set and
|
||||
response is None. Otherwise response is set and exc is None. Handles
|
||||
raise_for_status(), JSON decoding, and the dict shape check.
|
||||
"""
|
||||
if exc is not None:
|
||||
if isinstance(exc, httpx.NetworkError):
|
||||
msg = f"Hook network error (endpoint unreachable): {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False,
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
msg = f"Hook timed out after {timeout}s: {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # timeout doesn't indicate unreachability
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
msg = f"Hook call failed: {exc}"
|
||||
logger.exception(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # unknown error — don't make assumptions
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
"exactly one of response or exc must be non-None; both are None"
|
||||
)
|
||||
status_code = response.status_code
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
|
||||
# so the operator knows to update it. All other HTTP errors keep is_reachable
|
||||
# as-is (server is up, the request just failed for application reasons).
|
||||
auth_failed = e.response.status_code in (401, 403)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False if auth_failed else None,
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except (json.JSONDecodeError, httpx.DecodingError) as e:
|
||||
msg = f"Hook returned non-JSON response: {e}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
|
||||
logger.warning(msg)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
return _HttpOutcome(
|
||||
is_success=True,
|
||||
updated_is_reachable=True,
|
||||
status_code=status_code,
|
||||
error_message=None,
|
||||
response_payload=response_payload,
|
||||
)
|
||||
|
||||
|
||||
def _persist_result(
|
||||
*,
|
||||
hook_id: int,
|
||||
outcome: _HttpOutcome,
|
||||
duration_ms: int,
|
||||
) -> None:
|
||||
"""Write the execution log on failure and optionally update is_reachable, each
|
||||
in its own session so a failure in one does not affect the other."""
|
||||
# Only write the execution log on failure — success runs are not recorded.
|
||||
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
|
||||
# deleted between the initial lookup and here).
|
||||
if not outcome.is_success:
|
||||
try:
|
||||
with get_session_with_current_tenant() as log_session:
|
||||
create_hook_execution_log__no_commit(
|
||||
db_session=log_session,
|
||||
hook_id=hook_id,
|
||||
is_success=False,
|
||||
error_message=outcome.error_message,
|
||||
status_code=outcome.status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
log_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to persist hook execution log for hook_id={hook_id}"
|
||||
)
|
||||
|
||||
# Update is_reachable separately — best-effort, non-critical.
|
||||
# None means the value is unchanged (set by the caller to skip the no-op write).
|
||||
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
|
||||
# concurrently deleted, so keep this isolated from the log write above.
|
||||
if outcome.updated_is_reachable is not None:
|
||||
try:
|
||||
with get_session_with_current_tenant() as reachable_session:
|
||||
update_hook__no_commit(
|
||||
db_session=reachable_session,
|
||||
hook_id=hook_id,
|
||||
is_reachable=outcome.updated_is_reachable,
|
||||
)
|
||||
reachable_session.commit()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def execute_hook(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
payload: dict[str, Any],
|
||||
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
|
||||
"""Execute the hook for the given hook point synchronously."""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
"active hooks without an endpoint_url must be rejected by _lookup_hook"
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
response: httpx.Response | None = None
|
||||
exc: Exception | None = None
|
||||
try:
|
||||
api_key: str | None = (
|
||||
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
)
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(timeout=timeout) as client:
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
outcome = outcome.model_copy(update={"updated_is_reachable": None})
|
||||
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
|
||||
|
||||
if not outcome.is_success:
|
||||
if fail_strategy == HookFailStrategy.HARD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.HOOK_EXECUTION_FAILED,
|
||||
outcome.error_message or "Hook execution failed.",
|
||||
)
|
||||
logger.warning(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
if outcome.response_payload is None:
|
||||
raise ValueError(
|
||||
f"response_payload is None for successful hook call (hook_id={hook_id})"
|
||||
)
|
||||
return outcome.response_payload
|
||||
@@ -42,12 +42,8 @@ class HookUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
endpoint_url: str | None = None
|
||||
api_key: NonEmptySecretStr | None = None
|
||||
fail_strategy: HookFailStrategy | None = (
|
||||
None # if None in model_fields_set, reset to spec default
|
||||
)
|
||||
timeout_seconds: float | None = Field(
|
||||
default=None, gt=0
|
||||
) # if None in model_fields_set, reset to spec default
|
||||
fail_strategy: HookFailStrategy | None = None
|
||||
timeout_seconds: float | None = Field(default=None, gt=0)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_at_least_one_field(self) -> "HookUpdateRequest":
|
||||
@@ -60,6 +56,14 @@ class HookUpdateRequest(BaseModel):
|
||||
and not (self.endpoint_url or "").strip()
|
||||
):
|
||||
raise ValueError("endpoint_url cannot be cleared.")
|
||||
if "fail_strategy" in self.model_fields_set and self.fail_strategy is None:
|
||||
raise ValueError(
|
||||
"fail_strategy cannot be null; omit the field to leave it unchanged."
|
||||
)
|
||||
if "timeout_seconds" in self.model_fields_set and self.timeout_seconds is None:
|
||||
raise ValueError(
|
||||
"timeout_seconds cannot be null; omit the field to leave it unchanged."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@@ -90,38 +94,28 @@ class HookResponse(BaseModel):
|
||||
fail_strategy: HookFailStrategy
|
||||
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
|
||||
is_active: bool
|
||||
is_reachable: bool | None
|
||||
creator_email: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class HookValidateStatus(str, Enum):
|
||||
passed = "passed" # server responded (any status except 401/403)
|
||||
auth_failed = "auth_failed" # server responded with 401 or 403
|
||||
timeout = (
|
||||
"timeout" # TCP connected, but read/write timed out (server exists but slow)
|
||||
)
|
||||
cannot_connect = "cannot_connect" # could not connect to the server
|
||||
|
||||
|
||||
class HookValidateResponse(BaseModel):
|
||||
success: bool
|
||||
status: HookValidateStatus
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HookHealthStatus(str, Enum):
|
||||
healthy = "healthy" # green — reachable, no failures in last 1h
|
||||
degraded = "degraded" # yellow — reachable, failures in last 1h
|
||||
unreachable = "unreachable" # red — is_reachable=false or null
|
||||
|
||||
|
||||
class HookFailureRecord(BaseModel):
|
||||
class HookExecutionRecord(BaseModel):
|
||||
error_message: str | None = None
|
||||
status_code: int | None = None
|
||||
duration_ms: int | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class HookHealthResponse(BaseModel):
|
||||
status: HookHealthStatus
|
||||
recent_failures: list[HookFailureRecord] = Field(
|
||||
default_factory=list,
|
||||
description="Last 10 failures, newest first",
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
@@ -13,22 +14,25 @@ _REQUIRED_ATTRS = (
|
||||
"default_timeout_seconds",
|
||||
"fail_hard_description",
|
||||
"default_fail_strategy",
|
||||
"payload_model",
|
||||
"response_model",
|
||||
)
|
||||
|
||||
|
||||
class HookPointSpec(ABC):
|
||||
class HookPointSpec:
|
||||
"""Static metadata and contract for a pipeline hook point.
|
||||
|
||||
This is NOT a regular class meant for direct instantiation by callers.
|
||||
Each concrete subclass represents exactly one hook point and is instantiated
|
||||
once at startup, registered in onyx.hooks.registry._REGISTRY. No caller
|
||||
should ever create instances directly — use get_hook_point_spec() or
|
||||
get_all_specs() from the registry instead.
|
||||
once at startup, registered in onyx.hooks.registry._REGISTRY. Prefer
|
||||
get_hook_point_spec() or get_all_specs() from the registry over direct
|
||||
instantiation.
|
||||
|
||||
Each hook point is a concrete subclass of this class. Onyx engineers
|
||||
own these definitions — customers never touch this code.
|
||||
|
||||
Subclasses must define all attributes as class-level constants.
|
||||
payload_model and response_model must be Pydantic BaseModel subclasses;
|
||||
input_schema and output_schema are derived from them automatically.
|
||||
"""
|
||||
|
||||
hook_point: HookPoint
|
||||
@@ -39,21 +43,33 @@ class HookPointSpec(ABC):
|
||||
default_fail_strategy: HookFailStrategy
|
||||
docs_url: str | None = None
|
||||
|
||||
payload_model: ClassVar[type[BaseModel]]
|
||||
response_model: ClassVar[type[BaseModel]]
|
||||
|
||||
# Computed once at class definition time from payload_model / response_model.
|
||||
input_schema: ClassVar[dict[str, Any]]
|
||||
output_schema: ClassVar[dict[str, Any]]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: object) -> None:
|
||||
"""Enforce that every concrete subclass declares all required class attributes.
|
||||
|
||||
Called automatically by Python whenever a class inherits from HookPointSpec.
|
||||
Abstract subclasses (those still carrying unimplemented abstract methods) are
|
||||
skipped — they are intermediate base classes and may not yet define everything.
|
||||
Only fully concrete subclasses are validated, ensuring a clear TypeError at
|
||||
import time rather than a confusing AttributeError at runtime.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Skip intermediate abstract subclasses — they may still be partially defined.
|
||||
if getattr(cls, "__abstractmethods__", None):
|
||||
return
|
||||
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
|
||||
if missing:
|
||||
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
"""JSON schema describing the request payload sent to the customer's endpoint."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
"""JSON schema describing the expected response from the customer's endpoint."""
|
||||
for attr in ("payload_model", "response_model"):
|
||||
val = getattr(cls, attr, None)
|
||||
if val is None or not (
|
||||
isinstance(val, type) and issubclass(val, BaseModel)
|
||||
):
|
||||
raise TypeError(
|
||||
f"{cls.__name__}.{attr} must be a Pydantic BaseModel subclass, got {val!r}"
|
||||
)
|
||||
cls.input_schema = cls.payload_model.model_json_schema()
|
||||
cls.output_schema = cls.response_model.model_json_schema()
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
# TODO(@Bo-Onyx): define payload and response fields
|
||||
class DocumentIngestionPayload(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIngestionResponse(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIngestionSpec(HookPointSpec):
|
||||
"""Hook point that runs during document ingestion.
|
||||
|
||||
@@ -18,12 +27,5 @@ class DocumentIngestionSpec(HookPointSpec):
|
||||
fail_hard_description = "The document will not be indexed."
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
# TODO(@Bo-Onyx): define input schema
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
# TODO(@Bo-Onyx): define output schema
|
||||
return {"type": "object", "properties": {}}
|
||||
payload_model = DocumentIngestionPayload
|
||||
response_model = DocumentIngestionResponse
|
||||
|
||||
@@ -1,10 +1,39 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
class QueryProcessingPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
query: str = Field(description="The raw query string exactly as the user typed it.")
|
||||
user_email: str | None = Field(
|
||||
description="Email of the user submitting the query, or null if unauthenticated."
|
||||
)
|
||||
chat_session_id: str = Field(
|
||||
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
|
||||
)
|
||||
|
||||
|
||||
class QueryProcessingResponse(BaseModel):
|
||||
# Intentionally permissive — customer endpoints may return extra fields.
|
||||
query: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The query to use in the pipeline. "
|
||||
"Null, empty string, or absent = reject the query."
|
||||
),
|
||||
)
|
||||
rejection_message: str | None = Field(
|
||||
default=None,
|
||||
description="Message shown to the user when the query is rejected. Falls back to a generic message if not provided.",
|
||||
)
|
||||
|
||||
|
||||
class QueryProcessingSpec(HookPointSpec):
|
||||
"""Hook point that runs on every user query before it enters the pipeline.
|
||||
|
||||
@@ -37,47 +66,5 @@ class QueryProcessingSpec(HookPointSpec):
|
||||
)
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The raw query string exactly as the user typed it.",
|
||||
},
|
||||
"user_email": {
|
||||
"type": ["string", "null"],
|
||||
"description": "Email of the user submitting the query, or null if unauthenticated.",
|
||||
},
|
||||
"chat_session_id": {
|
||||
"type": "string",
|
||||
"description": "UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires.",
|
||||
},
|
||||
},
|
||||
"required": ["query", "user_email", "chat_session_id"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"The (optionally modified) query to use. "
|
||||
"Set to null to reject the query."
|
||||
),
|
||||
},
|
||||
"rejection_message": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"Message shown to the user when query is null. "
|
||||
"Falls back to a generic message if not provided."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
payload_model = QueryProcessingPayload
|
||||
response_model = QueryProcessingResponse
|
||||
|
||||
5
backend/onyx/hooks/utils.py
Normal file
5
backend/onyx/hooks/utils.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT
|
||||
@@ -530,6 +530,11 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
messages = _strip_tool_content_from_messages(messages)
|
||||
|
||||
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
|
||||
# reject requests where tool_choice is explicitly null.
|
||||
if tools and tool_choice is not None:
|
||||
optional_kwargs["tool_choice"] = tool_choice
|
||||
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
@@ -538,7 +543,6 @@ class LitellmLLM(LLM):
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=timeout_override or self._timeout,
|
||||
|
||||
@@ -77,6 +77,7 @@ from onyx.server.features.default_assistant.api import (
|
||||
)
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.hierarchy.api import router as hierarchy_router
|
||||
from onyx.server.features.hooks.api import router as hook_router
|
||||
from onyx.server.features.input_prompt.api import (
|
||||
admin_router as admin_input_prompt_router,
|
||||
)
|
||||
@@ -453,6 +454,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
|
||||
register_onyx_exception_handlers(application)
|
||||
|
||||
include_router_with_global_prefix_prepended(application, hook_router)
|
||||
include_router_with_global_prefix_prepended(application, password_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
|
||||
0
backend/onyx/server/features/hooks/__init__.py
Normal file
0
backend/onyx/server/features/hooks/__init__.py
Normal file
453
backend/onyx/server/features/hooks/api.py
Normal file
453
backend/onyx/server/features/hooks/api.py
Normal file
@@ -0,0 +1,453 @@
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.constants import UNSET
|
||||
from onyx.db.constants import UnsetType
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.hook import create_hook__no_commit
|
||||
from onyx.db.hook import delete_hook__no_commit
|
||||
from onyx.db.hook import get_hook_by_id
|
||||
from onyx.db.hook import get_hook_execution_logs
|
||||
from onyx.db.hook import get_hooks
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.api_dependencies import require_hook_enabled
|
||||
from onyx.hooks.models import HookCreateRequest
|
||||
from onyx.hooks.models import HookExecutionRecord
|
||||
from onyx.hooks.models import HookPointMetaResponse
|
||||
from onyx.hooks.models import HookResponse
|
||||
from onyx.hooks.models import HookUpdateRequest
|
||||
from onyx.hooks.models import HookValidateResponse
|
||||
from onyx.hooks.models import HookValidateStatus
|
||||
from onyx.hooks.registry import get_all_specs
|
||||
from onyx.hooks.registry import get_hook_point_spec
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.url import SSRFException
|
||||
from onyx.utils.url import validate_outbound_http_url
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSRF protection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _check_ssrf_safety(endpoint_url: str) -> None:
|
||||
"""Raise OnyxError if endpoint_url could be used for SSRF.
|
||||
|
||||
Delegates to validate_outbound_http_url with https_only=True.
|
||||
"""
|
||||
try:
|
||||
validate_outbound_http_url(endpoint_url, https_only=True)
|
||||
except (SSRFException, ValueError) as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookResponse:
|
||||
return HookResponse(
|
||||
id=hook.id,
|
||||
name=hook.name,
|
||||
hook_point=hook.hook_point,
|
||||
endpoint_url=hook.endpoint_url,
|
||||
fail_strategy=hook.fail_strategy,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
is_active=hook.is_active,
|
||||
is_reachable=hook.is_reachable,
|
||||
creator_email=(
|
||||
creator_email
|
||||
if creator_email is not None
|
||||
else (hook.creator.email if hook.creator else None)
|
||||
),
|
||||
created_at=hook.created_at,
|
||||
updated_at=hook.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _get_hook_or_404(
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
include_creator: bool = False,
|
||||
) -> Hook:
|
||||
hook = get_hook_by_id(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
include_creator=include_creator,
|
||||
)
|
||||
if hook is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook {hook_id} not found.")
|
||||
return hook
|
||||
|
||||
|
||||
def _raise_for_validation_failure(validation: HookValidateResponse) -> None:
|
||||
"""Raise an appropriate OnyxError for a non-passed validation result."""
|
||||
if validation.status == HookValidateStatus.auth_failed:
|
||||
raise OnyxError(OnyxErrorCode.CREDENTIAL_INVALID, validation.error_message)
|
||||
if validation.status == HookValidateStatus.timeout:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.GATEWAY_TIMEOUT,
|
||||
f"Endpoint validation failed: {validation.error_message}",
|
||||
)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Endpoint validation failed: {validation.error_message}",
|
||||
)
|
||||
|
||||
|
||||
def _validate_endpoint(
|
||||
endpoint_url: str,
|
||||
api_key: str | None,
|
||||
timeout_seconds: float,
|
||||
) -> HookValidateResponse:
|
||||
"""Check whether endpoint_url is reachable by sending an empty POST request.
|
||||
|
||||
We use POST since hook endpoints expect POST requests. The server will typically
|
||||
respond with 4xx (missing/invalid body) — that is fine. Any HTTP response means
|
||||
the server is up and routable. A 401/403 response returns auth_failed
|
||||
(not reachable — indicates the api_key is invalid).
|
||||
|
||||
Timeout handling:
|
||||
- ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
|
||||
(operator should consider increasing timeout_seconds).
|
||||
- All other exceptions → cannot_connect.
|
||||
"""
|
||||
_check_ssrf_safety(endpoint_url)
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
try:
|
||||
with httpx.Client(timeout=timeout_seconds, follow_redirects=False) as client:
|
||||
response = client.post(endpoint_url, headers=headers)
|
||||
if response.status_code in (401, 403):
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.auth_failed,
|
||||
error_message=f"Authentication failed (HTTP {response.status_code})",
|
||||
)
|
||||
return HookValidateResponse(status=HookValidateStatus.passed)
|
||||
except httpx.TimeoutException as exc:
|
||||
# ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
|
||||
if isinstance(exc, httpx.ConnectTimeout):
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connect timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
logger.warning(
|
||||
"Hook endpoint validation: read/write timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.timeout,
|
||||
error_message="Endpoint timed out — consider increasing timeout_seconds.",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connection error for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
router = APIRouter(prefix="/admin/hooks")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hook endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/specs")
|
||||
def get_hook_point_specs(
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
) -> list[HookPointMetaResponse]:
|
||||
return [
|
||||
HookPointMetaResponse(
|
||||
hook_point=spec.hook_point,
|
||||
display_name=spec.display_name,
|
||||
description=spec.description,
|
||||
docs_url=spec.docs_url,
|
||||
input_schema=spec.input_schema,
|
||||
output_schema=spec.output_schema,
|
||||
default_timeout_seconds=spec.default_timeout_seconds,
|
||||
default_fail_strategy=spec.default_fail_strategy,
|
||||
fail_hard_description=spec.fail_hard_description,
|
||||
)
|
||||
for spec in get_all_specs()
|
||||
]
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_hooks(
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[HookResponse]:
|
||||
hooks = get_hooks(db_session=db_session, include_creator=True)
|
||||
return [_hook_to_response(h) for h in hooks]
|
||||
|
||||
|
||||
@router.post("")
|
||||
def create_hook(
|
||||
req: HookCreateRequest,
|
||||
user: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
"""Create a new hook. The endpoint is validated before persisting — creation fails if
|
||||
the endpoint cannot be reached or the api_key is invalid. Hooks are created inactive;
|
||||
use POST /{hook_id}/activate once ready to receive traffic."""
|
||||
spec = get_hook_point_spec(req.hook_point)
|
||||
api_key = req.api_key.get_secret_value() if req.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=req.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
|
||||
)
|
||||
if validation.status != HookValidateStatus.passed:
|
||||
_raise_for_validation_failure(validation)
|
||||
|
||||
hook = create_hook__no_commit(
|
||||
db_session=db_session,
|
||||
name=req.name,
|
||||
hook_point=req.hook_point,
|
||||
endpoint_url=req.endpoint_url,
|
||||
api_key=api_key,
|
||||
fail_strategy=req.fail_strategy or spec.default_fail_strategy,
|
||||
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
|
||||
creator_id=user.id,
|
||||
)
|
||||
hook.is_reachable = True
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook, creator_email=user.email)
|
||||
|
||||
|
||||
@router.get("/{hook_id}")
|
||||
def get_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id, include_creator=True)
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.patch("/{hook_id}")
|
||||
def update_hook(
|
||||
hook_id: int,
|
||||
req: HookUpdateRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
"""Update hook fields. If endpoint_url, api_key, or timeout_seconds changes, the
|
||||
endpoint is re-validated using the effective values. For active hooks the update is
|
||||
rejected on validation failure, keeping live traffic unaffected. For inactive hooks
|
||||
the update goes through regardless and is_reachable is updated to reflect the result.
|
||||
|
||||
Note: if an active hook's endpoint is currently down, even a timeout_seconds-only
|
||||
increase will be rejected. The recovery flow is: deactivate → update → reactivate.
|
||||
"""
|
||||
# api_key: UNSET = no change, None = clear, value = update
|
||||
api_key: str | None | UnsetType
|
||||
if "api_key" not in req.model_fields_set:
|
||||
api_key = UNSET
|
||||
elif req.api_key is None:
|
||||
api_key = None
|
||||
else:
|
||||
api_key = req.api_key.get_secret_value()
|
||||
|
||||
endpoint_url_changing = "endpoint_url" in req.model_fields_set
|
||||
api_key_changing = not isinstance(api_key, UnsetType)
|
||||
timeout_changing = "timeout_seconds" in req.model_fields_set
|
||||
|
||||
validated_is_reachable: bool | None = None
|
||||
if endpoint_url_changing or api_key_changing or timeout_changing:
|
||||
existing = _get_hook_or_404(db_session, hook_id)
|
||||
effective_url: str = (
|
||||
req.endpoint_url if endpoint_url_changing else existing.endpoint_url # type: ignore[assignment] # endpoint_url is required on create and cannot be cleared on update
|
||||
)
|
||||
effective_api_key: str | None = (
|
||||
(api_key if not isinstance(api_key, UnsetType) else None)
|
||||
if api_key_changing
|
||||
else (
|
||||
existing.api_key.get_value(apply_mask=False)
|
||||
if existing.api_key
|
||||
else None
|
||||
)
|
||||
)
|
||||
effective_timeout: float = (
|
||||
req.timeout_seconds if timeout_changing else existing.timeout_seconds # type: ignore[assignment] # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest)
|
||||
)
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=effective_url,
|
||||
api_key=effective_api_key,
|
||||
timeout_seconds=effective_timeout,
|
||||
)
|
||||
if existing.is_active and validation.status != HookValidateStatus.passed:
|
||||
_raise_for_validation_failure(validation)
|
||||
validated_is_reachable = validation.status == HookValidateStatus.passed
|
||||
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
name=req.name,
|
||||
endpoint_url=(req.endpoint_url if endpoint_url_changing else UNSET),
|
||||
api_key=api_key,
|
||||
fail_strategy=req.fail_strategy,
|
||||
timeout_seconds=req.timeout_seconds,
|
||||
is_reachable=validated_is_reachable,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.delete("/{hook_id}")
|
||||
def delete_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
delete_hook__no_commit(db_session=db_session, hook_id=hook_id)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@router.post("/{hook_id}/activate")
|
||||
def activate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id)
|
||||
if not hook.endpoint_url:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
|
||||
)
|
||||
|
||||
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=hook.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
)
|
||||
if validation.status != HookValidateStatus.passed:
|
||||
# Persist is_reachable=False in a separate session so the request
|
||||
# session has no commits on the failure path and the transaction
|
||||
# boundary stays clean.
|
||||
if hook.is_reachable is not False:
|
||||
with get_session_with_current_tenant() as side_session:
|
||||
update_hook__no_commit(
|
||||
db_session=side_session, hook_id=hook_id, is_reachable=False
|
||||
)
|
||||
side_session.commit()
|
||||
_raise_for_validation_failure(validation)
|
||||
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
is_active=True,
|
||||
is_reachable=True,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.post("/{hook_id}/validate")
|
||||
def validate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookValidateResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id)
|
||||
if not hook.endpoint_url:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
|
||||
)
|
||||
|
||||
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=hook.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
)
|
||||
validation_passed = validation.status == HookValidateStatus.passed
|
||||
if hook.is_reachable != validation_passed:
|
||||
update_hook__no_commit(
|
||||
db_session=db_session, hook_id=hook_id, is_reachable=validation_passed
|
||||
)
|
||||
db_session.commit()
|
||||
return validation
|
||||
|
||||
|
||||
@router.post("/{hook_id}/deactivate")
|
||||
def deactivate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
is_active=False,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Execution log endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/{hook_id}/execution-logs")
|
||||
def list_hook_execution_logs(
|
||||
hook_id: int,
|
||||
limit: int = Query(default=10, ge=1, le=100),
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[HookExecutionRecord]:
|
||||
_get_hook_or_404(db_session, hook_id)
|
||||
logs = get_hook_execution_logs(db_session=db_session, hook_id=hook_id, limit=limit)
|
||||
return [
|
||||
HookExecutionRecord(
|
||||
error_message=log.error_message,
|
||||
status_code=log.status_code,
|
||||
duration_ms=log.duration_ms,
|
||||
created_at=log.created_at,
|
||||
)
|
||||
for log in logs
|
||||
]
|
||||
@@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
import mimetypes
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
@@ -83,6 +84,14 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
def __init__(self, tool_id: int, emitter: Emitter) -> None:
|
||||
super().__init__(emitter=emitter)
|
||||
self._id = tool_id
|
||||
# Cache of (filename, content_hash) -> ci_file_id to avoid re-uploading
|
||||
# the same file on every tool call iteration within the same agent session.
|
||||
# Filename is included in the key so two files with identical bytes but
|
||||
# different names each get their own upload slot.
|
||||
# TTL assumption: code-interpreter file TTLs (typically hours) greatly
|
||||
# exceed the lifetime of a single agent session (at most MAX_LLM_CYCLES
|
||||
# iterations, typically a few minutes), so stale-ID eviction is not needed.
|
||||
self._uploaded_file_cache: dict[tuple[str, str], str] = {}
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
@@ -182,8 +191,13 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
for ind, chat_file in enumerate(chat_files):
|
||||
file_name = chat_file.filename or f"file_{ind}"
|
||||
try:
|
||||
# Upload to Code Interpreter
|
||||
ci_file_id = client.upload_file(chat_file.content, file_name)
|
||||
content_hash = hashlib.sha256(chat_file.content).hexdigest()
|
||||
cache_key = (file_name, content_hash)
|
||||
ci_file_id = self._uploaded_file_cache.get(cache_key)
|
||||
if ci_file_id is None:
|
||||
# Upload to Code Interpreter
|
||||
ci_file_id = client.upload_file(chat_file.content, file_name)
|
||||
self._uploaded_file_cache[cache_key] = ci_file_id
|
||||
|
||||
# Stage for execution
|
||||
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
|
||||
@@ -299,14 +313,10 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup staged input files
|
||||
for file_mapping in files_to_stage:
|
||||
try:
|
||||
client.delete_file(file_mapping["file_id"])
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
|
||||
)
|
||||
# Note: staged input files are intentionally not deleted here because
|
||||
# _uploaded_file_cache reuses their file_ids across iterations. They are
|
||||
# orphaned when the session ends, but the code interpreter cleans up
|
||||
# stale files on its own TTL.
|
||||
|
||||
# Emit file_ids once files are processed
|
||||
if generated_file_ids:
|
||||
|
||||
@@ -140,10 +140,20 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
|
||||
return validated_ip, hostname, port
|
||||
|
||||
|
||||
def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
|
||||
def validate_outbound_http_url(
|
||||
url: str,
|
||||
*,
|
||||
allow_private_network: bool = False,
|
||||
https_only: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Validate a URL that will be used by backend outbound HTTP calls.
|
||||
|
||||
Args:
|
||||
url: The URL to validate.
|
||||
allow_private_network: If True, skip private/reserved IP checks.
|
||||
https_only: If True, reject http:// URLs (only https:// is allowed).
|
||||
|
||||
Returns:
|
||||
A normalized URL string with surrounding whitespace removed.
|
||||
|
||||
@@ -157,7 +167,12 @@ def validate_outbound_http_url(url: str, *, allow_private_network: bool = False)
|
||||
|
||||
parsed = urlparse(normalized_url)
|
||||
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
if https_only:
|
||||
if parsed.scheme != "https":
|
||||
raise SSRFException(
|
||||
f"Invalid URL scheme '{parsed.scheme}'. Only https is allowed."
|
||||
)
|
||||
elif parsed.scheme not in ("http", "https"):
|
||||
raise SSRFException(
|
||||
f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
|
||||
)
|
||||
|
||||
@@ -1640,3 +1640,275 @@ class TestOpenSearchClient:
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
|
||||
def test_keyword_search(
|
||||
self,
|
||||
test_client: OpenSearchIndexClient,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Tests keyword search with filters for ACL, hidden documents, and tenant
|
||||
isolation.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, True)
|
||||
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
|
||||
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
|
||||
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_x.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
docs = {
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
chunk_index=0,
|
||||
content="Public document content",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"hidden-doc": _create_test_document_chunk(
|
||||
document_id="hidden-doc",
|
||||
chunk_index=0,
|
||||
content="Hidden document content, spooky",
|
||||
hidden=True,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"private-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 123-45-6789",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
# Tests that we don't return documents that don't match keywords at
|
||||
# all, even if they match filters.
|
||||
"private-but-not-relevant-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-but-not-relevant-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="This text should not match the query at all",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"private-doc-user-b": _create_test_document_chunk(
|
||||
document_id="private-doc-user-b",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 987-65-4321",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-b@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
|
||||
document_id="should-not-exist-from-tenant-x-pov",
|
||||
chunk_index=0,
|
||||
content="This is an entirely different tenant, x should never see this",
|
||||
# Make this as permissive as possible to exercise tenant
|
||||
# isolation.
|
||||
hidden=False,
|
||||
tenant_state=tenant_y,
|
||||
),
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc, tenant_state=doc.tenant_id)
|
||||
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Should not match private-but-not-relevant-doc-user-a.
|
||||
query_text = "document content"
|
||||
search_body = DocumentQuery.get_keyword_search_query(
|
||||
query_text=query_text,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_x,
|
||||
# The user should only be able to see their private docs. tenant_id
|
||||
# in this object is not relevant.
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=[prefix_user_email("user-a@example.com")],
|
||||
tenant_id=None,
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
|
||||
# Postcondition.
|
||||
# Should only get the public, non-hidden document, and the private
|
||||
# document for which the user has access.
|
||||
assert len(results) == 2
|
||||
# This should be the highest-ranked result, as a higher percentage of
|
||||
# the content matches the query.
|
||||
assert results[0].document_chunk.document_id == "public-doc"
|
||||
# Make sure the chunk contents are preserved.
|
||||
assert results[0].document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(docs["public-doc"], k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
# Make sure score reporting seems reasonable (it should not be None
|
||||
# or 0).
|
||||
assert results[0].score
|
||||
# Make sure there is some kind of match highlight.
|
||||
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
# Same for the second result.
|
||||
assert results[1].document_chunk.document_id == "private-doc-user-a"
|
||||
assert results[1].document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(docs["private-doc-user-a"], k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
assert results[1].score
|
||||
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
assert results[1].score < results[0].score
|
||||
|
||||
def test_semantic_search(
|
||||
self,
|
||||
test_client: OpenSearchIndexClient,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Tests semantic search with filters for ACL, hidden documents, and tenant
|
||||
isolation.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, True)
|
||||
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
|
||||
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_x.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
docs = {
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
chunk_index=0,
|
||||
content="Public document content",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
# Make this identical to the query vector to test that this
|
||||
# result is returned first.
|
||||
content_vector=_generate_test_vector(0.6),
|
||||
),
|
||||
"hidden-doc": _create_test_document_chunk(
|
||||
document_id="hidden-doc",
|
||||
chunk_index=0,
|
||||
content="Hidden document content, spooky",
|
||||
hidden=True,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"private-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 123-45-6789",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
# Make this different from the query vector to test that this
|
||||
# result is returned second.
|
||||
content_vector=_generate_test_vector(0.5),
|
||||
),
|
||||
"private-doc-user-b": _create_test_document_chunk(
|
||||
document_id="private-doc-user-b",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 987-65-4321",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-b@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
|
||||
document_id="should-not-exist-from-tenant-x-pov",
|
||||
chunk_index=0,
|
||||
content="This is an entirely different tenant, x should never see this",
|
||||
# Make this as permissive as possible to exercise tenant
|
||||
# isolation.
|
||||
hidden=False,
|
||||
tenant_state=tenant_y,
|
||||
),
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc, tenant_state=doc.tenant_id)
|
||||
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
query_vector = _generate_test_vector(0.6)
|
||||
search_body = DocumentQuery.get_semantic_search_query(
|
||||
query_embedding=query_vector,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_x,
|
||||
# The user should only be able to see their private docs. tenant_id
|
||||
# in this object is not relevant.
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=[prefix_user_email("user-a@example.com")],
|
||||
tenant_id=None,
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
|
||||
# Postcondition.
|
||||
# Should only get the public, non-hidden document, and the private
|
||||
# document for which the user has access.
|
||||
assert len(results) == 2
|
||||
# We explicitly expect this to be the highest-ranked result.
|
||||
assert results[0].document_chunk.document_id == "public-doc"
|
||||
# Make sure the chunk contents are preserved.
|
||||
assert results[0].document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(docs["public-doc"], k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
assert results[0].score == 1.0
|
||||
# Same for the second result.
|
||||
assert results[1].document_chunk.document_id == "private-doc-user-a"
|
||||
assert results[1].document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(docs["private-doc-user-a"], k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
assert results[1].score
|
||||
assert 0.0 < results[1].score < 1.0
|
||||
|
||||
@@ -1219,15 +1219,16 @@ def test_code_interpreter_receives_chat_files(
|
||||
finally:
|
||||
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
|
||||
|
||||
# Verify: file uploaded, code executed via streaming, staged file cleaned up
|
||||
# Verify: file uploaded and code executed via streaming.
|
||||
assert len(mock_ci_server.get_requests(method="POST", path="/v1/files")) == 1
|
||||
assert (
|
||||
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
|
||||
)
|
||||
|
||||
delete_requests = mock_ci_server.get_requests(method="DELETE")
|
||||
assert len(delete_requests) == 1
|
||||
assert delete_requests[0].path.startswith("/v1/files/")
|
||||
# Staged input files are intentionally NOT deleted — PythonTool caches their
|
||||
# file IDs across agent-loop iterations to avoid re-uploading on every call.
|
||||
# The code interpreter cleans them up via its own TTL.
|
||||
assert len(mock_ci_server.get_requests(method="DELETE")) == 0
|
||||
|
||||
execute_body = mock_ci_server.get_requests(
|
||||
method="POST", path="/v1/execute/stream"
|
||||
|
||||
@@ -14,6 +14,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
|
||||
@@ -28,6 +29,9 @@ _BACKEND_DIR = os.path.normpath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
|
||||
)
|
||||
|
||||
_DROP_SCHEMA_MAX_RETRIES = 3
|
||||
_DROP_SCHEMA_RETRY_DELAY_SEC = 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -50,6 +54,39 @@ def _run_script(
|
||||
)
|
||||
|
||||
|
||||
def _force_drop_schema(engine: Engine, schema: str) -> None:
|
||||
"""Terminate backends using *schema* then drop it, retrying on deadlock.
|
||||
|
||||
Background Celery workers may discover test schemas (they match the
|
||||
``tenant_`` prefix) and hold locks on tables inside them. A bare
|
||||
``DROP SCHEMA … CASCADE`` can deadlock with those workers, so we
|
||||
first kill their connections and retry if we still hit a deadlock.
|
||||
"""
|
||||
for attempt in range(_DROP_SCHEMA_MAX_RETRIES):
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT pg_terminate_backend(l.pid)
|
||||
FROM pg_locks l
|
||||
JOIN pg_class c ON c.oid = l.relation
|
||||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE n.nspname = :schema
|
||||
AND l.pid != pg_backend_pid()
|
||||
"""
|
||||
),
|
||||
{"schema": schema},
|
||||
)
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
return
|
||||
except Exception:
|
||||
if attempt == _DROP_SCHEMA_MAX_RETRIES - 1:
|
||||
raise
|
||||
time.sleep(_DROP_SCHEMA_RETRY_DELAY_SEC)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -104,9 +141,7 @@ def tenant_schema_at_head(
|
||||
|
||||
yield schema
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
_force_drop_schema(engine, schema)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -123,9 +158,7 @@ def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
|
||||
|
||||
yield schema
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
_force_drop_schema(engine, schema)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -150,9 +183,7 @@ def tenant_schema_bad_rev(engine: Engine) -> Generator[str, None, None]:
|
||||
|
||||
yield schema
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
_force_drop_schema(engine, schema)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
@@ -11,12 +10,10 @@ def test_init_subclass_raises_for_missing_attrs() -> None:
|
||||
|
||||
class IncompleteSpec(HookPointSpec):
|
||||
hook_point = HookPoint.QUERY_PROCESSING
|
||||
# missing display_name, description, etc.
|
||||
# missing display_name, description, payload_model, response_model, etc.
|
||||
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return {}
|
||||
class _Payload(BaseModel):
|
||||
pass
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return {}
|
||||
payload_model = _Payload
|
||||
response_model = _Payload
|
||||
|
||||
541
backend/tests/unit/onyx/hooks/test_executor.py
Normal file
541
backend/tests/unit/onyx/hooks/test_executor.py
Normal file
@@ -0,0 +1,541 @@
|
||||
"""Unit tests for the hook executor."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.executor import execute_hook
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
|
||||
_RESPONSE_PAYLOAD: dict[str, Any] = {"rewritten_query": "better test"}
|
||||
|
||||
|
||||
def _make_hook(
|
||||
*,
|
||||
is_active: bool = True,
|
||||
endpoint_url: str | None = "https://hook.example.com/query",
|
||||
api_key: MagicMock | None = None,
|
||||
timeout_seconds: float = 5.0,
|
||||
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
|
||||
hook_id: int = 1,
|
||||
is_reachable: bool | None = None,
|
||||
) -> MagicMock:
|
||||
hook = MagicMock()
|
||||
hook.is_active = is_active
|
||||
hook.endpoint_url = endpoint_url
|
||||
hook.api_key = api_key
|
||||
hook.timeout_seconds = timeout_seconds
|
||||
hook.id = hook_id
|
||||
hook.fail_strategy = fail_strategy
|
||||
hook.is_reachable = is_reachable
|
||||
return hook
|
||||
|
||||
|
||||
def _make_api_key(value: str) -> MagicMock:
|
||||
api_key = MagicMock()
|
||||
api_key.get_value.return_value = value
|
||||
return api_key
|
||||
|
||||
|
||||
def _make_response(
|
||||
*,
|
||||
status_code: int = 200,
|
||||
json_return: Any = _RESPONSE_PAYLOAD,
|
||||
json_side_effect: Exception | None = None,
|
||||
) -> MagicMock:
|
||||
"""Build a response mock with controllable json() behaviour."""
|
||||
response = MagicMock()
|
||||
response.status_code = status_code
|
||||
if json_side_effect is not None:
|
||||
response.json.side_effect = json_side_effect
|
||||
else:
|
||||
response.json.return_value = json_return
|
||||
return response
|
||||
|
||||
|
||||
def _setup_client(
|
||||
mock_client_cls: MagicMock,
|
||||
*,
|
||||
response: MagicMock | None = None,
|
||||
side_effect: Exception | None = None,
|
||||
) -> MagicMock:
|
||||
"""Wire up the httpx.Client mock and return the inner client.
|
||||
|
||||
If side_effect is an httpx.HTTPStatusError, it is raised from
|
||||
raise_for_status() (matching real httpx behaviour) and post() returns a
|
||||
response mock with the matching status_code set. All other exceptions are
|
||||
raised directly from post().
|
||||
"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
if isinstance(side_effect, httpx.HTTPStatusError):
|
||||
error_response = MagicMock()
|
||||
error_response.status_code = side_effect.response.status_code
|
||||
error_response.raise_for_status.side_effect = side_effect
|
||||
mock_client.post = MagicMock(return_value=error_response)
|
||||
else:
|
||||
mock_client.post = MagicMock(
|
||||
side_effect=side_effect, return_value=response if not side_effect else None
|
||||
)
|
||||
|
||||
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Early-exit guards (no HTTP call, no DB writes)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hooks_available,hook",
|
||||
[
|
||||
# HOOKS_AVAILABLE=False exits before the DB lookup — hook is irrelevant.
|
||||
pytest.param(False, None, id="hooks_not_available"),
|
||||
pytest.param(True, None, id="hook_not_found"),
|
||||
pytest.param(True, _make_hook(is_active=False), id="hook_inactive"),
|
||||
pytest.param(True, _make_hook(endpoint_url=None), id="no_endpoint_url"),
|
||||
],
|
||||
)
|
||||
def test_early_exit_returns_skipped_with_no_db_writes(
|
||||
db_session: MagicMock,
|
||||
hooks_available: bool,
|
||||
hook: MagicMock | None,
|
||||
) -> None:
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", hooks_available),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSkipped)
|
||||
mock_update.assert_not_called()
|
||||
mock_log.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Successful HTTP call
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> None:
|
||||
hook = _make_hook()
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, response=_make_response())
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert result == _RESPONSE_PAYLOAD
|
||||
_, update_kwargs = mock_update.call_args
|
||||
assert update_kwargs["is_reachable"] is True
|
||||
mock_log.assert_not_called()
|
||||
|
||||
|
||||
def test_success_skips_reachable_write_when_already_true(db_session: MagicMock) -> None:
|
||||
"""Deduplication guard: a hook already at is_reachable=True that succeeds
|
||||
must not trigger a DB write."""
|
||||
hook = _make_hook(is_reachable=True)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, response=_make_response())
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert result == _RESPONSE_PAYLOAD
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
|
||||
"""response.json() returning a non-dict (e.g. list) must be treated as failure.
|
||||
The server responded, so is_reachable is not updated."""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(
|
||||
mock_client_cls,
|
||||
response=_make_response(json_return=["unexpected", "list"]),
|
||||
)
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
_, log_kwargs = mock_log.call_args
|
||||
assert log_kwargs["is_success"] is False
|
||||
assert "non-dict" in (log_kwargs["error_message"] or "")
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
|
||||
"""response.json() raising must be treated as failure with SOFT strategy.
|
||||
The server responded, so is_reachable is not updated."""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(
|
||||
mock_client_cls,
|
||||
response=_make_response(
|
||||
json_side_effect=json.JSONDecodeError("not JSON", "", 0)
|
||||
),
|
||||
)
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
_, log_kwargs = mock_log.call_args
|
||||
assert log_kwargs["is_success"] is False
|
||||
assert "non-JSON" in (log_kwargs["error_message"] or "")
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP failure paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"exception,fail_strategy,expected_type,expected_is_reachable",
|
||||
[
|
||||
# NetworkError → is_reachable=False
|
||||
pytest.param(
|
||||
httpx.ConnectError("refused"),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
False,
|
||||
id="connect_error_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.ConnectError("refused"),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
False,
|
||||
id="connect_error_hard",
|
||||
),
|
||||
# 401/403 → is_reachable=False (api_key revoked)
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"401",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=401, text="Unauthorized"),
|
||||
),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
False,
|
||||
id="auth_401_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"403",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=403, text="Forbidden"),
|
||||
),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
False,
|
||||
id="auth_403_hard",
|
||||
),
|
||||
# TimeoutException → no is_reachable write (None)
|
||||
pytest.param(
|
||||
httpx.TimeoutException("timeout"),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
None,
|
||||
id="timeout_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.TimeoutException("timeout"),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
None,
|
||||
id="timeout_hard",
|
||||
),
|
||||
# Other HTTP errors → no is_reachable write (None)
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"500",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=500, text="error"),
|
||||
),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
None,
|
||||
id="http_status_error_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"500",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=500, text="error"),
|
||||
),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
None,
|
||||
id="http_status_error_hard",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_http_failure_paths(
|
||||
db_session: MagicMock,
|
||||
exception: Exception,
|
||||
fail_strategy: HookFailStrategy,
|
||||
expected_type: type,
|
||||
expected_is_reachable: bool | None,
|
||||
) -> None:
|
||||
hook = _make_hook(fail_strategy=fail_strategy)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, side_effect=exception)
|
||||
|
||||
if expected_type is OnyxError:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
|
||||
else:
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert isinstance(result, expected_type)
|
||||
|
||||
if expected_is_reachable is None:
|
||||
mock_update.assert_not_called()
|
||||
else:
|
||||
mock_update.assert_called_once()
|
||||
_, kwargs = mock_update.call_args
|
||||
assert kwargs["is_reachable"] is expected_is_reachable
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authorization header
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_key_value,expect_auth_header",
|
||||
[
|
||||
pytest.param("secret-token", True, id="api_key_present"),
|
||||
pytest.param(None, False, id="api_key_absent"),
|
||||
],
|
||||
)
|
||||
def test_authorization_header(
|
||||
db_session: MagicMock,
|
||||
api_key_value: str | None,
|
||||
expect_auth_header: bool,
|
||||
) -> None:
|
||||
api_key = _make_api_key(api_key_value) if api_key_value else None
|
||||
hook = _make_hook(api_key=api_key)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit"),
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
mock_client = _setup_client(mock_client_cls, response=_make_response())
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
_, call_kwargs = mock_client.post.call_args
|
||||
if expect_auth_header:
|
||||
assert call_kwargs["headers"]["Authorization"] == f"Bearer {api_key_value}"
|
||||
else:
|
||||
assert "Authorization" not in call_kwargs["headers"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persist session failure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"http_exception,expected_result",
|
||||
[
|
||||
pytest.param(None, _RESPONSE_PAYLOAD, id="success_path"),
|
||||
pytest.param(httpx.ConnectError("refused"), OnyxError, id="hard_fail_path"),
|
||||
],
|
||||
)
|
||||
def test_persist_session_failure_is_swallowed(
|
||||
db_session: MagicMock,
|
||||
http_exception: Exception | None,
|
||||
expected_result: Any,
|
||||
) -> None:
|
||||
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_session_with_current_tenant",
|
||||
side_effect=RuntimeError("DB unavailable"),
|
||||
),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(
|
||||
mock_client_cls,
|
||||
response=_make_response() if not http_exception else None,
|
||||
side_effect=http_exception,
|
||||
)
|
||||
|
||||
if expected_result is OnyxError:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
|
||||
else:
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:
|
||||
"""is_reachable update failing (e.g. concurrent hook deletion) must not
|
||||
prevent the execution log from being written.
|
||||
|
||||
Simulates the production failure path: update_hook__no_commit raises
|
||||
OnyxError(NOT_FOUND) as it would if the hook was concurrently deleted
|
||||
between the initial lookup and the reachable update.
|
||||
"""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.hooks.executor.update_hook__no_commit",
|
||||
side_effect=OnyxError(OnyxErrorCode.NOT_FOUND, "hook deleted"),
|
||||
),
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, side_effect=httpx.ConnectError("refused"))
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
mock_log.assert_called_once()
|
||||
@@ -37,18 +37,20 @@ def test_input_schema_query_is_string() -> None:
|
||||
|
||||
def test_input_schema_user_email_is_nullable() -> None:
|
||||
props = QueryProcessingSpec().input_schema["properties"]
|
||||
assert "null" in props["user_email"]["type"]
|
||||
# Pydantic v2 emits anyOf for nullable fields
|
||||
assert any(s.get("type") == "null" for s in props["user_email"]["anyOf"])
|
||||
|
||||
|
||||
def test_output_schema_query_is_required() -> None:
|
||||
def test_output_schema_query_is_optional() -> None:
|
||||
# query defaults to None (absent = reject); not required in the schema
|
||||
schema = QueryProcessingSpec().output_schema
|
||||
assert "query" in schema["required"]
|
||||
assert "query" not in schema.get("required", [])
|
||||
|
||||
|
||||
def test_output_schema_query_is_nullable() -> None:
|
||||
# null means "reject the query"
|
||||
# null means "reject the query"; Pydantic v2 emits anyOf for nullable fields
|
||||
props = QueryProcessingSpec().output_schema["properties"]
|
||||
assert "null" in props["query"]["type"]
|
||||
assert any(s.get("type") == "null" for s in props["query"]["anyOf"])
|
||||
|
||||
|
||||
def test_output_schema_rejection_message_is_optional() -> None:
|
||||
|
||||
@@ -256,7 +256,6 @@ def test_multiple_tool_calls(default_multi_llm: LitellmLLM) -> None:
|
||||
{"role": "user", "content": "What's the weather and time in New York?"}
|
||||
],
|
||||
tools=tools,
|
||||
tool_choice=None,
|
||||
stream=True,
|
||||
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
|
||||
timeout=30,
|
||||
@@ -412,7 +411,6 @@ def test_multiple_tool_calls_streaming(default_multi_llm: LitellmLLM) -> None:
|
||||
{"role": "user", "content": "What's the weather and time in New York?"}
|
||||
],
|
||||
tools=tools,
|
||||
tool_choice=None,
|
||||
stream=True,
|
||||
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
|
||||
timeout=30,
|
||||
@@ -1431,3 +1429,36 @@ def test_strip_tool_content_merges_consecutive_tool_results() -> None:
|
||||
assert "sunny 72F" in merged
|
||||
assert "tc_2" in merged
|
||||
assert "headline news" in merged
|
||||
|
||||
|
||||
def test_no_tool_choice_sent_when_no_tools(default_multi_llm: LitellmLLM) -> None:
|
||||
"""Regression test for providers (e.g. Fireworks) that reject tool_choice=null.
|
||||
|
||||
When no tools are provided, tool_choice must not be forwarded to
|
||||
litellm.completion() at all — not even as None.
|
||||
"""
|
||||
messages: LanguageModelInput = [UserMessage(content="Hello!")]
|
||||
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello!"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
default_multi_llm.invoke(messages, tools=None)
|
||||
|
||||
_, kwargs = mock_completion.call_args
|
||||
assert (
|
||||
"tool_choice" not in kwargs
|
||||
), "tool_choice must not be sent to providers when no tools are provided"
|
||||
|
||||
278
backend/tests/unit/onyx/server/features/hooks/test_api.py
Normal file
278
backend/tests/unit/onyx/server/features/hooks/test_api.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Unit tests for onyx.server.features.hooks.api helpers.
|
||||
|
||||
Covers:
|
||||
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
|
||||
- _validate_endpoint: httpx exception → HookValidateStatus mapping
|
||||
ConnectTimeout → cannot_connect (TCP handshake never completed)
|
||||
ConnectError → cannot_connect (DNS / TLS failure)
|
||||
ReadTimeout et al. → timeout (TCP connected, server slow)
|
||||
Any other exc → cannot_connect
|
||||
- _raise_for_validation_failure: HookValidateStatus → OnyxError mapping
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.models import HookValidateResponse
|
||||
from onyx.hooks.models import HookValidateStatus
|
||||
from onyx.server.features.hooks.api import _check_ssrf_safety
|
||||
from onyx.server.features.hooks.api import _raise_for_validation_failure
|
||||
from onyx.server.features.hooks.api import _validate_endpoint
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_URL = "https://example.com/hook"
|
||||
_API_KEY = "secret"
|
||||
_TIMEOUT = 5.0
|
||||
|
||||
|
||||
def _mock_response(status_code: int) -> MagicMock:
|
||||
response = MagicMock()
|
||||
response.status_code = status_code
|
||||
return response
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _check_ssrf_safety
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckSsrfSafety:
|
||||
def _call(self, url: str) -> None:
|
||||
_check_ssrf_safety(url)
|
||||
|
||||
# --- scheme checks ---
|
||||
|
||||
def test_https_is_allowed(self) -> None:
|
||||
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
|
||||
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
|
||||
self._call("https://example.com/hook") # must not raise
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url", ["http://example.com/hook", "ftp://example.com/hook"]
|
||||
)
|
||||
def test_non_https_scheme_rejected(self, url: str) -> None:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self._call(url)
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert "https" in (exc_info.value.detail or "").lower()
|
||||
|
||||
# --- private IP blocklist ---
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ip",
|
||||
[
|
||||
pytest.param("127.0.0.1", id="loopback"),
|
||||
pytest.param("10.0.0.1", id="RFC1918-A"),
|
||||
pytest.param("172.16.0.1", id="RFC1918-B"),
|
||||
pytest.param("192.168.1.1", id="RFC1918-C"),
|
||||
pytest.param("169.254.169.254", id="link-local-IMDS"),
|
||||
pytest.param("100.64.0.1", id="shared-address-space"),
|
||||
pytest.param("::1", id="IPv6-loopback"),
|
||||
pytest.param("fc00::1", id="IPv6-ULA"),
|
||||
pytest.param("fe80::1", id="IPv6-link-local"),
|
||||
],
|
||||
)
|
||||
def test_private_ip_is_blocked(self, ip: str) -> None:
|
||||
with (
|
||||
patch("onyx.utils.url.socket.getaddrinfo") as mock_dns,
|
||||
pytest.raises(OnyxError) as exc_info,
|
||||
):
|
||||
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
|
||||
self._call("https://internal.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert ip in (exc_info.value.detail or "")
|
||||
|
||||
def test_public_ip_is_allowed(self) -> None:
|
||||
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
|
||||
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
|
||||
self._call("https://example.com/hook") # must not raise
|
||||
|
||||
def test_dns_resolution_failure_raises(self) -> None:
|
||||
import socket
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.utils.url.socket.getaddrinfo",
|
||||
side_effect=socket.gaierror("name not found"),
|
||||
),
|
||||
pytest.raises(OnyxError) as exc_info,
|
||||
):
|
||||
self._call("https://no-such-host.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateEndpoint:
|
||||
def _call(self, *, api_key: str | None = _API_KEY) -> HookValidateResponse:
|
||||
# Bypass SSRF check — tested separately in TestCheckSsrfSafety.
|
||||
with patch("onyx.server.features.hooks.api._check_ssrf_safety"):
|
||||
return _validate_endpoint(
|
||||
endpoint_url=_URL,
|
||||
api_key=api_key,
|
||||
timeout_seconds=_TIMEOUT,
|
||||
)
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_2xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(200)
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_5xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(500)
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@pytest.mark.parametrize("status_code", [401, 403])
|
||||
def test_401_403_returns_auth_failed(
|
||||
self, mock_client_cls: MagicMock, status_code: int
|
||||
) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(status_code)
|
||||
)
|
||||
result = self._call()
|
||||
assert result.status == HookValidateStatus.auth_failed
|
||||
assert str(status_code) in (result.error_message or "")
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_4xx_non_auth_returns_passed(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(422)
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_connect_timeout_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
httpx.ConnectTimeout("timed out")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@pytest.mark.parametrize(
|
||||
"exc",
|
||||
[
|
||||
httpx.ReadTimeout("read timeout"),
|
||||
httpx.WriteTimeout("write timeout"),
|
||||
httpx.PoolTimeout("pool timeout"),
|
||||
],
|
||||
)
|
||||
def test_read_write_pool_timeout_returns_timeout(
|
||||
self, mock_client_cls: MagicMock, exc: httpx.TimeoutException
|
||||
) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = exc
|
||||
assert self._call().status == HookValidateStatus.timeout
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_connect_error_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
# Covers DNS failures, TLS errors, and other connection-level errors.
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
httpx.ConnectError("name resolution failed")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_arbitrary_exception_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
ConnectionRefusedError("refused")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_api_key_sent_as_bearer(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_post = mock_client_cls.return_value.__enter__.return_value.post
|
||||
mock_post.return_value = _mock_response(200)
|
||||
self._call(api_key="mykey")
|
||||
_, kwargs = mock_post.call_args
|
||||
assert kwargs["headers"]["Authorization"] == "Bearer mykey"
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_no_api_key_omits_auth_header(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_post = mock_client_cls.return_value.__enter__.return_value.post
|
||||
mock_post.return_value = _mock_response(200)
|
||||
self._call(api_key=None)
|
||||
_, kwargs = mock_post.call_args
|
||||
assert "Authorization" not in kwargs["headers"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _raise_for_validation_failure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRaiseForValidationFailure:
|
||||
@pytest.mark.parametrize(
|
||||
"status, expected_code",
|
||||
[
|
||||
(HookValidateStatus.auth_failed, OnyxErrorCode.CREDENTIAL_INVALID),
|
||||
(HookValidateStatus.timeout, OnyxErrorCode.GATEWAY_TIMEOUT),
|
||||
(HookValidateStatus.cannot_connect, OnyxErrorCode.BAD_GATEWAY),
|
||||
],
|
||||
)
|
||||
def test_raises_correct_error_code(
|
||||
self, status: HookValidateStatus, expected_code: OnyxErrorCode
|
||||
) -> None:
|
||||
validation = HookValidateResponse(status=status, error_message="some error")
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_raise_for_validation_failure(validation)
|
||||
assert exc_info.value.error_code == expected_code
|
||||
|
||||
def test_auth_failed_passes_error_message_directly(self) -> None:
|
||||
validation = HookValidateResponse(
|
||||
status=HookValidateStatus.auth_failed, error_message="bad credentials"
|
||||
)
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_raise_for_validation_failure(validation)
|
||||
assert exc_info.value.detail == "bad credentials"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status", [HookValidateStatus.timeout, HookValidateStatus.cannot_connect]
|
||||
)
|
||||
def test_timeout_and_cannot_connect_wrap_error_message(
|
||||
self, status: HookValidateStatus
|
||||
) -> None:
|
||||
validation = HookValidateResponse(status=status, error_message="raw error")
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_raise_for_validation_failure(validation)
|
||||
assert exc_info.value.detail == "Endpoint validation failed: raw error"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HookValidateStatus enum string values (API contract)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHookValidateStatusValues:
|
||||
@pytest.mark.parametrize(
|
||||
"status, expected",
|
||||
[
|
||||
(HookValidateStatus.passed, "passed"),
|
||||
(HookValidateStatus.auth_failed, "auth_failed"),
|
||||
(HookValidateStatus.timeout, "timeout"),
|
||||
(HookValidateStatus.cannot_connect, "cannot_connect"),
|
||||
],
|
||||
)
|
||||
def test_string_values(self, status: HookValidateStatus, expected: str) -> None:
|
||||
assert status == expected
|
||||
@@ -0,0 +1,208 @@
|
||||
"""Unit tests for PythonTool file-upload caching.
|
||||
|
||||
Verifies that PythonTool reuses code-interpreter file IDs across multiple
|
||||
run() calls within the same session instead of re-uploading identical content
|
||||
on every agent loop iteration.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import PythonToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
StreamResultEvent,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
TOOL_MODULE = "onyx.tools.tool_implementations.python.python_tool"
|
||||
|
||||
|
||||
def _make_stream_result() -> StreamResultEvent:
|
||||
return StreamResultEvent(
|
||||
exit_code=0,
|
||||
timed_out=False,
|
||||
duration_ms=10,
|
||||
files=[],
|
||||
)
|
||||
|
||||
|
||||
def _make_tool() -> PythonTool:
|
||||
emitter = MagicMock()
|
||||
return PythonTool(tool_id=1, emitter=emitter)
|
||||
|
||||
|
||||
def _make_override(files: list[ChatFile]) -> PythonToolOverrideKwargs:
|
||||
return PythonToolOverrideKwargs(chat_files=files)
|
||||
|
||||
|
||||
def _run_tool(tool: PythonTool, mock_client: MagicMock, files: list[ChatFile]) -> None:
|
||||
"""Call tool.run() with a mocked CodeInterpreterClient context manager."""
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
|
||||
mock_client.execute_streaming.return_value = iter([_make_stream_result()])
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__ = MagicMock(return_value=mock_client)
|
||||
ctx.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
placement = Placement(turn_index=0, tab_index=0)
|
||||
override = _make_override(files)
|
||||
|
||||
with patch(f"{TOOL_MODULE}.CodeInterpreterClient", return_value=ctx):
|
||||
tool.run(placement=placement, override_kwargs=override, code="print('hi')")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache hit: same content uploaded in a second call reuses the file_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
|
||||
def test_same_file_uploaded_only_once_across_two_runs() -> None:
|
||||
tool = _make_tool()
|
||||
client = MagicMock()
|
||||
client.upload_file.return_value = "file-id-abc"
|
||||
|
||||
pptx_content = b"fake pptx bytes"
|
||||
files = [ChatFile(filename="report.pptx", content=pptx_content)]
|
||||
|
||||
_run_tool(tool, client, files)
|
||||
_run_tool(tool, client, files)
|
||||
|
||||
# upload_file should only have been called once across both runs
|
||||
client.upload_file.assert_called_once_with(pptx_content, "report.pptx")
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
|
||||
def test_cached_file_id_is_staged_on_second_run() -> None:
|
||||
tool = _make_tool()
|
||||
client = MagicMock()
|
||||
client.upload_file.return_value = "file-id-abc"
|
||||
|
||||
files = [ChatFile(filename="data.pptx", content=b"content")]
|
||||
|
||||
_run_tool(tool, client, files)
|
||||
|
||||
# On the second run, execute_streaming should still receive the file
|
||||
client.execute_streaming.return_value = iter([_make_stream_result()])
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__ = MagicMock(return_value=client)
|
||||
ctx.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
|
||||
placement = Placement(turn_index=1, tab_index=0)
|
||||
with patch(f"{TOOL_MODULE}.CodeInterpreterClient", return_value=ctx):
|
||||
tool.run(
|
||||
placement=placement,
|
||||
override_kwargs=_make_override(files),
|
||||
code="print('hi')",
|
||||
)
|
||||
|
||||
# The second execute_streaming call should include the file
|
||||
_, kwargs = client.execute_streaming.call_args
|
||||
staged_files = kwargs.get("files") or []
|
||||
assert any(f["file_id"] == "file-id-abc" for f in staged_files)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache miss: different content triggers a new upload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
|
||||
def test_different_file_content_uploaded_separately() -> None:
|
||||
tool = _make_tool()
|
||||
client = MagicMock()
|
||||
client.upload_file.side_effect = ["file-id-v1", "file-id-v2"]
|
||||
|
||||
file_v1 = ChatFile(filename="report.pptx", content=b"version 1")
|
||||
file_v2 = ChatFile(filename="report.pptx", content=b"version 2")
|
||||
|
||||
_run_tool(tool, client, [file_v1])
|
||||
_run_tool(tool, client, [file_v2])
|
||||
|
||||
assert client.upload_file.call_count == 2
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
|
||||
def test_multiple_distinct_files_each_uploaded_once() -> None:
|
||||
tool = _make_tool()
|
||||
client = MagicMock()
|
||||
client.upload_file.side_effect = ["id-a", "id-b"]
|
||||
|
||||
files = [
|
||||
ChatFile(filename="a.pptx", content=b"aaa"),
|
||||
ChatFile(filename="b.xlsx", content=b"bbb"),
|
||||
]
|
||||
|
||||
_run_tool(tool, client, files)
|
||||
_run_tool(tool, client, files)
|
||||
|
||||
# Two distinct files — each uploaded exactly once
|
||||
assert client.upload_file.call_count == 2
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
|
||||
def test_same_content_different_filename_uploaded_separately() -> None:
|
||||
# Identical bytes but different names must each get their own upload slot
|
||||
# so both files appear under their respective paths in the workspace.
|
||||
tool = _make_tool()
|
||||
client = MagicMock()
|
||||
client.upload_file.side_effect = ["id-v1", "id-v2"]
|
||||
|
||||
same_bytes = b"shared content"
|
||||
files = [
|
||||
ChatFile(filename="report_v1.csv", content=same_bytes),
|
||||
ChatFile(filename="report_v2.csv", content=same_bytes),
|
||||
]
|
||||
|
||||
_run_tool(tool, client, files)
|
||||
|
||||
assert client.upload_file.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# No cross-instance sharing: a fresh PythonTool re-uploads everything
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
|
||||
def test_new_tool_instance_re_uploads_file() -> None:
|
||||
client = MagicMock()
|
||||
client.upload_file.side_effect = ["id-session-1", "id-session-2"]
|
||||
|
||||
files = [ChatFile(filename="deck.pptx", content=b"slide data")]
|
||||
|
||||
tool_session_1 = _make_tool()
|
||||
_run_tool(tool_session_1, client, files)
|
||||
|
||||
tool_session_2 = _make_tool()
|
||||
_run_tool(tool_session_2, client, files)
|
||||
|
||||
# Different instances — each uploads independently
|
||||
assert client.upload_file.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upload failure: failed upload is not cached, retried next run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
|
||||
def test_upload_failure_not_cached() -> None:
|
||||
tool = _make_tool()
|
||||
client = MagicMock()
|
||||
# First call raises, second succeeds
|
||||
client.upload_file.side_effect = [Exception("network error"), "file-id-ok"]
|
||||
|
||||
files = [ChatFile(filename="slides.pptx", content=b"data")]
|
||||
|
||||
# First run — upload fails, file is skipped but not cached
|
||||
_run_tool(tool, client, files)
|
||||
|
||||
# Second run — should attempt upload again
|
||||
_run_tool(tool, client, files)
|
||||
|
||||
assert client.upload_file.call_count == 2
|
||||
93
cubic.yaml
Normal file
93
cubic.yaml
Normal file
@@ -0,0 +1,93 @@
|
||||
# yaml-language-server: $schema=https://cubic.dev/schema/cubic-repository-config.schema.json
|
||||
version: 1
|
||||
|
||||
reviews:
|
||||
enabled: true
|
||||
sensitivity: medium
|
||||
incremental_commits: true
|
||||
check_drafts: false
|
||||
|
||||
custom_instructions: |
|
||||
Use explicit type annotations for variables to enhance code clarity,
|
||||
especially when moving type hints around in the code.
|
||||
|
||||
Use `contributing_guides/best_practices.md` as core review context.
|
||||
Prefer consistency with existing patterns, fix issues in code you touch,
|
||||
avoid tacking new features onto muddy interfaces, fail loudly instead of
|
||||
silently swallowing errors, keep code strictly typed, preserve clear state
|
||||
boundaries, remove duplicate or dead logic, break up overly long functions,
|
||||
avoid hidden import-time side effects, respect module boundaries, and favor
|
||||
correctness-by-construction over relying on callers to use an API correctly.
|
||||
|
||||
Reference these files for additional context:
|
||||
- `contributing_guides/best_practices.md` — Best practices for contributing to the codebase
|
||||
- `CLAUDE.md` — Project instructions and coding standards
|
||||
- `backend/alembic/README.md` — Migration guidance, including multi-tenant migration behavior
|
||||
- `deployment/helm/charts/onyx/values-lite.yaml` — Lite deployment Helm values and service assumptions
|
||||
- `deployment/docker_compose/docker-compose.onyx-lite.yml` — Lite deployment Docker Compose overlay and disabled service behavior
|
||||
|
||||
ignore:
|
||||
files:
|
||||
- greptile.json
|
||||
- cubic.yaml
|
||||
|
||||
custom_rules:
|
||||
- name: TODO format
|
||||
description: >
|
||||
Whenever a TODO is added, there must always be an associated name or
|
||||
ticket in the style of TODO(name): ... or TODO(1234): ...
|
||||
|
||||
- name: Frontend standards
|
||||
description: >
|
||||
For frontend changes, enforce all standards described in the
|
||||
web/AGENTS.md file.
|
||||
include:
|
||||
- web/**
|
||||
- desktop/**
|
||||
|
||||
- name: No debugging code
|
||||
description: >
|
||||
Remove temporary debugging code before merging to production,
|
||||
especially tenant-specific debugging logs.
|
||||
|
||||
- name: No hardcoded booleans
|
||||
description: >
|
||||
When hardcoding a boolean variable to a constant value, remove the
|
||||
variable entirely and clean up all places where it's used rather than
|
||||
just setting it to a constant.
|
||||
|
||||
- name: Multi-tenant awareness
|
||||
description: >
|
||||
Code changes must consider both multi-tenant and single-tenant
|
||||
deployments. In multi-tenant mode, preserve tenant isolation, ensure
|
||||
tenant context is propagated correctly, and avoid assumptions that only
|
||||
hold for a single shared schema or globally shared state. In
|
||||
single-tenant mode, avoid introducing unnecessary tenant-specific
|
||||
requirements or cloud-only control-plane dependencies.
|
||||
|
||||
- name: Onyx lite compatibility
|
||||
description: >
|
||||
Code changes must consider both regular Onyx deployments and Onyx lite
|
||||
deployments. Lite deployments disable the vector DB, Redis, model
|
||||
servers, and background workers by default, use PostgreSQL-backed
|
||||
cache/auth/file storage, and rely on the API server to handle
|
||||
background work. Do not assume those services are available unless the
|
||||
code path is explicitly limited to full deployments.
|
||||
|
||||
- name: OnyxError over HTTPException
|
||||
description: >
|
||||
Never raise HTTPException directly in business code. Use
|
||||
`raise OnyxError(OnyxErrorCode.XXX, "message")` from
|
||||
`onyx.error_handling.exceptions`. A global FastAPI exception handler
|
||||
converts OnyxError into structured JSON responses with
|
||||
{"error_code": "...", "detail": "..."}. Error codes are defined in
|
||||
`onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors
|
||||
with dynamic HTTP status codes, use `status_code_override`:
|
||||
`raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`.
|
||||
include:
|
||||
- backend/**/*.py
|
||||
|
||||
issues:
|
||||
fix_with_cubic_buttons: true
|
||||
pr_comment_fixes: true
|
||||
fix_commits_to_pr: true
|
||||
@@ -489,20 +489,18 @@ services:
|
||||
- "${HOST_PORT_80:-80}:80"
|
||||
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
|
||||
volumes:
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template"
|
||||
|
||||
minio:
|
||||
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1
|
||||
|
||||
@@ -290,25 +290,20 @@ services:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
volumes:
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/certbot/conf:/etc/letsencrypt
|
||||
- ../data/certbot/www:/var/www/certbot
|
||||
# sleep a little bit to allow the web_server / api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template.prod"
|
||||
env_file:
|
||||
- .env.nginx
|
||||
environment:
|
||||
|
||||
@@ -314,21 +314,19 @@ services:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
volumes:
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/sslcerts:/etc/nginx/sslcerts
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod.no-letsencrypt"
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template.prod.no-letsencrypt"
|
||||
env_file:
|
||||
- .env.nginx
|
||||
environment:
|
||||
|
||||
@@ -333,25 +333,20 @@ services:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
volumes:
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/certbot/conf:/etc/letsencrypt
|
||||
- ../data/certbot/www:/var/www/certbot
|
||||
# sleep a little bit to allow the web_server / api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template.prod"
|
||||
env_file:
|
||||
- .env.nginx
|
||||
environment:
|
||||
|
||||
@@ -202,20 +202,18 @@ services:
|
||||
ports:
|
||||
- "${NGINX_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
|
||||
volumes:
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template"
|
||||
|
||||
minio:
|
||||
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1
|
||||
|
||||
@@ -477,7 +477,10 @@ services:
|
||||
- "${HOST_PORT_80:-80}:80"
|
||||
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
|
||||
volumes:
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
# Mount templates read-only; the startup command copies them into
|
||||
# the writable /etc/nginx/conf.d/ inside the container. This avoids
|
||||
# "Permission denied" errors on Windows Docker bind mounts.
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
# PRODUCTION: Add SSL certificate volumes for HTTPS support:
|
||||
# - ../data/certbot/conf:/etc/letsencrypt
|
||||
# - ../data/certbot/www:/var/www/certbot
|
||||
@@ -489,12 +492,13 @@ services:
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not receive any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
# PRODUCTION: Change to app.conf.template.prod for production nginx config
|
||||
command: >
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template"
|
||||
|
||||
cache:
|
||||
image: redis:7.4-alpine
|
||||
|
||||
1131
deployment/docker_compose/install.ps1
Normal file
1131
deployment/docker_compose/install.ps1
Normal file
File diff suppressed because it is too large
Load Diff
@@ -96,8 +96,8 @@ fi
|
||||
|
||||
# When --lite is passed as a flag, lower resource thresholds early (before the
|
||||
# resource check). When lite is chosen interactively, the thresholds are adjusted
|
||||
# inside the new-deployment flow, after the resource check has already passed
|
||||
# with the standard thresholds — which is the safer direction.
|
||||
# after the resource check has already passed with the standard thresholds —
|
||||
# which is the safer direction.
|
||||
if [[ "$LITE_MODE" = true ]]; then
|
||||
EXPECTED_DOCKER_RAM_GB=4
|
||||
EXPECTED_DISK_GB=16
|
||||
@@ -110,9 +110,6 @@ LITE_COMPOSE_FILE="docker-compose.onyx-lite.yml"
|
||||
# Build the -f flags for docker compose.
|
||||
# Pass "true" as $1 to auto-detect a previously-downloaded lite overlay
|
||||
# (used by shutdown/delete-data so users don't need to remember --lite).
|
||||
# Without the argument, the lite overlay is only included when --lite was
|
||||
# explicitly passed — preventing install/start from silently staying in
|
||||
# lite mode just because the file exists on disk from a prior run.
|
||||
compose_file_args() {
|
||||
local auto_detect="${1:-false}"
|
||||
local args="-f docker-compose.yml"
|
||||
@@ -177,7 +174,7 @@ ensure_file() {
|
||||
|
||||
# --- Interactive prompt helpers ---
|
||||
is_interactive() {
|
||||
[[ "$NO_PROMPT" = false ]] && [[ -t 0 ]]
|
||||
[[ "$NO_PROMPT" = false ]]
|
||||
}
|
||||
|
||||
prompt_or_default() {
|
||||
@@ -207,6 +204,16 @@ prompt_yn_or_default() {
|
||||
fi
|
||||
}
|
||||
|
||||
confirm_action() {
|
||||
local description="$1"
|
||||
prompt_yn_or_default "Install ${description}? (Y/n) [default: Y] " "Y"
|
||||
if [[ "$REPLY" =~ ^[Nn] ]]; then
|
||||
print_warning "Skipping: ${description}"
|
||||
return 1
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
@@ -395,6 +402,11 @@ fi
|
||||
|
||||
if ! command -v docker &> /dev/null; then
|
||||
if [[ "$OSTYPE" == "linux-gnu"* ]] || [[ -n "${WSL_DISTRO_NAME:-}" ]]; then
|
||||
print_info "Docker is required but not installed."
|
||||
if ! confirm_action "Docker Engine"; then
|
||||
print_error "Docker is required to run Onyx."
|
||||
exit 1
|
||||
fi
|
||||
install_docker_linux
|
||||
if ! command -v docker &> /dev/null; then
|
||||
print_error "Docker installation failed."
|
||||
@@ -411,7 +423,11 @@ if command -v docker &> /dev/null \
|
||||
&& ! command -v docker-compose &> /dev/null \
|
||||
&& { [[ "$OSTYPE" == "linux-gnu"* ]] || [[ -n "${WSL_DISTRO_NAME:-}" ]]; }; then
|
||||
|
||||
print_info "Docker Compose not found — installing plugin..."
|
||||
print_info "Docker Compose is required but not installed."
|
||||
if ! confirm_action "Docker Compose plugin"; then
|
||||
print_error "Docker Compose is required to run Onyx."
|
||||
exit 1
|
||||
fi
|
||||
COMPOSE_ARCH="$(uname -m)"
|
||||
COMPOSE_URL="https://github.com/docker/compose/releases/latest/download/docker-compose-linux-${COMPOSE_ARCH}"
|
||||
COMPOSE_DIR="/usr/local/lib/docker/cli-plugins"
|
||||
@@ -562,10 +578,31 @@ version_compare() {
|
||||
|
||||
# Check Docker daemon
|
||||
if ! docker info &> /dev/null; then
|
||||
print_error "Docker daemon is not running. Please start Docker."
|
||||
exit 1
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
print_info "Docker daemon is not running. Starting Docker Desktop..."
|
||||
open -a Docker
|
||||
# Wait up to 120 seconds for Docker to be ready
|
||||
DOCKER_WAIT=0
|
||||
DOCKER_MAX_WAIT=120
|
||||
while ! docker info &> /dev/null; do
|
||||
if [ $DOCKER_WAIT -ge $DOCKER_MAX_WAIT ]; then
|
||||
print_error "Docker Desktop did not start within ${DOCKER_MAX_WAIT} seconds."
|
||||
print_info "Please start Docker Desktop manually and re-run this script."
|
||||
exit 1
|
||||
fi
|
||||
printf "\r\033[KWaiting for Docker Desktop to start... (%ds)" "$DOCKER_WAIT"
|
||||
sleep 2
|
||||
DOCKER_WAIT=$((DOCKER_WAIT + 2))
|
||||
done
|
||||
echo ""
|
||||
print_success "Docker Desktop is now running"
|
||||
else
|
||||
print_error "Docker daemon is not running. Please start Docker."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
print_success "Docker daemon is running"
|
||||
fi
|
||||
print_success "Docker daemon is running"
|
||||
|
||||
# Check Docker resources
|
||||
print_step "Verifying Docker resources"
|
||||
@@ -705,25 +742,48 @@ if [ "$COMPOSE_VERSION" != "dev" ] && version_compare "$COMPOSE_VERSION" "2.24.0
|
||||
print_info "Proceeding with installation despite Docker Compose version compatibility issues..."
|
||||
fi
|
||||
|
||||
# Handle lite overlay: ensure it if --lite, clean up stale copies otherwise
|
||||
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
|
||||
if [[ "$LITE_MODE" = false ]]; then
|
||||
print_info "Which deployment mode would you like?"
|
||||
echo ""
|
||||
echo " 1) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
|
||||
echo " LLM chat, tools, file uploads, and Projects still work"
|
||||
echo " 2) Standard - Full deployment with search, connectors, and RAG"
|
||||
echo ""
|
||||
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
|
||||
echo ""
|
||||
|
||||
case "$REPLY" in
|
||||
2)
|
||||
print_info "Selected: Standard mode"
|
||||
;;
|
||||
*)
|
||||
LITE_MODE=true
|
||||
print_info "Selected: Lite mode"
|
||||
;;
|
||||
esac
|
||||
else
|
||||
print_info "Deployment mode: Lite (set via --lite flag)"
|
||||
fi
|
||||
|
||||
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
|
||||
print_error "--include-craft cannot be used with Lite mode."
|
||||
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$LITE_MODE" = true ]]; then
|
||||
EXPECTED_DOCKER_RAM_GB=4
|
||||
EXPECTED_DISK_GB=16
|
||||
fi
|
||||
|
||||
# Handle lite overlay file based on selected mode
|
||||
if [[ "$LITE_MODE" = true ]]; then
|
||||
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
|
||||
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
|
||||
elif [[ -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" ]]; then
|
||||
if [[ -f "${INSTALL_ROOT}/deployment/.env" ]]; then
|
||||
print_warning "Existing lite overlay found but --lite was not passed."
|
||||
prompt_yn_or_default "Remove lite overlay and switch to standard mode? (y/N): " "n"
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
print_info "Keeping existing lite overlay. Pass --lite to keep using lite mode."
|
||||
LITE_MODE=true
|
||||
else
|
||||
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
|
||||
print_info "Removed lite overlay (switching to standard mode)"
|
||||
fi
|
||||
else
|
||||
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
|
||||
print_info "Removed previous lite overlay (switching to standard mode)"
|
||||
fi
|
||||
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
|
||||
print_info "Removed previous lite overlay (switching to standard mode)"
|
||||
fi
|
||||
|
||||
ensure_file "${INSTALL_ROOT}/deployment/env.template" \
|
||||
@@ -745,6 +805,7 @@ print_success "All configuration files ready"
|
||||
# Set up deployment configuration
|
||||
print_step "Setting up deployment configs"
|
||||
ENV_FILE="${INSTALL_ROOT}/deployment/.env"
|
||||
ENV_TEMPLATE="${INSTALL_ROOT}/deployment/env.template"
|
||||
# Check if services are already running
|
||||
if [ -d "${INSTALL_ROOT}/deployment" ] && [ -f "${INSTALL_ROOT}/deployment/docker-compose.yml" ]; then
|
||||
# Determine compose command
|
||||
@@ -785,22 +846,22 @@ if [ -f "$ENV_FILE" ]; then
|
||||
if [ "$REPLY" = "update" ]; then
|
||||
print_info "Update selected. Which tag would you like to deploy?"
|
||||
echo ""
|
||||
echo "• Press Enter for latest (recommended)"
|
||||
echo "• Press Enter for edge (recommended)"
|
||||
echo "• Type a specific tag (e.g., v0.1.0)"
|
||||
echo ""
|
||||
if [ "$INCLUDE_CRAFT" = true ]; then
|
||||
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
|
||||
VERSION="$REPLY"
|
||||
else
|
||||
prompt_or_default "Enter tag [default: latest]: " "latest"
|
||||
prompt_or_default "Enter tag [default: edge]: " "edge"
|
||||
VERSION="$REPLY"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
|
||||
print_info "Selected: craft-latest (Craft enabled)"
|
||||
elif [ "$VERSION" = "latest" ]; then
|
||||
print_info "Selected: Latest version"
|
||||
elif [ "$VERSION" = "edge" ]; then
|
||||
print_info "Selected: edge (latest nightly)"
|
||||
else
|
||||
print_info "Selected: $VERSION"
|
||||
fi
|
||||
@@ -852,45 +913,6 @@ else
|
||||
print_info "No existing .env file found. Setting up new deployment..."
|
||||
echo ""
|
||||
|
||||
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
|
||||
if [[ "$LITE_MODE" = false ]]; then
|
||||
print_info "Which deployment mode would you like?"
|
||||
echo ""
|
||||
echo " 1) Standard - Full deployment with search, connectors, and RAG"
|
||||
echo " 2) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
|
||||
echo " LLM chat, tools, file uploads, and Projects still work"
|
||||
echo ""
|
||||
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
|
||||
echo ""
|
||||
|
||||
case "$REPLY" in
|
||||
2)
|
||||
LITE_MODE=true
|
||||
print_info "Selected: Lite mode"
|
||||
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
|
||||
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
|
||||
;;
|
||||
*)
|
||||
print_info "Selected: Standard mode"
|
||||
;;
|
||||
esac
|
||||
else
|
||||
print_info "Deployment mode: Lite (set via --lite flag)"
|
||||
fi
|
||||
|
||||
# Validate lite + craft combination (could now be set interactively)
|
||||
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
|
||||
print_error "--include-craft cannot be used with Lite mode."
|
||||
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Adjust resource expectations for lite mode
|
||||
if [[ "$LITE_MODE" = true ]]; then
|
||||
EXPECTED_DOCKER_RAM_GB=4
|
||||
EXPECTED_DISK_GB=16
|
||||
fi
|
||||
|
||||
# Ask for version
|
||||
print_info "Which tag would you like to deploy?"
|
||||
echo ""
|
||||
@@ -901,18 +923,18 @@ else
|
||||
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
|
||||
VERSION="$REPLY"
|
||||
else
|
||||
echo "• Press Enter for latest (recommended)"
|
||||
echo "• Press Enter for edge (recommended)"
|
||||
echo "• Type a specific tag (e.g., v0.1.0)"
|
||||
echo ""
|
||||
prompt_or_default "Enter tag [default: latest]: " "latest"
|
||||
prompt_or_default "Enter tag [default: edge]: " "edge"
|
||||
VERSION="$REPLY"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
|
||||
print_info "Selected: craft-latest (Craft enabled)"
|
||||
elif [ "$VERSION" = "latest" ]; then
|
||||
print_info "Selected: Latest tag"
|
||||
elif [ "$VERSION" = "edge" ]; then
|
||||
print_info "Selected: edge (latest nightly)"
|
||||
else
|
||||
print_info "Selected: $VERSION"
|
||||
fi
|
||||
@@ -1070,20 +1092,39 @@ fi
|
||||
export HOST_PORT=$AVAILABLE_PORT
|
||||
print_success "Using port $AVAILABLE_PORT for nginx"
|
||||
|
||||
# Determine if we're using the latest tag or a craft tag (both should force pull)
|
||||
# Determine if we're using a floating tag (edge, latest, craft-*) that should force pull
|
||||
# Read IMAGE_TAG from .env file and remove any quotes or whitespace
|
||||
CURRENT_IMAGE_TAG=$(grep "^IMAGE_TAG=" "$ENV_FILE" | head -1 | cut -d'=' -f2 | tr -d ' "'"'"'')
|
||||
if [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
|
||||
if [ "$CURRENT_IMAGE_TAG" = "edge" ] || [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
|
||||
USE_LATEST=true
|
||||
if [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
|
||||
print_info "Using craft tag '$CURRENT_IMAGE_TAG' - will force pull and recreate containers"
|
||||
else
|
||||
print_info "Using 'latest' tag - will force pull and recreate containers"
|
||||
print_info "Using '$CURRENT_IMAGE_TAG' tag - will force pull and recreate containers"
|
||||
fi
|
||||
else
|
||||
USE_LATEST=false
|
||||
fi
|
||||
|
||||
# For pinned version tags, re-download config files from that tag so the
|
||||
# compose file matches the images being pulled (the initial download used main).
|
||||
if [[ "$USE_LATEST" = false ]] && [[ "$USE_LOCAL_FILES" = false ]]; then
|
||||
PINNED_BASE="https://raw.githubusercontent.com/onyx-dot-app/onyx/${CURRENT_IMAGE_TAG}/deployment"
|
||||
print_info "Fetching config files matching tag ${CURRENT_IMAGE_TAG}..."
|
||||
if download_file "${PINNED_BASE}/docker_compose/docker-compose.yml" "${INSTALL_ROOT}/deployment/docker-compose.yml" 2>/dev/null; then
|
||||
download_file "${PINNED_BASE}/data/nginx/app.conf.template" "${INSTALL_ROOT}/data/nginx/app.conf.template" 2>/dev/null || true
|
||||
download_file "${PINNED_BASE}/data/nginx/run-nginx.sh" "${INSTALL_ROOT}/data/nginx/run-nginx.sh" 2>/dev/null || true
|
||||
chmod +x "${INSTALL_ROOT}/data/nginx/run-nginx.sh"
|
||||
if [[ "$LITE_MODE" = true ]]; then
|
||||
download_file "${PINNED_BASE}/docker_compose/${LITE_COMPOSE_FILE}" \
|
||||
"${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" 2>/dev/null || true
|
||||
fi
|
||||
print_success "Config files updated to match ${CURRENT_IMAGE_TAG}"
|
||||
else
|
||||
print_warning "Tag ${CURRENT_IMAGE_TAG} not found on GitHub — using main branch configs"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Pull Docker images with reduced output
|
||||
print_step "Pulling Docker images"
|
||||
print_info "This may take several minutes depending on your internet connection..."
|
||||
|
||||
@@ -127,6 +127,7 @@ Inputs (common):
|
||||
- `name` (default `onyx`), `region` (default `us-west-2`), `tags`
|
||||
- `postgres_username`, `postgres_password`
|
||||
- `create_vpc` (default true) or existing VPC details and `s3_vpc_endpoint_id`
|
||||
- WAF controls such as `waf_allowed_ip_cidrs`, `waf_common_rule_set_count_rules`, rate limits, geo restrictions, and logging retention
|
||||
|
||||
### `vpc`
|
||||
- Builds a VPC sized for EKS with multiple private and public subnets
|
||||
|
||||
@@ -88,6 +88,8 @@ module "waf" {
|
||||
tags = local.merged_tags
|
||||
|
||||
# WAF configuration with sensible defaults
|
||||
allowed_ip_cidrs = var.waf_allowed_ip_cidrs
|
||||
common_rule_set_count_rules = var.waf_common_rule_set_count_rules
|
||||
rate_limit_requests_per_5_minutes = var.waf_rate_limit_requests_per_5_minutes
|
||||
api_rate_limit_requests_per_5_minutes = var.waf_api_rate_limit_requests_per_5_minutes
|
||||
geo_restriction_countries = var.waf_geo_restriction_countries
|
||||
|
||||
@@ -117,6 +117,18 @@ variable "waf_rate_limit_requests_per_5_minutes" {
|
||||
default = 2000
|
||||
}
|
||||
|
||||
variable "waf_allowed_ip_cidrs" {
|
||||
type = list(string)
|
||||
description = "Optional IPv4 CIDR ranges allowed through the WAF. Leave empty to disable IP allowlisting."
|
||||
default = []
|
||||
}
|
||||
|
||||
variable "waf_common_rule_set_count_rules" {
|
||||
type = list(string)
|
||||
description = "Subrules within AWSManagedRulesCommonRuleSet to override to COUNT instead of BLOCK."
|
||||
default = []
|
||||
}
|
||||
|
||||
variable "waf_api_rate_limit_requests_per_5_minutes" {
|
||||
type = number
|
||||
description = "Rate limit for API requests per 5 minutes per IP address"
|
||||
|
||||
@@ -1,6 +1,20 @@
|
||||
locals {
|
||||
name = var.name
|
||||
tags = var.tags
|
||||
name = var.name
|
||||
tags = var.tags
|
||||
ip_allowlist_enabled = length(var.allowed_ip_cidrs) > 0
|
||||
managed_rule_priority = local.ip_allowlist_enabled ? 1 : 0
|
||||
}
|
||||
|
||||
resource "aws_wafv2_ip_set" "allowed_ips" {
|
||||
count = local.ip_allowlist_enabled ? 1 : 0
|
||||
|
||||
name = "${local.name}-allowed-ips"
|
||||
description = "IP allowlist for ${local.name}"
|
||||
scope = "REGIONAL"
|
||||
ip_address_version = "IPV4"
|
||||
addresses = var.allowed_ip_cidrs
|
||||
|
||||
tags = local.tags
|
||||
}
|
||||
|
||||
# AWS WAFv2 Web ACL
|
||||
@@ -13,10 +27,38 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
allow {}
|
||||
}
|
||||
|
||||
dynamic "rule" {
|
||||
for_each = local.ip_allowlist_enabled ? [1] : []
|
||||
content {
|
||||
name = "BlockRequestsOutsideAllowedIPs"
|
||||
priority = 1
|
||||
|
||||
action {
|
||||
block {}
|
||||
}
|
||||
|
||||
statement {
|
||||
not_statement {
|
||||
statement {
|
||||
ip_set_reference_statement {
|
||||
arn = aws_wafv2_ip_set.allowed_ips[0].arn
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
visibility_config {
|
||||
cloudwatch_metrics_enabled = true
|
||||
metric_name = "BlockRequestsOutsideAllowedIPsMetric"
|
||||
sampled_requests_enabled = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# AWS Managed Rules - Core Rule Set
|
||||
rule {
|
||||
name = "AWSManagedRulesCommonRuleSet"
|
||||
priority = 1
|
||||
priority = 1 + local.managed_rule_priority
|
||||
|
||||
override_action {
|
||||
none {}
|
||||
@@ -26,6 +68,16 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
managed_rule_group_statement {
|
||||
name = "AWSManagedRulesCommonRuleSet"
|
||||
vendor_name = "AWS"
|
||||
|
||||
dynamic "rule_action_override" {
|
||||
for_each = var.common_rule_set_count_rules
|
||||
content {
|
||||
name = rule_action_override.value
|
||||
action_to_use {
|
||||
count {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,7 +91,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# AWS Managed Rules - Known Bad Inputs
|
||||
rule {
|
||||
name = "AWSManagedRulesKnownBadInputsRuleSet"
|
||||
priority = 2
|
||||
priority = 2 + local.managed_rule_priority
|
||||
|
||||
override_action {
|
||||
none {}
|
||||
@@ -62,7 +114,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# Rate Limiting Rule
|
||||
rule {
|
||||
name = "RateLimitRule"
|
||||
priority = 3
|
||||
priority = 3 + local.managed_rule_priority
|
||||
|
||||
action {
|
||||
block {}
|
||||
@@ -87,7 +139,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
for_each = length(var.geo_restriction_countries) > 0 ? [1] : []
|
||||
content {
|
||||
name = "GeoRestrictionRule"
|
||||
priority = 4
|
||||
priority = 4 + local.managed_rule_priority
|
||||
|
||||
action {
|
||||
block {}
|
||||
@@ -110,7 +162,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# IP Rate Limiting
|
||||
rule {
|
||||
name = "APIRateLimitRule"
|
||||
priority = 5
|
||||
priority = 5 + local.managed_rule_priority
|
||||
|
||||
action {
|
||||
block {}
|
||||
@@ -133,7 +185,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# SQL Injection Protection
|
||||
rule {
|
||||
name = "AWSManagedRulesSQLiRuleSet"
|
||||
priority = 6
|
||||
priority = 6 + local.managed_rule_priority
|
||||
|
||||
override_action {
|
||||
none {}
|
||||
@@ -156,7 +208,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# Anonymous IP Protection
|
||||
rule {
|
||||
name = "AWSManagedRulesAnonymousIpList"
|
||||
priority = 7
|
||||
priority = 7 + local.managed_rule_priority
|
||||
|
||||
override_action {
|
||||
none {}
|
||||
|
||||
@@ -9,6 +9,18 @@ variable "tags" {
|
||||
default = {}
|
||||
}
|
||||
|
||||
variable "allowed_ip_cidrs" {
|
||||
type = list(string)
|
||||
description = "Optional IPv4 CIDR ranges allowed to reach the application. Leave empty to disable IP allowlisting."
|
||||
default = []
|
||||
}
|
||||
|
||||
variable "common_rule_set_count_rules" {
|
||||
type = list(string)
|
||||
description = "Subrules within AWSManagedRulesCommonRuleSet to override to COUNT instead of BLOCK."
|
||||
default = []
|
||||
}
|
||||
|
||||
variable "rate_limit_requests_per_5_minutes" {
|
||||
type = number
|
||||
description = "Rate limit for requests per 5 minutes per IP address"
|
||||
|
||||
1
desktop/AGENTS.md
Symbolic link
1
desktop/AGENTS.md
Symbolic link
@@ -0,0 +1 @@
|
||||
../web/AGENTS.md
|
||||
1
desktop/CLAUDE.md
Symbolic link
1
desktop/CLAUDE.md
Symbolic link
@@ -0,0 +1 @@
|
||||
AGENTS.md
|
||||
@@ -65,7 +65,7 @@
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "For frontend changes (changes that touch the /web directory), make sure to enforce all standards described in the web/STANDARDS.md file."
|
||||
"rule": "For frontend changes (changes that touch the /web directory), make sure to enforce all standards described in the web/AGENTS.md file."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
@@ -85,7 +85,7 @@
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"message\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
],
|
||||
"files": [
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
type RunCIOptions struct {
|
||||
DryRun bool
|
||||
Yes bool
|
||||
Rerun bool
|
||||
}
|
||||
|
||||
// NewRunCICommand creates a new run-ci command
|
||||
@@ -49,6 +50,7 @@ Example usage:
|
||||
|
||||
cmd.Flags().BoolVar(&opts.DryRun, "dry-run", false, "Perform all local operations but skip pushing to remote and creating PRs")
|
||||
cmd.Flags().BoolVar(&opts.Yes, "yes", false, "Skip confirmation prompts and automatically proceed")
|
||||
cmd.Flags().BoolVar(&opts.Rerun, "rerun", false, "Update an existing CI PR with the latest fork changes to re-trigger CI")
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -107,19 +109,44 @@ func runCI(cmd *cobra.Command, args []string, opts *RunCIOptions) {
|
||||
log.Fatalf("PR #%s is not from a fork - CI should already run automatically", prNumber)
|
||||
}
|
||||
|
||||
// Confirm before proceeding
|
||||
if !opts.Yes {
|
||||
if !prompt.Confirm(fmt.Sprintf("Create CI branch for PR #%s? (yes/no): ", prNumber)) {
|
||||
log.Info("Exiting...")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Create the CI branch
|
||||
ciBranch := fmt.Sprintf("run-ci/%s", prNumber)
|
||||
prTitle := fmt.Sprintf("chore: [Running GitHub actions for #%s]", prNumber)
|
||||
prBody := fmt.Sprintf("This PR runs GitHub Actions CI for #%s.\n\n- [x] Override Linear Check\n\n**This PR should be closed (not merged) after CI completes.**", prNumber)
|
||||
|
||||
// Check if a CI PR already exists for this branch
|
||||
existingPRURL, err := findExistingCIPR(ciBranch)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to check for existing CI PR: %v", err)
|
||||
}
|
||||
|
||||
if existingPRURL != "" && !opts.Rerun {
|
||||
log.Infof("A CI PR already exists for #%s: %s", prNumber, existingPRURL)
|
||||
log.Info("Run with --rerun to update it with the latest fork changes and re-trigger CI.")
|
||||
return
|
||||
}
|
||||
|
||||
if opts.Rerun && existingPRURL == "" {
|
||||
log.Warn("--rerun was specified but no existing open CI PR was found. A new PR will be created.")
|
||||
}
|
||||
|
||||
if existingPRURL != "" && opts.Rerun {
|
||||
log.Infof("Existing CI PR found: %s", existingPRURL)
|
||||
log.Info("Will update the CI branch with the latest fork changes to re-trigger CI.")
|
||||
}
|
||||
|
||||
// Confirm before proceeding
|
||||
if !opts.Yes {
|
||||
action := "Create CI branch"
|
||||
if existingPRURL != "" {
|
||||
action = "Update existing CI branch"
|
||||
}
|
||||
if !prompt.Confirm(fmt.Sprintf("%s for PR #%s? (yes/no): ", action, prNumber)) {
|
||||
log.Info("Exiting...")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch the fork's branch
|
||||
if forkRepo == "" {
|
||||
log.Fatalf("Could not determine fork repository - headRepositoryOwner or headRepository.name is empty")
|
||||
@@ -158,7 +185,11 @@ func runCI(cmd *cobra.Command, args []string, opts *RunCIOptions) {
|
||||
|
||||
if opts.DryRun {
|
||||
log.Warnf("[DRY RUN] Would push CI branch: %s", ciBranch)
|
||||
log.Warnf("[DRY RUN] Would create PR: %s", prTitle)
|
||||
if existingPRURL == "" {
|
||||
log.Warnf("[DRY RUN] Would create PR: %s", prTitle)
|
||||
} else {
|
||||
log.Warnf("[DRY RUN] Would update existing PR: %s", existingPRURL)
|
||||
}
|
||||
// Switch back to original branch
|
||||
if err := git.RunCommand("switch", "--quiet", originalBranch); err != nil {
|
||||
log.Warnf("Failed to switch back to original branch: %v", err)
|
||||
@@ -176,6 +207,17 @@ func runCI(cmd *cobra.Command, args []string, opts *RunCIOptions) {
|
||||
log.Fatalf("Failed to push CI branch: %v", err)
|
||||
}
|
||||
|
||||
if existingPRURL != "" {
|
||||
// PR already exists - force push is enough to re-trigger CI
|
||||
log.Infof("Switching back to original branch: %s", originalBranch)
|
||||
if err := git.RunCommand("switch", "--quiet", originalBranch); err != nil {
|
||||
log.Warnf("Failed to switch back to original branch: %v", err)
|
||||
}
|
||||
log.Infof("CI PR updated successfully: %s", existingPRURL)
|
||||
log.Info("The force push will re-trigger CI. Remember to close (not merge) this PR after CI completes!")
|
||||
return
|
||||
}
|
||||
|
||||
// Create PR using GitHub CLI
|
||||
log.Info("Creating PR...")
|
||||
prURL, err := createCIPR(ciBranch, prInfo.BaseRefName, prTitle, prBody)
|
||||
@@ -217,6 +259,39 @@ func getPRInfo(prNumber string) (*PRInfo, error) {
|
||||
return &prInfo, nil
|
||||
}
|
||||
|
||||
// findExistingCIPR checks if an open PR already exists for the given CI branch.
|
||||
// Returns the PR URL if found, or empty string if not.
|
||||
func findExistingCIPR(headBranch string) (string, error) {
|
||||
cmd := exec.Command("gh", "pr", "list",
|
||||
"--head", headBranch,
|
||||
"--state", "open",
|
||||
"--json", "url",
|
||||
)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return "", fmt.Errorf("%w: %s", err, string(exitErr.Stderr))
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
var prs []struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
if err := json.Unmarshal(output, &prs); err != nil {
|
||||
log.Debugf("Failed to parse PR list JSON: %v (raw: %s)", err, string(output))
|
||||
return "", fmt.Errorf("failed to parse PR list: %w", err)
|
||||
}
|
||||
|
||||
if len(prs) == 0 {
|
||||
log.Debugf("No existing open PRs found for branch %s", headBranch)
|
||||
return "", nil
|
||||
}
|
||||
|
||||
log.Debugf("Found existing PR for branch %s: %s", headBranch, prs[0].URL)
|
||||
return prs[0].URL, nil
|
||||
}
|
||||
|
||||
// createCIPR creates a pull request for CI using the GitHub CLI
|
||||
func createCIPR(headBranch, baseBranch, title, body string) (string, error) {
|
||||
cmd := exec.Command("gh", "pr", "create",
|
||||
|
||||
540
web/AGENTS.md
Normal file
540
web/AGENTS.md
Normal file
@@ -0,0 +1,540 @@
|
||||
# Frontend Standards
|
||||
|
||||
This file is the single source of truth for frontend coding standards across all Onyx frontend
|
||||
projects (including, but not limited to, `/web`, `/desktop`).
|
||||
|
||||
# Components
|
||||
|
||||
UI components are spread across several directories while the codebase migrates to Opal:
|
||||
|
||||
- **`web/lib/opal/src/`** — The Opal design system. Preferred for all new components.
|
||||
- **`web/src/refresh-components/`** — Production components not yet migrated to Opal.
|
||||
- **`web/src/sections/`** — Feature-specific composite components (cards, modals, etc.).
|
||||
- **`web/src/layouts/`** — Page-level layout components (settings pages, etc.).
|
||||
|
||||
**Do NOT use anything from `web/src/components/`** — this directory contains legacy components
|
||||
that are being phased out. Always prefer Opal first; fall back to `refresh-components` only for
|
||||
components not yet available in Opal.
|
||||
|
||||
## Opal Layouts (`lib/opal/src/layouts/`)
|
||||
|
||||
All layout primitives are imported from `@opal/layouts`. They handle sizing, font selection, icon
|
||||
alignment, and optional inline editing.
|
||||
|
||||
```typescript
|
||||
import { Content, ContentAction, IllustrationContent } from "@opal/layouts";
|
||||
```
|
||||
|
||||
### Content
|
||||
|
||||
**Use this for any combination of icon + title + description.**
|
||||
|
||||
A two-axis layout component that automatically routes to the correct internal layout
|
||||
(`ContentXl`, `ContentLg`, `ContentMd`, `ContentSm`) based on `sizePreset` and `variant`:
|
||||
|
||||
| sizePreset | variant | Routes to | Layout |
|
||||
|---|---|---|---|
|
||||
| `headline` / `section` | `heading` | `ContentXl` | Icon on top (flex-col) |
|
||||
| `headline` / `section` | `section` | `ContentLg` | Icon inline (flex-row) |
|
||||
| `main-content` / `main-ui` / `secondary` | `section` / `heading` | `ContentMd` | Compact inline |
|
||||
| `main-content` / `main-ui` / `secondary` | `body` | `ContentSm` | Body text layout |
|
||||
|
||||
```typescript
|
||||
<Content
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
icon={SvgSettings}
|
||||
title="Settings"
|
||||
description="Manage your preferences"
|
||||
/>
|
||||
```
|
||||
|
||||
### ContentAction
|
||||
|
||||
**Use this when a Content block needs right-side actions** (buttons, badges, icons, etc.).
|
||||
|
||||
Wraps `Content` and adds a `rightChildren` slot. Accepts all `Content` props plus:
|
||||
- `rightChildren`: `ReactNode` — actions rendered on the right
|
||||
- `paddingVariant`: `SizeVariant` — controls outer padding
|
||||
|
||||
```typescript
|
||||
<ContentAction
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
icon={SvgUser}
|
||||
title="John Doe"
|
||||
description="Admin"
|
||||
rightChildren={<Button icon={SvgEdit}>Edit</Button>}
|
||||
/>
|
||||
```
|
||||
|
||||
### IllustrationContent
|
||||
|
||||
**Use this for empty states, error pages, and informational placeholders.**
|
||||
|
||||
A vertically-stacked, center-aligned layout that pairs a large illustration (7.5rem x 7.5rem)
|
||||
with a title and optional description.
|
||||
|
||||
```typescript
|
||||
import SvgNoResult from "@opal/illustrations/no-result";
|
||||
|
||||
<IllustrationContent
|
||||
illustration={SvgNoResult}
|
||||
title="No results found"
|
||||
description="Try adjusting your search or filters."
|
||||
/>
|
||||
```
|
||||
|
||||
Props:
|
||||
- `illustration`: `IconFunctionComponent` — optional, from `@opal/illustrations`
|
||||
- `title`: `string` — required
|
||||
- `description`: `string` — optional
|
||||
|
||||
## Settings Page Layout (`src/layouts/settings-layouts.tsx`)
|
||||
|
||||
**Use this for all admin/settings pages.** Provides a standardized layout with scroll-aware
|
||||
sticky headers, centered content containers, and responsive behavior.
|
||||
|
||||
```typescript
|
||||
import SettingsLayouts from "@/layouts/settings-layouts";
|
||||
|
||||
function MySettingsPage() {
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgSettings}
|
||||
title="Account Settings"
|
||||
description="Manage your account preferences"
|
||||
rightChildren={<Button>Save</Button>}
|
||||
>
|
||||
<InputTypeIn placeholder="Search settings..." />
|
||||
</SettingsLayouts.Header>
|
||||
|
||||
<SettingsLayouts.Body>
|
||||
<Card>Settings content here</Card>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
Sub-components:
|
||||
- **`SettingsLayouts.Root`** — Wrapper with centered, scrollable container. Width options:
|
||||
`"sm"` (672px), `"sm-md"` (752px), `"md"` (872px, default), `"lg"` (992px), `"full"` (100%).
|
||||
- **`SettingsLayouts.Header`** — Sticky header with icon, title, description, optional
|
||||
`rightChildren` actions, optional `children` below (e.g., search/filter), optional `backButton`,
|
||||
and optional `separator`. Automatically shows a scroll shadow when scrolled.
|
||||
- **`SettingsLayouts.Body`** — Content container with consistent padding and vertical spacing.
|
||||
|
||||
## Cards (`src/sections/cards/`)
|
||||
|
||||
**When building a card that displays information about a specific entity (agent, document set,
|
||||
file, connector, etc.), add it to `web/src/sections/cards/`.**
|
||||
|
||||
Each card is a self-contained component focused on a single entity type. Cards typically include
|
||||
entity identification (name, avatar, icon), summary information, and quick actions.
|
||||
|
||||
```typescript
|
||||
import AgentCard from "@/sections/cards/AgentCard";
|
||||
import DocumentSetCard from "@/sections/cards/DocumentSetCard";
|
||||
import FileCard from "@/sections/cards/FileCard";
|
||||
```
|
||||
|
||||
Guidelines:
|
||||
- One card per entity type — keep card-specific logic within the card component.
|
||||
- Cards should be reusable across different pages and contexts.
|
||||
- Use shared components from `@opal/components`, `@opal/layouts`, and `@/refresh-components`
|
||||
inside cards — do not duplicate layout or styling logic.
|
||||
|
||||
## Button (`components/buttons/button/`)
|
||||
|
||||
**Always use the Opal `Button`.** Do not use raw `<button>` elements.
|
||||
|
||||
Built on `Interactive.Stateless` > `Interactive.Container`, so it inherits the full color/state
|
||||
system automatically.
|
||||
|
||||
```typescript
|
||||
import { Button } from "@opal/components/buttons/button/components";
|
||||
|
||||
// Labeled button
|
||||
<Button variant="default" prominence="primary" icon={SvgPlus}>
|
||||
Create
|
||||
</Button>
|
||||
|
||||
// Icon-only button (omit children)
|
||||
<Button variant="default" prominence="tertiary" icon={SvgTrash} size="sm" />
|
||||
```
|
||||
|
||||
Key props:
|
||||
- `variant`: `"default"` | `"action"` | `"danger"` | `"none"`
|
||||
- `prominence`: `"primary"` | `"secondary"` | `"tertiary"` | `"internal"`
|
||||
- `size`: `"lg"` | `"md"` | `"sm"` | `"xs"` | `"2xs"` | `"fit"`
|
||||
- `icon`, `rightIcon`, `children`, `disabled`, `href`, `tooltip`
|
||||
|
||||
## Core Primitives (`core/`)
|
||||
|
||||
The `core/` directory contains the lowest-level building blocks that power all Opal components.
|
||||
**Most code should not interface with these directly** — use higher-level components like `Button`,
|
||||
`Content`, and `ContentAction` instead. These are documented here for understanding, not everyday use.
|
||||
|
||||
### Interactive (`core/interactive/`)
|
||||
|
||||
The foundational layer for all clickable/interactive surfaces. Defines the color matrix for
|
||||
hover, active, and disabled states.
|
||||
|
||||
- **`Interactive.Stateless`** — Color system for stateless elements (buttons, links). Applies
|
||||
variant/prominence/state combinations via CSS custom properties.
|
||||
- **`Interactive.Stateful`** — Color system for stateful elements (toggles, sidebar items, selects).
|
||||
Uses `state` (`"empty"` | `"filled"` | `"selected"`) instead of prominence.
|
||||
- **`Interactive.Container`** — Structural box providing height, rounding, padding, and border.
|
||||
Shared by both Stateless and Stateful. Renders as `<div>`, `<button>`, or `<Link>` depending
|
||||
on context.
|
||||
- **`Interactive.Foldable`** — Zero-width collapsible wrapper with CSS grid animation.
|
||||
|
||||
### Disabled (`core/disabled/`)
|
||||
|
||||
Propagates disabled state via React context. `Interactive.Stateless` and `Interactive.Stateful`
|
||||
consume this automatically, so wrapping a subtree in `<Disabled disabled={true}>` disables all
|
||||
interactive descendants.
|
||||
|
||||
### Hoverable (`core/animations/`)
|
||||
|
||||
A standardized way to provide "opacity-100 on hover" behavior. Instead of manually wiring
|
||||
`opacity-0 group-hover:opacity-100` with Tailwind, use `Hoverable` for consistent, coordinated
|
||||
hover-to-reveal patterns.
|
||||
|
||||
- **`Hoverable.Root`** — Wraps a hover group. Tracks mouse enter/leave and broadcasts hover
|
||||
state to descendants via a per-group React context.
|
||||
- **`Hoverable.Item`** — Marks an element that should appear on hover. Supports two modes:
|
||||
- **Group mode** (`group` prop provided): visibility driven by a matching `Hoverable.Root`
|
||||
ancestor. Throws if no matching Root is found.
|
||||
- **Local mode** (`group` omitted): uses CSS `:hover` on the item itself.
|
||||
|
||||
```typescript
|
||||
import { Hoverable } from "@opal/core";
|
||||
|
||||
// Group mode — hovering anywhere on the row reveals the trash icon
|
||||
<Hoverable.Root group="row">
|
||||
<div className="flex items-center gap-2">
|
||||
<span>Row content</span>
|
||||
<Hoverable.Item group="row" variant="opacity-on-hover">
|
||||
<SvgTrash />
|
||||
</Hoverable.Item>
|
||||
</div>
|
||||
</Hoverable.Root>
|
||||
|
||||
// Local mode — hovering the item itself reveals it
|
||||
<Hoverable.Item variant="opacity-on-hover">
|
||||
<SvgTrash />
|
||||
</Hoverable.Item>
|
||||
```
|
||||
|
||||
# Best Practices
|
||||
|
||||
## 1. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
|
||||
|
||||
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
|
||||
<div className="bg-background-neutral-03 text-text-02">
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true, // Will apply dark:invert internally
|
||||
});
|
||||
|
||||
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
|
||||
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
|
||||
});
|
||||
|
||||
// ❌ Bad - Manual dark mode overrides
|
||||
<div className="bg-white dark:bg-black text-black dark:text-white">
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
## 2. Icon Usage
|
||||
|
||||
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
|
||||
|
||||
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import SvgX from "@/icons/x";
|
||||
import SvgMoreHorizontal from "@/icons/more-horizontal";
|
||||
|
||||
// ❌ Bad
|
||||
import { User } from "lucide-react";
|
||||
import { FiSearch } from "react-icons/fi";
|
||||
```
|
||||
|
||||
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
|
||||
If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
## 3. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## 4. Component Usage
|
||||
|
||||
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
|
||||
|
||||
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import Button from '@/refresh-components/buttons/Button'
|
||||
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
|
||||
import SvgPlusCircle from '@/icons/plus-circle'
|
||||
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<InputTypeIn placeholder="Search..." />
|
||||
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<input placeholder="Name" />
|
||||
<textarea placeholder="Message" />
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## 5. Colors
|
||||
|
||||
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
|
||||
|
||||
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
|
||||
|
||||
**Available color categories:**
|
||||
|
||||
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
|
||||
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
|
||||
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
|
||||
- **Actions:** `action-link-XX`, `action-danger-XX`
|
||||
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
|
||||
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Use custom Onyx color classes
|
||||
<div className="bg-background-neutral-01 border border-border-02" />
|
||||
<div className="bg-background-tint-02 border border-border-01" />
|
||||
<div className="bg-status-success-01" />
|
||||
<div className="bg-action-link-01" />
|
||||
<div className="bg-theme-primary-05" />
|
||||
|
||||
// ❌ Bad - Do NOT use standard Tailwind colors
|
||||
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
|
||||
<div className="bg-white border border-slate-200" />
|
||||
<div className="bg-green-100 text-green-700" />
|
||||
<div className="bg-blue-100 text-blue-600" />
|
||||
<div className="bg-indigo-500" />
|
||||
```
|
||||
|
||||
## 6. Data Fetching
|
||||
|
||||
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
|
||||
|
||||
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
|
||||
|
||||
# Stylistic Preferences
|
||||
|
||||
## 1. Import Standards
|
||||
|
||||
**Always use absolute imports with the `@` prefix.**
|
||||
|
||||
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useAuth } from "@/hooks/useAuth";
|
||||
import { Text } from "@/refresh-components/texts/Text";
|
||||
|
||||
// ❌ Bad
|
||||
import { Button } from "../../../components/ui/button";
|
||||
import { useAuth } from "./hooks/useAuth";
|
||||
```
|
||||
|
||||
## 2. React Component Functions
|
||||
|
||||
**Prefer regular functions over arrow functions for React components.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
function UserProfile({ userId }: UserProfileProps) {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
const UserProfile = ({ userId }: UserProfileProps) => {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
```
|
||||
|
||||
## 3. Props Interface Extraction
|
||||
|
||||
**Extract prop types into their own interface definitions. Keep prop interfaces in the same file
|
||||
as the component they belong to. Non-prop types (shared models, API response shapes, enums, etc.)
|
||||
should be placed in a co-located `interfaces.ts` file.**
|
||||
|
||||
**Reason:** Prop interfaces are tightly coupled to their component and rarely imported elsewhere,
|
||||
so co-location keeps things simple. Shared types belong in `interfaces.ts` so they can be
|
||||
imported without pulling in component code.
|
||||
|
||||
```typescript
|
||||
// ✅ Good — props interface in the same file as the component
|
||||
// UserCard.tsx
|
||||
interface UserCardProps {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}
|
||||
|
||||
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
|
||||
// ✅ Good — shared types in interfaces.ts
|
||||
// interfaces.ts
|
||||
export interface User {
|
||||
id: string
|
||||
name: string
|
||||
role: UserRole
|
||||
}
|
||||
|
||||
export type UserRole = "admin" | "member" | "viewer"
|
||||
|
||||
// ❌ Bad — inline prop types
|
||||
function UserCard({
|
||||
user,
|
||||
showActions = false,
|
||||
onEdit
|
||||
}: {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
```
|
||||
|
||||
## 4. Spacing Guidelines
|
||||
|
||||
**Prefer padding over margins for spacing. When a library component exposes a padding prop
|
||||
(e.g., `paddingVariant`), use that prop instead of wrapping it in a `<div>` with padding classes.
|
||||
If a library component does not expose a padding override and you find yourself adding a wrapper
|
||||
div for spacing, consider updating the library component to accept one.**
|
||||
|
||||
**Reason:** We want to consolidate usage to paddings instead of margins, and minimize wrapper
|
||||
divs that exist solely for spacing.
|
||||
|
||||
```typescript
|
||||
// ✅ Good — use the component's padding prop
|
||||
<ContentAction paddingVariant="md" ... />
|
||||
|
||||
// ✅ Good — padding utilities when no component prop exists
|
||||
<div className="p-4 space-y-2">
|
||||
<div className="p-2">Content</div>
|
||||
</div>
|
||||
|
||||
// ❌ Bad — wrapper div just for spacing
|
||||
<div className="p-4">
|
||||
<ContentAction ... />
|
||||
</div>
|
||||
|
||||
// ❌ Bad — margins
|
||||
<div className="m-4 space-y-2">
|
||||
<div className="m-2">Content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
## 5. Class Name Utilities
|
||||
|
||||
**Use the `cn` utility instead of raw string formatting for classNames.**
|
||||
|
||||
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
|
||||
|
||||
```typescript
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// ✅ Good
|
||||
<div className={cn(
|
||||
'base-class',
|
||||
isActive && 'active-class',
|
||||
className
|
||||
)}>
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
## 6. Custom Hooks Organization
|
||||
|
||||
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
|
||||
|
||||
**Reason:** This is just a layout preference. Keeps code clean.
|
||||
|
||||
```typescript
|
||||
// web/src/hooks/useUserData.ts
|
||||
export function useUserData(userId: string) {
|
||||
// hook implementation
|
||||
}
|
||||
|
||||
// web/src/hooks/useLocalStorage.ts
|
||||
export function useLocalStorage<T>(key: string, initialValue: T) {
|
||||
// hook implementation
|
||||
}
|
||||
```
|
||||
1
web/CLAUDE.md
Symbolic link
1
web/CLAUDE.md
Symbolic link
@@ -0,0 +1 @@
|
||||
AGENTS.md
|
||||
281
web/STANDARDS.md
281
web/STANDARDS.md
@@ -1,281 +0,0 @@
|
||||
# Web Standards
|
||||
|
||||
This document outlines the coding standards and best practices for the `web` directory Next.js project.
|
||||
|
||||
## 1. Import Standards
|
||||
|
||||
**Always use absolute imports with the `@` prefix.**
|
||||
|
||||
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useAuth } from "@/hooks/useAuth";
|
||||
import { Text } from "@/refresh-components/texts/Text";
|
||||
|
||||
// ❌ Bad
|
||||
import { Button } from "../../../components/ui/button";
|
||||
import { useAuth } from "./hooks/useAuth";
|
||||
```
|
||||
|
||||
## 2. React Component Functions
|
||||
|
||||
**Prefer regular functions over arrow functions for React components.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
function UserProfile({ userId }: UserProfileProps) {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
const UserProfile = ({ userId }: UserProfileProps) => {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
```
|
||||
|
||||
## 3. Props Interface Extraction
|
||||
|
||||
**Extract prop types into their own interface definitions.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
interface UserCardProps {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}
|
||||
|
||||
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({
|
||||
user,
|
||||
showActions = false,
|
||||
onEdit
|
||||
}: {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
```
|
||||
|
||||
## 4. Spacing Guidelines
|
||||
|
||||
**Prefer padding over margins for spacing.**
|
||||
|
||||
**Reason:** We want to consolidate usage to paddings instead of margins.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
<div className="p-4 space-y-2">
|
||||
<div className="p-2">Content</div>
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className="m-4 space-y-2">
|
||||
<div className="m-2">Content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
## 5. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
|
||||
|
||||
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
|
||||
<div className="bg-background-neutral-03 text-text-02">
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true, // Will apply dark:invert internally
|
||||
});
|
||||
|
||||
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
|
||||
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
|
||||
});
|
||||
|
||||
// ❌ Bad - Manual dark mode overrides
|
||||
<div className="bg-white dark:bg-black text-black dark:text-white">
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
## 6. Class Name Utilities
|
||||
|
||||
**Use the `cn` utility instead of raw string formatting for classNames.**
|
||||
|
||||
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
|
||||
|
||||
```typescript
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// ✅ Good
|
||||
<div className={cn(
|
||||
'base-class',
|
||||
isActive && 'active-class',
|
||||
className
|
||||
)}>
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
## 7. Custom Hooks Organization
|
||||
|
||||
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
|
||||
|
||||
**Reason:** This is just a layout preference. Keeps code clean.
|
||||
|
||||
```typescript
|
||||
// web/src/hooks/useUserData.ts
|
||||
export function useUserData(userId: string) {
|
||||
// hook implementation
|
||||
}
|
||||
|
||||
// web/src/hooks/useLocalStorage.ts
|
||||
export function useLocalStorage<T>(key: string, initialValue: T) {
|
||||
// hook implementation
|
||||
}
|
||||
```
|
||||
|
||||
## 8. Icon Usage
|
||||
|
||||
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
|
||||
|
||||
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import SvgX from "@/icons/x";
|
||||
import SvgMoreHorizontal from "@/icons/more-horizontal";
|
||||
|
||||
// ❌ Bad
|
||||
import { User } from "lucide-react";
|
||||
import { FiSearch } from "react-icons/fi";
|
||||
```
|
||||
|
||||
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
|
||||
If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
## 9. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## 10. Component Usage
|
||||
|
||||
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
|
||||
|
||||
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import Button from '@/refresh-components/buttons/Button'
|
||||
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
|
||||
import SvgPlusCircle from '@/icons/plus-circle'
|
||||
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<InputTypeIn placeholder="Search..." />
|
||||
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<input placeholder="Name" />
|
||||
<textarea placeholder="Message" />
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## 11. Colors
|
||||
|
||||
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
|
||||
|
||||
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
|
||||
|
||||
**Available color categories:**
|
||||
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
|
||||
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
|
||||
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
|
||||
- **Actions:** `action-link-XX`, `action-danger-XX`
|
||||
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
|
||||
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Use custom Onyx color classes
|
||||
<div className="bg-background-neutral-01 border border-border-02" />
|
||||
<div className="bg-background-tint-02 border border-border-01" />
|
||||
<div className="bg-status-success-01" />
|
||||
<div className="bg-action-link-01" />
|
||||
<div className="bg-theme-primary-05" />
|
||||
|
||||
// ❌ Bad - Do NOT use standard Tailwind colors
|
||||
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
|
||||
<div className="bg-white border border-slate-200" />
|
||||
<div className="bg-green-100 text-green-700" />
|
||||
<div className="bg-blue-100 text-blue-600" />
|
||||
<div className="bg-indigo-500" />
|
||||
```
|
||||
|
||||
## 12. Data Fetching
|
||||
|
||||
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
|
||||
|
||||
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
|
||||
162
web/lib/opal/src/core/animations/Hoverable.stories.tsx
Normal file
162
web/lib/opal/src/core/animations/Hoverable.stories.tsx
Normal file
@@ -0,0 +1,162 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react";
|
||||
import { Hoverable } from "@opal/core";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Meta
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const meta: Meta = {
|
||||
title: "Core/Hoverable",
|
||||
tags: ["autodocs"],
|
||||
parameters: {
|
||||
layout: "centered",
|
||||
},
|
||||
};
|
||||
|
||||
export default meta;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stories
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/** Group mode — hovering the root reveals hidden items. */
|
||||
export const GroupMode: StoryObj = {
|
||||
render: () => (
|
||||
<Hoverable.Root group="demo">
|
||||
<div
|
||||
style={{
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
gap: "0.75rem",
|
||||
padding: "1rem",
|
||||
border: "1px solid var(--border-02)",
|
||||
borderRadius: "0.5rem",
|
||||
minWidth: 260,
|
||||
}}
|
||||
>
|
||||
<span style={{ color: "var(--text-01)" }}>Hover this card</span>
|
||||
<Hoverable.Item group="demo" variant="opacity-on-hover">
|
||||
<span style={{ color: "var(--text-03)" }}>✓ Revealed</span>
|
||||
</Hoverable.Item>
|
||||
</div>
|
||||
</Hoverable.Root>
|
||||
),
|
||||
};
|
||||
|
||||
/** Local mode — hovering the item itself reveals it (no Root needed). */
|
||||
export const LocalMode: StoryObj = {
|
||||
render: () => (
|
||||
<div
|
||||
style={{
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
gap: "0.75rem",
|
||||
padding: "1rem",
|
||||
}}
|
||||
>
|
||||
<span style={{ color: "var(--text-01)" }}>Hover the icon →</span>
|
||||
<Hoverable.Item variant="opacity-on-hover">
|
||||
<span style={{ fontSize: "1.25rem" }}>🗑</span>
|
||||
</Hoverable.Item>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
/** Multiple independent groups on the same page. */
|
||||
export const MultipleGroups: StoryObj = {
|
||||
render: () => (
|
||||
<div style={{ display: "flex", flexDirection: "column", gap: "0.75rem" }}>
|
||||
{(["alpha", "beta"] as const).map((group) => (
|
||||
<Hoverable.Root key={group} group={group}>
|
||||
<div
|
||||
style={{
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
gap: "0.75rem",
|
||||
padding: "1rem",
|
||||
border: "1px solid var(--border-02)",
|
||||
borderRadius: "0.5rem",
|
||||
}}
|
||||
>
|
||||
<span style={{ color: "var(--text-01)" }}>Group: {group}</span>
|
||||
<Hoverable.Item group={group} variant="opacity-on-hover">
|
||||
<span style={{ color: "var(--text-03)" }}>✓ Revealed</span>
|
||||
</Hoverable.Item>
|
||||
</div>
|
||||
</Hoverable.Root>
|
||||
))}
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
/** Multiple items revealed by a single root. */
|
||||
export const MultipleItems: StoryObj = {
|
||||
render: () => (
|
||||
<Hoverable.Root group="multi">
|
||||
<div
|
||||
style={{
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
gap: "0.75rem",
|
||||
padding: "1rem",
|
||||
border: "1px solid var(--border-02)",
|
||||
borderRadius: "0.5rem",
|
||||
}}
|
||||
>
|
||||
<span style={{ color: "var(--text-01)" }}>Hover to reveal all</span>
|
||||
<Hoverable.Item group="multi" variant="opacity-on-hover">
|
||||
<span>Edit</span>
|
||||
</Hoverable.Item>
|
||||
<Hoverable.Item group="multi" variant="opacity-on-hover">
|
||||
<span>Delete</span>
|
||||
</Hoverable.Item>
|
||||
<Hoverable.Item group="multi" variant="opacity-on-hover">
|
||||
<span>Share</span>
|
||||
</Hoverable.Item>
|
||||
</div>
|
||||
</Hoverable.Root>
|
||||
),
|
||||
};
|
||||
|
||||
/** Nested groups — inner and outer hover independently. */
|
||||
export const NestedGroups: StoryObj = {
|
||||
render: () => (
|
||||
<Hoverable.Root group="outer">
|
||||
<div
|
||||
style={{
|
||||
padding: "1rem",
|
||||
border: "1px solid var(--border-02)",
|
||||
borderRadius: "0.5rem",
|
||||
display: "flex",
|
||||
flexDirection: "column",
|
||||
gap: "0.75rem",
|
||||
}}
|
||||
>
|
||||
<div style={{ display: "flex", alignItems: "center", gap: "0.75rem" }}>
|
||||
<span style={{ color: "var(--text-01)" }}>Outer card</span>
|
||||
<Hoverable.Item group="outer" variant="opacity-on-hover">
|
||||
<span style={{ color: "var(--text-03)" }}>Outer action</span>
|
||||
</Hoverable.Item>
|
||||
</div>
|
||||
|
||||
<Hoverable.Root group="inner">
|
||||
<div
|
||||
style={{
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
gap: "0.75rem",
|
||||
padding: "0.75rem",
|
||||
border: "1px solid var(--border-03)",
|
||||
borderRadius: "0.375rem",
|
||||
}}
|
||||
>
|
||||
<span style={{ color: "var(--text-02)" }}>Inner card</span>
|
||||
<Hoverable.Item group="inner" variant="opacity-on-hover">
|
||||
<span style={{ color: "var(--text-03)" }}>Inner action</span>
|
||||
</Hoverable.Item>
|
||||
</div>
|
||||
</Hoverable.Root>
|
||||
</div>
|
||||
</Hoverable.Root>
|
||||
),
|
||||
};
|
||||
@@ -1,127 +0,0 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react";
|
||||
import { Hoverable } from "@opal/core";
|
||||
import SvgX from "@opal/icons/x";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared styles
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const cardStyle: React.CSSProperties = {
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
justifyContent: "space-between",
|
||||
gap: "1rem",
|
||||
padding: "0.75rem 1rem",
|
||||
borderRadius: "0.5rem",
|
||||
border: "1px solid var(--border-02)",
|
||||
background: "var(--background-neutral-01)",
|
||||
minWidth: 220,
|
||||
};
|
||||
|
||||
const labelStyle: React.CSSProperties = {
|
||||
fontSize: "0.875rem",
|
||||
fontWeight: 500,
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Meta
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const meta: Meta = {
|
||||
title: "Core/Hoverable",
|
||||
tags: ["autodocs"],
|
||||
parameters: {
|
||||
layout: "centered",
|
||||
},
|
||||
};
|
||||
|
||||
export default meta;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stories
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Local hover mode -- no `group` prop on the Item.
|
||||
* The icon only appears when you hover directly over the Item element itself.
|
||||
*/
|
||||
export const LocalHover: StoryObj = {
|
||||
render: () => (
|
||||
<div style={cardStyle}>
|
||||
<span style={labelStyle}>Hover this card area</span>
|
||||
|
||||
<Hoverable.Item variant="opacity-on-hover">
|
||||
<SvgX width={16} height={16} />
|
||||
</Hoverable.Item>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
/**
|
||||
* Group hover mode -- hovering anywhere inside the Root reveals the Item.
|
||||
*/
|
||||
export const GroupHover: StoryObj = {
|
||||
render: () => (
|
||||
<Hoverable.Root group="card">
|
||||
<div style={cardStyle}>
|
||||
<span style={labelStyle}>Hover anywhere on this card</span>
|
||||
|
||||
<Hoverable.Item group="card" variant="opacity-on-hover">
|
||||
<SvgX width={16} height={16} />
|
||||
</Hoverable.Item>
|
||||
</div>
|
||||
</Hoverable.Root>
|
||||
),
|
||||
};
|
||||
|
||||
/**
|
||||
* Nested groups demonstrating isolation.
|
||||
*
|
||||
* - Hovering the outer card reveals only the outer icon.
|
||||
* - Hovering the inner card reveals only the inner icon.
|
||||
*/
|
||||
export const NestedGroups: StoryObj = {
|
||||
render: () => (
|
||||
<Hoverable.Root group="outer">
|
||||
<div
|
||||
style={{
|
||||
...cardStyle,
|
||||
flexDirection: "column",
|
||||
alignItems: "stretch",
|
||||
gap: "0.75rem",
|
||||
padding: "1rem",
|
||||
minWidth: 300,
|
||||
}}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
justifyContent: "space-between",
|
||||
}}
|
||||
>
|
||||
<span style={labelStyle}>Outer card</span>
|
||||
|
||||
<Hoverable.Item group="outer" variant="opacity-on-hover">
|
||||
<SvgX width={16} height={16} />
|
||||
</Hoverable.Item>
|
||||
</div>
|
||||
|
||||
<Hoverable.Root group="inner">
|
||||
<div
|
||||
style={{
|
||||
...cardStyle,
|
||||
background: "var(--background-neutral-02)",
|
||||
}}
|
||||
>
|
||||
<span style={labelStyle}>Inner card</span>
|
||||
|
||||
<Hoverable.Item group="inner" variant="opacity-on-hover">
|
||||
<SvgX width={16} height={16} />
|
||||
</Hoverable.Item>
|
||||
</div>
|
||||
</Hoverable.Root>
|
||||
</div>
|
||||
</Hoverable.Root>
|
||||
),
|
||||
};
|
||||
@@ -20,10 +20,15 @@
|
||||
|
||||
"use client";
|
||||
|
||||
import { cn, ensureHrefProtocol, noProp } from "@/lib/utils";
|
||||
import {
|
||||
cn,
|
||||
ensureHrefProtocol,
|
||||
INTERACTIVE_SELECTOR,
|
||||
noProp,
|
||||
} from "@/lib/utils";
|
||||
import type { Components } from "react-markdown";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { useCallback, useMemo, useState, useEffect } from "react";
|
||||
import { useCallback, useMemo, useRef, useState, useEffect } from "react";
|
||||
import { useAppBackground } from "@/providers/AppBackgroundProvider";
|
||||
import { useTheme } from "next-themes";
|
||||
import ShareChatSessionModal from "@/sections/modals/ShareChatSessionModal";
|
||||
@@ -532,6 +537,37 @@ function Root({ children, enableBackground }: AppRootProps) {
|
||||
const { isSafari } = useBrowserInfo();
|
||||
const isLightMode = resolvedTheme === "light";
|
||||
const showBackground = hasBackground && enableBackground;
|
||||
|
||||
// Track whether the chat input was focused before a mousedown, so we can
|
||||
// restore focus on mouseup if no text was selected. This preserves
|
||||
// click-drag text selection while keeping the input focused on plain clicks.
|
||||
const inputWasFocused = useRef(false);
|
||||
|
||||
const handleMouseDown = useCallback(
|
||||
(event: React.MouseEvent<HTMLDivElement>) => {
|
||||
const activeEl = document.activeElement;
|
||||
const isFocused =
|
||||
activeEl instanceof HTMLElement &&
|
||||
activeEl.id === "onyx-chat-input-textarea";
|
||||
const target = event.target;
|
||||
const isInteractive =
|
||||
target instanceof HTMLElement && !!target.closest(INTERACTIVE_SELECTOR);
|
||||
inputWasFocused.current = isFocused && !isInteractive;
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const handleMouseUp = useCallback(() => {
|
||||
if (!inputWasFocused.current) return;
|
||||
inputWasFocused.current = false;
|
||||
const sel = window.getSelection();
|
||||
if (sel && !sel.isCollapsed) return;
|
||||
const textarea = document.getElementById("onyx-chat-input-textarea");
|
||||
// Only restore focus if no other element has grabbed it since mousedown.
|
||||
if (textarea && document.activeElement !== textarea) {
|
||||
textarea.focus();
|
||||
}
|
||||
}, []);
|
||||
const horizontalBlurMask = `linear-gradient(
|
||||
to right,
|
||||
transparent 0%,
|
||||
@@ -549,6 +585,8 @@ function Root({ children, enableBackground }: AppRootProps) {
|
||||
*/
|
||||
<div
|
||||
data-main-container
|
||||
onMouseDown={handleMouseDown}
|
||||
onMouseUp={handleMouseUp}
|
||||
className={cn(
|
||||
"@container flex flex-col h-full w-full relative overflow-hidden",
|
||||
showBackground && "bg-cover bg-center bg-fixed"
|
||||
|
||||
@@ -13,6 +13,9 @@ import { ALLOWED_URL_PROTOCOLS } from "./constants";
|
||||
const URI_SCHEME_REGEX = /^[a-zA-Z][a-zA-Z\d+.-]*:/;
|
||||
const BARE_EMAIL_REGEX = /^[^\s@/]+@[^\s@/:]+\.[^\s@/:]+$/;
|
||||
|
||||
export const INTERACTIVE_SELECTOR =
|
||||
"a, button, input, textarea, select, label, [role='button'], [tabindex]:not([tabindex='-1']), [contenteditable]:not([contenteditable='false'])";
|
||||
|
||||
export function cn(...inputs: ClassValue[]) {
|
||||
return twMerge(clsx(inputs));
|
||||
}
|
||||
|
||||
@@ -451,13 +451,19 @@ const ModalHeader = React.forwardRef<HTMLDivElement, ModalHeaderProps>(
|
||||
);
|
||||
|
||||
return (
|
||||
<Section ref={ref} padding={1} alignItems="start" height="fit" {...props}>
|
||||
<Section
|
||||
ref={ref}
|
||||
padding={0.5}
|
||||
alignItems="start"
|
||||
height="fit"
|
||||
{...props}
|
||||
>
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="between"
|
||||
alignItems="start"
|
||||
gap={0}
|
||||
padding={0}
|
||||
padding={0.5}
|
||||
>
|
||||
<div className="relative w-full">
|
||||
{/* Close button is absolutely positioned because:
|
||||
@@ -485,7 +491,6 @@ const ModalHeader = React.forwardRef<HTMLDivElement, ModalHeaderProps>(
|
||||
</DialogPrimitive.Title>
|
||||
</div>
|
||||
</Section>
|
||||
|
||||
{children}
|
||||
</Section>
|
||||
);
|
||||
|
||||
@@ -112,9 +112,11 @@ function MemoryItem({
|
||||
/>
|
||||
</Disabled>
|
||||
</Section>
|
||||
{isFocused && (
|
||||
<div
|
||||
className={isFocused ? "visible" : "invisible h-0 overflow-hidden"}
|
||||
>
|
||||
<CharacterCount value={memory.content} limit={MAX_MEMORY_LENGTH} />
|
||||
)}
|
||||
</div>
|
||||
</Section>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -182,7 +182,7 @@ const ProjectFolderButton = memo(({ project }: ProjectFolderButtonProps) => {
|
||||
onClose={() => setIsEditing(false)}
|
||||
/>
|
||||
) : (
|
||||
<Truncated>{project.name}</Truncated>
|
||||
<Truncated text03>{project.name}</Truncated>
|
||||
)}
|
||||
</SidebarTab>
|
||||
</Popover.Anchor>
|
||||
|
||||
47
web/tests/e2e/chat/input_focus_retention.spec.ts
Normal file
47
web/tests/e2e/chat/input_focus_retention.spec.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
import { test, expect } from "@playwright/test";
|
||||
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
|
||||
|
||||
test.describe(`Chat Input Focus Retention`, () => {
|
||||
test.beforeEach(async ({ page }, testInfo) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAsWorkerUser(page, testInfo.workerIndex);
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
});
|
||||
|
||||
test("clicking empty space retains focus on chat input", async ({ page }) => {
|
||||
const textarea = page.locator("#onyx-chat-input-textarea");
|
||||
await textarea.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
// Focus the textarea and type something
|
||||
await textarea.focus();
|
||||
await textarea.fill("test message");
|
||||
await expect(textarea).toBeFocused();
|
||||
|
||||
// Click on the main container's empty space (top-left corner)
|
||||
const container = page.locator("[data-main-container]");
|
||||
await container.click({ position: { x: 10, y: 10 } });
|
||||
|
||||
// Focus should remain on the textarea
|
||||
await expect(textarea).toBeFocused();
|
||||
});
|
||||
|
||||
test("clicking interactive elements still moves focus away", async ({
|
||||
page,
|
||||
}) => {
|
||||
const textarea = page.locator("#onyx-chat-input-textarea");
|
||||
await textarea.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
// Focus the textarea
|
||||
await textarea.focus();
|
||||
await expect(textarea).toBeFocused();
|
||||
|
||||
// Click on an interactive element inside the container
|
||||
const button = page.locator("[data-main-container] button").first();
|
||||
await button.waitFor({ state: "visible", timeout: 5000 });
|
||||
await button.click();
|
||||
|
||||
// Focus should have moved away from the textarea
|
||||
await expect(textarea).not.toBeFocused();
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user