mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-23 16:42:42 +00:00
Compare commits
1 Commits
fix/table-
...
jamison/de
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a64c7c8ff |
279
AGENTS.md
279
AGENTS.md
@@ -167,7 +167,284 @@ web/
|
||||
|
||||
## Frontend Standards
|
||||
|
||||
Frontend standards for the `web/` and `desktop/` projects live in `web/AGENTS.md`.
|
||||
### 1. Import Standards
|
||||
|
||||
**Always use absolute imports with the `@` prefix.**
|
||||
|
||||
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useAuth } from "@/hooks/useAuth";
|
||||
import { Text } from "@/refresh-components/texts/Text";
|
||||
|
||||
// ❌ Bad
|
||||
import { Button } from "../../../components/ui/button";
|
||||
import { useAuth } from "./hooks/useAuth";
|
||||
```
|
||||
|
||||
### 2. React Component Functions
|
||||
|
||||
**Prefer regular functions over arrow functions for React components.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
function UserProfile({ userId }: UserProfileProps) {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
const UserProfile = ({ userId }: UserProfileProps) => {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Props Interface Extraction
|
||||
|
||||
**Extract prop types into their own interface definitions.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
interface UserCardProps {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}
|
||||
|
||||
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({
|
||||
user,
|
||||
showActions = false,
|
||||
onEdit
|
||||
}: {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Spacing Guidelines
|
||||
|
||||
**Prefer padding over margins for spacing.**
|
||||
|
||||
**Reason:** We want to consolidate usage to paddings instead of margins.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
<div className="p-4 space-y-2">
|
||||
<div className="p-2">Content</div>
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className="m-4 space-y-2">
|
||||
<div className="m-2">Content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
### 5. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
|
||||
|
||||
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
|
||||
<div className="bg-background-neutral-03 text-text-02">
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true, // Will apply dark:invert internally
|
||||
});
|
||||
|
||||
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
|
||||
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
|
||||
});
|
||||
|
||||
// ❌ Bad - Manual dark mode overrides
|
||||
<div className="bg-white dark:bg-black text-black dark:text-white">
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 6. Class Name Utilities
|
||||
|
||||
**Use the `cn` utility instead of raw string formatting for classNames.**
|
||||
|
||||
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
|
||||
|
||||
```typescript
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// ✅ Good
|
||||
<div className={cn(
|
||||
'base-class',
|
||||
isActive && 'active-class',
|
||||
className
|
||||
)}>
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 7. Custom Hooks Organization
|
||||
|
||||
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
|
||||
|
||||
**Reason:** This is just a layout preference. Keeps code clean.
|
||||
|
||||
```typescript
|
||||
// web/src/hooks/useUserData.ts
|
||||
export function useUserData(userId: string) {
|
||||
// hook implementation
|
||||
}
|
||||
|
||||
// web/src/hooks/useLocalStorage.ts
|
||||
export function useLocalStorage<T>(key: string, initialValue: T) {
|
||||
// hook implementation
|
||||
}
|
||||
```
|
||||
|
||||
### 8. Icon Usage
|
||||
|
||||
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
|
||||
|
||||
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import SvgX from "@/icons/x";
|
||||
import SvgMoreHorizontal from "@/icons/more-horizontal";
|
||||
|
||||
// ❌ Bad
|
||||
import { User } from "lucide-react";
|
||||
import { FiSearch } from "react-icons/fi";
|
||||
```
|
||||
|
||||
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
|
||||
If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
### 9. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 10. Component Usage
|
||||
|
||||
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
|
||||
|
||||
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import Button from '@/refresh-components/buttons/Button'
|
||||
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
|
||||
import SvgPlusCircle from '@/icons/plus-circle'
|
||||
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<InputTypeIn placeholder="Search..." />
|
||||
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<input placeholder="Name" />
|
||||
<textarea placeholder="Message" />
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 11. Colors
|
||||
|
||||
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
|
||||
|
||||
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
|
||||
|
||||
**Available color categories:**
|
||||
|
||||
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
|
||||
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
|
||||
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
|
||||
- **Actions:** `action-link-XX`, `action-danger-XX`
|
||||
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
|
||||
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Use custom Onyx color classes
|
||||
<div className="bg-background-neutral-01 border border-border-02" />
|
||||
<div className="bg-background-tint-02 border border-border-01" />
|
||||
<div className="bg-status-success-01" />
|
||||
<div className="bg-action-link-01" />
|
||||
<div className="bg-theme-primary-05" />
|
||||
|
||||
// ❌ Bad - Do NOT use standard Tailwind colors
|
||||
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
|
||||
<div className="bg-white border border-slate-200" />
|
||||
<div className="bg-green-100 text-green-700" />
|
||||
<div className="bg-blue-100 text-blue-600" />
|
||||
<div className="bg-indigo-500" />
|
||||
```
|
||||
|
||||
### 12. Data Fetching
|
||||
|
||||
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
|
||||
|
||||
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
|
||||
@@ -47,8 +47,6 @@ RUN apt-get update && \
|
||||
gcc \
|
||||
nano \
|
||||
vim \
|
||||
# Install procps so kubernetes exec sessions can use ps aux for debugging
|
||||
procps \
|
||||
libjemalloc2 \
|
||||
&& \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
|
||||
@@ -25,6 +25,9 @@ from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
# Default number of pre-provisioned tenants to maintain
|
||||
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
|
||||
|
||||
# Soft time limit for tenant pre-provisioning tasks (in seconds)
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
# Hard time limit for tenant pre-provisioning tasks (in seconds)
|
||||
@@ -55,7 +58,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
lock_check: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# These tasks should never overlap
|
||||
@@ -71,7 +74,9 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
num_available_tenants = db_session.query(AvailableTenant).count()
|
||||
|
||||
# Get the target number of available tenants
|
||||
num_minimum_available_tenants = TARGET_AVAILABLE_TENANTS
|
||||
num_minimum_available_tenants = getattr(
|
||||
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
|
||||
)
|
||||
|
||||
# Calculate how many new tenants we need to provision
|
||||
if num_available_tenants < num_minimum_available_tenants:
|
||||
@@ -93,12 +98,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
task_logger.exception("Error in check_available_tenants task")
|
||||
|
||||
finally:
|
||||
try:
|
||||
lock_check.release()
|
||||
except Exception:
|
||||
task_logger.warning(
|
||||
"Could not release check lock (likely expired), continuing"
|
||||
)
|
||||
lock_check.release()
|
||||
|
||||
|
||||
def pre_provision_tenant() -> None:
|
||||
@@ -113,7 +113,7 @@ def pre_provision_tenant() -> None:
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
lock_provision: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CLOUD_PRE_PROVISION_TENANT_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
|
||||
@@ -185,9 +185,4 @@ def pre_provision_tenant() -> None:
|
||||
except Exception:
|
||||
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
|
||||
finally:
|
||||
try:
|
||||
lock_provision.release()
|
||||
except Exception:
|
||||
task_logger.warning(
|
||||
"Could not release provision lock (likely expired), continuing"
|
||||
)
|
||||
lock_provision.release()
|
||||
|
||||
@@ -9,12 +9,12 @@ from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
@@ -362,7 +362,7 @@ if not MULTI_TENANT:
|
||||
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
if HOOKS_AVAILABLE:
|
||||
if not MULTI_TENANT and HOOK_ENABLED:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
|
||||
@@ -30,8 +30,6 @@ from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.utils import plaintext_file_name_for_id
|
||||
from onyx.file_store.utils import store_plaintext
|
||||
from onyx.kg.models import KGException
|
||||
from onyx.kg.setup.kg_default_entity_definitions import (
|
||||
populate_missing_default_entity_types__commit,
|
||||
@@ -291,33 +289,6 @@ def process_kg_commands(
|
||||
raise KGException("KG setup done")
|
||||
|
||||
|
||||
def _get_or_extract_plaintext(
|
||||
file_id: str,
|
||||
extract_fn: Callable[[], str],
|
||||
) -> str:
|
||||
"""Load cached plaintext for a file, or extract and store it.
|
||||
|
||||
Tries to read pre-stored plaintext from the file store. On a miss,
|
||||
calls extract_fn to produce the text, then stores the result so
|
||||
future calls skip the expensive extraction.
|
||||
"""
|
||||
file_store = get_default_file_store()
|
||||
plaintext_key = plaintext_file_name_for_id(file_id)
|
||||
|
||||
# Try cached plaintext first.
|
||||
try:
|
||||
plaintext_io = file_store.read_file(plaintext_key, mode="b")
|
||||
return plaintext_io.read().decode("utf-8")
|
||||
except Exception:
|
||||
logger.exception(f"Error when reading file, id={file_id}")
|
||||
|
||||
# Cache miss — extract and store.
|
||||
content_text = extract_fn()
|
||||
if content_text:
|
||||
store_plaintext(file_id, content_text)
|
||||
return content_text
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def load_chat_file(
|
||||
file_descriptor: FileDescriptor, db_session: Session
|
||||
@@ -332,23 +303,12 @@ def load_chat_file(
|
||||
file_type = ChatFileType(file_descriptor["type"])
|
||||
|
||||
if file_type.is_text_file():
|
||||
file_id = file_descriptor["id"]
|
||||
|
||||
def _extract() -> str:
|
||||
return extract_file_text(
|
||||
try:
|
||||
content_text = extract_file_text(
|
||||
file=file_io,
|
||||
file_name=file_descriptor.get("name") or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
|
||||
# Use the user_file_id as cache key when available (matches what
|
||||
# the celery indexing worker stores), otherwise fall back to the
|
||||
# file store id (covers code-interpreter-generated files, etc.).
|
||||
user_file_id_str = file_descriptor.get("user_file_id")
|
||||
cache_key = user_file_id_str or file_id
|
||||
|
||||
try:
|
||||
content_text = _get_or_extract_plaintext(cache_key, _extract)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to retrieve content for file {file_descriptor['id']}: {str(e)}"
|
||||
|
||||
@@ -157,7 +157,9 @@ def _execute_single_retrieval(
|
||||
logger.error(f"Error executing request: {e}")
|
||||
raise e
|
||||
elif _is_rate_limit_error(e):
|
||||
results = _execute_with_retry(retrieval_function(**request_kwargs))
|
||||
results = _execute_with_retry(
|
||||
lambda: retrieval_function(**request_kwargs).execute()
|
||||
)
|
||||
elif e.resp.status == 404 or e.resp.status == 403:
|
||||
if continue_on_404_or_403:
|
||||
logger.debug(f"Error executing request: {e}")
|
||||
|
||||
@@ -2,7 +2,6 @@ import time
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
@@ -150,9 +149,6 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
Returns None if search settings did not change, or the old search settings if they
|
||||
did change.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
return None
|
||||
|
||||
# Default CC-pair created for Ingestion API unused here
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from contextlib import AbstractContextManager
|
||||
@@ -1063,7 +1062,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
f"Body: {get_new_body_without_vectors(body)}\n"
|
||||
f"Search pipeline ID: {search_pipeline_id}\n"
|
||||
f"Phase took: {phase_took}\n"
|
||||
f"Profile: {json.dumps(profile, indent=2)}\n"
|
||||
f"Profile: {profile}\n"
|
||||
)
|
||||
if timed_out:
|
||||
error_str = f"OpenSearch client error: Search timed out for index {self._index_name}."
|
||||
|
||||
@@ -950,86 +950,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_pipeline_id=normalization_pipeline_name,
|
||||
)
|
||||
|
||||
# Good place for a breakpoint to inspect the search hits if you have
|
||||
# "explain" enabled.
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
)
|
||||
for search_hit in search_hits
|
||||
]
|
||||
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
|
||||
inference_chunks_uncleaned
|
||||
)
|
||||
|
||||
return inference_chunks
|
||||
|
||||
def keyword_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Keyword retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_keyword_search_query(
|
||||
query_text=query,
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
# NOTE: Index filters includes metadata tags which were filtered
|
||||
# for invalid unicode at indexing time. In theory it would be
|
||||
# ideal to do filtering here as well, in practice we never did
|
||||
# that in the Vespa codepath and have not seen issues in
|
||||
# production, so we deliberately conform to the existing logic
|
||||
# in order to not unknowningly introduce a possible bug.
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
)
|
||||
for search_hit in search_hits
|
||||
]
|
||||
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
|
||||
inference_chunks_uncleaned
|
||||
)
|
||||
|
||||
return inference_chunks
|
||||
|
||||
def semantic_retrieval(
|
||||
self,
|
||||
query_embedding: Embedding,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Semantic retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_semantic_search_query(
|
||||
query_embedding=query_embedding,
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
# NOTE: Index filters includes metadata tags which were filtered
|
||||
# for invalid unicode at indexing time. In theory it would be
|
||||
# ideal to do filtering here as well, in practice we never did
|
||||
# that in the Vespa codepath and have not seen issues in
|
||||
# production, so we deliberately conform to the existing logic
|
||||
# in order to not unknowningly introduce a possible bug.
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
|
||||
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
|
||||
@@ -404,170 +404,12 @@ class DocumentQuery:
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
# Explain is for scoring breakdowns.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_hybrid_search_body["explain"] = True
|
||||
|
||||
return final_hybrid_search_body
|
||||
|
||||
@staticmethod
|
||||
def get_keyword_search_query(
|
||||
query_text: str,
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final keyword search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the keyword search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final keyword search query.
|
||||
"""
|
||||
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
|
||||
raise ValueError(
|
||||
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
|
||||
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
|
||||
)
|
||||
|
||||
keyword_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
attached_document_ids=index_filters.attached_document_ids,
|
||||
hierarchy_node_ids=index_filters.hierarchy_node_ids,
|
||||
)
|
||||
|
||||
keyword_search_query = (
|
||||
DocumentQuery._get_title_content_combined_keyword_search_query(
|
||||
query_text, search_filters=keyword_search_filters
|
||||
)
|
||||
)
|
||||
|
||||
final_keyword_search_query: dict[str, Any] = {
|
||||
"query": keyword_search_query,
|
||||
"size": num_hits,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
# Exclude retrieving the vector fields in order to save on
|
||||
# retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
}
|
||||
|
||||
if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED:
|
||||
final_keyword_search_query["highlight"] = (
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
if not OPENSEARCH_PROFILING_DISABLED:
|
||||
final_keyword_search_query["profile"] = True
|
||||
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_keyword_search_query["explain"] = True
|
||||
|
||||
return final_keyword_search_query
|
||||
|
||||
@staticmethod
|
||||
def get_semantic_search_query(
|
||||
query_embedding: list[float],
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final semantic search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
|
||||
Args:
|
||||
query_embedding: The vector embedding of the text to query for.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the semantic search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final semantic search query.
|
||||
"""
|
||||
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
|
||||
raise ValueError(
|
||||
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
|
||||
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
|
||||
)
|
||||
|
||||
semantic_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
attached_document_ids=index_filters.attached_document_ids,
|
||||
hierarchy_node_ids=index_filters.hierarchy_node_ids,
|
||||
)
|
||||
|
||||
semantic_search_query = (
|
||||
DocumentQuery._get_content_vector_similarity_search_query(
|
||||
query_embedding,
|
||||
vector_candidates=num_hits,
|
||||
search_filters=semantic_search_filters,
|
||||
)
|
||||
)
|
||||
|
||||
final_semantic_search_query: dict[str, Any] = {
|
||||
"query": semantic_search_query,
|
||||
"size": num_hits,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
# Exclude retrieving the vector fields in order to save on
|
||||
# retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
}
|
||||
|
||||
if not OPENSEARCH_PROFILING_DISABLED:
|
||||
final_semantic_search_query["profile"] = True
|
||||
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_semantic_search_query["explain"] = True
|
||||
|
||||
return final_semantic_search_query
|
||||
|
||||
@staticmethod
|
||||
def get_random_search_query(
|
||||
tenant_state: TenantState,
|
||||
@@ -739,9 +581,8 @@ class DocumentQuery:
|
||||
def _get_content_vector_similarity_search_query(
|
||||
query_vector: list[float],
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
|
||||
search_filters: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
query = {
|
||||
return {
|
||||
"knn": {
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
@@ -750,19 +591,11 @@ class DocumentQuery:
|
||||
}
|
||||
}
|
||||
|
||||
if search_filters is not None:
|
||||
query["knn"][CONTENT_VECTOR_FIELD_NAME]["filter"] = {
|
||||
"bool": {"filter": search_filters}
|
||||
}
|
||||
|
||||
return query
|
||||
|
||||
@staticmethod
|
||||
def _get_title_content_combined_keyword_search_query(
|
||||
query_text: str,
|
||||
search_filters: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
query = {
|
||||
return {
|
||||
"bool": {
|
||||
"should": [
|
||||
{
|
||||
@@ -803,19 +636,10 @@ class DocumentQuery:
|
||||
}
|
||||
}
|
||||
},
|
||||
],
|
||||
# Ensure at least one term from the query is present in the
|
||||
# document. This defaults to 1, unless a filter or must clause
|
||||
# is supplied, in which case it defaults to 0.
|
||||
"minimum_should_match": 1,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
if search_filters is not None:
|
||||
query["bool"]["filter"] = search_filters
|
||||
|
||||
return query
|
||||
|
||||
@staticmethod
|
||||
def _get_search_filters(
|
||||
tenant_state: TenantState,
|
||||
|
||||
@@ -88,7 +88,6 @@ class OnyxErrorCode(Enum):
|
||||
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
|
||||
BAD_GATEWAY = ("BAD_GATEWAY", 502)
|
||||
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
|
||||
HOOK_EXECUTION_FAILED = ("HOOK_EXECUTION_FAILED", 502)
|
||||
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
|
||||
|
||||
def __init__(self, code: str, status_code: int) -> None:
|
||||
|
||||
@@ -23,55 +23,45 @@ from onyx.utils.timing import log_function_time
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def plaintext_file_name_for_id(file_id: str) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a file."""
|
||||
return f"plaintext_{file_id}"
|
||||
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a user file."""
|
||||
return f"plaintext_{user_file_id}"
|
||||
|
||||
|
||||
def store_plaintext(file_id: str, plaintext_content: str) -> bool:
|
||||
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
|
||||
"""
|
||||
Store plaintext content for a file in the file store.
|
||||
Store plaintext content for a user file in the file store.
|
||||
|
||||
Args:
|
||||
file_id: The ID of the file (user_file or artifact_file)
|
||||
user_file_id: The ID of the user file
|
||||
plaintext_content: The plaintext content to store
|
||||
|
||||
Returns:
|
||||
bool: True if storage was successful, False otherwise
|
||||
"""
|
||||
# Skip empty content
|
||||
if not plaintext_content:
|
||||
return False
|
||||
|
||||
plaintext_file_name = plaintext_file_name_for_id(file_id)
|
||||
# Get plaintext file name
|
||||
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
|
||||
|
||||
try:
|
||||
file_store = get_default_file_store()
|
||||
file_content = BytesIO(plaintext_content.encode("utf-8"))
|
||||
file_store.save_file(
|
||||
content=file_content,
|
||||
display_name=f"Plaintext for {file_id}",
|
||||
display_name=f"Plaintext for user file {user_file_id}",
|
||||
file_origin=FileOrigin.PLAINTEXT_CACHE,
|
||||
file_type="text/plain",
|
||||
file_id=plaintext_file_name,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store plaintext for {file_id}: {e}")
|
||||
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# --- Convenience wrappers for callers that use user-file UUIDs ---
|
||||
|
||||
|
||||
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a user file."""
|
||||
return plaintext_file_name_for_id(str(user_file_id))
|
||||
|
||||
|
||||
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
|
||||
"""Store plaintext content for a user file (delegates to :func:`store_plaintext`)."""
|
||||
return store_plaintext(str(user_file_id), plaintext_content)
|
||||
|
||||
|
||||
def load_chat_file_by_id(file_id: str) -> InMemoryChatFile:
|
||||
"""Load a file directly from the file store using its file_record ID.
|
||||
|
||||
|
||||
@@ -1,330 +0,0 @@
|
||||
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
|
||||
|
||||
Usage (Celery tasks and FastAPI handlers):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
|
||||
)
|
||||
|
||||
if isinstance(result, HookSkipped):
|
||||
# no active hook configured — continue with original behavior
|
||||
...
|
||||
elif isinstance(result, HookSoftFailed):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is the response payload dict from the customer's endpoint
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
--------------------------
|
||||
``is_reachable`` on the Hook row is updated selectively — only when the outcome
|
||||
carries meaningful signal about physical reachability:
|
||||
|
||||
NetworkError (DNS, connection refused) → False (cannot reach the server)
|
||||
HTTP 401 / 403 → False (api_key revoked or invalid)
|
||||
TimeoutException → None (server may be slow, skip write)
|
||||
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
|
||||
Unknown exception → None (no signal, skip write)
|
||||
Non-JSON / non-dict response → None (server responded, skip write)
|
||||
Success (2xx, valid dict) → True (confirmed reachable)
|
||||
|
||||
None means "leave the current value unchanged" — no DB round-trip is made.
|
||||
|
||||
DB session design
|
||||
-----------------
|
||||
The executor uses three sessions:
|
||||
|
||||
1. Caller's session (db_session) — used only for the hook lookup read. All
|
||||
needed fields are extracted from the Hook object before the HTTP call, so
|
||||
the caller's session is not held open during the external HTTP request.
|
||||
|
||||
2. Log session — a separate short-lived session opened after the HTTP call
|
||||
completes to write the HookExecutionLog row on failure. Success runs are
|
||||
not recorded. Committed independently of everything else.
|
||||
|
||||
3. Reachable session — a second short-lived session to update is_reachable on
|
||||
the Hook. Kept separate from the log session so a concurrent hook deletion
|
||||
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
|
||||
prevent the execution log from being written. This update is best-effort.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.hook import create_hook_execution_log__no_commit
|
||||
from onyx.db.hook import get_non_deleted_hook_by_hook_point
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class HookSkipped:
|
||||
"""No active hook configured for this hook point."""
|
||||
|
||||
|
||||
class HookSoftFailed:
|
||||
"""Hook was called but failed with SOFT fail strategy — continuing."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _HttpOutcome(BaseModel):
|
||||
"""Structured result of an HTTP hook call, returned by _process_response."""
|
||||
|
||||
is_success: bool
|
||||
updated_is_reachable: (
|
||||
bool | None
|
||||
) # True/False = write to DB, None = unchanged (skip write)
|
||||
status_code: int | None
|
||||
error_message: str | None
|
||||
response_payload: dict[str, Any] | None
|
||||
|
||||
|
||||
def _lookup_hook(
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
) -> Hook | HookSkipped:
|
||||
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
|
||||
|
||||
No HTTP call is made and no DB writes are performed for any HookSkipped path.
|
||||
There is nothing to log and no reachability information to update.
|
||||
"""
|
||||
if not HOOKS_AVAILABLE:
|
||||
return HookSkipped()
|
||||
hook = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if hook is None or not hook.is_active:
|
||||
return HookSkipped()
|
||||
if not hook.endpoint_url:
|
||||
return HookSkipped()
|
||||
return hook
|
||||
|
||||
|
||||
def _process_response(
|
||||
*,
|
||||
response: httpx.Response | None,
|
||||
exc: Exception | None,
|
||||
timeout: float,
|
||||
) -> _HttpOutcome:
|
||||
"""Process the result of an HTTP call and return a structured outcome.
|
||||
|
||||
Called after the client.post() try/except. If post() raised, exc is set and
|
||||
response is None. Otherwise response is set and exc is None. Handles
|
||||
raise_for_status(), JSON decoding, and the dict shape check.
|
||||
"""
|
||||
if exc is not None:
|
||||
if isinstance(exc, httpx.NetworkError):
|
||||
msg = f"Hook network error (endpoint unreachable): {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False,
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
msg = f"Hook timed out after {timeout}s: {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # timeout doesn't indicate unreachability
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
msg = f"Hook call failed: {exc}"
|
||||
logger.exception(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # unknown error — don't make assumptions
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
"exactly one of response or exc must be non-None; both are None"
|
||||
)
|
||||
status_code = response.status_code
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
|
||||
# so the operator knows to update it. All other HTTP errors keep is_reachable
|
||||
# as-is (server is up, the request just failed for application reasons).
|
||||
auth_failed = e.response.status_code in (401, 403)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False if auth_failed else None,
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except (json.JSONDecodeError, httpx.DecodingError) as e:
|
||||
msg = f"Hook returned non-JSON response: {e}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
|
||||
logger.warning(msg)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
return _HttpOutcome(
|
||||
is_success=True,
|
||||
updated_is_reachable=True,
|
||||
status_code=status_code,
|
||||
error_message=None,
|
||||
response_payload=response_payload,
|
||||
)
|
||||
|
||||
|
||||
def _persist_result(
|
||||
*,
|
||||
hook_id: int,
|
||||
outcome: _HttpOutcome,
|
||||
duration_ms: int,
|
||||
) -> None:
|
||||
"""Write the execution log on failure and optionally update is_reachable, each
|
||||
in its own session so a failure in one does not affect the other."""
|
||||
# Only write the execution log on failure — success runs are not recorded.
|
||||
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
|
||||
# deleted between the initial lookup and here).
|
||||
if not outcome.is_success:
|
||||
try:
|
||||
with get_session_with_current_tenant() as log_session:
|
||||
create_hook_execution_log__no_commit(
|
||||
db_session=log_session,
|
||||
hook_id=hook_id,
|
||||
is_success=False,
|
||||
error_message=outcome.error_message,
|
||||
status_code=outcome.status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
log_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to persist hook execution log for hook_id={hook_id}"
|
||||
)
|
||||
|
||||
# Update is_reachable separately — best-effort, non-critical.
|
||||
# None means the value is unchanged (set by the caller to skip the no-op write).
|
||||
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
|
||||
# concurrently deleted, so keep this isolated from the log write above.
|
||||
if outcome.updated_is_reachable is not None:
|
||||
try:
|
||||
with get_session_with_current_tenant() as reachable_session:
|
||||
update_hook__no_commit(
|
||||
db_session=reachable_session,
|
||||
hook_id=hook_id,
|
||||
is_reachable=outcome.updated_is_reachable,
|
||||
)
|
||||
reachable_session.commit()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def execute_hook(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
payload: dict[str, Any],
|
||||
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
|
||||
"""Execute the hook for the given hook point synchronously."""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
"active hooks without an endpoint_url must be rejected by _lookup_hook"
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
response: httpx.Response | None = None
|
||||
exc: Exception | None = None
|
||||
try:
|
||||
api_key: str | None = (
|
||||
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
)
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(timeout=timeout) as client:
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
outcome = outcome.model_copy(update={"updated_is_reachable": None})
|
||||
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
|
||||
|
||||
if not outcome.is_success:
|
||||
if fail_strategy == HookFailStrategy.HARD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.HOOK_EXECUTION_FAILED,
|
||||
outcome.error_message or "Hook execution failed.",
|
||||
)
|
||||
logger.warning(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
if outcome.response_payload is None:
|
||||
raise ValueError(
|
||||
f"response_payload is None for successful hook call (hook_id={hook_id})"
|
||||
)
|
||||
return outcome.response_payload
|
||||
@@ -42,8 +42,12 @@ class HookUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
endpoint_url: str | None = None
|
||||
api_key: NonEmptySecretStr | None = None
|
||||
fail_strategy: HookFailStrategy | None = None
|
||||
timeout_seconds: float | None = Field(default=None, gt=0)
|
||||
fail_strategy: HookFailStrategy | None = (
|
||||
None # if None in model_fields_set, reset to spec default
|
||||
)
|
||||
timeout_seconds: float | None = Field(
|
||||
default=None, gt=0
|
||||
) # if None in model_fields_set, reset to spec default
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_at_least_one_field(self) -> "HookUpdateRequest":
|
||||
@@ -56,14 +60,6 @@ class HookUpdateRequest(BaseModel):
|
||||
and not (self.endpoint_url or "").strip()
|
||||
):
|
||||
raise ValueError("endpoint_url cannot be cleared.")
|
||||
if "fail_strategy" in self.model_fields_set and self.fail_strategy is None:
|
||||
raise ValueError(
|
||||
"fail_strategy cannot be null; omit the field to leave it unchanged."
|
||||
)
|
||||
if "timeout_seconds" in self.model_fields_set and self.timeout_seconds is None:
|
||||
raise ValueError(
|
||||
"timeout_seconds cannot be null; omit the field to leave it unchanged."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@@ -94,28 +90,38 @@ class HookResponse(BaseModel):
|
||||
fail_strategy: HookFailStrategy
|
||||
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
|
||||
is_active: bool
|
||||
is_reachable: bool | None
|
||||
creator_email: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class HookValidateStatus(str, Enum):
|
||||
passed = "passed" # server responded (any status except 401/403)
|
||||
auth_failed = "auth_failed" # server responded with 401 or 403
|
||||
timeout = (
|
||||
"timeout" # TCP connected, but read/write timed out (server exists but slow)
|
||||
)
|
||||
cannot_connect = "cannot_connect" # could not connect to the server
|
||||
|
||||
|
||||
class HookValidateResponse(BaseModel):
|
||||
status: HookValidateStatus
|
||||
success: bool
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class HookExecutionRecord(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HookHealthStatus(str, Enum):
|
||||
healthy = "healthy" # green — reachable, no failures in last 1h
|
||||
degraded = "degraded" # yellow — reachable, failures in last 1h
|
||||
unreachable = "unreachable" # red — is_reachable=false or null
|
||||
|
||||
|
||||
class HookFailureRecord(BaseModel):
|
||||
error_message: str | None = None
|
||||
status_code: int | None = None
|
||||
duration_ms: int | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class HookHealthResponse(BaseModel):
|
||||
status: HookHealthStatus
|
||||
recent_failures: list[HookFailureRecord] = Field(
|
||||
default_factory=list,
|
||||
description="Last 10 failures, newest first",
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
@@ -14,25 +13,22 @@ _REQUIRED_ATTRS = (
|
||||
"default_timeout_seconds",
|
||||
"fail_hard_description",
|
||||
"default_fail_strategy",
|
||||
"payload_model",
|
||||
"response_model",
|
||||
)
|
||||
|
||||
|
||||
class HookPointSpec:
|
||||
class HookPointSpec(ABC):
|
||||
"""Static metadata and contract for a pipeline hook point.
|
||||
|
||||
This is NOT a regular class meant for direct instantiation by callers.
|
||||
Each concrete subclass represents exactly one hook point and is instantiated
|
||||
once at startup, registered in onyx.hooks.registry._REGISTRY. Prefer
|
||||
get_hook_point_spec() or get_all_specs() from the registry over direct
|
||||
instantiation.
|
||||
once at startup, registered in onyx.hooks.registry._REGISTRY. No caller
|
||||
should ever create instances directly — use get_hook_point_spec() or
|
||||
get_all_specs() from the registry instead.
|
||||
|
||||
Each hook point is a concrete subclass of this class. Onyx engineers
|
||||
own these definitions — customers never touch this code.
|
||||
|
||||
Subclasses must define all attributes as class-level constants.
|
||||
payload_model and response_model must be Pydantic BaseModel subclasses;
|
||||
input_schema and output_schema are derived from them automatically.
|
||||
"""
|
||||
|
||||
hook_point: HookPoint
|
||||
@@ -43,33 +39,21 @@ class HookPointSpec:
|
||||
default_fail_strategy: HookFailStrategy
|
||||
docs_url: str | None = None
|
||||
|
||||
payload_model: ClassVar[type[BaseModel]]
|
||||
response_model: ClassVar[type[BaseModel]]
|
||||
|
||||
# Computed once at class definition time from payload_model / response_model.
|
||||
input_schema: ClassVar[dict[str, Any]]
|
||||
output_schema: ClassVar[dict[str, Any]]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: object) -> None:
|
||||
"""Enforce that every concrete subclass declares all required class attributes.
|
||||
|
||||
Called automatically by Python whenever a class inherits from HookPointSpec.
|
||||
Abstract subclasses (those still carrying unimplemented abstract methods) are
|
||||
skipped — they are intermediate base classes and may not yet define everything.
|
||||
Only fully concrete subclasses are validated, ensuring a clear TypeError at
|
||||
import time rather than a confusing AttributeError at runtime.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Skip intermediate abstract subclasses — they may still be partially defined.
|
||||
if getattr(cls, "__abstractmethods__", None):
|
||||
return
|
||||
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
|
||||
if missing:
|
||||
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
|
||||
for attr in ("payload_model", "response_model"):
|
||||
val = getattr(cls, attr, None)
|
||||
if val is None or not (
|
||||
isinstance(val, type) and issubclass(val, BaseModel)
|
||||
):
|
||||
raise TypeError(
|
||||
f"{cls.__name__}.{attr} must be a Pydantic BaseModel subclass, got {val!r}"
|
||||
)
|
||||
cls.input_schema = cls.payload_model.model_json_schema()
|
||||
cls.output_schema = cls.response_model.model_json_schema()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
"""JSON schema describing the request payload sent to the customer's endpoint."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
"""JSON schema describing the expected response from the customer's endpoint."""
|
||||
|
||||
@@ -1,19 +1,10 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Any
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
# TODO(@Bo-Onyx): define payload and response fields
|
||||
class DocumentIngestionPayload(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIngestionResponse(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIngestionSpec(HookPointSpec):
|
||||
"""Hook point that runs during document ingestion.
|
||||
|
||||
@@ -27,5 +18,12 @@ class DocumentIngestionSpec(HookPointSpec):
|
||||
fail_hard_description = "The document will not be indexed."
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
|
||||
payload_model = DocumentIngestionPayload
|
||||
response_model = DocumentIngestionResponse
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
# TODO(@Bo-Onyx): define input schema
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
# TODO(@Bo-Onyx): define output schema
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
@@ -1,39 +1,10 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from typing import Any
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
class QueryProcessingPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
query: str = Field(description="The raw query string exactly as the user typed it.")
|
||||
user_email: str | None = Field(
|
||||
description="Email of the user submitting the query, or null if unauthenticated."
|
||||
)
|
||||
chat_session_id: str = Field(
|
||||
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
|
||||
)
|
||||
|
||||
|
||||
class QueryProcessingResponse(BaseModel):
|
||||
# Intentionally permissive — customer endpoints may return extra fields.
|
||||
query: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The query to use in the pipeline. "
|
||||
"Null, empty string, or absent = reject the query."
|
||||
),
|
||||
)
|
||||
rejection_message: str | None = Field(
|
||||
default=None,
|
||||
description="Message shown to the user when the query is rejected. Falls back to a generic message if not provided.",
|
||||
)
|
||||
|
||||
|
||||
class QueryProcessingSpec(HookPointSpec):
|
||||
"""Hook point that runs on every user query before it enters the pipeline.
|
||||
|
||||
@@ -66,5 +37,47 @@ class QueryProcessingSpec(HookPointSpec):
|
||||
)
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
|
||||
payload_model = QueryProcessingPayload
|
||||
response_model = QueryProcessingResponse
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The raw query string exactly as the user typed it.",
|
||||
},
|
||||
"user_email": {
|
||||
"type": ["string", "null"],
|
||||
"description": "Email of the user submitting the query, or null if unauthenticated.",
|
||||
},
|
||||
"chat_session_id": {
|
||||
"type": "string",
|
||||
"description": "UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires.",
|
||||
},
|
||||
},
|
||||
"required": ["query", "user_email", "chat_session_id"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"The (optionally modified) query to use. "
|
||||
"Set to null to reject the query."
|
||||
),
|
||||
},
|
||||
"rejection_message": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"Message shown to the user when query is null. "
|
||||
"Falls back to a generic message if not provided."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT
|
||||
@@ -530,11 +530,6 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
messages = _strip_tool_content_from_messages(messages)
|
||||
|
||||
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
|
||||
# reject requests where tool_choice is explicitly null.
|
||||
if tools and tool_choice is not None:
|
||||
optional_kwargs["tool_choice"] = tool_choice
|
||||
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
@@ -543,6 +538,7 @@ class LitellmLLM(LLM):
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=timeout_override or self._timeout,
|
||||
|
||||
@@ -77,7 +77,6 @@ from onyx.server.features.default_assistant.api import (
|
||||
)
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.hierarchy.api import router as hierarchy_router
|
||||
from onyx.server.features.hooks.api import router as hook_router
|
||||
from onyx.server.features.input_prompt.api import (
|
||||
admin_router as admin_input_prompt_router,
|
||||
)
|
||||
@@ -454,7 +453,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
|
||||
register_onyx_exception_handlers(application)
|
||||
|
||||
include_router_with_global_prefix_prepended(application, hook_router)
|
||||
include_router_with_global_prefix_prepended(application, password_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
|
||||
@@ -1,453 +0,0 @@
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.constants import UNSET
|
||||
from onyx.db.constants import UnsetType
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.hook import create_hook__no_commit
|
||||
from onyx.db.hook import delete_hook__no_commit
|
||||
from onyx.db.hook import get_hook_by_id
|
||||
from onyx.db.hook import get_hook_execution_logs
|
||||
from onyx.db.hook import get_hooks
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.api_dependencies import require_hook_enabled
|
||||
from onyx.hooks.models import HookCreateRequest
|
||||
from onyx.hooks.models import HookExecutionRecord
|
||||
from onyx.hooks.models import HookPointMetaResponse
|
||||
from onyx.hooks.models import HookResponse
|
||||
from onyx.hooks.models import HookUpdateRequest
|
||||
from onyx.hooks.models import HookValidateResponse
|
||||
from onyx.hooks.models import HookValidateStatus
|
||||
from onyx.hooks.registry import get_all_specs
|
||||
from onyx.hooks.registry import get_hook_point_spec
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.url import SSRFException
|
||||
from onyx.utils.url import validate_outbound_http_url
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSRF protection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _check_ssrf_safety(endpoint_url: str) -> None:
|
||||
"""Raise OnyxError if endpoint_url could be used for SSRF.
|
||||
|
||||
Delegates to validate_outbound_http_url with https_only=True.
|
||||
"""
|
||||
try:
|
||||
validate_outbound_http_url(endpoint_url, https_only=True)
|
||||
except (SSRFException, ValueError) as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookResponse:
|
||||
return HookResponse(
|
||||
id=hook.id,
|
||||
name=hook.name,
|
||||
hook_point=hook.hook_point,
|
||||
endpoint_url=hook.endpoint_url,
|
||||
fail_strategy=hook.fail_strategy,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
is_active=hook.is_active,
|
||||
is_reachable=hook.is_reachable,
|
||||
creator_email=(
|
||||
creator_email
|
||||
if creator_email is not None
|
||||
else (hook.creator.email if hook.creator else None)
|
||||
),
|
||||
created_at=hook.created_at,
|
||||
updated_at=hook.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _get_hook_or_404(
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
include_creator: bool = False,
|
||||
) -> Hook:
|
||||
hook = get_hook_by_id(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
include_creator=include_creator,
|
||||
)
|
||||
if hook is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook {hook_id} not found.")
|
||||
return hook
|
||||
|
||||
|
||||
def _raise_for_validation_failure(validation: HookValidateResponse) -> None:
|
||||
"""Raise an appropriate OnyxError for a non-passed validation result."""
|
||||
if validation.status == HookValidateStatus.auth_failed:
|
||||
raise OnyxError(OnyxErrorCode.CREDENTIAL_INVALID, validation.error_message)
|
||||
if validation.status == HookValidateStatus.timeout:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.GATEWAY_TIMEOUT,
|
||||
f"Endpoint validation failed: {validation.error_message}",
|
||||
)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Endpoint validation failed: {validation.error_message}",
|
||||
)
|
||||
|
||||
|
||||
def _validate_endpoint(
|
||||
endpoint_url: str,
|
||||
api_key: str | None,
|
||||
timeout_seconds: float,
|
||||
) -> HookValidateResponse:
|
||||
"""Check whether endpoint_url is reachable by sending an empty POST request.
|
||||
|
||||
We use POST since hook endpoints expect POST requests. The server will typically
|
||||
respond with 4xx (missing/invalid body) — that is fine. Any HTTP response means
|
||||
the server is up and routable. A 401/403 response returns auth_failed
|
||||
(not reachable — indicates the api_key is invalid).
|
||||
|
||||
Timeout handling:
|
||||
- ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
|
||||
(operator should consider increasing timeout_seconds).
|
||||
- All other exceptions → cannot_connect.
|
||||
"""
|
||||
_check_ssrf_safety(endpoint_url)
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
try:
|
||||
with httpx.Client(timeout=timeout_seconds, follow_redirects=False) as client:
|
||||
response = client.post(endpoint_url, headers=headers)
|
||||
if response.status_code in (401, 403):
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.auth_failed,
|
||||
error_message=f"Authentication failed (HTTP {response.status_code})",
|
||||
)
|
||||
return HookValidateResponse(status=HookValidateStatus.passed)
|
||||
except httpx.TimeoutException as exc:
|
||||
# ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
|
||||
if isinstance(exc, httpx.ConnectTimeout):
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connect timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
logger.warning(
|
||||
"Hook endpoint validation: read/write timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.timeout,
|
||||
error_message="Endpoint timed out — consider increasing timeout_seconds.",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connection error for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
router = APIRouter(prefix="/admin/hooks")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hook endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/specs")
|
||||
def get_hook_point_specs(
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
) -> list[HookPointMetaResponse]:
|
||||
return [
|
||||
HookPointMetaResponse(
|
||||
hook_point=spec.hook_point,
|
||||
display_name=spec.display_name,
|
||||
description=spec.description,
|
||||
docs_url=spec.docs_url,
|
||||
input_schema=spec.input_schema,
|
||||
output_schema=spec.output_schema,
|
||||
default_timeout_seconds=spec.default_timeout_seconds,
|
||||
default_fail_strategy=spec.default_fail_strategy,
|
||||
fail_hard_description=spec.fail_hard_description,
|
||||
)
|
||||
for spec in get_all_specs()
|
||||
]
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_hooks(
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[HookResponse]:
|
||||
hooks = get_hooks(db_session=db_session, include_creator=True)
|
||||
return [_hook_to_response(h) for h in hooks]
|
||||
|
||||
|
||||
@router.post("")
|
||||
def create_hook(
|
||||
req: HookCreateRequest,
|
||||
user: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
"""Create a new hook. The endpoint is validated before persisting — creation fails if
|
||||
the endpoint cannot be reached or the api_key is invalid. Hooks are created inactive;
|
||||
use POST /{hook_id}/activate once ready to receive traffic."""
|
||||
spec = get_hook_point_spec(req.hook_point)
|
||||
api_key = req.api_key.get_secret_value() if req.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=req.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
|
||||
)
|
||||
if validation.status != HookValidateStatus.passed:
|
||||
_raise_for_validation_failure(validation)
|
||||
|
||||
hook = create_hook__no_commit(
|
||||
db_session=db_session,
|
||||
name=req.name,
|
||||
hook_point=req.hook_point,
|
||||
endpoint_url=req.endpoint_url,
|
||||
api_key=api_key,
|
||||
fail_strategy=req.fail_strategy or spec.default_fail_strategy,
|
||||
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
|
||||
creator_id=user.id,
|
||||
)
|
||||
hook.is_reachable = True
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook, creator_email=user.email)
|
||||
|
||||
|
||||
@router.get("/{hook_id}")
|
||||
def get_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id, include_creator=True)
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.patch("/{hook_id}")
|
||||
def update_hook(
|
||||
hook_id: int,
|
||||
req: HookUpdateRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
"""Update hook fields. If endpoint_url, api_key, or timeout_seconds changes, the
|
||||
endpoint is re-validated using the effective values. For active hooks the update is
|
||||
rejected on validation failure, keeping live traffic unaffected. For inactive hooks
|
||||
the update goes through regardless and is_reachable is updated to reflect the result.
|
||||
|
||||
Note: if an active hook's endpoint is currently down, even a timeout_seconds-only
|
||||
increase will be rejected. The recovery flow is: deactivate → update → reactivate.
|
||||
"""
|
||||
# api_key: UNSET = no change, None = clear, value = update
|
||||
api_key: str | None | UnsetType
|
||||
if "api_key" not in req.model_fields_set:
|
||||
api_key = UNSET
|
||||
elif req.api_key is None:
|
||||
api_key = None
|
||||
else:
|
||||
api_key = req.api_key.get_secret_value()
|
||||
|
||||
endpoint_url_changing = "endpoint_url" in req.model_fields_set
|
||||
api_key_changing = not isinstance(api_key, UnsetType)
|
||||
timeout_changing = "timeout_seconds" in req.model_fields_set
|
||||
|
||||
validated_is_reachable: bool | None = None
|
||||
if endpoint_url_changing or api_key_changing or timeout_changing:
|
||||
existing = _get_hook_or_404(db_session, hook_id)
|
||||
effective_url: str = (
|
||||
req.endpoint_url if endpoint_url_changing else existing.endpoint_url # type: ignore[assignment] # endpoint_url is required on create and cannot be cleared on update
|
||||
)
|
||||
effective_api_key: str | None = (
|
||||
(api_key if not isinstance(api_key, UnsetType) else None)
|
||||
if api_key_changing
|
||||
else (
|
||||
existing.api_key.get_value(apply_mask=False)
|
||||
if existing.api_key
|
||||
else None
|
||||
)
|
||||
)
|
||||
effective_timeout: float = (
|
||||
req.timeout_seconds if timeout_changing else existing.timeout_seconds # type: ignore[assignment] # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest)
|
||||
)
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=effective_url,
|
||||
api_key=effective_api_key,
|
||||
timeout_seconds=effective_timeout,
|
||||
)
|
||||
if existing.is_active and validation.status != HookValidateStatus.passed:
|
||||
_raise_for_validation_failure(validation)
|
||||
validated_is_reachable = validation.status == HookValidateStatus.passed
|
||||
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
name=req.name,
|
||||
endpoint_url=(req.endpoint_url if endpoint_url_changing else UNSET),
|
||||
api_key=api_key,
|
||||
fail_strategy=req.fail_strategy,
|
||||
timeout_seconds=req.timeout_seconds,
|
||||
is_reachable=validated_is_reachable,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.delete("/{hook_id}")
|
||||
def delete_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
delete_hook__no_commit(db_session=db_session, hook_id=hook_id)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@router.post("/{hook_id}/activate")
|
||||
def activate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id)
|
||||
if not hook.endpoint_url:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
|
||||
)
|
||||
|
||||
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=hook.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
)
|
||||
if validation.status != HookValidateStatus.passed:
|
||||
# Persist is_reachable=False in a separate session so the request
|
||||
# session has no commits on the failure path and the transaction
|
||||
# boundary stays clean.
|
||||
if hook.is_reachable is not False:
|
||||
with get_session_with_current_tenant() as side_session:
|
||||
update_hook__no_commit(
|
||||
db_session=side_session, hook_id=hook_id, is_reachable=False
|
||||
)
|
||||
side_session.commit()
|
||||
_raise_for_validation_failure(validation)
|
||||
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
is_active=True,
|
||||
is_reachable=True,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.post("/{hook_id}/validate")
|
||||
def validate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookValidateResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id)
|
||||
if not hook.endpoint_url:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
|
||||
)
|
||||
|
||||
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=hook.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
)
|
||||
validation_passed = validation.status == HookValidateStatus.passed
|
||||
if hook.is_reachable != validation_passed:
|
||||
update_hook__no_commit(
|
||||
db_session=db_session, hook_id=hook_id, is_reachable=validation_passed
|
||||
)
|
||||
db_session.commit()
|
||||
return validation
|
||||
|
||||
|
||||
@router.post("/{hook_id}/deactivate")
|
||||
def deactivate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
is_active=False,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Execution log endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/{hook_id}/execution-logs")
|
||||
def list_hook_execution_logs(
|
||||
hook_id: int,
|
||||
limit: int = Query(default=10, ge=1, le=100),
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[HookExecutionRecord]:
|
||||
_get_hook_or_404(db_session, hook_id)
|
||||
logs = get_hook_execution_logs(db_session=db_session, hook_id=hook_id, limit=limit)
|
||||
return [
|
||||
HookExecutionRecord(
|
||||
error_message=log.error_message,
|
||||
status_code=log.status_code,
|
||||
duration_ms=log.duration_ms,
|
||||
created_at=log.created_at,
|
||||
)
|
||||
for log in logs
|
||||
]
|
||||
@@ -140,20 +140,10 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
|
||||
return validated_ip, hostname, port
|
||||
|
||||
|
||||
def validate_outbound_http_url(
|
||||
url: str,
|
||||
*,
|
||||
allow_private_network: bool = False,
|
||||
https_only: bool = False,
|
||||
) -> str:
|
||||
def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
|
||||
"""
|
||||
Validate a URL that will be used by backend outbound HTTP calls.
|
||||
|
||||
Args:
|
||||
url: The URL to validate.
|
||||
allow_private_network: If True, skip private/reserved IP checks.
|
||||
https_only: If True, reject http:// URLs (only https:// is allowed).
|
||||
|
||||
Returns:
|
||||
A normalized URL string with surrounding whitespace removed.
|
||||
|
||||
@@ -167,12 +157,7 @@ def validate_outbound_http_url(
|
||||
|
||||
parsed = urlparse(normalized_url)
|
||||
|
||||
if https_only:
|
||||
if parsed.scheme != "https":
|
||||
raise SSRFException(
|
||||
f"Invalid URL scheme '{parsed.scheme}'. Only https is allowed."
|
||||
)
|
||||
elif parsed.scheme not in ("http", "https"):
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise SSRFException(
|
||||
f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
|
||||
)
|
||||
|
||||
@@ -1640,275 +1640,3 @@ class TestOpenSearchClient:
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
|
||||
def test_keyword_search(
|
||||
self,
|
||||
test_client: OpenSearchIndexClient,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Tests keyword search with filters for ACL, hidden documents, and tenant
|
||||
isolation.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, True)
|
||||
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
|
||||
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
|
||||
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_x.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
docs = {
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
chunk_index=0,
|
||||
content="Public document content",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"hidden-doc": _create_test_document_chunk(
|
||||
document_id="hidden-doc",
|
||||
chunk_index=0,
|
||||
content="Hidden document content, spooky",
|
||||
hidden=True,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"private-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 123-45-6789",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
# Tests that we don't return documents that don't match keywords at
|
||||
# all, even if they match filters.
|
||||
"private-but-not-relevant-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-but-not-relevant-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="This text should not match the query at all",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"private-doc-user-b": _create_test_document_chunk(
|
||||
document_id="private-doc-user-b",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 987-65-4321",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-b@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
|
||||
document_id="should-not-exist-from-tenant-x-pov",
|
||||
chunk_index=0,
|
||||
content="This is an entirely different tenant, x should never see this",
|
||||
# Make this as permissive as possible to exercise tenant
|
||||
# isolation.
|
||||
hidden=False,
|
||||
tenant_state=tenant_y,
|
||||
),
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc, tenant_state=doc.tenant_id)
|
||||
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Should not match private-but-not-relevant-doc-user-a.
|
||||
query_text = "document content"
|
||||
search_body = DocumentQuery.get_keyword_search_query(
|
||||
query_text=query_text,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_x,
|
||||
# The user should only be able to see their private docs. tenant_id
|
||||
# in this object is not relevant.
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=[prefix_user_email("user-a@example.com")],
|
||||
tenant_id=None,
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
|
||||
# Postcondition.
|
||||
# Should only get the public, non-hidden document, and the private
|
||||
# document for which the user has access.
|
||||
assert len(results) == 2
|
||||
# This should be the highest-ranked result, as a higher percentage of
|
||||
# the content matches the query.
|
||||
assert results[0].document_chunk.document_id == "public-doc"
|
||||
# Make sure the chunk contents are preserved.
|
||||
assert results[0].document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(docs["public-doc"], k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
# Make sure score reporting seems reasonable (it should not be None
|
||||
# or 0).
|
||||
assert results[0].score
|
||||
# Make sure there is some kind of match highlight.
|
||||
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
# Same for the second result.
|
||||
assert results[1].document_chunk.document_id == "private-doc-user-a"
|
||||
assert results[1].document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(docs["private-doc-user-a"], k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
assert results[1].score
|
||||
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
assert results[1].score < results[0].score
|
||||
|
||||
def test_semantic_search(
|
||||
self,
|
||||
test_client: OpenSearchIndexClient,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Tests semantic search with filters for ACL, hidden documents, and tenant
|
||||
isolation.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, True)
|
||||
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
|
||||
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_x.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
docs = {
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
chunk_index=0,
|
||||
content="Public document content",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
# Make this identical to the query vector to test that this
|
||||
# result is returned first.
|
||||
content_vector=_generate_test_vector(0.6),
|
||||
),
|
||||
"hidden-doc": _create_test_document_chunk(
|
||||
document_id="hidden-doc",
|
||||
chunk_index=0,
|
||||
content="Hidden document content, spooky",
|
||||
hidden=True,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"private-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 123-45-6789",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
# Make this different from the query vector to test that this
|
||||
# result is returned second.
|
||||
content_vector=_generate_test_vector(0.5),
|
||||
),
|
||||
"private-doc-user-b": _create_test_document_chunk(
|
||||
document_id="private-doc-user-b",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 987-65-4321",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-b@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
|
||||
document_id="should-not-exist-from-tenant-x-pov",
|
||||
chunk_index=0,
|
||||
content="This is an entirely different tenant, x should never see this",
|
||||
# Make this as permissive as possible to exercise tenant
|
||||
# isolation.
|
||||
hidden=False,
|
||||
tenant_state=tenant_y,
|
||||
),
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc, tenant_state=doc.tenant_id)
|
||||
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
query_vector = _generate_test_vector(0.6)
|
||||
search_body = DocumentQuery.get_semantic_search_query(
|
||||
query_embedding=query_vector,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_x,
|
||||
# The user should only be able to see their private docs. tenant_id
|
||||
# in this object is not relevant.
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=[prefix_user_email("user-a@example.com")],
|
||||
tenant_id=None,
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
|
||||
# Postcondition.
|
||||
# Should only get the public, non-hidden document, and the private
|
||||
# document for which the user has access.
|
||||
assert len(results) == 2
|
||||
# We explicitly expect this to be the highest-ranked result.
|
||||
assert results[0].document_chunk.document_id == "public-doc"
|
||||
# Make sure the chunk contents are preserved.
|
||||
assert results[0].document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(docs["public-doc"], k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
assert results[0].score == 1.0
|
||||
# Same for the second result.
|
||||
assert results[1].document_chunk.document_id == "private-doc-user-a"
|
||||
assert results[1].document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(docs["private-doc-user-a"], k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
assert results[1].score
|
||||
assert 0.0 < results[1].score < 1.0
|
||||
|
||||
@@ -14,7 +14,6 @@ from __future__ import annotations
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
|
||||
@@ -29,9 +28,6 @@ _BACKEND_DIR = os.path.normpath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
|
||||
)
|
||||
|
||||
_DROP_SCHEMA_MAX_RETRIES = 3
|
||||
_DROP_SCHEMA_RETRY_DELAY_SEC = 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -54,39 +50,6 @@ def _run_script(
|
||||
)
|
||||
|
||||
|
||||
def _force_drop_schema(engine: Engine, schema: str) -> None:
|
||||
"""Terminate backends using *schema* then drop it, retrying on deadlock.
|
||||
|
||||
Background Celery workers may discover test schemas (they match the
|
||||
``tenant_`` prefix) and hold locks on tables inside them. A bare
|
||||
``DROP SCHEMA … CASCADE`` can deadlock with those workers, so we
|
||||
first kill their connections and retry if we still hit a deadlock.
|
||||
"""
|
||||
for attempt in range(_DROP_SCHEMA_MAX_RETRIES):
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT pg_terminate_backend(l.pid)
|
||||
FROM pg_locks l
|
||||
JOIN pg_class c ON c.oid = l.relation
|
||||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE n.nspname = :schema
|
||||
AND l.pid != pg_backend_pid()
|
||||
"""
|
||||
),
|
||||
{"schema": schema},
|
||||
)
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
return
|
||||
except Exception:
|
||||
if attempt == _DROP_SCHEMA_MAX_RETRIES - 1:
|
||||
raise
|
||||
time.sleep(_DROP_SCHEMA_RETRY_DELAY_SEC)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -141,7 +104,9 @@ def tenant_schema_at_head(
|
||||
|
||||
yield schema
|
||||
|
||||
_force_drop_schema(engine, schema)
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -158,7 +123,9 @@ def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
|
||||
|
||||
yield schema
|
||||
|
||||
_force_drop_schema(engine, schema)
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -183,7 +150,9 @@ def tenant_schema_bad_rev(engine: Engine) -> Generator[str, None, None]:
|
||||
|
||||
yield schema
|
||||
|
||||
_force_drop_schema(engine, schema)
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import csv
|
||||
import io
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@@ -141,12 +139,12 @@ def test_chat_history_csv_export(
|
||||
assert headers["Content-Type"] == "text/csv; charset=utf-8"
|
||||
assert "Content-Disposition" in headers
|
||||
|
||||
# Use csv.reader to properly handle newlines inside quoted fields
|
||||
csv_rows = list(csv.reader(io.StringIO(csv_content)))
|
||||
assert len(csv_rows) == 3 # Header + 2 QA pairs
|
||||
assert csv_rows[0][0] == "chat_session_id"
|
||||
assert "user_message" in csv_rows[0]
|
||||
assert "ai_response" in csv_rows[0]
|
||||
# Verify CSV content
|
||||
csv_lines = csv_content.strip().split("\n")
|
||||
assert len(csv_lines) == 3 # Header + 2 QA pairs
|
||||
assert "chat_session_id" in csv_content
|
||||
assert "user_message" in csv_content
|
||||
assert "ai_response" in csv_content
|
||||
assert "What was the Q1 revenue?" in csv_content
|
||||
assert "What about Q2 revenue?" in csv_content
|
||||
|
||||
@@ -158,5 +156,5 @@ def test_chat_history_csv_export(
|
||||
end_time=past_end,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
csv_rows = list(csv.reader(io.StringIO(csv_content)))
|
||||
assert len(csv_rows) == 1 # Only header, no data rows
|
||||
csv_lines = csv_content.strip().split("\n")
|
||||
assert len(csv_lines) == 1 # Only header, no data rows
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
@@ -10,10 +11,12 @@ def test_init_subclass_raises_for_missing_attrs() -> None:
|
||||
|
||||
class IncompleteSpec(HookPointSpec):
|
||||
hook_point = HookPoint.QUERY_PROCESSING
|
||||
# missing display_name, description, payload_model, response_model, etc.
|
||||
# missing display_name, description, etc.
|
||||
|
||||
class _Payload(BaseModel):
|
||||
pass
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
payload_model = _Payload
|
||||
response_model = _Payload
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@@ -1,541 +0,0 @@
|
||||
"""Unit tests for the hook executor."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.executor import execute_hook
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
|
||||
_RESPONSE_PAYLOAD: dict[str, Any] = {"rewritten_query": "better test"}
|
||||
|
||||
|
||||
def _make_hook(
|
||||
*,
|
||||
is_active: bool = True,
|
||||
endpoint_url: str | None = "https://hook.example.com/query",
|
||||
api_key: MagicMock | None = None,
|
||||
timeout_seconds: float = 5.0,
|
||||
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
|
||||
hook_id: int = 1,
|
||||
is_reachable: bool | None = None,
|
||||
) -> MagicMock:
|
||||
hook = MagicMock()
|
||||
hook.is_active = is_active
|
||||
hook.endpoint_url = endpoint_url
|
||||
hook.api_key = api_key
|
||||
hook.timeout_seconds = timeout_seconds
|
||||
hook.id = hook_id
|
||||
hook.fail_strategy = fail_strategy
|
||||
hook.is_reachable = is_reachable
|
||||
return hook
|
||||
|
||||
|
||||
def _make_api_key(value: str) -> MagicMock:
|
||||
api_key = MagicMock()
|
||||
api_key.get_value.return_value = value
|
||||
return api_key
|
||||
|
||||
|
||||
def _make_response(
|
||||
*,
|
||||
status_code: int = 200,
|
||||
json_return: Any = _RESPONSE_PAYLOAD,
|
||||
json_side_effect: Exception | None = None,
|
||||
) -> MagicMock:
|
||||
"""Build a response mock with controllable json() behaviour."""
|
||||
response = MagicMock()
|
||||
response.status_code = status_code
|
||||
if json_side_effect is not None:
|
||||
response.json.side_effect = json_side_effect
|
||||
else:
|
||||
response.json.return_value = json_return
|
||||
return response
|
||||
|
||||
|
||||
def _setup_client(
|
||||
mock_client_cls: MagicMock,
|
||||
*,
|
||||
response: MagicMock | None = None,
|
||||
side_effect: Exception | None = None,
|
||||
) -> MagicMock:
|
||||
"""Wire up the httpx.Client mock and return the inner client.
|
||||
|
||||
If side_effect is an httpx.HTTPStatusError, it is raised from
|
||||
raise_for_status() (matching real httpx behaviour) and post() returns a
|
||||
response mock with the matching status_code set. All other exceptions are
|
||||
raised directly from post().
|
||||
"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
if isinstance(side_effect, httpx.HTTPStatusError):
|
||||
error_response = MagicMock()
|
||||
error_response.status_code = side_effect.response.status_code
|
||||
error_response.raise_for_status.side_effect = side_effect
|
||||
mock_client.post = MagicMock(return_value=error_response)
|
||||
else:
|
||||
mock_client.post = MagicMock(
|
||||
side_effect=side_effect, return_value=response if not side_effect else None
|
||||
)
|
||||
|
||||
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Early-exit guards (no HTTP call, no DB writes)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hooks_available,hook",
|
||||
[
|
||||
# HOOKS_AVAILABLE=False exits before the DB lookup — hook is irrelevant.
|
||||
pytest.param(False, None, id="hooks_not_available"),
|
||||
pytest.param(True, None, id="hook_not_found"),
|
||||
pytest.param(True, _make_hook(is_active=False), id="hook_inactive"),
|
||||
pytest.param(True, _make_hook(endpoint_url=None), id="no_endpoint_url"),
|
||||
],
|
||||
)
|
||||
def test_early_exit_returns_skipped_with_no_db_writes(
|
||||
db_session: MagicMock,
|
||||
hooks_available: bool,
|
||||
hook: MagicMock | None,
|
||||
) -> None:
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", hooks_available),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSkipped)
|
||||
mock_update.assert_not_called()
|
||||
mock_log.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Successful HTTP call
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> None:
|
||||
hook = _make_hook()
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, response=_make_response())
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert result == _RESPONSE_PAYLOAD
|
||||
_, update_kwargs = mock_update.call_args
|
||||
assert update_kwargs["is_reachable"] is True
|
||||
mock_log.assert_not_called()
|
||||
|
||||
|
||||
def test_success_skips_reachable_write_when_already_true(db_session: MagicMock) -> None:
|
||||
"""Deduplication guard: a hook already at is_reachable=True that succeeds
|
||||
must not trigger a DB write."""
|
||||
hook = _make_hook(is_reachable=True)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, response=_make_response())
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert result == _RESPONSE_PAYLOAD
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
|
||||
"""response.json() returning a non-dict (e.g. list) must be treated as failure.
|
||||
The server responded, so is_reachable is not updated."""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(
|
||||
mock_client_cls,
|
||||
response=_make_response(json_return=["unexpected", "list"]),
|
||||
)
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
_, log_kwargs = mock_log.call_args
|
||||
assert log_kwargs["is_success"] is False
|
||||
assert "non-dict" in (log_kwargs["error_message"] or "")
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
|
||||
"""response.json() raising must be treated as failure with SOFT strategy.
|
||||
The server responded, so is_reachable is not updated."""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(
|
||||
mock_client_cls,
|
||||
response=_make_response(
|
||||
json_side_effect=json.JSONDecodeError("not JSON", "", 0)
|
||||
),
|
||||
)
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
_, log_kwargs = mock_log.call_args
|
||||
assert log_kwargs["is_success"] is False
|
||||
assert "non-JSON" in (log_kwargs["error_message"] or "")
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP failure paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"exception,fail_strategy,expected_type,expected_is_reachable",
|
||||
[
|
||||
# NetworkError → is_reachable=False
|
||||
pytest.param(
|
||||
httpx.ConnectError("refused"),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
False,
|
||||
id="connect_error_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.ConnectError("refused"),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
False,
|
||||
id="connect_error_hard",
|
||||
),
|
||||
# 401/403 → is_reachable=False (api_key revoked)
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"401",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=401, text="Unauthorized"),
|
||||
),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
False,
|
||||
id="auth_401_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"403",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=403, text="Forbidden"),
|
||||
),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
False,
|
||||
id="auth_403_hard",
|
||||
),
|
||||
# TimeoutException → no is_reachable write (None)
|
||||
pytest.param(
|
||||
httpx.TimeoutException("timeout"),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
None,
|
||||
id="timeout_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.TimeoutException("timeout"),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
None,
|
||||
id="timeout_hard",
|
||||
),
|
||||
# Other HTTP errors → no is_reachable write (None)
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"500",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=500, text="error"),
|
||||
),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
None,
|
||||
id="http_status_error_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"500",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=500, text="error"),
|
||||
),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
None,
|
||||
id="http_status_error_hard",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_http_failure_paths(
|
||||
db_session: MagicMock,
|
||||
exception: Exception,
|
||||
fail_strategy: HookFailStrategy,
|
||||
expected_type: type,
|
||||
expected_is_reachable: bool | None,
|
||||
) -> None:
|
||||
hook = _make_hook(fail_strategy=fail_strategy)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, side_effect=exception)
|
||||
|
||||
if expected_type is OnyxError:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
|
||||
else:
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert isinstance(result, expected_type)
|
||||
|
||||
if expected_is_reachable is None:
|
||||
mock_update.assert_not_called()
|
||||
else:
|
||||
mock_update.assert_called_once()
|
||||
_, kwargs = mock_update.call_args
|
||||
assert kwargs["is_reachable"] is expected_is_reachable
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authorization header
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_key_value,expect_auth_header",
|
||||
[
|
||||
pytest.param("secret-token", True, id="api_key_present"),
|
||||
pytest.param(None, False, id="api_key_absent"),
|
||||
],
|
||||
)
|
||||
def test_authorization_header(
|
||||
db_session: MagicMock,
|
||||
api_key_value: str | None,
|
||||
expect_auth_header: bool,
|
||||
) -> None:
|
||||
api_key = _make_api_key(api_key_value) if api_key_value else None
|
||||
hook = _make_hook(api_key=api_key)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit"),
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
mock_client = _setup_client(mock_client_cls, response=_make_response())
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
_, call_kwargs = mock_client.post.call_args
|
||||
if expect_auth_header:
|
||||
assert call_kwargs["headers"]["Authorization"] == f"Bearer {api_key_value}"
|
||||
else:
|
||||
assert "Authorization" not in call_kwargs["headers"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persist session failure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"http_exception,expected_result",
|
||||
[
|
||||
pytest.param(None, _RESPONSE_PAYLOAD, id="success_path"),
|
||||
pytest.param(httpx.ConnectError("refused"), OnyxError, id="hard_fail_path"),
|
||||
],
|
||||
)
|
||||
def test_persist_session_failure_is_swallowed(
|
||||
db_session: MagicMock,
|
||||
http_exception: Exception | None,
|
||||
expected_result: Any,
|
||||
) -> None:
|
||||
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_session_with_current_tenant",
|
||||
side_effect=RuntimeError("DB unavailable"),
|
||||
),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(
|
||||
mock_client_cls,
|
||||
response=_make_response() if not http_exception else None,
|
||||
side_effect=http_exception,
|
||||
)
|
||||
|
||||
if expected_result is OnyxError:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
|
||||
else:
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:
|
||||
"""is_reachable update failing (e.g. concurrent hook deletion) must not
|
||||
prevent the execution log from being written.
|
||||
|
||||
Simulates the production failure path: update_hook__no_commit raises
|
||||
OnyxError(NOT_FOUND) as it would if the hook was concurrently deleted
|
||||
between the initial lookup and the reachable update.
|
||||
"""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.hooks.executor.update_hook__no_commit",
|
||||
side_effect=OnyxError(OnyxErrorCode.NOT_FOUND, "hook deleted"),
|
||||
),
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, side_effect=httpx.ConnectError("refused"))
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
mock_log.assert_called_once()
|
||||
@@ -37,20 +37,18 @@ def test_input_schema_query_is_string() -> None:
|
||||
|
||||
def test_input_schema_user_email_is_nullable() -> None:
|
||||
props = QueryProcessingSpec().input_schema["properties"]
|
||||
# Pydantic v2 emits anyOf for nullable fields
|
||||
assert any(s.get("type") == "null" for s in props["user_email"]["anyOf"])
|
||||
assert "null" in props["user_email"]["type"]
|
||||
|
||||
|
||||
def test_output_schema_query_is_optional() -> None:
|
||||
# query defaults to None (absent = reject); not required in the schema
|
||||
def test_output_schema_query_is_required() -> None:
|
||||
schema = QueryProcessingSpec().output_schema
|
||||
assert "query" not in schema.get("required", [])
|
||||
assert "query" in schema["required"]
|
||||
|
||||
|
||||
def test_output_schema_query_is_nullable() -> None:
|
||||
# null means "reject the query"; Pydantic v2 emits anyOf for nullable fields
|
||||
# null means "reject the query"
|
||||
props = QueryProcessingSpec().output_schema["properties"]
|
||||
assert any(s.get("type") == "null" for s in props["query"]["anyOf"])
|
||||
assert "null" in props["query"]["type"]
|
||||
|
||||
|
||||
def test_output_schema_rejection_message_is_optional() -> None:
|
||||
|
||||
@@ -256,6 +256,7 @@ def test_multiple_tool_calls(default_multi_llm: LitellmLLM) -> None:
|
||||
{"role": "user", "content": "What's the weather and time in New York?"}
|
||||
],
|
||||
tools=tools,
|
||||
tool_choice=None,
|
||||
stream=True,
|
||||
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
|
||||
timeout=30,
|
||||
@@ -411,6 +412,7 @@ def test_multiple_tool_calls_streaming(default_multi_llm: LitellmLLM) -> None:
|
||||
{"role": "user", "content": "What's the weather and time in New York?"}
|
||||
],
|
||||
tools=tools,
|
||||
tool_choice=None,
|
||||
stream=True,
|
||||
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
|
||||
timeout=30,
|
||||
@@ -1429,36 +1431,3 @@ def test_strip_tool_content_merges_consecutive_tool_results() -> None:
|
||||
assert "sunny 72F" in merged
|
||||
assert "tc_2" in merged
|
||||
assert "headline news" in merged
|
||||
|
||||
|
||||
def test_no_tool_choice_sent_when_no_tools(default_multi_llm: LitellmLLM) -> None:
|
||||
"""Regression test for providers (e.g. Fireworks) that reject tool_choice=null.
|
||||
|
||||
When no tools are provided, tool_choice must not be forwarded to
|
||||
litellm.completion() at all — not even as None.
|
||||
"""
|
||||
messages: LanguageModelInput = [UserMessage(content="Hello!")]
|
||||
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello!"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
default_multi_llm.invoke(messages, tools=None)
|
||||
|
||||
_, kwargs = mock_completion.call_args
|
||||
assert (
|
||||
"tool_choice" not in kwargs
|
||||
), "tool_choice must not be sent to providers when no tools are provided"
|
||||
|
||||
@@ -1,278 +0,0 @@
|
||||
"""Unit tests for onyx.server.features.hooks.api helpers.
|
||||
|
||||
Covers:
|
||||
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
|
||||
- _validate_endpoint: httpx exception → HookValidateStatus mapping
|
||||
ConnectTimeout → cannot_connect (TCP handshake never completed)
|
||||
ConnectError → cannot_connect (DNS / TLS failure)
|
||||
ReadTimeout et al. → timeout (TCP connected, server slow)
|
||||
Any other exc → cannot_connect
|
||||
- _raise_for_validation_failure: HookValidateStatus → OnyxError mapping
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.models import HookValidateResponse
|
||||
from onyx.hooks.models import HookValidateStatus
|
||||
from onyx.server.features.hooks.api import _check_ssrf_safety
|
||||
from onyx.server.features.hooks.api import _raise_for_validation_failure
|
||||
from onyx.server.features.hooks.api import _validate_endpoint
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_URL = "https://example.com/hook"
|
||||
_API_KEY = "secret"
|
||||
_TIMEOUT = 5.0
|
||||
|
||||
|
||||
def _mock_response(status_code: int) -> MagicMock:
|
||||
response = MagicMock()
|
||||
response.status_code = status_code
|
||||
return response
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _check_ssrf_safety
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckSsrfSafety:
|
||||
def _call(self, url: str) -> None:
|
||||
_check_ssrf_safety(url)
|
||||
|
||||
# --- scheme checks ---
|
||||
|
||||
def test_https_is_allowed(self) -> None:
|
||||
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
|
||||
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
|
||||
self._call("https://example.com/hook") # must not raise
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url", ["http://example.com/hook", "ftp://example.com/hook"]
|
||||
)
|
||||
def test_non_https_scheme_rejected(self, url: str) -> None:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self._call(url)
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert "https" in (exc_info.value.detail or "").lower()
|
||||
|
||||
# --- private IP blocklist ---
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ip",
|
||||
[
|
||||
pytest.param("127.0.0.1", id="loopback"),
|
||||
pytest.param("10.0.0.1", id="RFC1918-A"),
|
||||
pytest.param("172.16.0.1", id="RFC1918-B"),
|
||||
pytest.param("192.168.1.1", id="RFC1918-C"),
|
||||
pytest.param("169.254.169.254", id="link-local-IMDS"),
|
||||
pytest.param("100.64.0.1", id="shared-address-space"),
|
||||
pytest.param("::1", id="IPv6-loopback"),
|
||||
pytest.param("fc00::1", id="IPv6-ULA"),
|
||||
pytest.param("fe80::1", id="IPv6-link-local"),
|
||||
],
|
||||
)
|
||||
def test_private_ip_is_blocked(self, ip: str) -> None:
|
||||
with (
|
||||
patch("onyx.utils.url.socket.getaddrinfo") as mock_dns,
|
||||
pytest.raises(OnyxError) as exc_info,
|
||||
):
|
||||
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
|
||||
self._call("https://internal.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert ip in (exc_info.value.detail or "")
|
||||
|
||||
def test_public_ip_is_allowed(self) -> None:
|
||||
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
|
||||
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
|
||||
self._call("https://example.com/hook") # must not raise
|
||||
|
||||
def test_dns_resolution_failure_raises(self) -> None:
|
||||
import socket
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.utils.url.socket.getaddrinfo",
|
||||
side_effect=socket.gaierror("name not found"),
|
||||
),
|
||||
pytest.raises(OnyxError) as exc_info,
|
||||
):
|
||||
self._call("https://no-such-host.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateEndpoint:
|
||||
def _call(self, *, api_key: str | None = _API_KEY) -> HookValidateResponse:
|
||||
# Bypass SSRF check — tested separately in TestCheckSsrfSafety.
|
||||
with patch("onyx.server.features.hooks.api._check_ssrf_safety"):
|
||||
return _validate_endpoint(
|
||||
endpoint_url=_URL,
|
||||
api_key=api_key,
|
||||
timeout_seconds=_TIMEOUT,
|
||||
)
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_2xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(200)
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_5xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(500)
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@pytest.mark.parametrize("status_code", [401, 403])
|
||||
def test_401_403_returns_auth_failed(
|
||||
self, mock_client_cls: MagicMock, status_code: int
|
||||
) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(status_code)
|
||||
)
|
||||
result = self._call()
|
||||
assert result.status == HookValidateStatus.auth_failed
|
||||
assert str(status_code) in (result.error_message or "")
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_4xx_non_auth_returns_passed(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(422)
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_connect_timeout_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
httpx.ConnectTimeout("timed out")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@pytest.mark.parametrize(
|
||||
"exc",
|
||||
[
|
||||
httpx.ReadTimeout("read timeout"),
|
||||
httpx.WriteTimeout("write timeout"),
|
||||
httpx.PoolTimeout("pool timeout"),
|
||||
],
|
||||
)
|
||||
def test_read_write_pool_timeout_returns_timeout(
|
||||
self, mock_client_cls: MagicMock, exc: httpx.TimeoutException
|
||||
) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = exc
|
||||
assert self._call().status == HookValidateStatus.timeout
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_connect_error_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
# Covers DNS failures, TLS errors, and other connection-level errors.
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
httpx.ConnectError("name resolution failed")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_arbitrary_exception_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
ConnectionRefusedError("refused")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_api_key_sent_as_bearer(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_post = mock_client_cls.return_value.__enter__.return_value.post
|
||||
mock_post.return_value = _mock_response(200)
|
||||
self._call(api_key="mykey")
|
||||
_, kwargs = mock_post.call_args
|
||||
assert kwargs["headers"]["Authorization"] == "Bearer mykey"
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_no_api_key_omits_auth_header(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_post = mock_client_cls.return_value.__enter__.return_value.post
|
||||
mock_post.return_value = _mock_response(200)
|
||||
self._call(api_key=None)
|
||||
_, kwargs = mock_post.call_args
|
||||
assert "Authorization" not in kwargs["headers"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _raise_for_validation_failure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRaiseForValidationFailure:
|
||||
@pytest.mark.parametrize(
|
||||
"status, expected_code",
|
||||
[
|
||||
(HookValidateStatus.auth_failed, OnyxErrorCode.CREDENTIAL_INVALID),
|
||||
(HookValidateStatus.timeout, OnyxErrorCode.GATEWAY_TIMEOUT),
|
||||
(HookValidateStatus.cannot_connect, OnyxErrorCode.BAD_GATEWAY),
|
||||
],
|
||||
)
|
||||
def test_raises_correct_error_code(
|
||||
self, status: HookValidateStatus, expected_code: OnyxErrorCode
|
||||
) -> None:
|
||||
validation = HookValidateResponse(status=status, error_message="some error")
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_raise_for_validation_failure(validation)
|
||||
assert exc_info.value.error_code == expected_code
|
||||
|
||||
def test_auth_failed_passes_error_message_directly(self) -> None:
|
||||
validation = HookValidateResponse(
|
||||
status=HookValidateStatus.auth_failed, error_message="bad credentials"
|
||||
)
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_raise_for_validation_failure(validation)
|
||||
assert exc_info.value.detail == "bad credentials"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status", [HookValidateStatus.timeout, HookValidateStatus.cannot_connect]
|
||||
)
|
||||
def test_timeout_and_cannot_connect_wrap_error_message(
|
||||
self, status: HookValidateStatus
|
||||
) -> None:
|
||||
validation = HookValidateResponse(status=status, error_message="raw error")
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_raise_for_validation_failure(validation)
|
||||
assert exc_info.value.detail == "Endpoint validation failed: raw error"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HookValidateStatus enum string values (API contract)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHookValidateStatusValues:
|
||||
@pytest.mark.parametrize(
|
||||
"status, expected",
|
||||
[
|
||||
(HookValidateStatus.passed, "passed"),
|
||||
(HookValidateStatus.auth_failed, "auth_failed"),
|
||||
(HookValidateStatus.timeout, "timeout"),
|
||||
(HookValidateStatus.cannot_connect, "cannot_connect"),
|
||||
],
|
||||
)
|
||||
def test_string_values(self, status: HookValidateStatus, expected: str) -> None:
|
||||
assert status == expected
|
||||
93
cubic.yaml
93
cubic.yaml
@@ -1,93 +0,0 @@
|
||||
# yaml-language-server: $schema=https://cubic.dev/schema/cubic-repository-config.schema.json
|
||||
version: 1
|
||||
|
||||
reviews:
|
||||
enabled: true
|
||||
sensitivity: medium
|
||||
incremental_commits: true
|
||||
check_drafts: false
|
||||
|
||||
custom_instructions: |
|
||||
Use explicit type annotations for variables to enhance code clarity,
|
||||
especially when moving type hints around in the code.
|
||||
|
||||
Use `contributing_guides/best_practices.md` as core review context.
|
||||
Prefer consistency with existing patterns, fix issues in code you touch,
|
||||
avoid tacking new features onto muddy interfaces, fail loudly instead of
|
||||
silently swallowing errors, keep code strictly typed, preserve clear state
|
||||
boundaries, remove duplicate or dead logic, break up overly long functions,
|
||||
avoid hidden import-time side effects, respect module boundaries, and favor
|
||||
correctness-by-construction over relying on callers to use an API correctly.
|
||||
|
||||
Reference these files for additional context:
|
||||
- `contributing_guides/best_practices.md` — Best practices for contributing to the codebase
|
||||
- `CLAUDE.md` — Project instructions and coding standards
|
||||
- `backend/alembic/README.md` — Migration guidance, including multi-tenant migration behavior
|
||||
- `deployment/helm/charts/onyx/values-lite.yaml` — Lite deployment Helm values and service assumptions
|
||||
- `deployment/docker_compose/docker-compose.onyx-lite.yml` — Lite deployment Docker Compose overlay and disabled service behavior
|
||||
|
||||
ignore:
|
||||
files:
|
||||
- greptile.json
|
||||
- cubic.yaml
|
||||
|
||||
custom_rules:
|
||||
- name: TODO format
|
||||
description: >
|
||||
Whenever a TODO is added, there must always be an associated name or
|
||||
ticket in the style of TODO(name): ... or TODO(1234): ...
|
||||
|
||||
- name: Frontend standards
|
||||
description: >
|
||||
For frontend changes, enforce all standards described in the
|
||||
web/AGENTS.md file.
|
||||
include:
|
||||
- web/**
|
||||
- desktop/**
|
||||
|
||||
- name: No debugging code
|
||||
description: >
|
||||
Remove temporary debugging code before merging to production,
|
||||
especially tenant-specific debugging logs.
|
||||
|
||||
- name: No hardcoded booleans
|
||||
description: >
|
||||
When hardcoding a boolean variable to a constant value, remove the
|
||||
variable entirely and clean up all places where it's used rather than
|
||||
just setting it to a constant.
|
||||
|
||||
- name: Multi-tenant awareness
|
||||
description: >
|
||||
Code changes must consider both multi-tenant and single-tenant
|
||||
deployments. In multi-tenant mode, preserve tenant isolation, ensure
|
||||
tenant context is propagated correctly, and avoid assumptions that only
|
||||
hold for a single shared schema or globally shared state. In
|
||||
single-tenant mode, avoid introducing unnecessary tenant-specific
|
||||
requirements or cloud-only control-plane dependencies.
|
||||
|
||||
- name: Onyx lite compatibility
|
||||
description: >
|
||||
Code changes must consider both regular Onyx deployments and Onyx lite
|
||||
deployments. Lite deployments disable the vector DB, Redis, model
|
||||
servers, and background workers by default, use PostgreSQL-backed
|
||||
cache/auth/file storage, and rely on the API server to handle
|
||||
background work. Do not assume those services are available unless the
|
||||
code path is explicitly limited to full deployments.
|
||||
|
||||
- name: OnyxError over HTTPException
|
||||
description: >
|
||||
Never raise HTTPException directly in business code. Use
|
||||
`raise OnyxError(OnyxErrorCode.XXX, "message")` from
|
||||
`onyx.error_handling.exceptions`. A global FastAPI exception handler
|
||||
converts OnyxError into structured JSON responses with
|
||||
{"error_code": "...", "detail": "..."}. Error codes are defined in
|
||||
`onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors
|
||||
with dynamic HTTP status codes, use `status_code_override`:
|
||||
`raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`.
|
||||
include:
|
||||
- backend/**/*.py
|
||||
|
||||
issues:
|
||||
fix_with_cubic_buttons: true
|
||||
pr_comment_fixes: true
|
||||
fix_commits_to_pr: true
|
||||
@@ -489,18 +489,20 @@ services:
|
||||
- "${HOST_PORT_80:-80}:80"
|
||||
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
|
||||
volumes:
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template"
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
|
||||
|
||||
minio:
|
||||
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1
|
||||
|
||||
@@ -290,20 +290,25 @@ services:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
volumes:
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
- ../data/certbot/conf:/etc/letsencrypt
|
||||
- ../data/certbot/www:/var/www/certbot
|
||||
# sleep a little bit to allow the web_server / api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template.prod"
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
|
||||
env_file:
|
||||
- .env.nginx
|
||||
environment:
|
||||
|
||||
@@ -314,19 +314,21 @@ services:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
volumes:
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
- ../data/sslcerts:/etc/nginx/sslcerts
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template.prod.no-letsencrypt"
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod.no-letsencrypt"
|
||||
env_file:
|
||||
- .env.nginx
|
||||
environment:
|
||||
|
||||
@@ -333,20 +333,25 @@ services:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
volumes:
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
- ../data/certbot/conf:/etc/letsencrypt
|
||||
- ../data/certbot/www:/var/www/certbot
|
||||
# sleep a little bit to allow the web_server / api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template.prod"
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
|
||||
env_file:
|
||||
- .env.nginx
|
||||
environment:
|
||||
|
||||
@@ -202,18 +202,20 @@ services:
|
||||
ports:
|
||||
- "${NGINX_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
|
||||
volumes:
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template"
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
|
||||
|
||||
minio:
|
||||
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1
|
||||
|
||||
@@ -477,10 +477,7 @@ services:
|
||||
- "${HOST_PORT_80:-80}:80"
|
||||
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
|
||||
volumes:
|
||||
# Mount templates read-only; the startup command copies them into
|
||||
# the writable /etc/nginx/conf.d/ inside the container. This avoids
|
||||
# "Permission denied" errors on Windows Docker bind mounts.
|
||||
- ../data/nginx:/nginx-templates:ro
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
# PRODUCTION: Add SSL certificate volumes for HTTPS support:
|
||||
# - ../data/certbot/conf:/etc/letsencrypt
|
||||
# - ../data/certbot/www:/var/www/certbot
|
||||
@@ -492,13 +489,12 @@ services:
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not receive any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
# PRODUCTION: Change to app.conf.template.prod for production nginx config
|
||||
command: >
|
||||
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
|
||||
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
|
||||
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
|
||||
&& chmod +x /tmp/run-nginx.sh
|
||||
&& /tmp/run-nginx.sh app.conf.template"
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
|
||||
|
||||
cache:
|
||||
image: redis:7.4-alpine
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -96,8 +96,8 @@ fi
|
||||
|
||||
# When --lite is passed as a flag, lower resource thresholds early (before the
|
||||
# resource check). When lite is chosen interactively, the thresholds are adjusted
|
||||
# after the resource check has already passed with the standard thresholds —
|
||||
# which is the safer direction.
|
||||
# inside the new-deployment flow, after the resource check has already passed
|
||||
# with the standard thresholds — which is the safer direction.
|
||||
if [[ "$LITE_MODE" = true ]]; then
|
||||
EXPECTED_DOCKER_RAM_GB=4
|
||||
EXPECTED_DISK_GB=16
|
||||
@@ -110,6 +110,9 @@ LITE_COMPOSE_FILE="docker-compose.onyx-lite.yml"
|
||||
# Build the -f flags for docker compose.
|
||||
# Pass "true" as $1 to auto-detect a previously-downloaded lite overlay
|
||||
# (used by shutdown/delete-data so users don't need to remember --lite).
|
||||
# Without the argument, the lite overlay is only included when --lite was
|
||||
# explicitly passed — preventing install/start from silently staying in
|
||||
# lite mode just because the file exists on disk from a prior run.
|
||||
compose_file_args() {
|
||||
local auto_detect="${1:-false}"
|
||||
local args="-f docker-compose.yml"
|
||||
@@ -174,52 +177,34 @@ ensure_file() {
|
||||
|
||||
# --- Interactive prompt helpers ---
|
||||
is_interactive() {
|
||||
[[ "$NO_PROMPT" = false ]] && [[ -r /dev/tty ]] && [[ -w /dev/tty ]]
|
||||
}
|
||||
|
||||
read_prompt_line() {
|
||||
local prompt_text="$1"
|
||||
if ! is_interactive; then
|
||||
REPLY=""
|
||||
return
|
||||
fi
|
||||
[[ -n "$prompt_text" ]] && printf "%s" "$prompt_text" > /dev/tty
|
||||
IFS= read -r REPLY < /dev/tty || REPLY=""
|
||||
}
|
||||
|
||||
read_prompt_char() {
|
||||
local prompt_text="$1"
|
||||
if ! is_interactive; then
|
||||
REPLY=""
|
||||
return
|
||||
fi
|
||||
[[ -n "$prompt_text" ]] && printf "%s" "$prompt_text" > /dev/tty
|
||||
IFS= read -r -n 1 REPLY < /dev/tty || REPLY=""
|
||||
printf "\n" > /dev/tty
|
||||
[[ "$NO_PROMPT" = false ]] && [[ -t 0 ]]
|
||||
}
|
||||
|
||||
prompt_or_default() {
|
||||
local prompt_text="$1"
|
||||
local default_value="$2"
|
||||
read_prompt_line "$prompt_text"
|
||||
[[ -z "$REPLY" ]] && REPLY="$default_value"
|
||||
if is_interactive; then
|
||||
read -p "$prompt_text" -r REPLY
|
||||
if [[ -z "$REPLY" ]]; then
|
||||
REPLY="$default_value"
|
||||
fi
|
||||
else
|
||||
REPLY="$default_value"
|
||||
fi
|
||||
}
|
||||
|
||||
prompt_yn_or_default() {
|
||||
local prompt_text="$1"
|
||||
local default_value="$2"
|
||||
read_prompt_char "$prompt_text"
|
||||
[[ -z "$REPLY" ]] && REPLY="$default_value"
|
||||
}
|
||||
|
||||
confirm_action() {
|
||||
local description="$1"
|
||||
prompt_yn_or_default "Install ${description}? (Y/n) [default: Y] " "Y"
|
||||
if [[ "$REPLY" =~ ^[Nn] ]]; then
|
||||
print_warning "Skipping: ${description}"
|
||||
return 1
|
||||
if is_interactive; then
|
||||
read -p "$prompt_text" -n 1 -r
|
||||
echo ""
|
||||
if [[ -z "$REPLY" ]]; then
|
||||
REPLY="$default_value"
|
||||
fi
|
||||
else
|
||||
REPLY="$default_value"
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
# Colors for output
|
||||
@@ -310,8 +295,8 @@ if [ "$DELETE_DATA_MODE" = true ]; then
|
||||
echo " • All user data and documents"
|
||||
echo ""
|
||||
if is_interactive; then
|
||||
prompt_or_default "Are you sure you want to continue? Type 'DELETE' to confirm: " ""
|
||||
echo "" > /dev/tty
|
||||
read -p "Are you sure you want to continue? Type 'DELETE' to confirm: " -r
|
||||
echo ""
|
||||
if [ "$REPLY" != "DELETE" ]; then
|
||||
print_info "Operation cancelled."
|
||||
exit 0
|
||||
@@ -410,11 +395,6 @@ fi
|
||||
|
||||
if ! command -v docker &> /dev/null; then
|
||||
if [[ "$OSTYPE" == "linux-gnu"* ]] || [[ -n "${WSL_DISTRO_NAME:-}" ]]; then
|
||||
print_info "Docker is required but not installed."
|
||||
if ! confirm_action "Docker Engine"; then
|
||||
print_error "Docker is required to run Onyx."
|
||||
exit 1
|
||||
fi
|
||||
install_docker_linux
|
||||
if ! command -v docker &> /dev/null; then
|
||||
print_error "Docker installation failed."
|
||||
@@ -431,11 +411,7 @@ if command -v docker &> /dev/null \
|
||||
&& ! command -v docker-compose &> /dev/null \
|
||||
&& { [[ "$OSTYPE" == "linux-gnu"* ]] || [[ -n "${WSL_DISTRO_NAME:-}" ]]; }; then
|
||||
|
||||
print_info "Docker Compose is required but not installed."
|
||||
if ! confirm_action "Docker Compose plugin"; then
|
||||
print_error "Docker Compose is required to run Onyx."
|
||||
exit 1
|
||||
fi
|
||||
print_info "Docker Compose not found — installing plugin..."
|
||||
COMPOSE_ARCH="$(uname -m)"
|
||||
COMPOSE_URL="https://github.com/docker/compose/releases/latest/download/docker-compose-linux-${COMPOSE_ARCH}"
|
||||
COMPOSE_DIR="/usr/local/lib/docker/cli-plugins"
|
||||
@@ -505,7 +481,7 @@ echo ""
|
||||
|
||||
if is_interactive; then
|
||||
echo -e "${YELLOW}${BOLD}Please acknowledge and press Enter to continue...${NC}"
|
||||
read_prompt_line ""
|
||||
read -r
|
||||
echo ""
|
||||
else
|
||||
echo -e "${YELLOW}${BOLD}Running in non-interactive mode - proceeding automatically...${NC}"
|
||||
@@ -586,31 +562,10 @@ version_compare() {
|
||||
|
||||
# Check Docker daemon
|
||||
if ! docker info &> /dev/null; then
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
print_info "Docker daemon is not running. Starting Docker Desktop..."
|
||||
open -a Docker
|
||||
# Wait up to 120 seconds for Docker to be ready
|
||||
DOCKER_WAIT=0
|
||||
DOCKER_MAX_WAIT=120
|
||||
while ! docker info &> /dev/null; do
|
||||
if [ $DOCKER_WAIT -ge $DOCKER_MAX_WAIT ]; then
|
||||
print_error "Docker Desktop did not start within ${DOCKER_MAX_WAIT} seconds."
|
||||
print_info "Please start Docker Desktop manually and re-run this script."
|
||||
exit 1
|
||||
fi
|
||||
printf "\r\033[KWaiting for Docker Desktop to start... (%ds)" "$DOCKER_WAIT"
|
||||
sleep 2
|
||||
DOCKER_WAIT=$((DOCKER_WAIT + 2))
|
||||
done
|
||||
echo ""
|
||||
print_success "Docker Desktop is now running"
|
||||
else
|
||||
print_error "Docker daemon is not running. Please start Docker."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
print_success "Docker daemon is running"
|
||||
print_error "Docker daemon is not running. Please start Docker."
|
||||
exit 1
|
||||
fi
|
||||
print_success "Docker daemon is running"
|
||||
|
||||
# Check Docker resources
|
||||
print_step "Verifying Docker resources"
|
||||
@@ -750,48 +705,25 @@ if [ "$COMPOSE_VERSION" != "dev" ] && version_compare "$COMPOSE_VERSION" "2.24.0
|
||||
print_info "Proceeding with installation despite Docker Compose version compatibility issues..."
|
||||
fi
|
||||
|
||||
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
|
||||
if [[ "$LITE_MODE" = false ]]; then
|
||||
print_info "Which deployment mode would you like?"
|
||||
echo ""
|
||||
echo " 1) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
|
||||
echo " LLM chat, tools, file uploads, and Projects still work"
|
||||
echo " 2) Standard - Full deployment with search, connectors, and RAG"
|
||||
echo ""
|
||||
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
|
||||
echo ""
|
||||
|
||||
case "$REPLY" in
|
||||
2)
|
||||
print_info "Selected: Standard mode"
|
||||
;;
|
||||
*)
|
||||
LITE_MODE=true
|
||||
print_info "Selected: Lite mode"
|
||||
;;
|
||||
esac
|
||||
else
|
||||
print_info "Deployment mode: Lite (set via --lite flag)"
|
||||
fi
|
||||
|
||||
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
|
||||
print_error "--include-craft cannot be used with Lite mode."
|
||||
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$LITE_MODE" = true ]]; then
|
||||
EXPECTED_DOCKER_RAM_GB=4
|
||||
EXPECTED_DISK_GB=16
|
||||
fi
|
||||
|
||||
# Handle lite overlay file based on selected mode
|
||||
# Handle lite overlay: ensure it if --lite, clean up stale copies otherwise
|
||||
if [[ "$LITE_MODE" = true ]]; then
|
||||
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
|
||||
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
|
||||
elif [[ -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" ]]; then
|
||||
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
|
||||
print_info "Removed previous lite overlay (switching to standard mode)"
|
||||
if [[ -f "${INSTALL_ROOT}/deployment/.env" ]]; then
|
||||
print_warning "Existing lite overlay found but --lite was not passed."
|
||||
prompt_yn_or_default "Remove lite overlay and switch to standard mode? (y/N): " "n"
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
print_info "Keeping existing lite overlay. Pass --lite to keep using lite mode."
|
||||
LITE_MODE=true
|
||||
else
|
||||
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
|
||||
print_info "Removed lite overlay (switching to standard mode)"
|
||||
fi
|
||||
else
|
||||
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
|
||||
print_info "Removed previous lite overlay (switching to standard mode)"
|
||||
fi
|
||||
fi
|
||||
|
||||
ensure_file "${INSTALL_ROOT}/deployment/env.template" \
|
||||
@@ -813,7 +745,6 @@ print_success "All configuration files ready"
|
||||
# Set up deployment configuration
|
||||
print_step "Setting up deployment configs"
|
||||
ENV_FILE="${INSTALL_ROOT}/deployment/.env"
|
||||
ENV_TEMPLATE="${INSTALL_ROOT}/deployment/env.template"
|
||||
# Check if services are already running
|
||||
if [ -d "${INSTALL_ROOT}/deployment" ] && [ -f "${INSTALL_ROOT}/deployment/docker-compose.yml" ]; then
|
||||
# Determine compose command
|
||||
@@ -854,22 +785,22 @@ if [ -f "$ENV_FILE" ]; then
|
||||
if [ "$REPLY" = "update" ]; then
|
||||
print_info "Update selected. Which tag would you like to deploy?"
|
||||
echo ""
|
||||
echo "• Press Enter for edge (recommended)"
|
||||
echo "• Press Enter for latest (recommended)"
|
||||
echo "• Type a specific tag (e.g., v0.1.0)"
|
||||
echo ""
|
||||
if [ "$INCLUDE_CRAFT" = true ]; then
|
||||
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
|
||||
VERSION="$REPLY"
|
||||
else
|
||||
prompt_or_default "Enter tag [default: edge]: " "edge"
|
||||
prompt_or_default "Enter tag [default: latest]: " "latest"
|
||||
VERSION="$REPLY"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
|
||||
print_info "Selected: craft-latest (Craft enabled)"
|
||||
elif [ "$VERSION" = "edge" ]; then
|
||||
print_info "Selected: edge (latest nightly)"
|
||||
elif [ "$VERSION" = "latest" ]; then
|
||||
print_info "Selected: Latest version"
|
||||
else
|
||||
print_info "Selected: $VERSION"
|
||||
fi
|
||||
@@ -921,6 +852,45 @@ else
|
||||
print_info "No existing .env file found. Setting up new deployment..."
|
||||
echo ""
|
||||
|
||||
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
|
||||
if [[ "$LITE_MODE" = false ]]; then
|
||||
print_info "Which deployment mode would you like?"
|
||||
echo ""
|
||||
echo " 1) Standard - Full deployment with search, connectors, and RAG"
|
||||
echo " 2) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
|
||||
echo " LLM chat, tools, file uploads, and Projects still work"
|
||||
echo ""
|
||||
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
|
||||
echo ""
|
||||
|
||||
case "$REPLY" in
|
||||
2)
|
||||
LITE_MODE=true
|
||||
print_info "Selected: Lite mode"
|
||||
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
|
||||
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
|
||||
;;
|
||||
*)
|
||||
print_info "Selected: Standard mode"
|
||||
;;
|
||||
esac
|
||||
else
|
||||
print_info "Deployment mode: Lite (set via --lite flag)"
|
||||
fi
|
||||
|
||||
# Validate lite + craft combination (could now be set interactively)
|
||||
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
|
||||
print_error "--include-craft cannot be used with Lite mode."
|
||||
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Adjust resource expectations for lite mode
|
||||
if [[ "$LITE_MODE" = true ]]; then
|
||||
EXPECTED_DOCKER_RAM_GB=4
|
||||
EXPECTED_DISK_GB=16
|
||||
fi
|
||||
|
||||
# Ask for version
|
||||
print_info "Which tag would you like to deploy?"
|
||||
echo ""
|
||||
@@ -931,18 +901,18 @@ else
|
||||
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
|
||||
VERSION="$REPLY"
|
||||
else
|
||||
echo "• Press Enter for edge (recommended)"
|
||||
echo "• Press Enter for latest (recommended)"
|
||||
echo "• Type a specific tag (e.g., v0.1.0)"
|
||||
echo ""
|
||||
prompt_or_default "Enter tag [default: edge]: " "edge"
|
||||
prompt_or_default "Enter tag [default: latest]: " "latest"
|
||||
VERSION="$REPLY"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
|
||||
print_info "Selected: craft-latest (Craft enabled)"
|
||||
elif [ "$VERSION" = "edge" ]; then
|
||||
print_info "Selected: edge (latest nightly)"
|
||||
elif [ "$VERSION" = "latest" ]; then
|
||||
print_info "Selected: Latest tag"
|
||||
else
|
||||
print_info "Selected: $VERSION"
|
||||
fi
|
||||
@@ -1100,39 +1070,20 @@ fi
|
||||
export HOST_PORT=$AVAILABLE_PORT
|
||||
print_success "Using port $AVAILABLE_PORT for nginx"
|
||||
|
||||
# Determine if we're using a floating tag (edge, latest, craft-*) that should force pull
|
||||
# Determine if we're using the latest tag or a craft tag (both should force pull)
|
||||
# Read IMAGE_TAG from .env file and remove any quotes or whitespace
|
||||
CURRENT_IMAGE_TAG=$(grep "^IMAGE_TAG=" "$ENV_FILE" | head -1 | cut -d'=' -f2 | tr -d ' "'"'"'')
|
||||
if [ "$CURRENT_IMAGE_TAG" = "edge" ] || [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
|
||||
if [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
|
||||
USE_LATEST=true
|
||||
if [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
|
||||
print_info "Using craft tag '$CURRENT_IMAGE_TAG' - will force pull and recreate containers"
|
||||
else
|
||||
print_info "Using '$CURRENT_IMAGE_TAG' tag - will force pull and recreate containers"
|
||||
print_info "Using 'latest' tag - will force pull and recreate containers"
|
||||
fi
|
||||
else
|
||||
USE_LATEST=false
|
||||
fi
|
||||
|
||||
# For pinned version tags, re-download config files from that tag so the
|
||||
# compose file matches the images being pulled (the initial download used main).
|
||||
if [[ "$USE_LATEST" = false ]] && [[ "$USE_LOCAL_FILES" = false ]]; then
|
||||
PINNED_BASE="https://raw.githubusercontent.com/onyx-dot-app/onyx/${CURRENT_IMAGE_TAG}/deployment"
|
||||
print_info "Fetching config files matching tag ${CURRENT_IMAGE_TAG}..."
|
||||
if download_file "${PINNED_BASE}/docker_compose/docker-compose.yml" "${INSTALL_ROOT}/deployment/docker-compose.yml" 2>/dev/null; then
|
||||
download_file "${PINNED_BASE}/data/nginx/app.conf.template" "${INSTALL_ROOT}/data/nginx/app.conf.template" 2>/dev/null || true
|
||||
download_file "${PINNED_BASE}/data/nginx/run-nginx.sh" "${INSTALL_ROOT}/data/nginx/run-nginx.sh" 2>/dev/null || true
|
||||
chmod +x "${INSTALL_ROOT}/data/nginx/run-nginx.sh"
|
||||
if [[ "$LITE_MODE" = true ]]; then
|
||||
download_file "${PINNED_BASE}/docker_compose/${LITE_COMPOSE_FILE}" \
|
||||
"${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" 2>/dev/null || true
|
||||
fi
|
||||
print_success "Config files updated to match ${CURRENT_IMAGE_TAG}"
|
||||
else
|
||||
print_warning "Tag ${CURRENT_IMAGE_TAG} not found on GitHub — using main branch configs"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Pull Docker images with reduced output
|
||||
print_step "Pulling Docker images"
|
||||
print_info "This may take several minutes depending on your internet connection..."
|
||||
|
||||
@@ -127,7 +127,6 @@ Inputs (common):
|
||||
- `name` (default `onyx`), `region` (default `us-west-2`), `tags`
|
||||
- `postgres_username`, `postgres_password`
|
||||
- `create_vpc` (default true) or existing VPC details and `s3_vpc_endpoint_id`
|
||||
- WAF controls such as `waf_allowed_ip_cidrs`, `waf_common_rule_set_count_rules`, rate limits, geo restrictions, and logging retention
|
||||
|
||||
### `vpc`
|
||||
- Builds a VPC sized for EKS with multiple private and public subnets
|
||||
|
||||
@@ -88,8 +88,6 @@ module "waf" {
|
||||
tags = local.merged_tags
|
||||
|
||||
# WAF configuration with sensible defaults
|
||||
allowed_ip_cidrs = var.waf_allowed_ip_cidrs
|
||||
common_rule_set_count_rules = var.waf_common_rule_set_count_rules
|
||||
rate_limit_requests_per_5_minutes = var.waf_rate_limit_requests_per_5_minutes
|
||||
api_rate_limit_requests_per_5_minutes = var.waf_api_rate_limit_requests_per_5_minutes
|
||||
geo_restriction_countries = var.waf_geo_restriction_countries
|
||||
|
||||
@@ -117,18 +117,6 @@ variable "waf_rate_limit_requests_per_5_minutes" {
|
||||
default = 2000
|
||||
}
|
||||
|
||||
variable "waf_allowed_ip_cidrs" {
|
||||
type = list(string)
|
||||
description = "Optional IPv4 CIDR ranges allowed through the WAF. Leave empty to disable IP allowlisting."
|
||||
default = []
|
||||
}
|
||||
|
||||
variable "waf_common_rule_set_count_rules" {
|
||||
type = list(string)
|
||||
description = "Subrules within AWSManagedRulesCommonRuleSet to override to COUNT instead of BLOCK."
|
||||
default = []
|
||||
}
|
||||
|
||||
variable "waf_api_rate_limit_requests_per_5_minutes" {
|
||||
type = number
|
||||
description = "Rate limit for API requests per 5 minutes per IP address"
|
||||
|
||||
@@ -1,20 +1,6 @@
|
||||
locals {
|
||||
name = var.name
|
||||
tags = var.tags
|
||||
ip_allowlist_enabled = length(var.allowed_ip_cidrs) > 0
|
||||
managed_rule_priority = local.ip_allowlist_enabled ? 1 : 0
|
||||
}
|
||||
|
||||
resource "aws_wafv2_ip_set" "allowed_ips" {
|
||||
count = local.ip_allowlist_enabled ? 1 : 0
|
||||
|
||||
name = "${local.name}-allowed-ips"
|
||||
description = "IP allowlist for ${local.name}"
|
||||
scope = "REGIONAL"
|
||||
ip_address_version = "IPV4"
|
||||
addresses = var.allowed_ip_cidrs
|
||||
|
||||
tags = local.tags
|
||||
name = var.name
|
||||
tags = var.tags
|
||||
}
|
||||
|
||||
# AWS WAFv2 Web ACL
|
||||
@@ -27,38 +13,10 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
allow {}
|
||||
}
|
||||
|
||||
dynamic "rule" {
|
||||
for_each = local.ip_allowlist_enabled ? [1] : []
|
||||
content {
|
||||
name = "BlockRequestsOutsideAllowedIPs"
|
||||
priority = 1
|
||||
|
||||
action {
|
||||
block {}
|
||||
}
|
||||
|
||||
statement {
|
||||
not_statement {
|
||||
statement {
|
||||
ip_set_reference_statement {
|
||||
arn = aws_wafv2_ip_set.allowed_ips[0].arn
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
visibility_config {
|
||||
cloudwatch_metrics_enabled = true
|
||||
metric_name = "BlockRequestsOutsideAllowedIPsMetric"
|
||||
sampled_requests_enabled = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# AWS Managed Rules - Core Rule Set
|
||||
rule {
|
||||
name = "AWSManagedRulesCommonRuleSet"
|
||||
priority = 1 + local.managed_rule_priority
|
||||
priority = 1
|
||||
|
||||
override_action {
|
||||
none {}
|
||||
@@ -68,16 +26,6 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
managed_rule_group_statement {
|
||||
name = "AWSManagedRulesCommonRuleSet"
|
||||
vendor_name = "AWS"
|
||||
|
||||
dynamic "rule_action_override" {
|
||||
for_each = var.common_rule_set_count_rules
|
||||
content {
|
||||
name = rule_action_override.value
|
||||
action_to_use {
|
||||
count {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,7 +39,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# AWS Managed Rules - Known Bad Inputs
|
||||
rule {
|
||||
name = "AWSManagedRulesKnownBadInputsRuleSet"
|
||||
priority = 2 + local.managed_rule_priority
|
||||
priority = 2
|
||||
|
||||
override_action {
|
||||
none {}
|
||||
@@ -114,7 +62,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# Rate Limiting Rule
|
||||
rule {
|
||||
name = "RateLimitRule"
|
||||
priority = 3 + local.managed_rule_priority
|
||||
priority = 3
|
||||
|
||||
action {
|
||||
block {}
|
||||
@@ -139,7 +87,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
for_each = length(var.geo_restriction_countries) > 0 ? [1] : []
|
||||
content {
|
||||
name = "GeoRestrictionRule"
|
||||
priority = 4 + local.managed_rule_priority
|
||||
priority = 4
|
||||
|
||||
action {
|
||||
block {}
|
||||
@@ -162,7 +110,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# IP Rate Limiting
|
||||
rule {
|
||||
name = "APIRateLimitRule"
|
||||
priority = 5 + local.managed_rule_priority
|
||||
priority = 5
|
||||
|
||||
action {
|
||||
block {}
|
||||
@@ -185,7 +133,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# SQL Injection Protection
|
||||
rule {
|
||||
name = "AWSManagedRulesSQLiRuleSet"
|
||||
priority = 6 + local.managed_rule_priority
|
||||
priority = 6
|
||||
|
||||
override_action {
|
||||
none {}
|
||||
@@ -208,7 +156,7 @@ resource "aws_wafv2_web_acl" "main" {
|
||||
# Anonymous IP Protection
|
||||
rule {
|
||||
name = "AWSManagedRulesAnonymousIpList"
|
||||
priority = 7 + local.managed_rule_priority
|
||||
priority = 7
|
||||
|
||||
override_action {
|
||||
none {}
|
||||
|
||||
@@ -9,18 +9,6 @@ variable "tags" {
|
||||
default = {}
|
||||
}
|
||||
|
||||
variable "allowed_ip_cidrs" {
|
||||
type = list(string)
|
||||
description = "Optional IPv4 CIDR ranges allowed to reach the application. Leave empty to disable IP allowlisting."
|
||||
default = []
|
||||
}
|
||||
|
||||
variable "common_rule_set_count_rules" {
|
||||
type = list(string)
|
||||
description = "Subrules within AWSManagedRulesCommonRuleSet to override to COUNT instead of BLOCK."
|
||||
default = []
|
||||
}
|
||||
|
||||
variable "rate_limit_requests_per_5_minutes" {
|
||||
type = number
|
||||
description = "Rate limit for requests per 5 minutes per IP address"
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../web/AGENTS.md
|
||||
@@ -1 +0,0 @@
|
||||
AGENTS.md
|
||||
@@ -65,7 +65,7 @@
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "For frontend changes (changes that touch the /web directory), make sure to enforce all standards described in the web/AGENTS.md file."
|
||||
"rule": "For frontend changes (changes that touch the /web directory), make sure to enforce all standards described in the web/STANDARDS.md file."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
@@ -85,7 +85,7 @@
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"message\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
],
|
||||
"files": [
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
type RunCIOptions struct {
|
||||
DryRun bool
|
||||
Yes bool
|
||||
Rerun bool
|
||||
}
|
||||
|
||||
// NewRunCICommand creates a new run-ci command
|
||||
@@ -50,7 +49,6 @@ Example usage:
|
||||
|
||||
cmd.Flags().BoolVar(&opts.DryRun, "dry-run", false, "Perform all local operations but skip pushing to remote and creating PRs")
|
||||
cmd.Flags().BoolVar(&opts.Yes, "yes", false, "Skip confirmation prompts and automatically proceed")
|
||||
cmd.Flags().BoolVar(&opts.Rerun, "rerun", false, "Update an existing CI PR with the latest fork changes to re-trigger CI")
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -109,44 +107,19 @@ func runCI(cmd *cobra.Command, args []string, opts *RunCIOptions) {
|
||||
log.Fatalf("PR #%s is not from a fork - CI should already run automatically", prNumber)
|
||||
}
|
||||
|
||||
// Create the CI branch
|
||||
ciBranch := fmt.Sprintf("run-ci/%s", prNumber)
|
||||
prTitle := fmt.Sprintf("chore: [Running GitHub actions for #%s]", prNumber)
|
||||
prBody := fmt.Sprintf("This PR runs GitHub Actions CI for #%s.\n\n- [x] Override Linear Check\n\n**This PR should be closed (not merged) after CI completes.**", prNumber)
|
||||
|
||||
// Check if a CI PR already exists for this branch
|
||||
existingPRURL, err := findExistingCIPR(ciBranch)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to check for existing CI PR: %v", err)
|
||||
}
|
||||
|
||||
if existingPRURL != "" && !opts.Rerun {
|
||||
log.Infof("A CI PR already exists for #%s: %s", prNumber, existingPRURL)
|
||||
log.Info("Run with --rerun to update it with the latest fork changes and re-trigger CI.")
|
||||
return
|
||||
}
|
||||
|
||||
if opts.Rerun && existingPRURL == "" {
|
||||
log.Warn("--rerun was specified but no existing open CI PR was found. A new PR will be created.")
|
||||
}
|
||||
|
||||
if existingPRURL != "" && opts.Rerun {
|
||||
log.Infof("Existing CI PR found: %s", existingPRURL)
|
||||
log.Info("Will update the CI branch with the latest fork changes to re-trigger CI.")
|
||||
}
|
||||
|
||||
// Confirm before proceeding
|
||||
if !opts.Yes {
|
||||
action := "Create CI branch"
|
||||
if existingPRURL != "" {
|
||||
action = "Update existing CI branch"
|
||||
}
|
||||
if !prompt.Confirm(fmt.Sprintf("%s for PR #%s? (yes/no): ", action, prNumber)) {
|
||||
if !prompt.Confirm(fmt.Sprintf("Create CI branch for PR #%s? (yes/no): ", prNumber)) {
|
||||
log.Info("Exiting...")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Create the CI branch
|
||||
ciBranch := fmt.Sprintf("run-ci/%s", prNumber)
|
||||
prTitle := fmt.Sprintf("chore: [Running GitHub actions for #%s]", prNumber)
|
||||
prBody := fmt.Sprintf("This PR runs GitHub Actions CI for #%s.\n\n- [x] Override Linear Check\n\n**This PR should be closed (not merged) after CI completes.**", prNumber)
|
||||
|
||||
// Fetch the fork's branch
|
||||
if forkRepo == "" {
|
||||
log.Fatalf("Could not determine fork repository - headRepositoryOwner or headRepository.name is empty")
|
||||
@@ -185,11 +158,7 @@ func runCI(cmd *cobra.Command, args []string, opts *RunCIOptions) {
|
||||
|
||||
if opts.DryRun {
|
||||
log.Warnf("[DRY RUN] Would push CI branch: %s", ciBranch)
|
||||
if existingPRURL == "" {
|
||||
log.Warnf("[DRY RUN] Would create PR: %s", prTitle)
|
||||
} else {
|
||||
log.Warnf("[DRY RUN] Would update existing PR: %s", existingPRURL)
|
||||
}
|
||||
log.Warnf("[DRY RUN] Would create PR: %s", prTitle)
|
||||
// Switch back to original branch
|
||||
if err := git.RunCommand("switch", "--quiet", originalBranch); err != nil {
|
||||
log.Warnf("Failed to switch back to original branch: %v", err)
|
||||
@@ -207,17 +176,6 @@ func runCI(cmd *cobra.Command, args []string, opts *RunCIOptions) {
|
||||
log.Fatalf("Failed to push CI branch: %v", err)
|
||||
}
|
||||
|
||||
if existingPRURL != "" {
|
||||
// PR already exists - force push is enough to re-trigger CI
|
||||
log.Infof("Switching back to original branch: %s", originalBranch)
|
||||
if err := git.RunCommand("switch", "--quiet", originalBranch); err != nil {
|
||||
log.Warnf("Failed to switch back to original branch: %v", err)
|
||||
}
|
||||
log.Infof("CI PR updated successfully: %s", existingPRURL)
|
||||
log.Info("The force push will re-trigger CI. Remember to close (not merge) this PR after CI completes!")
|
||||
return
|
||||
}
|
||||
|
||||
// Create PR using GitHub CLI
|
||||
log.Info("Creating PR...")
|
||||
prURL, err := createCIPR(ciBranch, prInfo.BaseRefName, prTitle, prBody)
|
||||
@@ -259,39 +217,6 @@ func getPRInfo(prNumber string) (*PRInfo, error) {
|
||||
return &prInfo, nil
|
||||
}
|
||||
|
||||
// findExistingCIPR checks if an open PR already exists for the given CI branch.
|
||||
// Returns the PR URL if found, or empty string if not.
|
||||
func findExistingCIPR(headBranch string) (string, error) {
|
||||
cmd := exec.Command("gh", "pr", "list",
|
||||
"--head", headBranch,
|
||||
"--state", "open",
|
||||
"--json", "url",
|
||||
)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return "", fmt.Errorf("%w: %s", err, string(exitErr.Stderr))
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
var prs []struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
if err := json.Unmarshal(output, &prs); err != nil {
|
||||
log.Debugf("Failed to parse PR list JSON: %v (raw: %s)", err, string(output))
|
||||
return "", fmt.Errorf("failed to parse PR list: %w", err)
|
||||
}
|
||||
|
||||
if len(prs) == 0 {
|
||||
log.Debugf("No existing open PRs found for branch %s", headBranch)
|
||||
return "", nil
|
||||
}
|
||||
|
||||
log.Debugf("Found existing PR for branch %s: %s", headBranch, prs[0].URL)
|
||||
return prs[0].URL, nil
|
||||
}
|
||||
|
||||
// createCIPR creates a pull request for CI using the GitHub CLI
|
||||
func createCIPR(headBranch, baseBranch, title, body string) (string, error) {
|
||||
cmd := exec.Command("gh", "pr", "create",
|
||||
|
||||
540
web/AGENTS.md
540
web/AGENTS.md
@@ -1,540 +0,0 @@
|
||||
# Frontend Standards
|
||||
|
||||
This file is the single source of truth for frontend coding standards across all Onyx frontend
|
||||
projects (including, but not limited to, `/web`, `/desktop`).
|
||||
|
||||
# Components
|
||||
|
||||
UI components are spread across several directories while the codebase migrates to Opal:
|
||||
|
||||
- **`web/lib/opal/src/`** — The Opal design system. Preferred for all new components.
|
||||
- **`web/src/refresh-components/`** — Production components not yet migrated to Opal.
|
||||
- **`web/src/sections/`** — Feature-specific composite components (cards, modals, etc.).
|
||||
- **`web/src/layouts/`** — Page-level layout components (settings pages, etc.).
|
||||
|
||||
**Do NOT use anything from `web/src/components/`** — this directory contains legacy components
|
||||
that are being phased out. Always prefer Opal first; fall back to `refresh-components` only for
|
||||
components not yet available in Opal.
|
||||
|
||||
## Opal Layouts (`lib/opal/src/layouts/`)
|
||||
|
||||
All layout primitives are imported from `@opal/layouts`. They handle sizing, font selection, icon
|
||||
alignment, and optional inline editing.
|
||||
|
||||
```typescript
|
||||
import { Content, ContentAction, IllustrationContent } from "@opal/layouts";
|
||||
```
|
||||
|
||||
### Content
|
||||
|
||||
**Use this for any combination of icon + title + description.**
|
||||
|
||||
A two-axis layout component that automatically routes to the correct internal layout
|
||||
(`ContentXl`, `ContentLg`, `ContentMd`, `ContentSm`) based on `sizePreset` and `variant`:
|
||||
|
||||
| sizePreset | variant | Routes to | Layout |
|
||||
|---|---|---|---|
|
||||
| `headline` / `section` | `heading` | `ContentXl` | Icon on top (flex-col) |
|
||||
| `headline` / `section` | `section` | `ContentLg` | Icon inline (flex-row) |
|
||||
| `main-content` / `main-ui` / `secondary` | `section` / `heading` | `ContentMd` | Compact inline |
|
||||
| `main-content` / `main-ui` / `secondary` | `body` | `ContentSm` | Body text layout |
|
||||
|
||||
```typescript
|
||||
<Content
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
icon={SvgSettings}
|
||||
title="Settings"
|
||||
description="Manage your preferences"
|
||||
/>
|
||||
```
|
||||
|
||||
### ContentAction
|
||||
|
||||
**Use this when a Content block needs right-side actions** (buttons, badges, icons, etc.).
|
||||
|
||||
Wraps `Content` and adds a `rightChildren` slot. Accepts all `Content` props plus:
|
||||
- `rightChildren`: `ReactNode` — actions rendered on the right
|
||||
- `paddingVariant`: `SizeVariant` — controls outer padding
|
||||
|
||||
```typescript
|
||||
<ContentAction
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
icon={SvgUser}
|
||||
title="John Doe"
|
||||
description="Admin"
|
||||
rightChildren={<Button icon={SvgEdit}>Edit</Button>}
|
||||
/>
|
||||
```
|
||||
|
||||
### IllustrationContent
|
||||
|
||||
**Use this for empty states, error pages, and informational placeholders.**
|
||||
|
||||
A vertically-stacked, center-aligned layout that pairs a large illustration (7.5rem x 7.5rem)
|
||||
with a title and optional description.
|
||||
|
||||
```typescript
|
||||
import SvgNoResult from "@opal/illustrations/no-result";
|
||||
|
||||
<IllustrationContent
|
||||
illustration={SvgNoResult}
|
||||
title="No results found"
|
||||
description="Try adjusting your search or filters."
|
||||
/>
|
||||
```
|
||||
|
||||
Props:
|
||||
- `illustration`: `IconFunctionComponent` — optional, from `@opal/illustrations`
|
||||
- `title`: `string` — required
|
||||
- `description`: `string` — optional
|
||||
|
||||
## Settings Page Layout (`src/layouts/settings-layouts.tsx`)
|
||||
|
||||
**Use this for all admin/settings pages.** Provides a standardized layout with scroll-aware
|
||||
sticky headers, centered content containers, and responsive behavior.
|
||||
|
||||
```typescript
|
||||
import SettingsLayouts from "@/layouts/settings-layouts";
|
||||
|
||||
function MySettingsPage() {
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgSettings}
|
||||
title="Account Settings"
|
||||
description="Manage your account preferences"
|
||||
rightChildren={<Button>Save</Button>}
|
||||
>
|
||||
<InputTypeIn placeholder="Search settings..." />
|
||||
</SettingsLayouts.Header>
|
||||
|
||||
<SettingsLayouts.Body>
|
||||
<Card>Settings content here</Card>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
Sub-components:
|
||||
- **`SettingsLayouts.Root`** — Wrapper with centered, scrollable container. Width options:
|
||||
`"sm"` (672px), `"sm-md"` (752px), `"md"` (872px, default), `"lg"` (992px), `"full"` (100%).
|
||||
- **`SettingsLayouts.Header`** — Sticky header with icon, title, description, optional
|
||||
`rightChildren` actions, optional `children` below (e.g., search/filter), optional `backButton`,
|
||||
and optional `separator`. Automatically shows a scroll shadow when scrolled.
|
||||
- **`SettingsLayouts.Body`** — Content container with consistent padding and vertical spacing.
|
||||
|
||||
## Cards (`src/sections/cards/`)
|
||||
|
||||
**When building a card that displays information about a specific entity (agent, document set,
|
||||
file, connector, etc.), add it to `web/src/sections/cards/`.**
|
||||
|
||||
Each card is a self-contained component focused on a single entity type. Cards typically include
|
||||
entity identification (name, avatar, icon), summary information, and quick actions.
|
||||
|
||||
```typescript
|
||||
import AgentCard from "@/sections/cards/AgentCard";
|
||||
import DocumentSetCard from "@/sections/cards/DocumentSetCard";
|
||||
import FileCard from "@/sections/cards/FileCard";
|
||||
```
|
||||
|
||||
Guidelines:
|
||||
- One card per entity type — keep card-specific logic within the card component.
|
||||
- Cards should be reusable across different pages and contexts.
|
||||
- Use shared components from `@opal/components`, `@opal/layouts`, and `@/refresh-components`
|
||||
inside cards — do not duplicate layout or styling logic.
|
||||
|
||||
## Button (`components/buttons/button/`)
|
||||
|
||||
**Always use the Opal `Button`.** Do not use raw `<button>` elements.
|
||||
|
||||
Built on `Interactive.Stateless` > `Interactive.Container`, so it inherits the full color/state
|
||||
system automatically.
|
||||
|
||||
```typescript
|
||||
import { Button } from "@opal/components/buttons/button/components";
|
||||
|
||||
// Labeled button
|
||||
<Button variant="default" prominence="primary" icon={SvgPlus}>
|
||||
Create
|
||||
</Button>
|
||||
|
||||
// Icon-only button (omit children)
|
||||
<Button variant="default" prominence="tertiary" icon={SvgTrash} size="sm" />
|
||||
```
|
||||
|
||||
Key props:
|
||||
- `variant`: `"default"` | `"action"` | `"danger"` | `"none"`
|
||||
- `prominence`: `"primary"` | `"secondary"` | `"tertiary"` | `"internal"`
|
||||
- `size`: `"lg"` | `"md"` | `"sm"` | `"xs"` | `"2xs"` | `"fit"`
|
||||
- `icon`, `rightIcon`, `children`, `disabled`, `href`, `tooltip`
|
||||
|
||||
## Core Primitives (`core/`)
|
||||
|
||||
The `core/` directory contains the lowest-level building blocks that power all Opal components.
|
||||
**Most code should not interface with these directly** — use higher-level components like `Button`,
|
||||
`Content`, and `ContentAction` instead. These are documented here for understanding, not everyday use.
|
||||
|
||||
### Interactive (`core/interactive/`)
|
||||
|
||||
The foundational layer for all clickable/interactive surfaces. Defines the color matrix for
|
||||
hover, active, and disabled states.
|
||||
|
||||
- **`Interactive.Stateless`** — Color system for stateless elements (buttons, links). Applies
|
||||
variant/prominence/state combinations via CSS custom properties.
|
||||
- **`Interactive.Stateful`** — Color system for stateful elements (toggles, sidebar items, selects).
|
||||
Uses `state` (`"empty"` | `"filled"` | `"selected"`) instead of prominence.
|
||||
- **`Interactive.Container`** — Structural box providing height, rounding, padding, and border.
|
||||
Shared by both Stateless and Stateful. Renders as `<div>`, `<button>`, or `<Link>` depending
|
||||
on context.
|
||||
- **`Interactive.Foldable`** — Zero-width collapsible wrapper with CSS grid animation.
|
||||
|
||||
### Disabled (`core/disabled/`)
|
||||
|
||||
Propagates disabled state via React context. `Interactive.Stateless` and `Interactive.Stateful`
|
||||
consume this automatically, so wrapping a subtree in `<Disabled disabled={true}>` disables all
|
||||
interactive descendants.
|
||||
|
||||
### Hoverable (`core/animations/`)
|
||||
|
||||
A standardized way to provide "opacity-100 on hover" behavior. Instead of manually wiring
|
||||
`opacity-0 group-hover:opacity-100` with Tailwind, use `Hoverable` for consistent, coordinated
|
||||
hover-to-reveal patterns.
|
||||
|
||||
- **`Hoverable.Root`** — Wraps a hover group. Tracks mouse enter/leave and broadcasts hover
|
||||
state to descendants via a per-group React context.
|
||||
- **`Hoverable.Item`** — Marks an element that should appear on hover. Supports two modes:
|
||||
- **Group mode** (`group` prop provided): visibility driven by a matching `Hoverable.Root`
|
||||
ancestor. Throws if no matching Root is found.
|
||||
- **Local mode** (`group` omitted): uses CSS `:hover` on the item itself.
|
||||
|
||||
```typescript
|
||||
import { Hoverable } from "@opal/core";
|
||||
|
||||
// Group mode — hovering anywhere on the row reveals the trash icon
|
||||
<Hoverable.Root group="row">
|
||||
<div className="flex items-center gap-2">
|
||||
<span>Row content</span>
|
||||
<Hoverable.Item group="row" variant="opacity-on-hover">
|
||||
<SvgTrash />
|
||||
</Hoverable.Item>
|
||||
</div>
|
||||
</Hoverable.Root>
|
||||
|
||||
// Local mode — hovering the item itself reveals it
|
||||
<Hoverable.Item variant="opacity-on-hover">
|
||||
<SvgTrash />
|
||||
</Hoverable.Item>
|
||||
```
|
||||
|
||||
# Best Practices
|
||||
|
||||
## 1. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
|
||||
|
||||
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
|
||||
<div className="bg-background-neutral-03 text-text-02">
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true, // Will apply dark:invert internally
|
||||
});
|
||||
|
||||
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
|
||||
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
|
||||
});
|
||||
|
||||
// ❌ Bad - Manual dark mode overrides
|
||||
<div className="bg-white dark:bg-black text-black dark:text-white">
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
## 2. Icon Usage
|
||||
|
||||
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
|
||||
|
||||
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import SvgX from "@/icons/x";
|
||||
import SvgMoreHorizontal from "@/icons/more-horizontal";
|
||||
|
||||
// ❌ Bad
|
||||
import { User } from "lucide-react";
|
||||
import { FiSearch } from "react-icons/fi";
|
||||
```
|
||||
|
||||
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
|
||||
If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
## 3. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## 4. Component Usage
|
||||
|
||||
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
|
||||
|
||||
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import Button from '@/refresh-components/buttons/Button'
|
||||
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
|
||||
import SvgPlusCircle from '@/icons/plus-circle'
|
||||
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<InputTypeIn placeholder="Search..." />
|
||||
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<input placeholder="Name" />
|
||||
<textarea placeholder="Message" />
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## 5. Colors
|
||||
|
||||
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
|
||||
|
||||
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
|
||||
|
||||
**Available color categories:**
|
||||
|
||||
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
|
||||
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
|
||||
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
|
||||
- **Actions:** `action-link-XX`, `action-danger-XX`
|
||||
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
|
||||
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Use custom Onyx color classes
|
||||
<div className="bg-background-neutral-01 border border-border-02" />
|
||||
<div className="bg-background-tint-02 border border-border-01" />
|
||||
<div className="bg-status-success-01" />
|
||||
<div className="bg-action-link-01" />
|
||||
<div className="bg-theme-primary-05" />
|
||||
|
||||
// ❌ Bad - Do NOT use standard Tailwind colors
|
||||
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
|
||||
<div className="bg-white border border-slate-200" />
|
||||
<div className="bg-green-100 text-green-700" />
|
||||
<div className="bg-blue-100 text-blue-600" />
|
||||
<div className="bg-indigo-500" />
|
||||
```
|
||||
|
||||
## 6. Data Fetching
|
||||
|
||||
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
|
||||
|
||||
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
|
||||
|
||||
# Stylistic Preferences
|
||||
|
||||
## 1. Import Standards
|
||||
|
||||
**Always use absolute imports with the `@` prefix.**
|
||||
|
||||
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useAuth } from "@/hooks/useAuth";
|
||||
import { Text } from "@/refresh-components/texts/Text";
|
||||
|
||||
// ❌ Bad
|
||||
import { Button } from "../../../components/ui/button";
|
||||
import { useAuth } from "./hooks/useAuth";
|
||||
```
|
||||
|
||||
## 2. React Component Functions
|
||||
|
||||
**Prefer regular functions over arrow functions for React components.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
function UserProfile({ userId }: UserProfileProps) {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
const UserProfile = ({ userId }: UserProfileProps) => {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
```
|
||||
|
||||
## 3. Props Interface Extraction
|
||||
|
||||
**Extract prop types into their own interface definitions. Keep prop interfaces in the same file
|
||||
as the component they belong to. Non-prop types (shared models, API response shapes, enums, etc.)
|
||||
should be placed in a co-located `interfaces.ts` file.**
|
||||
|
||||
**Reason:** Prop interfaces are tightly coupled to their component and rarely imported elsewhere,
|
||||
so co-location keeps things simple. Shared types belong in `interfaces.ts` so they can be
|
||||
imported without pulling in component code.
|
||||
|
||||
```typescript
|
||||
// ✅ Good — props interface in the same file as the component
|
||||
// UserCard.tsx
|
||||
interface UserCardProps {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}
|
||||
|
||||
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
|
||||
// ✅ Good — shared types in interfaces.ts
|
||||
// interfaces.ts
|
||||
export interface User {
|
||||
id: string
|
||||
name: string
|
||||
role: UserRole
|
||||
}
|
||||
|
||||
export type UserRole = "admin" | "member" | "viewer"
|
||||
|
||||
// ❌ Bad — inline prop types
|
||||
function UserCard({
|
||||
user,
|
||||
showActions = false,
|
||||
onEdit
|
||||
}: {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
```
|
||||
|
||||
## 4. Spacing Guidelines
|
||||
|
||||
**Prefer padding over margins for spacing. When a library component exposes a padding prop
|
||||
(e.g., `paddingVariant`), use that prop instead of wrapping it in a `<div>` with padding classes.
|
||||
If a library component does not expose a padding override and you find yourself adding a wrapper
|
||||
div for spacing, consider updating the library component to accept one.**
|
||||
|
||||
**Reason:** We want to consolidate usage to paddings instead of margins, and minimize wrapper
|
||||
divs that exist solely for spacing.
|
||||
|
||||
```typescript
|
||||
// ✅ Good — use the component's padding prop
|
||||
<ContentAction paddingVariant="md" ... />
|
||||
|
||||
// ✅ Good — padding utilities when no component prop exists
|
||||
<div className="p-4 space-y-2">
|
||||
<div className="p-2">Content</div>
|
||||
</div>
|
||||
|
||||
// ❌ Bad — wrapper div just for spacing
|
||||
<div className="p-4">
|
||||
<ContentAction ... />
|
||||
</div>
|
||||
|
||||
// ❌ Bad — margins
|
||||
<div className="m-4 space-y-2">
|
||||
<div className="m-2">Content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
## 5. Class Name Utilities
|
||||
|
||||
**Use the `cn` utility instead of raw string formatting for classNames.**
|
||||
|
||||
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
|
||||
|
||||
```typescript
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// ✅ Good
|
||||
<div className={cn(
|
||||
'base-class',
|
||||
isActive && 'active-class',
|
||||
className
|
||||
)}>
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
## 6. Custom Hooks Organization
|
||||
|
||||
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
|
||||
|
||||
**Reason:** This is just a layout preference. Keeps code clean.
|
||||
|
||||
```typescript
|
||||
// web/src/hooks/useUserData.ts
|
||||
export function useUserData(userId: string) {
|
||||
// hook implementation
|
||||
}
|
||||
|
||||
// web/src/hooks/useLocalStorage.ts
|
||||
export function useLocalStorage<T>(key: string, initialValue: T) {
|
||||
// hook implementation
|
||||
}
|
||||
```
|
||||
@@ -1 +0,0 @@
|
||||
AGENTS.md
|
||||
281
web/STANDARDS.md
Normal file
281
web/STANDARDS.md
Normal file
@@ -0,0 +1,281 @@
|
||||
# Web Standards
|
||||
|
||||
This document outlines the coding standards and best practices for the `web` directory Next.js project.
|
||||
|
||||
## 1. Import Standards
|
||||
|
||||
**Always use absolute imports with the `@` prefix.**
|
||||
|
||||
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useAuth } from "@/hooks/useAuth";
|
||||
import { Text } from "@/refresh-components/texts/Text";
|
||||
|
||||
// ❌ Bad
|
||||
import { Button } from "../../../components/ui/button";
|
||||
import { useAuth } from "./hooks/useAuth";
|
||||
```
|
||||
|
||||
## 2. React Component Functions
|
||||
|
||||
**Prefer regular functions over arrow functions for React components.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
function UserProfile({ userId }: UserProfileProps) {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
const UserProfile = ({ userId }: UserProfileProps) => {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
```
|
||||
|
||||
## 3. Props Interface Extraction
|
||||
|
||||
**Extract prop types into their own interface definitions.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
interface UserCardProps {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}
|
||||
|
||||
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({
|
||||
user,
|
||||
showActions = false,
|
||||
onEdit
|
||||
}: {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
```
|
||||
|
||||
## 4. Spacing Guidelines
|
||||
|
||||
**Prefer padding over margins for spacing.**
|
||||
|
||||
**Reason:** We want to consolidate usage to paddings instead of margins.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
<div className="p-4 space-y-2">
|
||||
<div className="p-2">Content</div>
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className="m-4 space-y-2">
|
||||
<div className="m-2">Content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
## 5. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
|
||||
|
||||
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
|
||||
<div className="bg-background-neutral-03 text-text-02">
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true, // Will apply dark:invert internally
|
||||
});
|
||||
|
||||
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
|
||||
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
|
||||
});
|
||||
|
||||
// ❌ Bad - Manual dark mode overrides
|
||||
<div className="bg-white dark:bg-black text-black dark:text-white">
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
## 6. Class Name Utilities
|
||||
|
||||
**Use the `cn` utility instead of raw string formatting for classNames.**
|
||||
|
||||
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
|
||||
|
||||
```typescript
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// ✅ Good
|
||||
<div className={cn(
|
||||
'base-class',
|
||||
isActive && 'active-class',
|
||||
className
|
||||
)}>
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
## 7. Custom Hooks Organization
|
||||
|
||||
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
|
||||
|
||||
**Reason:** This is just a layout preference. Keeps code clean.
|
||||
|
||||
```typescript
|
||||
// web/src/hooks/useUserData.ts
|
||||
export function useUserData(userId: string) {
|
||||
// hook implementation
|
||||
}
|
||||
|
||||
// web/src/hooks/useLocalStorage.ts
|
||||
export function useLocalStorage<T>(key: string, initialValue: T) {
|
||||
// hook implementation
|
||||
}
|
||||
```
|
||||
|
||||
## 8. Icon Usage
|
||||
|
||||
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
|
||||
|
||||
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import SvgX from "@/icons/x";
|
||||
import SvgMoreHorizontal from "@/icons/more-horizontal";
|
||||
|
||||
// ❌ Bad
|
||||
import { User } from "lucide-react";
|
||||
import { FiSearch } from "react-icons/fi";
|
||||
```
|
||||
|
||||
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
|
||||
If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
## 9. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## 10. Component Usage
|
||||
|
||||
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
|
||||
|
||||
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import Button from '@/refresh-components/buttons/Button'
|
||||
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
|
||||
import SvgPlusCircle from '@/icons/plus-circle'
|
||||
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<InputTypeIn placeholder="Search..." />
|
||||
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<input placeholder="Name" />
|
||||
<textarea placeholder="Message" />
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## 11. Colors
|
||||
|
||||
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
|
||||
|
||||
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
|
||||
|
||||
**Available color categories:**
|
||||
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
|
||||
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
|
||||
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
|
||||
- **Actions:** `action-link-XX`, `action-danger-XX`
|
||||
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
|
||||
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Use custom Onyx color classes
|
||||
<div className="bg-background-neutral-01 border border-border-02" />
|
||||
<div className="bg-background-tint-02 border border-border-01" />
|
||||
<div className="bg-status-success-01" />
|
||||
<div className="bg-action-link-01" />
|
||||
<div className="bg-theme-primary-05" />
|
||||
|
||||
// ❌ Bad - Do NOT use standard Tailwind colors
|
||||
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
|
||||
<div className="bg-white border border-slate-200" />
|
||||
<div className="bg-green-100 text-green-700" />
|
||||
<div className="bg-blue-100 text-blue-600" />
|
||||
<div className="bg-indigo-500" />
|
||||
```
|
||||
|
||||
## 12. Data Fetching
|
||||
|
||||
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
|
||||
|
||||
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
|
||||
@@ -1,30 +1,33 @@
|
||||
"use client";
|
||||
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
|
||||
interface ActionsContainerProps {
|
||||
type: "head" | "cell";
|
||||
children: React.ReactNode;
|
||||
size?: TableSize;
|
||||
/** Pass-through click handler (e.g. stopPropagation on body cells). */
|
||||
onClick?: (e: React.MouseEvent) => void;
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export default function ActionsContainer({
|
||||
type,
|
||||
children,
|
||||
size,
|
||||
onClick,
|
||||
}: ActionsContainerProps) {
|
||||
const size = useTableSize();
|
||||
const contextSize = useTableSize();
|
||||
const resolvedSize = size ?? contextSize;
|
||||
|
||||
const Tag = type === "head" ? "th" : "td";
|
||||
|
||||
return (
|
||||
<Tag
|
||||
className="tbl-actions"
|
||||
data-type={type}
|
||||
data-size={size}
|
||||
data-size={resolvedSize}
|
||||
onClick={onClick}
|
||||
>
|
||||
<div className="flex h-full items-center justify-end">{children}</div>
|
||||
<div className="flex h-full items-center justify-center">{children}</div>
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import {
|
||||
type SortingState,
|
||||
} from "@tanstack/react-table";
|
||||
import { Button, LineItemButton } from "@opal/components";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import { SvgArrowUpDown, SvgSortOrder, SvgCheck } from "@opal/icons";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import Divider from "@/refresh-components/Divider";
|
||||
@@ -21,6 +20,7 @@ import Text from "@/refresh-components/texts/Text";
|
||||
interface SortingPopoverProps<TData extends RowData = RowData> {
|
||||
table: Table<TData>;
|
||||
sorting: SortingState;
|
||||
size?: "md" | "lg";
|
||||
footerText?: string;
|
||||
ascendingLabel?: string;
|
||||
descendingLabel?: string;
|
||||
@@ -29,11 +29,11 @@ interface SortingPopoverProps<TData extends RowData = RowData> {
|
||||
function SortingPopover<TData extends RowData>({
|
||||
table,
|
||||
sorting,
|
||||
size = "lg",
|
||||
footerText,
|
||||
ascendingLabel = "Ascending",
|
||||
descendingLabel = "Descending",
|
||||
}: SortingPopoverProps<TData>) {
|
||||
const size = useTableSize();
|
||||
const [open, setOpen] = useState(false);
|
||||
const sortableColumns = table
|
||||
.getAllLeafColumns()
|
||||
@@ -158,6 +158,7 @@ function SortingPopover<TData extends RowData>({
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface CreateSortingColumnOptions {
|
||||
size?: "md" | "lg";
|
||||
footerText?: string;
|
||||
ascendingLabel?: string;
|
||||
descendingLabel?: string;
|
||||
@@ -176,6 +177,7 @@ function createSortingColumn<TData>(
|
||||
<SortingPopover
|
||||
table={table}
|
||||
sorting={table.getState().sorting}
|
||||
size={options?.size}
|
||||
footerText={options?.footerText}
|
||||
ascendingLabel={options?.ascendingLabel}
|
||||
descendingLabel={options?.descendingLabel}
|
||||
|
||||
@@ -8,7 +8,6 @@ import {
|
||||
type VisibilityState,
|
||||
} from "@tanstack/react-table";
|
||||
import { Button, LineItemButton, Tag } from "@opal/components";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import { SvgColumn, SvgCheck } from "@opal/icons";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import Divider from "@/refresh-components/Divider";
|
||||
@@ -20,13 +19,14 @@ import Divider from "@/refresh-components/Divider";
|
||||
interface ColumnVisibilityPopoverProps<TData extends RowData = RowData> {
|
||||
table: Table<TData>;
|
||||
columnVisibility: VisibilityState;
|
||||
size?: "md" | "lg";
|
||||
}
|
||||
|
||||
function ColumnVisibilityPopover<TData extends RowData>({
|
||||
table,
|
||||
columnVisibility,
|
||||
size = "lg",
|
||||
}: ColumnVisibilityPopoverProps<TData>) {
|
||||
const size = useTableSize();
|
||||
const [open, setOpen] = useState(false);
|
||||
|
||||
// User-defined columns only (exclude internal qualifier/actions)
|
||||
@@ -87,7 +87,13 @@ function ColumnVisibilityPopover<TData extends RowData>({
|
||||
// Column definition factory
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function createColumnVisibilityColumn<TData>(): ColumnDef<TData, unknown> {
|
||||
interface CreateColumnVisibilityColumnOptions {
|
||||
size?: "md" | "lg";
|
||||
}
|
||||
|
||||
function createColumnVisibilityColumn<TData>(
|
||||
options?: CreateColumnVisibilityColumnOptions
|
||||
): ColumnDef<TData, unknown> {
|
||||
return {
|
||||
id: "__columnVisibility",
|
||||
size: 44,
|
||||
@@ -98,6 +104,7 @@ function createColumnVisibilityColumn<TData>(): ColumnDef<TData, unknown> {
|
||||
<ColumnVisibilityPopover
|
||||
table={table}
|
||||
columnVisibility={table.getState().columnVisibility}
|
||||
size={options?.size}
|
||||
/>
|
||||
),
|
||||
cell: () => null,
|
||||
|
||||
@@ -57,10 +57,9 @@ function DragOverlayRowInner<TData>({
|
||||
<QualifierContainer key={cell.id} type="cell">
|
||||
<TableQualifier
|
||||
content={qualifierColumn.content}
|
||||
icon={qualifierColumn.getContent?.(row.original)}
|
||||
initials={qualifierColumn.getInitials?.(row.original)}
|
||||
icon={qualifierColumn.getIcon?.(row.original)}
|
||||
imageSrc={qualifierColumn.getImageSrc?.(row.original)}
|
||||
imageAlt={qualifierColumn.getImageAlt?.(row.original)}
|
||||
background={qualifierColumn.background}
|
||||
selectable={isSelectable}
|
||||
selected={isSelectable && row.getIsSelected()}
|
||||
/>
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@opal/utils";
|
||||
import { Button, Pagination, SelectButton } from "@opal/components";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
import { SvgEye, SvgXCircle } from "@opal/icons";
|
||||
import type { ReactNode } from "react";
|
||||
|
||||
@@ -43,6 +45,9 @@ interface FooterSelectionModeProps {
|
||||
onPageChange: (page: number) => void;
|
||||
/** Unit label for count pagination. @default "items" */
|
||||
units?: string;
|
||||
/** Controls overall footer sizing. `"lg"` (default) or `"md"`. */
|
||||
size?: TableSize;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -68,6 +73,7 @@ interface FooterSummaryModeProps {
|
||||
leftExtra?: ReactNode;
|
||||
/** Unit label for the summary text, e.g. "users". */
|
||||
units?: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -104,7 +110,11 @@ export default function Footer(props: FooterProps) {
|
||||
const isSmall = resolvedSize === "md";
|
||||
return (
|
||||
<div
|
||||
className="table-footer flex w-full items-center justify-between border-t border-border-01"
|
||||
className={cn(
|
||||
"table-footer",
|
||||
"flex w-full items-center justify-between border-t border-border-01",
|
||||
props.className
|
||||
)}
|
||||
data-size={resolvedSize}
|
||||
>
|
||||
{/* Left side */}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"use client";
|
||||
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
|
||||
interface QualifierContainerProps {
|
||||
type: "head" | "cell";
|
||||
children?: React.ReactNode;
|
||||
size?: TableSize;
|
||||
/** Pass-through click handler (e.g. stopPropagation on body cells). */
|
||||
onClick?: (e: React.MouseEvent) => void;
|
||||
}
|
||||
@@ -12,9 +12,11 @@ interface QualifierContainerProps {
|
||||
export default function QualifierContainer({
|
||||
type,
|
||||
children,
|
||||
size,
|
||||
onClick,
|
||||
}: QualifierContainerProps) {
|
||||
const resolvedSize = useTableSize();
|
||||
const contextSize = useTableSize();
|
||||
const resolvedSize = size ?? contextSize;
|
||||
|
||||
const Tag = type === "head" ? "th" : "td";
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ row selection, drag-and-drop reordering, and server-side mode.
|
||||
|
||||
```tsx
|
||||
import { Table, createTableColumns } from "@opal/components";
|
||||
import { SvgUser } from "@opal/icons";
|
||||
|
||||
interface User {
|
||||
id: string;
|
||||
@@ -19,10 +18,11 @@ interface User {
|
||||
const tc = createTableColumns<User>();
|
||||
|
||||
const columns = [
|
||||
tc.qualifier({ content: "icon", getContent: () => SvgUser }),
|
||||
tc.qualifier({ content: "avatar-user", getInitials: (r) => r.name?.[0] ?? "?" }),
|
||||
tc.column("email", {
|
||||
header: "Name",
|
||||
weight: 22,
|
||||
minWidth: 140,
|
||||
cell: (email, row) => <span>{row.name ?? email}</span>,
|
||||
}),
|
||||
tc.column("status", {
|
||||
@@ -40,7 +40,7 @@ function UsersTable({ users }: { users: User[] }) {
|
||||
columns={columns}
|
||||
getRowId={(r) => r.id}
|
||||
pageSize={10}
|
||||
footer={{}}
|
||||
footer={{ mode: "summary" }}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -55,7 +55,7 @@ function UsersTable({ users }: { users: User[] }) {
|
||||
| `getRowId` | `(row: TData) => string` | required | Unique row identifier |
|
||||
| `pageSize` | `number` | `10` | Rows per page (`Infinity` disables pagination) |
|
||||
| `size` | `"md" \| "lg"` | `"lg"` | Density variant |
|
||||
| `footer` | `DataTableFooterConfig` | — | Footer configuration (mode is derived from `selectionBehavior`) |
|
||||
| `footer` | `DataTableFooterConfig` | — | Footer mode (`"selection"` or `"summary"`) |
|
||||
| `initialSorting` | `SortingState` | — | Initial sort state |
|
||||
| `initialColumnVisibility` | `VisibilityState` | — | Initial column visibility |
|
||||
| `draggable` | `DataTableDraggableConfig` | — | Enable drag-and-drop reordering |
|
||||
@@ -63,6 +63,7 @@ function UsersTable({ users }: { users: User[] }) {
|
||||
| `onRowClick` | `(row: TData) => void` | — | Row click handler |
|
||||
| `searchTerm` | `string` | — | Global text filter |
|
||||
| `height` | `number \| string` | — | Max scrollable height |
|
||||
| `headerBackground` | `string` | — | Sticky header background |
|
||||
| `serverSide` | `ServerSideConfig` | — | Server-side pagination/sorting/filtering |
|
||||
| `emptyState` | `ReactNode` | — | Empty state content |
|
||||
|
||||
@@ -75,8 +76,7 @@ function UsersTable({ users }: { users: User[] }) {
|
||||
- `tc.displayColumn(opts)` — non-accessor custom column
|
||||
- `tc.actions(opts)` — trailing actions column with visibility/sorting popovers
|
||||
|
||||
## Footer
|
||||
## Footer Modes
|
||||
|
||||
The footer mode is derived automatically from `selectionBehavior`:
|
||||
- **Selection footer** (when `selectionBehavior` is `"single-select"` or `"multi-select"`) — shows selection count, optional view/clear buttons, count pagination
|
||||
- **Summary footer** (when `selectionBehavior` is `"no-select"` or omitted) — shows "Showing X\~Y of Z", list pagination, optional extra element
|
||||
- **`"selection"`** — shows selection count, optional view/clear buttons, count pagination
|
||||
- **`"summary"`** — shows "Showing X~Y of Z", list pagination, optional extra element
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react";
|
||||
import { Table, createTableColumns } from "@opal/components";
|
||||
import { SvgUser } from "@opal/icons";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Sample data
|
||||
@@ -109,14 +108,17 @@ const tc = createTableColumns<User>();
|
||||
|
||||
const columns = [
|
||||
tc.qualifier({
|
||||
content: "icon",
|
||||
getContent: () => SvgUser,
|
||||
background: true,
|
||||
content: "avatar-user",
|
||||
getInitials: (r) =>
|
||||
r.name
|
||||
.split(" ")
|
||||
.map((n) => n[0])
|
||||
.join(""),
|
||||
}),
|
||||
tc.column("name", { header: "Name", weight: 25 }),
|
||||
tc.column("email", { header: "Email", weight: 30 }),
|
||||
tc.column("role", { header: "Role", weight: 15 }),
|
||||
tc.column("status", { header: "Status", weight: 15 }),
|
||||
tc.column("name", { header: "Name", weight: 25, minWidth: 120 }),
|
||||
tc.column("email", { header: "Email", weight: 30, minWidth: 160 }),
|
||||
tc.column("role", { header: "Role", weight: 15, minWidth: 80 }),
|
||||
tc.column("status", { header: "Status", weight: 15, minWidth: 80 }),
|
||||
tc.actions(),
|
||||
];
|
||||
|
||||
@@ -140,7 +142,7 @@ export const Default: Story = {
|
||||
columns={columns}
|
||||
getRowId={(r) => r.id}
|
||||
pageSize={8}
|
||||
footer={{}}
|
||||
footer={{ mode: "summary" }}
|
||||
/>
|
||||
),
|
||||
};
|
||||
|
||||
@@ -1,20 +1,24 @@
|
||||
import { cn } from "@opal/utils";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { WithoutStyles } from "@/types";
|
||||
|
||||
interface TableCellProps
|
||||
extends WithoutStyles<React.TdHTMLAttributes<HTMLTableCellElement>> {
|
||||
children: React.ReactNode;
|
||||
size?: TableSize;
|
||||
/** Explicit pixel width for the cell. */
|
||||
width?: number;
|
||||
}
|
||||
|
||||
export default function TableCell({
|
||||
size,
|
||||
width,
|
||||
children,
|
||||
...props
|
||||
}: TableCellProps) {
|
||||
const resolvedSize = useTableSize();
|
||||
const contextSize = useTableSize();
|
||||
const resolvedSize = size ?? contextSize;
|
||||
return (
|
||||
<td
|
||||
className="tbl-cell overflow-hidden"
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { WithoutStyles } from "@/types";
|
||||
import type { ExtremaSizeVariants, SizeVariants } from "@opal/types";
|
||||
|
||||
@@ -12,15 +9,20 @@ import type { ExtremaSizeVariants, SizeVariants } from "@opal/types";
|
||||
|
||||
type TableSize = Extract<SizeVariants, "md" | "lg">;
|
||||
type TableVariant = "rows" | "cards";
|
||||
type TableQualifier = "simple" | "avatar" | "icon";
|
||||
type SelectionBehavior = "no-select" | "single-select" | "multi-select";
|
||||
|
||||
interface TableProps
|
||||
extends WithoutStyles<React.TableHTMLAttributes<HTMLTableElement>> {
|
||||
ref?: React.Ref<HTMLTableElement>;
|
||||
/** Size preset for the table. @default "lg" */
|
||||
size?: TableSize;
|
||||
/** Visual row variant. @default "cards" */
|
||||
variant?: TableVariant;
|
||||
/** Row selection behavior. @default "no-select" */
|
||||
selectionBehavior?: SelectionBehavior;
|
||||
/** Leading qualifier column type. @default null */
|
||||
qualifier?: TableQualifier;
|
||||
/** Height behavior. `"fit"` = shrink to content, `"full"` = fill available space. */
|
||||
heightVariant?: ExtremaSizeVariants;
|
||||
/** Explicit pixel width for the table (e.g. from `table.getTotalSize()`).
|
||||
@@ -36,21 +38,23 @@ interface TableProps
|
||||
|
||||
function Table({
|
||||
ref,
|
||||
size = "lg",
|
||||
variant = "cards",
|
||||
selectionBehavior = "no-select",
|
||||
qualifier = "simple",
|
||||
heightVariant,
|
||||
width,
|
||||
...props
|
||||
}: TableProps) {
|
||||
const size = useTableSize();
|
||||
return (
|
||||
<table
|
||||
ref={ref}
|
||||
className={cn("border-separate border-spacing-0", !width && "min-w-full")}
|
||||
style={{ width }}
|
||||
style={{ tableLayout: "fixed", width }}
|
||||
data-size={size}
|
||||
data-variant={variant}
|
||||
data-selection={selectionBehavior}
|
||||
data-qualifier={qualifier}
|
||||
data-height={heightVariant}
|
||||
{...props}
|
||||
/>
|
||||
@@ -58,4 +62,10 @@ function Table({
|
||||
}
|
||||
|
||||
export default Table;
|
||||
export type { TableProps, TableSize, TableVariant, SelectionBehavior };
|
||||
export type {
|
||||
TableProps,
|
||||
TableSize,
|
||||
TableVariant,
|
||||
TableQualifier,
|
||||
SelectionBehavior,
|
||||
};
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { cn } from "@opal/utils";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { WithoutStyles } from "@/types";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgChevronDown, SvgChevronUp, SvgHandle, SvgSort } from "@opal/icons";
|
||||
@@ -29,6 +30,8 @@ interface TableHeadCustomProps {
|
||||
icon?: (sorted: SortDirection) => IconFunctionComponent;
|
||||
/** Text alignment for the column. Defaults to `"left"`. */
|
||||
alignment?: "left" | "center" | "right";
|
||||
/** Cell density. `"md"` uses tighter padding for denser layouts. */
|
||||
size?: TableSize;
|
||||
/** Column width in pixels. Applied as an inline style on the `<th>`. */
|
||||
width?: number;
|
||||
/** When `true`, shows a bottom border on hover. Defaults to `true`. */
|
||||
@@ -78,11 +81,13 @@ export default function TableHead({
|
||||
resizable,
|
||||
onResizeStart,
|
||||
alignment = "left",
|
||||
size,
|
||||
width,
|
||||
bottomBorder = true,
|
||||
...thProps
|
||||
}: TableHeadProps) {
|
||||
const resolvedSize = useTableSize();
|
||||
const contextSize = useTableSize();
|
||||
const resolvedSize = size ?? contextSize;
|
||||
const isSmall = resolvedSize === "md";
|
||||
return (
|
||||
<th
|
||||
@@ -92,7 +97,9 @@ export default function TableHead({
|
||||
data-size={resolvedSize}
|
||||
data-bottom-border={bottomBorder || undefined}
|
||||
>
|
||||
<div className="flex items-center gap-1">
|
||||
<div
|
||||
className={cn("flex items-center gap-1", alignmentFlexClass[alignment])}
|
||||
>
|
||||
<div className="table-head-label">
|
||||
<Text
|
||||
mainUiAction={!isSmall}
|
||||
|
||||
@@ -3,13 +3,19 @@
|
||||
import React from "react";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
import { SvgUser } from "@opal/icons";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import type { QualifierContentType } from "@opal/components/table/types";
|
||||
import Checkbox from "@/refresh-components/inputs/Checkbox";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
interface TableQualifierProps {
|
||||
className?: string;
|
||||
/** Content type displayed in the qualifier */
|
||||
content: QualifierContentType;
|
||||
/** Size variant */
|
||||
size?: TableSize;
|
||||
/** Disables interaction */
|
||||
disabled?: boolean;
|
||||
/** Whether to show a selection checkbox overlay */
|
||||
@@ -18,33 +24,54 @@ interface TableQualifierProps {
|
||||
selected?: boolean;
|
||||
/** Called when the checkbox is toggled */
|
||||
onSelectChange?: (selected: boolean) => void;
|
||||
/** Icon component to render (for "icon" content). */
|
||||
/** Icon component to render (for "icon" content type) */
|
||||
icon?: IconFunctionComponent;
|
||||
/** Image source URL (for "image" content). */
|
||||
/** Image source URL (for "image" content type) */
|
||||
imageSrc?: string;
|
||||
/** Image alt text (for "image" content). */
|
||||
/** Image alt text */
|
||||
imageAlt?: string;
|
||||
/** Show a tinted background container behind the content. */
|
||||
background?: boolean;
|
||||
/** User initials (for "avatar-user" content type) */
|
||||
initials?: string;
|
||||
}
|
||||
|
||||
const iconSizes = {
|
||||
lg: 28,
|
||||
md: 24,
|
||||
lg: 16,
|
||||
md: 14,
|
||||
} as const;
|
||||
|
||||
function getOverlayStyles(selected: boolean, disabled: boolean) {
|
||||
function getQualifierStyles(selected: boolean, disabled: boolean) {
|
||||
if (disabled) {
|
||||
return selected ? "flex bg-action-link-00" : "hidden";
|
||||
return {
|
||||
container: "bg-background-neutral-03",
|
||||
icon: "stroke-text-02",
|
||||
overlay: selected ? "flex bg-action-link-00" : "hidden",
|
||||
overlayImage: selected ? "flex bg-mask-01 backdrop-blur-02" : "hidden",
|
||||
};
|
||||
}
|
||||
|
||||
if (selected) {
|
||||
return "flex bg-action-link-00";
|
||||
return {
|
||||
container: "bg-action-link-00",
|
||||
icon: "stroke-text-03",
|
||||
overlay: "flex bg-action-link-00",
|
||||
overlayImage: "flex bg-mask-01 backdrop-blur-02",
|
||||
};
|
||||
}
|
||||
return "flex opacity-0 group-hover/row:opacity-100 group-focus-within/row:opacity-100 bg-background-tint-01";
|
||||
|
||||
return {
|
||||
container: "bg-background-tint-01",
|
||||
icon: "stroke-text-03",
|
||||
overlay:
|
||||
"flex opacity-0 group-hover/row:opacity-100 group-focus-within/row:opacity-100 bg-background-tint-01",
|
||||
overlayImage:
|
||||
"flex opacity-0 group-hover/row:opacity-100 group-focus-within/row:opacity-100 bg-mask-01 group-hover/row:backdrop-blur-02 group-focus-within/row:backdrop-blur-02",
|
||||
};
|
||||
}
|
||||
|
||||
function TableQualifier({
|
||||
className,
|
||||
content,
|
||||
size,
|
||||
disabled = false,
|
||||
selectable = false,
|
||||
selected = false,
|
||||
@@ -52,67 +79,100 @@ function TableQualifier({
|
||||
icon: Icon,
|
||||
imageSrc,
|
||||
imageAlt = "",
|
||||
background = false,
|
||||
initials,
|
||||
}: TableQualifierProps) {
|
||||
const resolvedSize = useTableSize();
|
||||
const contextSize = useTableSize();
|
||||
const resolvedSize = size ?? contextSize;
|
||||
const isRound = content === "avatar-icon" || content === "avatar-user";
|
||||
const iconSize = iconSizes[resolvedSize];
|
||||
const overlayStyles = getOverlayStyles(selected, disabled);
|
||||
const styles = getQualifierStyles(selected, disabled);
|
||||
|
||||
function renderContent() {
|
||||
switch (content) {
|
||||
case "icon":
|
||||
return Icon ? <Icon size={iconSize} /> : null;
|
||||
return Icon ? <Icon size={iconSize} className={styles.icon} /> : null;
|
||||
|
||||
case "simple":
|
||||
return null;
|
||||
|
||||
case "image":
|
||||
return imageSrc ? (
|
||||
<img
|
||||
src={imageSrc}
|
||||
alt={imageAlt}
|
||||
className="h-full w-full rounded-08 object-cover"
|
||||
className={cn(
|
||||
"h-full w-full object-cover",
|
||||
isRound ? "rounded-full" : "rounded-08"
|
||||
)}
|
||||
/>
|
||||
) : null;
|
||||
|
||||
case "simple":
|
||||
case "avatar-icon":
|
||||
return <SvgUser size={iconSize} className={styles.icon} />;
|
||||
|
||||
case "avatar-user":
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center rounded-full bg-background-neutral-inverted-00",
|
||||
resolvedSize === "lg" ? "h-7 w-7" : "h-6 w-6"
|
||||
)}
|
||||
>
|
||||
<Text
|
||||
inverted
|
||||
secondaryAction
|
||||
text05
|
||||
className="select-none uppercase"
|
||||
>
|
||||
{initials}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
const inner = renderContent();
|
||||
const showBackground = background && content !== "simple";
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"group relative inline-flex shrink-0 items-center justify-center",
|
||||
resolvedSize === "lg" ? "h-9 w-9" : "h-7 w-7",
|
||||
disabled ? "cursor-not-allowed" : "cursor-default"
|
||||
disabled ? "cursor-not-allowed" : "cursor-default",
|
||||
className
|
||||
)}
|
||||
>
|
||||
{showBackground ? (
|
||||
{/* Inner qualifier container — no background for "simple" */}
|
||||
{content !== "simple" && (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center overflow-hidden rounded-08 transition-colors",
|
||||
"flex items-center justify-center overflow-hidden transition-colors",
|
||||
resolvedSize === "lg" ? "h-9 w-9" : "h-7 w-7",
|
||||
disabled
|
||||
? "bg-background-neutral-03"
|
||||
: selected
|
||||
? "bg-action-link-00"
|
||||
: "bg-background-tint-01"
|
||||
isRound ? "rounded-full" : "rounded-08",
|
||||
styles.container,
|
||||
content === "image" && disabled && !selected && "opacity-50"
|
||||
)}
|
||||
>
|
||||
{inner}
|
||||
{renderContent()}
|
||||
</div>
|
||||
) : (
|
||||
inner
|
||||
)}
|
||||
|
||||
{/* Selection overlay */}
|
||||
{selectable && (
|
||||
<div
|
||||
className={cn(
|
||||
"absolute inset-0 items-center justify-center rounded-08",
|
||||
content === "simple" ? "flex" : overlayStyles
|
||||
"absolute inset-0 items-center justify-center",
|
||||
content === "simple"
|
||||
? "flex"
|
||||
: isRound
|
||||
? "rounded-full"
|
||||
: "rounded-08",
|
||||
content === "simple"
|
||||
? "flex"
|
||||
: content === "image"
|
||||
? styles.overlayImage
|
||||
: styles.overlay
|
||||
)}
|
||||
>
|
||||
<Checkbox
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import { cn } from "@opal/utils";
|
||||
import { useTableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
import type { WithoutStyles } from "@/types";
|
||||
import { useSortable } from "@dnd-kit/sortable";
|
||||
import { CSS } from "@dnd-kit/utilities";
|
||||
@@ -11,7 +12,7 @@ import { SvgHandle } from "@opal/icons";
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export interface TableRowProps
|
||||
interface TableRowProps
|
||||
extends WithoutStyles<React.HTMLAttributes<HTMLTableRowElement>> {
|
||||
ref?: React.Ref<HTMLTableRowElement>;
|
||||
selected?: boolean;
|
||||
@@ -21,6 +22,8 @@ export interface TableRowProps
|
||||
sortableId?: string;
|
||||
/** Show drag handle overlay. Defaults to true when sortableId is set. */
|
||||
showDragHandle?: boolean;
|
||||
/** Size variant for the drag handle */
|
||||
size?: TableSize;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -30,13 +33,15 @@ export interface TableRowProps
|
||||
function SortableTableRow({
|
||||
sortableId,
|
||||
showDragHandle = true,
|
||||
size,
|
||||
selected,
|
||||
disabled,
|
||||
ref: _externalRef,
|
||||
children,
|
||||
...props
|
||||
}: TableRowProps) {
|
||||
const resolvedSize = useTableSize();
|
||||
const contextSize = useTableSize();
|
||||
const resolvedSize = size ?? contextSize;
|
||||
|
||||
const {
|
||||
attributes,
|
||||
@@ -100,9 +105,10 @@ function SortableTableRow({
|
||||
// Main component
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export default function TableRow({
|
||||
function TableRow({
|
||||
sortableId,
|
||||
showDragHandle,
|
||||
size,
|
||||
selected,
|
||||
disabled,
|
||||
ref,
|
||||
@@ -113,6 +119,7 @@ export default function TableRow({
|
||||
<SortableTableRow
|
||||
sortableId={sortableId}
|
||||
showDragHandle={showDragHandle}
|
||||
size={size}
|
||||
selected={selected}
|
||||
disabled={disabled}
|
||||
ref={ref}
|
||||
@@ -131,3 +138,6 @@ export default function TableRow({
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export default TableRow;
|
||||
export type { TableRowProps };
|
||||
|
||||
@@ -25,14 +25,18 @@ import type { SortDirection } from "@opal/components/table/TableHead";
|
||||
interface QualifierConfig<TData> {
|
||||
/** Content type for body-row `<TableQualifier>`. @default "simple" */
|
||||
content?: QualifierContentType;
|
||||
/** Return the icon component to render for a row (for "icon" content). */
|
||||
getContent?: (row: TData) => IconFunctionComponent;
|
||||
/** Return the image URL to render for a row (for "image" content). */
|
||||
/** Content type for the header `<TableQualifier>`. @default "simple" */
|
||||
headerContentType?: QualifierContentType;
|
||||
/** Extract initials from a row (for "avatar-user" content). */
|
||||
getInitials?: (row: TData) => string;
|
||||
/** Extract icon from a row (for "icon" / "avatar-icon" content). */
|
||||
getIcon?: (row: TData) => IconFunctionComponent;
|
||||
/** Extract image src from a row (for "image" content). */
|
||||
getImageSrc?: (row: TData) => string;
|
||||
/** Return the image alt text for a row (for "image" content). @default "" */
|
||||
getImageAlt?: (row: TData) => string;
|
||||
/** Show a tinted background container behind the content. @default false */
|
||||
background?: boolean;
|
||||
/** Whether to show selection checkboxes on the qualifier. @default true */
|
||||
selectable?: boolean;
|
||||
/** Whether to render qualifier content in the header. @default true */
|
||||
header?: boolean;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -54,6 +58,8 @@ interface DataColumnConfig<TData, TValue> {
|
||||
icon?: (sorted: SortDirection) => IconFunctionComponent;
|
||||
/** Column weight for proportional distribution. @default 20 */
|
||||
weight?: number;
|
||||
/** Minimum column width in pixels. @default 50 */
|
||||
minWidth?: number;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -126,9 +132,9 @@ interface TableColumnsBuilder<TData> {
|
||||
* ```ts
|
||||
* const tc = createTableColumns<TeamMember>();
|
||||
* const columns = [
|
||||
* tc.qualifier({ content: "icon", getContent: (r) => UserIcon }),
|
||||
* tc.column("name", { header: "Name", weight: 23 }),
|
||||
* tc.column("email", { header: "Email", weight: 28 }),
|
||||
* tc.qualifier({ content: "avatar-user", getInitials: (r) => r.initials }),
|
||||
* tc.column("name", { header: "Name", weight: 23, minWidth: 120 }),
|
||||
* tc.column("email", { header: "Email", weight: 28, minWidth: 150 }),
|
||||
* tc.actions(),
|
||||
* ];
|
||||
* ```
|
||||
@@ -156,10 +162,12 @@ export function createTableColumns<TData>(): TableColumnsBuilder<TData> {
|
||||
width: (size: TableSize) =>
|
||||
size === "md" ? { fixed: 36 } : { fixed: 44 },
|
||||
content,
|
||||
getContent: config?.getContent,
|
||||
headerContentType: config?.headerContentType,
|
||||
getInitials: config?.getInitials,
|
||||
getIcon: config?.getIcon,
|
||||
getImageSrc: config?.getImageSrc,
|
||||
getImageAlt: config?.getImageAlt,
|
||||
background: config?.background,
|
||||
selectable: config?.selectable,
|
||||
header: config?.header,
|
||||
};
|
||||
},
|
||||
|
||||
@@ -175,6 +183,7 @@ export function createTableColumns<TData>(): TableColumnsBuilder<TData> {
|
||||
enableHiding = true,
|
||||
icon,
|
||||
weight = 20,
|
||||
minWidth = 50,
|
||||
} = config;
|
||||
|
||||
const def = helper.accessor(accessor as any, {
|
||||
@@ -192,7 +201,7 @@ export function createTableColumns<TData>(): TableColumnsBuilder<TData> {
|
||||
kind: "data",
|
||||
id: accessor as string,
|
||||
def,
|
||||
width: { weight, minWidth: Math.max(header.length * 8 + 40, 80) },
|
||||
width: { weight, minWidth },
|
||||
icon,
|
||||
};
|
||||
},
|
||||
|
||||
@@ -39,12 +39,15 @@ import type {
|
||||
import type { TableSize } from "@opal/components/table/TableSizeContext";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SelectionBehavior
|
||||
// Qualifier × SelectionBehavior
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type Qualifier = "simple" | "avatar" | "icon";
|
||||
type SelectionBehavior = "no-select" | "single-select" | "multi-select";
|
||||
|
||||
export type DataTableProps<TData> = BaseDataTableProps<TData> & {
|
||||
/** Leading qualifier column type. @default "simple" */
|
||||
qualifier?: Qualifier;
|
||||
/** Row selection behavior. @default "no-select" */
|
||||
selectionBehavior?: SelectionBehavior;
|
||||
};
|
||||
@@ -128,8 +131,8 @@ function processColumns<TData>(
|
||||
* ```tsx
|
||||
* const tc = createTableColumns<TeamMember>();
|
||||
* const columns = [
|
||||
* tc.qualifier({ content: "icon", getContent: (r) => UserIcon }),
|
||||
* tc.column("name", { header: "Name", weight: 23 }),
|
||||
* tc.qualifier({ content: "avatar-user", getInitials: (r) => r.initials }),
|
||||
* tc.column("name", { header: "Name", weight: 23, minWidth: 120 }),
|
||||
* tc.column("email", { header: "Email", weight: 28 }),
|
||||
* tc.actions(),
|
||||
* ];
|
||||
@@ -149,11 +152,13 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
footer,
|
||||
size = "lg",
|
||||
variant = "cards",
|
||||
qualifier = "simple",
|
||||
selectionBehavior = "no-select",
|
||||
onSelectionChange,
|
||||
onRowClick,
|
||||
searchTerm,
|
||||
height,
|
||||
headerBackground,
|
||||
serverSide,
|
||||
emptyState,
|
||||
} = props;
|
||||
@@ -161,15 +166,11 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
const effectivePageSize = pageSize ?? (footer ? 10 : data.length);
|
||||
|
||||
// Whether the qualifier column should exist in the DOM.
|
||||
// Derived from the column definitions: if a qualifier column exists with
|
||||
// content !== "simple", always show it. If content === "simple" (or no
|
||||
// qualifier column defined), show only for multi-select (checkboxes).
|
||||
const qualifierColDef = columns.find(
|
||||
(c): c is OnyxQualifierColumn<TData> => c.kind === "qualifier"
|
||||
);
|
||||
// "simple" only gets a qualifier column for multi-select (checkboxes).
|
||||
// "simple" + no-select/single-select = no qualifier column — single-select
|
||||
// uses row-level background coloring instead.
|
||||
const hasQualifierColumn =
|
||||
(qualifierColDef != null && qualifierColDef.content !== "simple") ||
|
||||
selectionBehavior === "multi-select";
|
||||
qualifier !== "simple" || selectionBehavior === "multi-select";
|
||||
|
||||
// 1. Process columns (memoized on columns + size)
|
||||
const { tanstackColumns, widthConfig, qualifierColumn, columnKindMap } =
|
||||
@@ -348,9 +349,15 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
overflowY: "auto" as const,
|
||||
}
|
||||
: undefined),
|
||||
...(headerBackground
|
||||
? ({
|
||||
"--table-header-bg": headerBackground,
|
||||
} as React.CSSProperties)
|
||||
: undefined),
|
||||
}}
|
||||
>
|
||||
<TableElement
|
||||
size={size}
|
||||
variant={variant}
|
||||
selectionBehavior={selectionBehavior}
|
||||
width={
|
||||
@@ -412,12 +419,14 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
columnVisibility={
|
||||
table.getState().columnVisibility
|
||||
}
|
||||
size={size}
|
||||
/>
|
||||
)}
|
||||
{actionsDef.showSorting !== false && (
|
||||
<SortingPopover
|
||||
table={table}
|
||||
sorting={table.getState().sorting}
|
||||
size={size}
|
||||
footerText={actionsDef.sortingFooterText}
|
||||
/>
|
||||
)}
|
||||
@@ -532,6 +541,12 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
if (cellColDef?.kind === "qualifier") {
|
||||
const qDef = cellColDef as OnyxQualifierColumn<TData>;
|
||||
|
||||
// Resolve content based on the qualifier prop:
|
||||
// - "simple" renders nothing (checkbox only when selectable)
|
||||
// - "avatar"/"icon" render from column config
|
||||
const qualifierContent =
|
||||
qualifier === "simple" ? "simple" : qDef.content;
|
||||
|
||||
return (
|
||||
<QualifierContainer
|
||||
key={cell.id}
|
||||
@@ -539,11 +554,10 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<TableQualifier
|
||||
content={qDef.content}
|
||||
icon={qDef.getContent?.(row.original)}
|
||||
content={qualifierContent}
|
||||
initials={qDef.getInitials?.(row.original)}
|
||||
icon={qDef.getIcon?.(row.original)}
|
||||
imageSrc={qDef.getImageSrc?.(row.original)}
|
||||
imageAlt={qDef.getImageAlt?.(row.original)}
|
||||
background={qDef.background}
|
||||
selectable={showQualifierCheckbox}
|
||||
selected={
|
||||
showQualifierCheckbox && row.getIsSelected()
|
||||
|
||||
@@ -277,7 +277,7 @@ function createSplitterResizeHandler(
|
||||
* const { containerRef, columnWidths, createResizeHandler } = useColumnWidths({
|
||||
* headers: table.getHeaderGroups()[0].headers,
|
||||
* fixedColumnIds: new Set(["actions"]),
|
||||
* columnMinWidths: { name: 72, status: 80 },
|
||||
* columnMinWidths: { name: 120, status: 80 },
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
|
||||
@@ -25,7 +25,8 @@
|
||||
/* ---- TableHead ---- */
|
||||
|
||||
.table-head {
|
||||
@apply relative;
|
||||
@apply relative sticky top-0 z-20;
|
||||
background: var(--table-header-bg, transparent);
|
||||
}
|
||||
.table-head[data-size="lg"] {
|
||||
@apply px-2 py-1;
|
||||
@@ -129,7 +130,8 @@ table[data-variant="cards"] .tbl-row:has(:focus-visible) > td {
|
||||
/* ---- QualifierContainer ---- */
|
||||
|
||||
.tbl-qualifier[data-type="head"] {
|
||||
@apply w-px whitespace-nowrap py-1;
|
||||
@apply w-px whitespace-nowrap py-1 sticky top-0 z-20;
|
||||
background: var(--table-header-bg, transparent);
|
||||
}
|
||||
.tbl-qualifier[data-type="head"][data-size="md"] {
|
||||
@apply py-0.5;
|
||||
@@ -145,10 +147,11 @@ table[data-variant="cards"] .tbl-row:has(:focus-visible) > td {
|
||||
/* ---- ActionsContainer ---- */
|
||||
|
||||
.tbl-actions {
|
||||
@apply w-px whitespace-nowrap px-1;
|
||||
@apply sticky right-0 w-px whitespace-nowrap px-1;
|
||||
}
|
||||
.tbl-actions[data-type="head"] {
|
||||
@apply px-2 py-1;
|
||||
@apply z-30 sticky top-0 px-2 py-1;
|
||||
background: var(--table-header-bg, transparent);
|
||||
}
|
||||
|
||||
/* ---- Footer ---- */
|
||||
|
||||
@@ -30,7 +30,12 @@ export type ColumnWidth = DataColumnWidth | FixedColumnWidth;
|
||||
// Column kind discriminant
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export type QualifierContentType = "simple" | "icon" | "image";
|
||||
export type QualifierContentType =
|
||||
| "icon"
|
||||
| "simple"
|
||||
| "image"
|
||||
| "avatar-icon"
|
||||
| "avatar-user";
|
||||
|
||||
export type OnyxColumnKind = "qualifier" | "data" | "display" | "actions";
|
||||
|
||||
@@ -51,14 +56,18 @@ export interface OnyxQualifierColumn<TData> extends OnyxColumnBase<TData> {
|
||||
kind: "qualifier";
|
||||
/** Content type for body-row `<TableQualifier>`. */
|
||||
content: QualifierContentType;
|
||||
/** Return the icon component to render for a row (for "icon" content). */
|
||||
getContent?: (row: TData) => IconFunctionComponent;
|
||||
/** Return the image URL to render for a row (for "image" content). */
|
||||
/** Content type for the header `<TableQualifier>`. @default "simple" */
|
||||
headerContentType?: QualifierContentType;
|
||||
/** Extract initials from a row (for "avatar-user" content). */
|
||||
getInitials?: (row: TData) => string;
|
||||
/** Extract icon from a row (for "icon" / "avatar-icon" content). */
|
||||
getIcon?: (row: TData) => IconFunctionComponent;
|
||||
/** Extract image src from a row (for "image" content). */
|
||||
getImageSrc?: (row: TData) => string;
|
||||
/** Return the image alt text for a row (for "image" content). @default "" */
|
||||
getImageAlt?: (row: TData) => string;
|
||||
/** Show a tinted background container behind the content. @default false */
|
||||
background?: boolean;
|
||||
/** Whether to show selection checkboxes on the qualifier. @default true */
|
||||
selectable?: boolean;
|
||||
/** Whether to render qualifier content in the header. @default true */
|
||||
header?: boolean;
|
||||
}
|
||||
|
||||
/** Data column — accessor-based column with sorting/resizing. */
|
||||
@@ -165,6 +174,9 @@ export interface DataTableProps<TData> {
|
||||
* Accepts a pixel number (e.g. `300`) or a CSS value string (e.g. `"50vh"`).
|
||||
*/
|
||||
height?: number | string;
|
||||
/** Background color for the sticky header row, preventing rows from showing
|
||||
* through when scrolling. Accepts any CSS color value. */
|
||||
headerBackground?: string;
|
||||
/**
|
||||
* Enable server-side mode. When provided:
|
||||
* - TanStack uses manualPagination/manualSorting/manualFiltering
|
||||
|
||||
@@ -45,7 +45,7 @@ function MemoryTagWithTooltip({
|
||||
<MemoriesModal
|
||||
initialTargetMemoryId={memoryId}
|
||||
initialTargetIndex={memoryIndex}
|
||||
highlightOnOpen
|
||||
highlightFirstOnOpen
|
||||
/>
|
||||
</memoriesModal.Provider>
|
||||
{memoriesModal.isOpen ? (
|
||||
@@ -56,14 +56,8 @@ function MemoryTagWithTooltip({
|
||||
side="bottom"
|
||||
className="bg-background-neutral-00 text-text-01 shadow-md max-w-[17.5rem] p-1"
|
||||
tooltip={
|
||||
<Section
|
||||
flexDirection="column"
|
||||
alignItems="start"
|
||||
padding={0.25}
|
||||
gap={0.25}
|
||||
height="auto"
|
||||
>
|
||||
<div className="p-1">
|
||||
<Section flexDirection="column" gap={0.25} height="auto">
|
||||
<div className="p-1 w-full">
|
||||
<Text as="p" secondaryBody text03>
|
||||
{memoryText}
|
||||
</Text>
|
||||
@@ -72,7 +66,6 @@ function MemoryTagWithTooltip({
|
||||
icon={SvgAddLines}
|
||||
title={operationLabel}
|
||||
sizePreset="secondary"
|
||||
paddingVariant="sm"
|
||||
variant="body"
|
||||
prominence="muted"
|
||||
rightChildren={
|
||||
|
||||
@@ -103,7 +103,7 @@ export const MemoryToolRenderer: MessageRenderer<MemoryToolPacket, {}> = ({
|
||||
<MemoriesModal
|
||||
initialTargetMemoryId={memoryId}
|
||||
initialTargetIndex={index}
|
||||
highlightOnOpen
|
||||
highlightFirstOnOpen
|
||||
/>
|
||||
</memoriesModal.Provider>
|
||||
{memoryText ? (
|
||||
|
||||
@@ -123,6 +123,9 @@ export interface LLMProviderFormProps {
|
||||
open?: boolean;
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
|
||||
/** The current default model name for this provider (from the global default). */
|
||||
defaultModelName?: string;
|
||||
|
||||
// Onboarding-specific (only when variant === "onboarding")
|
||||
onboardingState?: OnboardingState;
|
||||
onboardingActions?: OnboardingActions;
|
||||
|
||||
@@ -20,15 +20,10 @@
|
||||
|
||||
"use client";
|
||||
|
||||
import {
|
||||
cn,
|
||||
ensureHrefProtocol,
|
||||
INTERACTIVE_SELECTOR,
|
||||
noProp,
|
||||
} from "@/lib/utils";
|
||||
import { cn, ensureHrefProtocol, noProp } from "@/lib/utils";
|
||||
import type { Components } from "react-markdown";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { useCallback, useMemo, useRef, useState, useEffect } from "react";
|
||||
import { useCallback, useMemo, useState, useEffect } from "react";
|
||||
import { useAppBackground } from "@/providers/AppBackgroundProvider";
|
||||
import { useTheme } from "next-themes";
|
||||
import ShareChatSessionModal from "@/sections/modals/ShareChatSessionModal";
|
||||
@@ -537,37 +532,6 @@ function Root({ children, enableBackground }: AppRootProps) {
|
||||
const { isSafari } = useBrowserInfo();
|
||||
const isLightMode = resolvedTheme === "light";
|
||||
const showBackground = hasBackground && enableBackground;
|
||||
|
||||
// Track whether the chat input was focused before a mousedown, so we can
|
||||
// restore focus on mouseup if no text was selected. This preserves
|
||||
// click-drag text selection while keeping the input focused on plain clicks.
|
||||
const inputWasFocused = useRef(false);
|
||||
|
||||
const handleMouseDown = useCallback(
|
||||
(event: React.MouseEvent<HTMLDivElement>) => {
|
||||
const activeEl = document.activeElement;
|
||||
const isFocused =
|
||||
activeEl instanceof HTMLElement &&
|
||||
activeEl.id === "onyx-chat-input-textarea";
|
||||
const target = event.target;
|
||||
const isInteractive =
|
||||
target instanceof HTMLElement && !!target.closest(INTERACTIVE_SELECTOR);
|
||||
inputWasFocused.current = isFocused && !isInteractive;
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const handleMouseUp = useCallback(() => {
|
||||
if (!inputWasFocused.current) return;
|
||||
inputWasFocused.current = false;
|
||||
const sel = window.getSelection();
|
||||
if (sel && !sel.isCollapsed) return;
|
||||
const textarea = document.getElementById("onyx-chat-input-textarea");
|
||||
// Only restore focus if no other element has grabbed it since mousedown.
|
||||
if (textarea && document.activeElement !== textarea) {
|
||||
textarea.focus();
|
||||
}
|
||||
}, []);
|
||||
const horizontalBlurMask = `linear-gradient(
|
||||
to right,
|
||||
transparent 0%,
|
||||
@@ -585,8 +549,6 @@ function Root({ children, enableBackground }: AppRootProps) {
|
||||
*/
|
||||
<div
|
||||
data-main-container
|
||||
onMouseDown={handleMouseDown}
|
||||
onMouseUp={handleMouseUp}
|
||||
className={cn(
|
||||
"@container flex flex-col h-full w-full relative overflow-hidden",
|
||||
showBackground && "bg-cover bg-center bg-fixed"
|
||||
|
||||
@@ -125,7 +125,7 @@ export const MAX_FILES_TO_SHOW = 3;
|
||||
export const MOBILE_SIDEBAR_BREAKPOINT_PX = 640;
|
||||
export const DESKTOP_SMALL_BREAKPOINT_PX = 912;
|
||||
export const DESKTOP_MEDIUM_BREAKPOINT_PX = 1232;
|
||||
export const DEFAULT_AVATAR_SIZE_PX = 18;
|
||||
export const DEFAULT_AGENT_AVATAR_SIZE_PX = 18;
|
||||
export const HORIZON_DISTANCE_PX = 800;
|
||||
export const LOGO_FOLDED_SIZE_PX = 24;
|
||||
export const LOGO_UNFOLDED_SIZE_PX = 88;
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
import { getUserInitials } from "@/lib/user";
|
||||
|
||||
describe("getUserInitials", () => {
|
||||
it("returns first letters of first two name parts", () => {
|
||||
expect(getUserInitials("Alice Smith", "alice@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("returns first two chars of a single-word name", () => {
|
||||
expect(getUserInitials("Alice", "alice@example.com")).toBe("AL");
|
||||
});
|
||||
|
||||
it("handles three-word names (uses first two)", () => {
|
||||
expect(getUserInitials("Alice B. Smith", "alice@example.com")).toBe("AB");
|
||||
});
|
||||
|
||||
it("falls back to email local part with dot separator", () => {
|
||||
expect(getUserInitials(null, "alice.smith@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("falls back to email local part with underscore separator", () => {
|
||||
expect(getUserInitials(null, "alice_smith@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("falls back to email local part with hyphen separator", () => {
|
||||
expect(getUserInitials(null, "alice-smith@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("uses first two chars of email local if no separator", () => {
|
||||
expect(getUserInitials(null, "alice@example.com")).toBe("AL");
|
||||
});
|
||||
|
||||
it("returns null for empty email local part", () => {
|
||||
expect(getUserInitials(null, "@example.com")).toBeNull();
|
||||
});
|
||||
|
||||
it("uppercases the result", () => {
|
||||
expect(getUserInitials("john doe", "jd@test.com")).toBe("JD");
|
||||
});
|
||||
|
||||
it("trims whitespace from name", () => {
|
||||
expect(getUserInitials(" Alice Smith ", "a@test.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("returns null for numeric name parts", () => {
|
||||
expect(getUserInitials("Alice 1st", "x@test.com")).toBeNull();
|
||||
});
|
||||
|
||||
it("returns null for numeric email", () => {
|
||||
expect(getUserInitials(null, "42@domain.com")).toBeNull();
|
||||
});
|
||||
|
||||
it("falls back to email when name has non-alpha chars", () => {
|
||||
expect(getUserInitials("A1", "alice@example.com")).toBe("AL");
|
||||
});
|
||||
});
|
||||
@@ -128,54 +128,3 @@ export function getUserDisplayName(user: User | null): string {
|
||||
// If nothing works, then fall back to anonymous user name
|
||||
return "Anonymous";
|
||||
}
|
||||
|
||||
/**
|
||||
* Derive display initials from a user's name or email.
|
||||
*
|
||||
* - If a name is provided, uses the first letter of the first two words.
|
||||
* - Falls back to the email local part, splitting on `.`, `_`, or `-`.
|
||||
* - Returns `null` when no valid alpha initials can be derived.
|
||||
*/
|
||||
export function getUserInitials(
|
||||
name: string | null,
|
||||
email: string
|
||||
): string | null {
|
||||
if (name) {
|
||||
const words = name.trim().split(/\s+/);
|
||||
if (words.length >= 2) {
|
||||
const first = words[0]?.[0];
|
||||
const second = words[1]?.[0];
|
||||
if (first && second) {
|
||||
const result = (first + second).toUpperCase();
|
||||
if (/^[A-Z]{2}$/.test(result)) return result;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
if (name.trim().length >= 1) {
|
||||
const result = name.trim().slice(0, 2).toUpperCase();
|
||||
if (/^[A-Z]{1,2}$/.test(result)) return result;
|
||||
}
|
||||
}
|
||||
|
||||
const local = email.split("@")[0];
|
||||
if (!local || local.length === 0) return null;
|
||||
const parts = local.split(/[._-]/);
|
||||
if (parts.length >= 2) {
|
||||
const first = parts[0]?.[0];
|
||||
const second = parts[1]?.[0];
|
||||
if (first && second) {
|
||||
const result = (first + second).toUpperCase();
|
||||
if (/^[A-Z]{2}$/.test(result)) return result;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
if (local.length >= 2) {
|
||||
const result = local.slice(0, 2).toUpperCase();
|
||||
if (/^[A-Z]{2}$/.test(result)) return result;
|
||||
}
|
||||
if (local.length === 1) {
|
||||
const result = local.toUpperCase();
|
||||
if (/^[A-Z]$/.test(result)) return result;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -13,9 +13,6 @@ import { ALLOWED_URL_PROTOCOLS } from "./constants";
|
||||
const URI_SCHEME_REGEX = /^[a-zA-Z][a-zA-Z\d+.-]*:/;
|
||||
const BARE_EMAIL_REGEX = /^[^\s@/]+@[^\s@/:]+\.[^\s@/:]+$/;
|
||||
|
||||
export const INTERACTIVE_SELECTOR =
|
||||
"a, button, input, textarea, select, label, [role='button'], [tabindex]:not([tabindex='-1']), [contenteditable]:not([contenteditable='false'])";
|
||||
|
||||
export function cn(...inputs: ClassValue[]) {
|
||||
return twMerge(clsx(inputs));
|
||||
}
|
||||
|
||||
@@ -4,7 +4,10 @@ import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
|
||||
import { buildImgUrl } from "@/app/app/components/files/images/utils";
|
||||
import { OnyxIcon } from "@/components/icons/icons";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import { DEFAULT_AVATAR_SIZE_PX, DEFAULT_AGENT_ID } from "@/lib/constants";
|
||||
import {
|
||||
DEFAULT_AGENT_AVATAR_SIZE_PX,
|
||||
DEFAULT_AGENT_ID,
|
||||
} from "@/lib/constants";
|
||||
import CustomAgentAvatar from "@/refresh-components/avatars/CustomAgentAvatar";
|
||||
import Image from "next/image";
|
||||
|
||||
@@ -15,7 +18,7 @@ export interface AgentAvatarProps {
|
||||
|
||||
export default function AgentAvatar({
|
||||
agent,
|
||||
size = DEFAULT_AVATAR_SIZE_PX,
|
||||
size = DEFAULT_AGENT_AVATAR_SIZE_PX,
|
||||
...props
|
||||
}: AgentAvatarProps) {
|
||||
const settings = useSettingsContext();
|
||||
|
||||
@@ -4,7 +4,7 @@ import { cn } from "@/lib/utils";
|
||||
import type { IconProps } from "@opal/types";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Image from "next/image";
|
||||
import { DEFAULT_AVATAR_SIZE_PX } from "@/lib/constants";
|
||||
import { DEFAULT_AGENT_AVATAR_SIZE_PX } from "@/lib/constants";
|
||||
import {
|
||||
SvgActivitySmall,
|
||||
SvgAudioEqSmall,
|
||||
@@ -96,7 +96,7 @@ export default function CustomAgentAvatar({
|
||||
src,
|
||||
iconName,
|
||||
|
||||
size = DEFAULT_AVATAR_SIZE_PX,
|
||||
size = DEFAULT_AGENT_AVATAR_SIZE_PX,
|
||||
}: CustomAgentAvatarProps) {
|
||||
if (src) {
|
||||
return (
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
import { SvgUser } from "@opal/icons";
|
||||
import { DEFAULT_AVATAR_SIZE_PX } from "@/lib/constants";
|
||||
import { getUserInitials } from "@/lib/user";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import type { User } from "@/lib/types";
|
||||
|
||||
export interface UserAvatarProps {
|
||||
user: User;
|
||||
size?: number;
|
||||
}
|
||||
|
||||
export default function UserAvatar({
|
||||
user,
|
||||
size = DEFAULT_AVATAR_SIZE_PX,
|
||||
}: UserAvatarProps) {
|
||||
const initials = getUserInitials(
|
||||
user.personalization?.name ?? null,
|
||||
user.email
|
||||
);
|
||||
|
||||
if (!initials) {
|
||||
return (
|
||||
<div
|
||||
role="img"
|
||||
aria-label={`${user.email} avatar`}
|
||||
className="flex items-center justify-center rounded-full bg-background-tint-01"
|
||||
style={{ width: size, height: size }}
|
||||
>
|
||||
<SvgUser size={size * 0.55} className="stroke-text-03" aria-hidden />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
role="img"
|
||||
aria-label={`${user.email} avatar`}
|
||||
className="flex items-center justify-center rounded-full bg-background-neutral-inverted-00"
|
||||
style={{ width: size, height: size }}
|
||||
>
|
||||
<Text
|
||||
inverted
|
||||
secondaryAction
|
||||
text05
|
||||
className="select-none"
|
||||
style={{ fontSize: size * 0.4 }}
|
||||
>
|
||||
{initials}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -54,10 +54,8 @@ function MemoryItem({
|
||||
const wrapperRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (shouldFocus && textareaRef.current) {
|
||||
const el = textareaRef.current;
|
||||
el.focus();
|
||||
el.selectionStart = el.selectionEnd = el.value.length;
|
||||
if (shouldFocus) {
|
||||
textareaRef.current?.focus();
|
||||
onFocused?.();
|
||||
}
|
||||
}, [shouldFocus, onFocused]);
|
||||
@@ -65,10 +63,8 @@ function MemoryItem({
|
||||
useEffect(() => {
|
||||
if (!shouldHighlight) return;
|
||||
|
||||
wrapperRef.current?.scrollIntoView({
|
||||
block: "start",
|
||||
behavior: "smooth",
|
||||
});
|
||||
wrapperRef.current?.scrollIntoView({ block: "center", behavior: "smooth" });
|
||||
textareaRef.current?.focus();
|
||||
setIsHighlighting(true);
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
@@ -83,10 +79,10 @@ function MemoryItem({
|
||||
<div
|
||||
ref={wrapperRef}
|
||||
className={cn(
|
||||
"rounded-08 w-full p-0.5 border border-transparent",
|
||||
"rounded-08 hover:bg-background-tint-00 w-full p-0.5",
|
||||
"transition-colors ",
|
||||
isHighlighting &&
|
||||
"bg-action-link-01 hover:bg-action-link-01 border-action-link-05 duration-700"
|
||||
"bg-action-link-01 border border-action-link-05 duration-700"
|
||||
)}
|
||||
>
|
||||
<Section gap={0.25} alignItems="start">
|
||||
@@ -104,7 +100,7 @@ function MemoryItem({
|
||||
rows={3}
|
||||
maxLength={MAX_MEMORY_LENGTH}
|
||||
resizable={false}
|
||||
className="bg-background-tint-01 hover:bg-background-tint-00 focus-within:bg-background-tint-00"
|
||||
className={cn(!isFocused && "bg-transparent")}
|
||||
/>
|
||||
<Disabled disabled={!memory.content.trim() && memory.isNew}>
|
||||
<Button
|
||||
@@ -126,29 +122,13 @@ function MemoryItem({
|
||||
);
|
||||
}
|
||||
|
||||
function resolveTargetMemoryId(
|
||||
targetMemoryId: number | null | undefined,
|
||||
targetIndex: number | null | undefined,
|
||||
memories: MemoryItem[]
|
||||
): number | null {
|
||||
if (targetMemoryId != null) return targetMemoryId;
|
||||
|
||||
if (targetIndex != null && memories.length > 0) {
|
||||
// Backend index is ASC (oldest-first), frontend displays DESC (newest-first)
|
||||
const descIdx = memories.length - 1 - targetIndex;
|
||||
return memories[descIdx]?.id ?? null;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
interface MemoriesModalProps {
|
||||
memories?: MemoryItem[];
|
||||
onSaveMemories?: (memories: MemoryItem[]) => Promise<boolean>;
|
||||
onClose?: () => void;
|
||||
initialTargetMemoryId?: number | null;
|
||||
initialTargetIndex?: number | null;
|
||||
highlightOnOpen?: boolean;
|
||||
highlightFirstOnOpen?: boolean;
|
||||
}
|
||||
|
||||
export default function MemoriesModal({
|
||||
@@ -157,7 +137,7 @@ export default function MemoriesModal({
|
||||
onClose,
|
||||
initialTargetMemoryId,
|
||||
initialTargetIndex,
|
||||
highlightOnOpen = false,
|
||||
highlightFirstOnOpen = false,
|
||||
}: MemoriesModalProps) {
|
||||
const close = useModalClose(onClose);
|
||||
const [focusMemoryId, setFocusMemoryId] = useState<number | null>(null);
|
||||
@@ -202,16 +182,24 @@ export default function MemoriesModal({
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const targetId = resolveTargetMemoryId(
|
||||
initialTargetMemoryId,
|
||||
initialTargetIndex,
|
||||
effectiveMemories
|
||||
);
|
||||
if (targetId == null) return;
|
||||
|
||||
setFocusMemoryId(targetId);
|
||||
if (highlightOnOpen) {
|
||||
setHighlightMemoryId(targetId);
|
||||
if (initialTargetMemoryId != null) {
|
||||
// Direct DB id available — use it
|
||||
setHighlightMemoryId(initialTargetMemoryId);
|
||||
} else if (initialTargetIndex != null && effectiveMemories.length > 0) {
|
||||
// Backend index is ASC (oldest-first), but the frontend displays DESC
|
||||
// (newest-first). Convert: descIdx = totalCount - 1 - ascIdx
|
||||
const descIdx = effectiveMemories.length - 1 - initialTargetIndex;
|
||||
const target = effectiveMemories[descIdx];
|
||||
if (target) {
|
||||
setHighlightMemoryId(target.id);
|
||||
}
|
||||
} else if (
|
||||
highlightFirstOnOpen &&
|
||||
effectiveMemories.length > 0 &&
|
||||
effectiveMemories[0]
|
||||
) {
|
||||
// Fallback: highlight the first displayed item (newest)
|
||||
setHighlightMemoryId(effectiveMemories[0].id);
|
||||
}
|
||||
}, [initialTargetMemoryId, initialTargetIndex]);
|
||||
|
||||
|
||||
@@ -148,12 +148,14 @@ interface ExistingProviderCardProps {
|
||||
provider: LLMProviderView;
|
||||
isDefault: boolean;
|
||||
isLastProvider: boolean;
|
||||
defaultModelName?: string;
|
||||
}
|
||||
|
||||
function ExistingProviderCard({
|
||||
provider,
|
||||
isDefault,
|
||||
isLastProvider,
|
||||
defaultModelName,
|
||||
}: ExistingProviderCardProps) {
|
||||
const { mutate } = useSWRConfig();
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
@@ -230,7 +232,12 @@ function ExistingProviderCard({
|
||||
</Section>
|
||||
}
|
||||
/>
|
||||
{getModalForExistingProvider(provider, isOpen, setIsOpen)}
|
||||
{getModalForExistingProvider(
|
||||
provider,
|
||||
isOpen,
|
||||
setIsOpen,
|
||||
defaultModelName
|
||||
)}
|
||||
</Card>
|
||||
</Hoverable.Root>
|
||||
</>
|
||||
@@ -446,6 +453,11 @@ export default function LLMConfigurationPage() {
|
||||
provider={provider}
|
||||
isDefault={defaultText?.provider_id === provider.id}
|
||||
isLastProvider={sortedProviders.length === 1}
|
||||
defaultModelName={
|
||||
defaultText?.provider_id === provider.id
|
||||
? defaultText.model_name
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
||||
@@ -26,8 +26,7 @@ import type {
|
||||
StatusFilter,
|
||||
StatusCountMap,
|
||||
} from "./interfaces";
|
||||
import UserAvatar from "@/refresh-components/avatars/UserAvatar";
|
||||
import type { User } from "@/lib/types";
|
||||
import { getInitials } from "./utils";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Column renderers
|
||||
@@ -76,25 +75,20 @@ const tc = createTableColumns<UserRow>();
|
||||
function buildColumns(onMutate: () => void) {
|
||||
return [
|
||||
tc.qualifier({
|
||||
content: "icon",
|
||||
getContent: (row) => {
|
||||
const user = {
|
||||
email: row.email,
|
||||
personalization: row.personal_name
|
||||
? { name: row.personal_name }
|
||||
: undefined,
|
||||
} as User;
|
||||
return (props) => <UserAvatar user={user} size={props.size} />;
|
||||
},
|
||||
content: "avatar-user",
|
||||
getInitials: (row) => getInitials(row.personal_name, row.email),
|
||||
selectable: false,
|
||||
}),
|
||||
tc.column("email", {
|
||||
header: "Name",
|
||||
weight: 22,
|
||||
minWidth: 140,
|
||||
cell: renderNameColumn,
|
||||
}),
|
||||
tc.column("groups", {
|
||||
header: "Groups",
|
||||
weight: 24,
|
||||
minWidth: 200,
|
||||
enableSorting: false,
|
||||
cell: (value, row) => (
|
||||
<GroupsCell groups={value} user={row} onMutate={onMutate} />
|
||||
@@ -103,16 +97,19 @@ function buildColumns(onMutate: () => void) {
|
||||
tc.column("role", {
|
||||
header: "Account Type",
|
||||
weight: 16,
|
||||
minWidth: 180,
|
||||
cell: (_value, row) => <UserRoleCell user={row} onMutate={onMutate} />,
|
||||
}),
|
||||
tc.column("status", {
|
||||
header: "Status",
|
||||
weight: 14,
|
||||
minWidth: 100,
|
||||
cell: renderStatusColumn,
|
||||
}),
|
||||
tc.column("updated_at", {
|
||||
header: "Last Updated",
|
||||
weight: 14,
|
||||
minWidth: 100,
|
||||
cell: renderLastUpdatedColumn,
|
||||
}),
|
||||
tc.actions({
|
||||
@@ -222,6 +219,7 @@ export default function UsersTable({
|
||||
data={filteredUsers}
|
||||
columns={columns}
|
||||
getRowId={(row) => row.id ?? row.email}
|
||||
qualifier="avatar"
|
||||
pageSize={PAGE_SIZE}
|
||||
searchTerm={searchTerm}
|
||||
emptyState={
|
||||
|
||||
43
web/src/refresh-pages/admin/UsersPage/utils.test.ts
Normal file
43
web/src/refresh-pages/admin/UsersPage/utils.test.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
import { getInitials } from "./utils";
|
||||
|
||||
describe("getInitials", () => {
|
||||
it("returns first letters of first two name parts", () => {
|
||||
expect(getInitials("Alice Smith", "alice@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("returns first two chars of a single-word name", () => {
|
||||
expect(getInitials("Alice", "alice@example.com")).toBe("AL");
|
||||
});
|
||||
|
||||
it("handles three-word names (uses first two)", () => {
|
||||
expect(getInitials("Alice B. Smith", "alice@example.com")).toBe("AB");
|
||||
});
|
||||
|
||||
it("falls back to email local part with dot separator", () => {
|
||||
expect(getInitials(null, "alice.smith@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("falls back to email local part with underscore separator", () => {
|
||||
expect(getInitials(null, "alice_smith@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("falls back to email local part with hyphen separator", () => {
|
||||
expect(getInitials(null, "alice-smith@example.com")).toBe("AS");
|
||||
});
|
||||
|
||||
it("uses first two chars of email local if no separator", () => {
|
||||
expect(getInitials(null, "alice@example.com")).toBe("AL");
|
||||
});
|
||||
|
||||
it("returns ? for empty email local part", () => {
|
||||
expect(getInitials(null, "@example.com")).toBe("?");
|
||||
});
|
||||
|
||||
it("uppercases the result", () => {
|
||||
expect(getInitials("john doe", "jd@test.com")).toBe("JD");
|
||||
});
|
||||
|
||||
it("trims whitespace from name", () => {
|
||||
expect(getInitials(" Alice Smith ", "a@test.com")).toBe("AS");
|
||||
});
|
||||
});
|
||||
23
web/src/refresh-pages/admin/UsersPage/utils.ts
Normal file
23
web/src/refresh-pages/admin/UsersPage/utils.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
/**
|
||||
* Derive display initials from a user's name or email.
|
||||
*
|
||||
* - If a name is provided, uses the first letter of the first two words.
|
||||
* - Falls back to the email local part, splitting on `.`, `_`, or `-`.
|
||||
* - Returns at most 2 uppercase characters.
|
||||
*/
|
||||
export function getInitials(name: string | null, email: string): string {
|
||||
if (name) {
|
||||
const parts = name.trim().split(/\s+/);
|
||||
if (parts.length >= 2) {
|
||||
return ((parts[0]?.[0] ?? "") + (parts[1]?.[0] ?? "")).toUpperCase();
|
||||
}
|
||||
return name.slice(0, 2).toUpperCase();
|
||||
}
|
||||
const local = email.split("@")[0];
|
||||
if (!local) return "?";
|
||||
const parts = local.split(/[._-]/);
|
||||
if (parts.length >= 2) {
|
||||
return ((parts[0]?.[0] ?? "") + (parts[1]?.[0] ?? "")).toUpperCase();
|
||||
}
|
||||
return local.slice(0, 2).toUpperCase();
|
||||
}
|
||||
@@ -35,6 +35,7 @@ export default function AnthropicModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
@@ -64,10 +65,15 @@ export default function AnthropicModal({
|
||||
default_model_name: DEFAULT_DEFAULT_MODEL_NAME,
|
||||
}
|
||||
: {
|
||||
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? undefined,
|
||||
default_model_name:
|
||||
defaultModelName ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
DEFAULT_DEFAULT_MODEL_NAME,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
|
||||
@@ -81,6 +81,7 @@ export default function AzureModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
@@ -109,7 +110,11 @@ export default function AzureModal({
|
||||
default_model_name: "",
|
||||
} as AzureModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
target_uri: buildTargetUri(existingLlmProvider),
|
||||
};
|
||||
|
||||
@@ -315,6 +315,7 @@ export default function BedrockModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
@@ -351,7 +352,11 @@ export default function BedrockModal({
|
||||
},
|
||||
} as BedrockModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
custom_config: {
|
||||
AWS_REGION_NAME:
|
||||
(existingLlmProvider?.custom_config?.AWS_REGION_NAME as string) ??
|
||||
|
||||
@@ -197,6 +197,7 @@ export default function CustomModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
}: LLMProviderFormProps) {
|
||||
@@ -209,7 +210,11 @@ export default function CustomModal({
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const initialValues = {
|
||||
...buildDefaultInitialValues(existingLlmProvider),
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
undefined,
|
||||
defaultModelName
|
||||
),
|
||||
...(isOnboarding ? buildOnboardingInitialValues() : {}),
|
||||
provider: existingLlmProvider?.provider ?? "",
|
||||
model_configurations: existingLlmProvider?.model_configurations.map(
|
||||
|
||||
@@ -192,6 +192,7 @@ export default function LMStudioForm({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
@@ -225,7 +226,11 @@ export default function LMStudioForm({
|
||||
},
|
||||
} as LMStudioFormValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY:
|
||||
|
||||
@@ -159,6 +159,7 @@ export default function LiteLLMProxyModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
@@ -190,7 +191,11 @@ export default function LiteLLMProxyModal({
|
||||
default_model_name: "",
|
||||
} as LiteLLMProxyModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
};
|
||||
|
||||
@@ -212,6 +212,7 @@ export default function OllamaModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
@@ -244,7 +245,11 @@ export default function OllamaModal({
|
||||
},
|
||||
} as OllamaModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
custom_config: {
|
||||
OLLAMA_API_KEY:
|
||||
|
||||
@@ -35,6 +35,7 @@ export default function OpenAIModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
@@ -63,9 +64,14 @@ export default function OpenAIModal({
|
||||
default_model_name: DEFAULT_DEFAULT_MODEL_NAME,
|
||||
}
|
||||
: {
|
||||
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
default_model_name:
|
||||
defaultModelName ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
DEFAULT_DEFAULT_MODEL_NAME,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
|
||||
@@ -158,6 +158,7 @@ export default function OpenRouterModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
@@ -189,7 +190,11 @@ export default function OpenRouterModal({
|
||||
default_model_name: "",
|
||||
} as OpenRouterModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
};
|
||||
|
||||
@@ -48,6 +48,7 @@ export default function VertexAIModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
@@ -80,8 +81,13 @@ export default function VertexAIModal({
|
||||
},
|
||||
} as VertexAIModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(existingLlmProvider, modelConfigurations),
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
default_model_name:
|
||||
defaultModelName ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
VERTEXAI_DEFAULT_MODEL,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
|
||||
@@ -22,9 +22,15 @@ function detectIfRealOpenAIProvider(provider: LLMProviderView) {
|
||||
export function getModalForExistingProvider(
|
||||
provider: LLMProviderView,
|
||||
open?: boolean,
|
||||
onOpenChange?: (open: boolean) => void
|
||||
onOpenChange?: (open: boolean) => void,
|
||||
defaultModelName?: string
|
||||
) {
|
||||
const props = { existingLlmProvider: provider, open, onOpenChange };
|
||||
const props = {
|
||||
existingLlmProvider: provider,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
};
|
||||
|
||||
switch (provider.provider) {
|
||||
case LLMProviderName.OPENAI:
|
||||
|
||||
@@ -12,9 +12,11 @@ export const LLM_FORM_CLASS_NAME = "flex flex-col gap-y-4 items-stretch mt-6";
|
||||
|
||||
export const buildDefaultInitialValues = (
|
||||
existingLlmProvider?: LLMProviderView,
|
||||
modelConfigurations?: ModelConfiguration[]
|
||||
modelConfigurations?: ModelConfiguration[],
|
||||
currentDefaultModelName?: string
|
||||
) => {
|
||||
const defaultModelName =
|
||||
currentDefaultModelName ??
|
||||
existingLlmProvider?.model_configurations?.[0]?.name ??
|
||||
modelConfigurations?.[0]?.name ??
|
||||
"";
|
||||
|
||||
@@ -182,7 +182,7 @@ const ProjectFolderButton = memo(({ project }: ProjectFolderButtonProps) => {
|
||||
onClose={() => setIsEditing(false)}
|
||||
/>
|
||||
) : (
|
||||
<Truncated text03>{project.name}</Truncated>
|
||||
<Truncated>{project.name}</Truncated>
|
||||
)}
|
||||
</SidebarTab>
|
||||
</Popover.Anchor>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user