Compare commits

..

50 Commits

Author SHA1 Message Date
Evan Lohn
068ac543ad fix: deadlock in multitenant test (#9530) 2026-03-20 23:05:20 +00:00
Bo-Onyx
30e7a831a5 feat(hook): Add hook management API (#9513)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-20 22:53:59 +00:00
Evan Lohn
276261c96d fix: windows installer (#9507) 2026-03-20 22:53:46 +00:00
Bo-Onyx
205f1410e4 chore(hook): Hook executor. (#9467) 2026-03-20 22:47:01 +00:00
Bo-Onyx
a93d154c27 feat(hook): improve on hook point definition (#9522) 2026-03-20 22:20:42 +00:00
Jamison Lahman
1361879bd0 fix(fe): clicking outside chat area keeps chat input focused (#9521) 2026-03-20 19:22:11 +00:00
Justin Tahara
c58cc320b2 feat(tf): Port over WAF updates (#9520) 2026-03-20 18:45:09 +00:00
Jamison Lahman
461350958a fix(fe): dim project name in sidebar color (#9519) 2026-03-20 17:47:49 +00:00
Raunak Bhagat
50dde0be1a chore: edit AGENTS.md and CLAUDE.md files (#9486) 2026-03-20 00:59:30 +00:00
acaprau
199e1df453 feat(opensearch): Add functions for keyword and semantic retrieval (#9479)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-20 00:48:01 +00:00
Justin Tahara
996b674840 feat(backend): Adding procps (#9509) 2026-03-19 23:26:36 +00:00
Justin Tahara
5413723ccc feat(ods): Rerun run-ci workflow (#9501) 2026-03-19 22:11:59 +00:00
Evan Lohn
9660056a51 fix: drive rate limit retry (#9498) 2026-03-19 21:32:08 +00:00
Fizza Mukhtar
3105177238 fix(llm): don't send tool_choice when no tools are provided (#9224) 2026-03-19 21:26:46 +00:00
Evan Lohn
24bb4bda8b feat: windows installer and install improvements (#9476) 2026-03-19 20:47:44 +00:00
Raunak Bhagat
9532af4ceb chore: move Hoverable story (#9495) 2026-03-19 20:40:27 +00:00
Jamison Lahman
0a913f6af5 fix(fe): fix memories immediately losing focus on click (#9493) 2026-03-19 20:15:34 +00:00
Justin Tahara
fe30c55199 fix(code interpreter): Caching files (#9484) 2026-03-19 19:32:37 +00:00
Jamison Lahman
2cf0a65dd3 chore(fe): reduce padding on elements at the bottom of modal headers (#9488) 2026-03-19 19:27:37 +00:00
Nikolas Garza
659416f363 feat(admin): groups page - list page and group cards (#9453) 2026-03-19 18:23:15 +00:00
Raunak Bhagat
40aecbc4b9 refactor(fe): move table to opal, update size API (#9438) 2026-03-19 17:23:41 +00:00
Jamison Lahman
710b39074f chore(fe): remove opal-button* class names (#9471) 2026-03-19 02:15:00 +00:00
acaprau
8fe2f67d38 chore(opensearch): Allow disabling match highlights via env var; default to disabled (#9436) 2026-03-19 00:43:17 +00:00
Justin Tahara
f00aaf9fc0 fix(agents): Agents are Private by Default (#9465) 2026-03-19 00:01:46 +00:00
Bo-Onyx
5b2426b002 chore(hooks): Define Hook Point in the backend (#9391) 2026-03-18 23:43:26 +00:00
Justin Tahara
ba6ab0245b fix(celery): add dedup guardrails to user file delete queue (#9454)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 23:38:52 +00:00
Justin Tahara
b64ebb57e1 fix(logging): extract LiteLLM error details in image summarization failures (#9458) 2026-03-18 23:29:04 +00:00
Justin Tahara
2fcfdbabde fix(celery): add task expiry to upload API send_task call (#9456)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 23:17:08 +00:00
Justin Tahara
ea1a2749c1 fix(image): add diagnostic logging to vision model selection (#9460) 2026-03-18 22:06:56 +00:00
Justin Tahara
73c4e22588 fix(image): stop dumping base64 image data into error logs (#9457) 2026-03-18 21:43:55 +00:00
Jamison Lahman
fceaac6e13 fix(fe): make indexing attempt error rows click to show trace (#9463)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-18 21:38:53 +00:00
Jamison Lahman
e8bf45cfd2 feat(fe): "Full Exception Trace" modal uses CodePreview rendering (#9464)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-18 21:04:55 +00:00
Bo-Onyx
13ff648fcd chore(hooks): Add Celery task to remove hook running records older than 30 days (#9433) 2026-03-18 21:03:01 +00:00
Jamison Lahman
ae8268afb1 fix(fe): truncate connector names in table (#9459) 2026-03-18 20:59:49 +00:00
acaprau
b338bd9e97 feat(opensearch): Can override number of shards and replicas via env var (#9431) 2026-03-18 20:16:05 +00:00
acaprau
0dcc90a042 fix(opensearch): Exclude retrieving vectors during hybrid and random search (#9430) 2026-03-18 20:13:12 +00:00
Jamison Lahman
0f6a6693d3 fix(fe): truncate project name in sidebar button (#9462) 2026-03-18 20:06:09 +00:00
Jamison Lahman
e32cc450b2 fix(fe): update connector indexing error modal (#9426)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-18 11:57:28 -07:00
Jamison Lahman
732fb71edf chore(tests): unit tests for pdf processing (#9452)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 18:31:37 +00:00
dependabot[bot]
ca3320c0e0 chore(deps): bump pypdf from 6.8.0 to 6.9.1 (#9450)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-18 17:52:50 +00:00
Jamison Lahman
d7c554aca7 chore(ruff): fix and enable S324 (#9451) 2026-03-18 17:26:29 +00:00
dependabot[bot]
69e5c19695 chore(deps): bump next from 16.1.5 to 16.1.7 in /examples/widget (#9425)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-18 09:25:27 -07:00
Nikolas Garza
b4ce1c7a97 chore: bump next to 16.1.7 (#9423) 2026-03-18 09:22:40 -07:00
Jamison Lahman
cd64a91154 fix(fe): display name on attachment file card hover (#9446) 2026-03-18 16:13:21 +00:00
Danelegend
c282cdc096 fix(file upload): Allow zip file upload via query param (#9432) 2026-03-18 07:32:07 +00:00
Jamison Lahman
b1de1c59b6 chore(playwright): projects screenshot is main container only (#9440) 2026-03-18 05:35:30 +00:00
acaprau
64d484039f chore(opensearch): Disable test_update_single_can_clear_user_projects_and_personas (#9434) 2026-03-18 00:40:29 +00:00
Jamison Lahman
0530095b71 fix(fe): replace users table buttons with LineItems (#9435) 2026-03-17 23:45:15 +00:00
acaprau
23280d5b91 fix(opensearch): Fix env var mismatch issue with configuring subquery results; set default to 100 (#9428) 2026-03-17 16:01:45 -07:00
Bo-Onyx
229442679c chore(hooks): Add db CRUD (#9411) 2026-03-17 22:36:50 +00:00
179 changed files with 9102 additions and 3855 deletions

279
AGENTS.md
View File

@@ -167,284 +167,7 @@ web/
## Frontend Standards
### 1. Import Standards
**Always use absolute imports with the `@` prefix.**
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
```typescript
// ✅ Good
import { Button } from "@/components/ui/button";
import { useAuth } from "@/hooks/useAuth";
import { Text } from "@/refresh-components/texts/Text";
// ❌ Bad
import { Button } from "../../../components/ui/button";
import { useAuth } from "./hooks/useAuth";
```
### 2. React Component Functions
**Prefer regular functions over arrow functions for React components.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
function UserProfile({ userId }: UserProfileProps) {
return <div>User Profile</div>
}
// ❌ Bad
const UserProfile = ({ userId }: UserProfileProps) => {
return <div>User Profile</div>
}
```
### 3. Props Interface Extraction
**Extract prop types into their own interface definitions.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
interface UserCardProps {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
return <div>User Card</div>
}
// ❌ Bad
function UserCard({
user,
showActions = false,
onEdit
}: {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}) {
return <div>User Card</div>
}
```
### 4. Spacing Guidelines
**Prefer padding over margins for spacing.**
**Reason:** We want to consolidate usage to paddings instead of margins.
```typescript
// ✅ Good
<div className="p-4 space-y-2">
<div className="p-2">Content</div>
</div>
// ❌ Bad
<div className="m-4 space-y-2">
<div className="m-2">Content</div>
</div>
```
### 5. Tailwind Dark Mode
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
```typescript
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
<div className="bg-background-neutral-03 text-text-02">
Content
</div>
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
export const GithubIcon = createLogoIcon(githubLightIcon, {
monochromatic: true, // Will apply dark:invert internally
});
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
});
// ❌ Bad - Manual dark mode overrides
<div className="bg-white dark:bg-black text-black dark:text-white">
Content
</div>
```
### 6. Class Name Utilities
**Use the `cn` utility instead of raw string formatting for classNames.**
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
```typescript
import { cn } from '@/lib/utils'
// ✅ Good
<div className={cn(
'base-class',
isActive && 'active-class',
className
)}>
Content
</div>
// ❌ Bad
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
Content
</div>
```
### 7. Custom Hooks Organization
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
**Reason:** This is just a layout preference. Keeps code clean.
```typescript
// web/src/hooks/useUserData.ts
export function useUserData(userId: string) {
// hook implementation
}
// web/src/hooks/useLocalStorage.ts
export function useLocalStorage<T>(key: string, initialValue: T) {
// hook implementation
}
```
### 8. Icon Usage
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
```typescript
// ✅ Good
import SvgX from "@/icons/x";
import SvgMoreHorizontal from "@/icons/more-horizontal";
// ❌ Bad
import { User } from "lucide-react";
import { FiSearch } from "react-icons/fi";
```
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
If you need help with this step, reach out to `raunak@onyx.app`.
### 9. Text Rendering
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
```typescript
// ✅ Good
import { Text } from '@/refresh-components/texts/Text'
function UserCard({ name }: { name: string }) {
return (
<Text
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
text03
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
mainAction
>
{name}
</Text>
)
}
// ❌ Bad
function UserCard({ name }: { name: string }) {
return (
<div>
<h2>{name}</h2>
<p>User details</p>
</div>
)
}
```
### 10. Component Usage
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
```typescript
// ✅ Good
import Button from '@/refresh-components/buttons/Button'
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
import SvgPlusCircle from '@/icons/plus-circle'
function ContactForm() {
return (
<form>
<InputTypeIn placeholder="Search..." />
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
</form>
)
}
// ❌ Bad
function ContactForm() {
return (
<form>
<input placeholder="Name" />
<textarea placeholder="Message" />
<button type="submit">Submit</button>
</form>
)
}
```
### 11. Colors
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
**Available color categories:**
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
- **Actions:** `action-link-XX`, `action-danger-XX`
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
```typescript
// ✅ Good - Use custom Onyx color classes
<div className="bg-background-neutral-01 border border-border-02" />
<div className="bg-background-tint-02 border border-border-01" />
<div className="bg-status-success-01" />
<div className="bg-action-link-01" />
<div className="bg-theme-primary-05" />
// ❌ Bad - Do NOT use standard Tailwind colors
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
<div className="bg-white border border-slate-200" />
<div className="bg-green-100 text-green-700" />
<div className="bg-blue-100 text-blue-600" />
<div className="bg-indigo-500" />
```
### 12. Data Fetching
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
Frontend standards for the `web/` and `desktop/` projects live in `web/AGENTS.md`.
## Database & Migrations

View File

@@ -47,6 +47,8 @@ RUN apt-get update && \
gcc \
nano \
vim \
# Install procps so kubernetes exec sessions can use ps aux for debugging
procps \
libjemalloc2 \
&& \
rm -rf /var/lib/apt/lists/* && \

View File

@@ -317,6 +317,7 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.hierarchyfetching",
"onyx.background.celery.tasks.hooks",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",

View File

@@ -14,6 +14,7 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.hooks.utils import HOOKS_AVAILABLE
from shared_configs.configs import MULTI_TENANT
# choosing 15 minutes because it roughly gives us enough time to process many tasks
@@ -361,6 +362,19 @@ if not MULTI_TENANT:
tasks_to_schedule.extend(beat_task_templates)
if HOOKS_AVAILABLE:
tasks_to_schedule.append(
{
"name": "hook-execution-log-cleanup",
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
"schedule": timedelta(days=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def generate_cloud_tasks(
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float

View File

@@ -0,0 +1,35 @@
from celery import shared_task
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.hook import cleanup_old_execution_logs__no_commit
from onyx.utils.logger import setup_logger
logger = setup_logger()
_HOOK_EXECUTION_LOG_RETENTION_DAYS: int = 30
@shared_task(
name=OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def hook_execution_log_cleanup_task(*, tenant_id: str) -> None: # noqa: ARG001
try:
with get_session_with_current_tenant() as db_session:
deleted: int = cleanup_old_execution_logs__no_commit(
db_session=db_session,
max_age_days=_HOOK_EXECUTION_LOG_RETENTION_DAYS,
)
db_session.commit()
if deleted:
logger.info(
f"Deleted {deleted} hook execution log(s) older than "
f"{_HOOK_EXECUTION_LOG_RETENTION_DAYS} days."
)
except Exception:
logger.exception("Failed to clean up hook execution logs")
raise

View File

@@ -24,6 +24,7 @@ from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_DELETE_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
@@ -33,6 +34,7 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_DELETE_MAX_QUEUE_DEPTH
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
@@ -91,6 +93,17 @@ def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a delete_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_DELETE_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_DELETE_QUEUED_PREFIX}:{user_file_id}"
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
return celery_get_queue_length(
@@ -546,7 +559,23 @@ def process_single_user_file(
ignore_result=True,
)
def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with DELETING status and enqueue per-file tasks."""
"""Scan for user files with DELETING status and enqueue per-file tasks.
Three mechanisms prevent queue runaway (mirrors check_user_file_processing):
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_DELETE_MAX_QUEUE_DEPTH items we skip this beat cycle entirely.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_DELETE_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still DELETING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_DELETE_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
"""
task_logger.info("check_for_user_file_delete - Starting")
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
@@ -555,8 +584,23 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
)
if not lock.acquire(blocking=False):
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
# NOTE: must use the broker's Redis client (not redis_client) because
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_for_user_file_delete - Queue depth {queue_len} exceeds "
f"{USER_FILE_DELETE_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -568,23 +612,40 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
.all()
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_DELETE,
priority=OnyxCeleryPriority.HIGH,
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_delete_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_DELETE_TASK_EXPIRES,
nx=True,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
try:
self.app.send_task(
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_DELETE,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_DELETE_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
except Exception as e:
task_logger.exception(
f"check_for_user_file_delete - Error enqueuing deletes - {e.__class__.__name__}"
)
return None
finally:
if lock.owned():
lock.release()
task_logger.info(
f"check_for_user_file_delete - Enqueued {enqueued} tasks for tenant={tenant_id}"
f"check_for_user_file_delete - Enqueued {enqueued} tasks, skipped_guard={skipped_guard} for tenant={tenant_id}"
)
return None
@@ -602,6 +663,9 @@ def delete_user_file_impl(
file_lock: RedisLock | None = None
if redis_locking:
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the queued guard so the beat can re-enqueue if deletion fails
# and the file remains in DELETING status.
redis_client.delete(_user_file_delete_queued_key(user_file_id))
file_lock = redis_client.lock(
_user_file_delete_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,

View File

@@ -297,7 +297,9 @@ class PostgresCacheBackend(CacheBackend):
def _lock_id_for(self, name: str) -> int:
"""Map *name* to a 64-bit signed int for ``pg_advisory_lock``."""
h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
h = hashlib.md5(
f"{self._tenant_id}:{name}".encode(), usedforsecurity=False
).digest()
return struct.unpack("q", h[:8])[0]

View File

@@ -278,14 +278,17 @@ USING_AWS_MANAGED_OPENSEARCH = (
OPENSEARCH_PROFILING_DISABLED = (
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
)
# Whether to disable match highlights for OpenSearch. Defaults to True for now
# as we investigate query performance.
OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED = (
os.environ.get("OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED", "true").lower() == "true"
)
# When enabled, OpenSearch returns detailed score breakdowns for each hit.
# Useful for debugging and tuning search relevance. Has ~10-30% performance overhead according to documentation.
# Seems for Hybrid Search in practice, the impact is actually more like 1000x slower.
OPENSEARCH_EXPLAIN_ENABLED = (
os.environ.get("OPENSEARCH_EXPLAIN_ENABLED", "").lower() == "true"
)
# Analyzer used for full-text fields (title, content). Use OpenSearch built-in analyzer
# names (e.g. "english", "standard", "german"). Affects stemming and tokenization;
# existing indices need reindexing after a change.
@@ -318,8 +321,16 @@ VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE = int(
os.environ.get("OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE") or 500
)
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = int(
os.environ.get("OPENSEARCH_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES") or 0
# If set, will override the default number of shards and replicas for the index.
OPENSEARCH_INDEX_NUM_SHARDS: int | None = (
int(os.environ["OPENSEARCH_INDEX_NUM_SHARDS"])
if os.environ.get("OPENSEARCH_INDEX_NUM_SHARDS", None) is not None
else None
)
OPENSEARCH_INDEX_NUM_REPLICAS: int | None = (
int(os.environ["OPENSEARCH_INDEX_NUM_REPLICAS"])
if os.environ.get("OPENSEARCH_INDEX_NUM_REPLICAS", None) is not None
else None
)
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
@@ -776,9 +787,6 @@ MINI_CHUNK_SIZE = 150
# This is the number of regular chunks per large chunk
LARGE_CHUNK_RATIO = 4
# The number of chunks in an indexing batch
CHUNKS_PER_BATCH = 1000
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"

View File

@@ -177,6 +177,14 @@ USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
# How long a queued user-file-delete task is valid before workers discard it.
# Mirrors the processing task expiry to prevent indefinite queue growth when
# files are stuck in DELETING status and the beat keeps re-enqueuing them.
CELERY_USER_FILE_DELETE_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Max queue depth before the delete beat stops enqueuing more delete tasks.
USER_FILE_DELETE_MAX_QUEUE_DEPTH = 500
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
@@ -469,6 +477,9 @@ class OnyxRedisLocks:
USER_FILE_PROJECT_SYNC_QUEUED_PREFIX = "da_lock:user_file_project_sync_queued"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
# Short-lived key set when a delete task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a delete task is already queued.
USER_FILE_DELETE_QUEUED_PREFIX = "da_lock:user_file_delete_queued"
# Release notes
RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"
@@ -597,6 +608,9 @@ class OnyxCeleryTask:
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"
# Hook execution log retention
HOOK_EXECUTION_LOG_CLEANUP_TASK = "hook_execution_log_cleanup_task"
# Sandbox cleanup
CLEANUP_IDLE_SANDBOXES = "cleanup_idle_sandboxes"
CLEANUP_OLD_SNAPSHOTS = "cleanup_old_snapshots"

View File

@@ -157,9 +157,7 @@ def _execute_single_retrieval(
logger.error(f"Error executing request: {e}")
raise e
elif _is_rate_limit_error(e):
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
results = _execute_with_retry(retrieval_function(**request_kwargs))
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.debug(f"Error executing request: {e}")

233
backend/onyx/db/hook.py Normal file
View File

@@ -0,0 +1,233 @@
import datetime
from uuid import UUID
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.db.constants import UNSET
from onyx.db.constants import UnsetType
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.models import Hook
from onyx.db.models import HookExecutionLog
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
# ── Hook CRUD ────────────────────────────────────────────────────────────
def get_hook_by_id(
*,
db_session: Session,
hook_id: int,
include_deleted: bool = False,
include_creator: bool = False,
) -> Hook | None:
stmt = select(Hook).where(Hook.id == hook_id)
if not include_deleted:
stmt = stmt.where(Hook.deleted.is_(False))
if include_creator:
stmt = stmt.options(selectinload(Hook.creator))
return db_session.scalar(stmt)
def get_non_deleted_hook_by_hook_point(
*,
db_session: Session,
hook_point: HookPoint,
include_creator: bool = False,
) -> Hook | None:
stmt = (
select(Hook).where(Hook.hook_point == hook_point).where(Hook.deleted.is_(False))
)
if include_creator:
stmt = stmt.options(selectinload(Hook.creator))
return db_session.scalar(stmt)
def get_hooks(
*,
db_session: Session,
include_deleted: bool = False,
include_creator: bool = False,
) -> list[Hook]:
stmt = select(Hook)
if not include_deleted:
stmt = stmt.where(Hook.deleted.is_(False))
if include_creator:
stmt = stmt.options(selectinload(Hook.creator))
stmt = stmt.order_by(Hook.hook_point, Hook.created_at.desc())
return list(db_session.scalars(stmt).all())
def create_hook__no_commit(
*,
db_session: Session,
name: str,
hook_point: HookPoint,
endpoint_url: str | None = None,
api_key: str | None = None,
fail_strategy: HookFailStrategy,
timeout_seconds: float,
is_active: bool = False,
creator_id: UUID | None = None,
) -> Hook:
"""Create a new hook for the given hook point.
At most one non-deleted hook per hook point is allowed. Raises
OnyxError(CONFLICT) if a hook already exists, including under concurrent
duplicate creates where the partial unique index fires an IntegrityError.
"""
existing = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if existing:
raise OnyxError(
OnyxErrorCode.CONFLICT,
f"A hook for '{hook_point.value}' already exists (id={existing.id}).",
)
hook = Hook(
name=name,
hook_point=hook_point,
endpoint_url=endpoint_url,
api_key=api_key,
fail_strategy=fail_strategy,
timeout_seconds=timeout_seconds,
is_active=is_active,
creator_id=creator_id,
)
# Use a savepoint so that a failed insert only rolls back this operation,
# not the entire outer transaction.
savepoint = db_session.begin_nested()
try:
db_session.add(hook)
savepoint.commit()
except IntegrityError as exc:
savepoint.rollback()
if "ix_hook_one_non_deleted_per_point" in str(exc.orig):
raise OnyxError(
OnyxErrorCode.CONFLICT,
f"A hook for '{hook_point.value}' already exists.",
)
raise # re-raise unrelated integrity errors (FK violations, etc.)
return hook
def update_hook__no_commit(
*,
db_session: Session,
hook_id: int,
name: str | None = None,
endpoint_url: str | None | UnsetType = UNSET,
api_key: str | None | UnsetType = UNSET,
fail_strategy: HookFailStrategy | None = None,
timeout_seconds: float | None = None,
is_active: bool | None = None,
is_reachable: bool | None = None,
include_creator: bool = False,
) -> Hook:
"""Update hook fields.
Sentinel conventions:
- endpoint_url, api_key: pass UNSET to leave unchanged; pass None to clear.
- name, fail_strategy, timeout_seconds, is_active, is_reachable: pass None to leave unchanged.
"""
hook = get_hook_by_id(
db_session=db_session, hook_id=hook_id, include_creator=include_creator
)
if hook is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.")
if name is not None:
hook.name = name
if not isinstance(endpoint_url, UnsetType):
hook.endpoint_url = endpoint_url
if not isinstance(api_key, UnsetType):
hook.api_key = api_key # type: ignore[assignment] # EncryptedString coerces str → SensitiveValue at the ORM level
if fail_strategy is not None:
hook.fail_strategy = fail_strategy
if timeout_seconds is not None:
hook.timeout_seconds = timeout_seconds
if is_active is not None:
hook.is_active = is_active
if is_reachable is not None:
hook.is_reachable = is_reachable
db_session.flush()
return hook
def delete_hook__no_commit(
*,
db_session: Session,
hook_id: int,
) -> None:
hook = get_hook_by_id(db_session=db_session, hook_id=hook_id)
if hook is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.")
hook.deleted = True
hook.is_active = False
db_session.flush()
# ── HookExecutionLog CRUD ────────────────────────────────────────────────
def create_hook_execution_log__no_commit(
*,
db_session: Session,
hook_id: int,
is_success: bool,
error_message: str | None = None,
status_code: int | None = None,
duration_ms: int | None = None,
) -> HookExecutionLog:
log = HookExecutionLog(
hook_id=hook_id,
is_success=is_success,
error_message=error_message,
status_code=status_code,
duration_ms=duration_ms,
)
db_session.add(log)
db_session.flush()
return log
def get_hook_execution_logs(
*,
db_session: Session,
hook_id: int,
limit: int,
) -> list[HookExecutionLog]:
stmt = (
select(HookExecutionLog)
.where(HookExecutionLog.hook_id == hook_id)
.order_by(HookExecutionLog.created_at.desc())
.limit(limit)
)
return list(db_session.scalars(stmt).all())
def cleanup_old_execution_logs__no_commit(
*,
db_session: Session,
max_age_days: int,
) -> int:
"""Delete execution logs older than max_age_days. Returns the number of rows deleted."""
cutoff = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=max_age_days
)
result: CursorResult = db_session.execute( # type: ignore[assignment]
delete(HookExecutionLog)
.where(HookExecutionLog.created_at < cutoff)
.execution_options(synchronize_session=False)
)
return result.rowcount

View File

@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from starlette.background import BackgroundTasks
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -144,6 +145,7 @@ def upload_files_to_user_files_with_indexing(
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
logger.info(
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"

View File

@@ -2,6 +2,7 @@ import time
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.db.connector_credential_pair import get_connector_credential_pairs
@@ -149,6 +150,9 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
Returns None if search settings did not change, or the old search settings if they
did change.
"""
if DISABLE_VECTOR_DB:
return None
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = max(len(all_cc_pairs) - 1, 0)

View File

@@ -5,7 +5,6 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
out against a nonexistent Vespa/OpenSearch instance.
"""
from collections.abc import Iterable
from typing import Any
from onyx.context.search.models import IndexFilters
@@ -67,7 +66,7 @@ class DisabledDocumentIndex(DocumentIndex):
# ------------------------------------------------------------------
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
index_batch_params: IndexBatchParams, # noqa: ARG002
) -> set[DocumentInsertionRecord]:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)

View File

@@ -1,5 +1,4 @@
import abc
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any
@@ -207,7 +206,7 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[DocumentInsertionRecord]:
"""
@@ -227,8 +226,8 @@ class Indexable(abc.ABC):
it is done automatically outside of this code.
Parameters:
- chunks: Document chunks with all of the information needed for
indexing to the document index.
- chunks: Document chunks with all of the information needed for indexing to the document
index.
- tenant_id: The tenant id of the user whose chunks are being indexed
- large_chunks_enabled: Whether large chunks are enabled

View File

@@ -1,5 +1,4 @@
import abc
from collections.abc import Iterable
from typing import Self
from pydantic import BaseModel
@@ -210,10 +209,10 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
"""Indexes an iterable of document chunks into the document index.
"""Indexes a list of document chunks into the document index.
This is often a batch operation including chunks from multiple
documents.

View File

@@ -1,3 +1,4 @@
import json
import logging
import time
from contextlib import AbstractContextManager
@@ -18,6 +19,7 @@ from onyx.configs.app_configs import OPENSEARCH_HOST
from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.search import DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW
from onyx.utils.logger import setup_logger
@@ -56,8 +58,8 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
# Maps schema property name to a list of highlighted snippets with match
# terms wrapped in tags (e.g. "something <hi>keyword</hi> other thing").
match_highlights: dict[str, list[str]] = {}
# Score explanation from OpenSearch when "explain": true is set in the query.
# Contains detailed breakdown of how the score was calculated.
# Score explanation from OpenSearch when "explain": true is set in the
# query. Contains detailed breakdown of how the score was calculated.
explanation: dict[str, Any] | None = None
@@ -833,9 +835,13 @@ class OpenSearchIndexClient(OpenSearchClient):
@log_function_time(print_only=True, debug_only=True)
def search(
self, body: dict[str, Any], search_pipeline_id: str | None
) -> list[SearchHit[DocumentChunk]]:
) -> list[SearchHit[DocumentChunkWithoutVectors]]:
"""Searches the index.
NOTE: Does not return vector fields. In order to take advantage of
performance benefits, the search body should exclude the schema's vector
fields.
TODO(andrei): Ideally we could check that every field in the body is
present in the index, to avoid a class of runtime bugs that could easily
be caught during development. Or change the function signature to accept
@@ -883,7 +889,7 @@ class OpenSearchIndexClient(OpenSearchClient):
raise_on_timeout=True,
)
search_hits: list[SearchHit[DocumentChunk]] = []
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = []
for hit in hits:
document_chunk_source: dict[str, Any] | None = hit.get("_source")
if not document_chunk_source:
@@ -893,8 +899,10 @@ class OpenSearchIndexClient(OpenSearchClient):
document_chunk_score = hit.get("_score", None)
match_highlights: dict[str, list[str]] = hit.get("highlight", {})
explanation: dict[str, Any] | None = hit.get("_explanation", None)
search_hit = SearchHit[DocumentChunk](
document_chunk=DocumentChunk.model_validate(document_chunk_source),
search_hit = SearchHit[DocumentChunkWithoutVectors](
document_chunk=DocumentChunkWithoutVectors.model_validate(
document_chunk_source
),
score=document_chunk_score,
match_highlights=match_highlights,
explanation=explanation,
@@ -1055,7 +1063,7 @@ class OpenSearchIndexClient(OpenSearchClient):
f"Body: {get_new_body_without_vectors(body)}\n"
f"Search pipeline ID: {search_pipeline_id}\n"
f"Phase took: {phase_took}\n"
f"Profile: {profile}\n"
f"Profile: {json.dumps(profile, indent=2)}\n"
)
if timed_out:
error_str = f"OpenSearch client error: Search timed out for index {self._index_name}."

View File

@@ -3,10 +3,6 @@
import os
from enum import Enum
from onyx.configs.app_configs import (
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
)
DEFAULT_MAX_CHUNK_SIZE = 512
@@ -41,10 +37,10 @@ M = 32 # Set relatively high for better accuracy.
# we have a much higher chance of all 10 of the final desired docs showing up
# and getting scored. In worse situations, the final 10 docs don't even show up
# as the final 10 (worse than just a miss at the reranking step).
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = (
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
if OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES > 0
else 750
# Defaults to 100 for now. Initially this defaulted to 750 but we were seeing
# poor search performance.
DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES = int(
os.environ.get("DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES", 100)
)
# Number of vectors to examine to decide the top k neighbors for the HNSW
@@ -54,7 +50,7 @@ DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = (
# larger than k, you can provide the size parameter to limit the final number of
# results to k." from
# https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#ef_search
EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
EF_SEARCH = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
class HybridSearchSubqueryConfiguration(Enum):

View File

@@ -1,12 +1,11 @@
import json
from collections.abc import Iterable
from collections import defaultdict
from typing import Any
import httpx
from opensearchpy import NotFoundError
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import CHUNKS_PER_BATCH
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -48,6 +47,7 @@ from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
from onyx.document_index.opensearch.schema import DocumentSchema
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
@@ -118,7 +118,7 @@ def set_cluster_state(client: OpenSearchClient) -> None:
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
chunk: DocumentChunk,
chunk: DocumentChunkWithoutVectors,
score: float | None,
highlights: dict[str, list[str]],
) -> InferenceChunkUncleaned:
@@ -350,7 +350,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
@@ -646,8 +646,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata, # noqa: ARG002
) -> list[DocumentInsertionRecord]:
"""Indexes a list of document chunks into the document index.
@@ -672,32 +672,29 @@ class OpenSearchDocumentIndex(DocumentIndex):
document is newly indexed or had already existed and was just
updated.
"""
total_chunks = sum(
cc.new_chunk_cnt
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
# Group chunks by document ID.
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
list
)
for chunk in chunks:
doc_id_to_chunks[chunk.source_document.id].append(chunk)
logger.debug(
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
f"documents for index {self._index_name}."
)
document_indexing_results: list[DocumentInsertionRecord] = []
deleted_doc_ids: set[str] = set()
# Buffer chunks per document as they arrive from the iterable.
# When the document ID changes flush the buffered chunks.
current_doc_id: str | None = None
current_chunks: list[DocMetadataAwareIndexChunk] = []
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
# Try to index per-document.
for _, chunks in doc_id_to_chunks.items():
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
# Since we are doing this in batches, an error occurring midway
# can result in a state where chunks are deleted and not all the
# new chunks have been indexed.
# Do this before deleting existing chunks to reduce the amount of
# time the document index has no content for a given document, and
# to reduce the chance of entering a state where we delete chunks,
# then some error happens, and never successfully index new chunks.
chunk_batch: list[DocumentChunk] = [
_convert_onyx_chunk_to_opensearch_document(chunk)
for chunk in doc_chunks
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
]
onyx_document: Document = doc_chunks[0].source_document
onyx_document: Document = chunks[0].source_document
# First delete the doc's chunks from the index. This is so that
# there are no dangling chunks in the index, in the event that the
# new document's content contains fewer chunks than the previous
@@ -706,43 +703,22 @@ class OpenSearchDocumentIndex(DocumentIndex):
# if the chunk count has actually decreased. This assumes that
# overlapping chunks are perfectly overwritten. If we can't
# guarantee that then we need the code as-is.
if onyx_document.id not in deleted_doc_ids:
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
deleted_doc_ids.add(onyx_document.id)
# If we see that chunks were deleted we assume the doc already
# existed. We record the result before bulk_index_documents
# runs. If indexing raises, this entire result list is discarded
# by the caller's retry logic, so early recording is safe.
document_indexing_results.append(
DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
)
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
# If we see that chunks were deleted we assume the doc already
# existed.
document_insertion_record = DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
# Now index. This will raise if a chunk of the same ID exists, which
# we do not expect because we should have deleted all chunks.
self._client.bulk_index_documents(
documents=chunk_batch,
tenant_state=self._tenant_state,
)
for chunk in chunks:
doc_id = chunk.source_document.id
if doc_id != current_doc_id:
if current_chunks:
_flush_chunks(current_chunks)
current_doc_id = doc_id
current_chunks = [chunk]
elif len(current_chunks) >= CHUNKS_PER_BATCH:
_flush_chunks(current_chunks)
current_chunks = [chunk]
else:
current_chunks.append(chunk)
if current_chunks:
_flush_chunks(current_chunks)
document_indexing_results.append(document_insertion_record)
return document_indexing_results
@@ -905,7 +881,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
)
results: list[InferenceChunk] = []
for chunk_request in chunk_requests:
search_hits: list[SearchHit[DocumentChunk]] = []
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = []
query_body = DocumentQuery.get_from_document_id_query(
document_id=chunk_request.document_id,
tenant_state=self._tenant_state,
@@ -969,12 +945,91 @@ class OpenSearchDocumentIndex(DocumentIndex):
include_hidden=False,
)
normalization_pipeline_name, _ = get_normalization_pipeline_name_and_config()
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=normalization_pipeline_name,
)
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
# Good place for a breakpoint to inspect the search hits if you have
# "explain" enabled.
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
)
for search_hit in search_hits
]
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
inference_chunks_uncleaned
)
return inference_chunks
def keyword_retrieval(
self,
query: str,
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
logger.debug(
f"[OpenSearchDocumentIndex] Keyword retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
query_body = DocumentQuery.get_keyword_search_query(
query_text=query,
num_hits=num_to_retrieve,
tenant_state=self._tenant_state,
# NOTE: Index filters includes metadata tags which were filtered
# for invalid unicode at indexing time. In theory it would be
# ideal to do filtering here as well, in practice we never did
# that in the Vespa codepath and have not seen issues in
# production, so we deliberately conform to the existing logic
# in order to not unknowningly introduce a possible bug.
index_filters=filters,
include_hidden=False,
)
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
)
for search_hit in search_hits
]
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
inference_chunks_uncleaned
)
return inference_chunks
def semantic_retrieval(
self,
query_embedding: Embedding,
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
logger.debug(
f"[OpenSearchDocumentIndex] Semantic retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
query_body = DocumentQuery.get_semantic_search_query(
query_embedding=query_embedding,
num_hits=num_to_retrieve,
tenant_state=self._tenant_state,
# NOTE: Index filters includes metadata tags which were filtered
# for invalid unicode at indexing time. In theory it would be
# ideal to do filtering here as well, in practice we never did
# that in the Vespa codepath and have not seen issues in
# production, so we deliberately conform to the existing logic
# in order to not unknowningly introduce a possible bug.
index_filters=filters,
include_hidden=False,
)
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
@@ -1001,7 +1056,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
index_filters=filters,
num_to_retrieve=num_to_retrieve,
)
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
)

View File

@@ -11,6 +11,8 @@ from pydantic import model_serializer
from pydantic import model_validator
from pydantic import SerializerFunctionWrapHandler
from onyx.configs.app_configs import OPENSEARCH_INDEX_NUM_REPLICAS
from onyx.configs.app_configs import OPENSEARCH_INDEX_NUM_SHARDS
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
from onyx.document_index.interfaces_new import TenantState
@@ -100,9 +102,9 @@ def set_or_convert_timezone_to_utc(value: datetime) -> datetime:
return value
class DocumentChunk(BaseModel):
class DocumentChunkWithoutVectors(BaseModel):
"""
Represents a chunk of a document in the OpenSearch index.
Represents a chunk of a document in the OpenSearch index without vectors.
The names of these fields are based on the OpenSearch schema. Changes to the
schema require changes here. See get_document_schema.
@@ -124,9 +126,7 @@ class DocumentChunk(BaseModel):
# Either both should be None or both should be non-None.
title: str | None = None
title_vector: list[float] | None = None
content: str
content_vector: list[float]
source_type: str
# A list of key-value pairs separated by INDEX_SEPARATOR. See
@@ -176,19 +176,9 @@ class DocumentChunk(BaseModel):
def __str__(self) -> str:
return (
f"DocumentChunk(document_id={self.document_id}, chunk_index={self.chunk_index}, "
f"content length={len(self.content)}, content vector length={len(self.content_vector)}, "
f"tenant_id={self.tenant_id.tenant_id})"
f"content length={len(self.content)}, tenant_id={self.tenant_id.tenant_id})."
)
@model_validator(mode="after")
def check_title_and_title_vector_are_consistent(self) -> Self:
# title and title_vector should both either be None or not.
if self.title is not None and self.title_vector is None:
raise ValueError("Bug: Title vector must not be None if title is not None.")
if self.title_vector is not None and self.title is None:
raise ValueError("Bug: Title must not be None if title vector is not None.")
return self
@model_serializer(mode="wrap")
def serialize_model(
self, handler: SerializerFunctionWrapHandler
@@ -305,6 +295,35 @@ class DocumentChunk(BaseModel):
return TenantState(tenant_id=value, multitenant=MULTI_TENANT)
class DocumentChunk(DocumentChunkWithoutVectors):
"""Represents a chunk of a document in the OpenSearch index.
The names of these fields are based on the OpenSearch schema. Changes to the
schema require changes here. See get_document_schema.
"""
model_config = {"frozen": True}
title_vector: list[float] | None = None
content_vector: list[float]
def __str__(self) -> str:
return (
f"DocumentChunk(document_id={self.document_id}, chunk_index={self.chunk_index}, "
f"content length={len(self.content)}, content vector length={len(self.content_vector)}, "
f"tenant_id={self.tenant_id.tenant_id})"
)
@model_validator(mode="after")
def check_title_and_title_vector_are_consistent(self) -> Self:
# title and title_vector should both either be None or not.
if self.title is not None and self.title_vector is None:
raise ValueError("Bug: Title vector must not be None if title is not None.")
if self.title_vector is not None and self.title is None:
raise ValueError("Bug: Title must not be None if title vector is not None.")
return self
class DocumentSchema:
"""
Represents the schema and indexing strategies of the OpenSearch index.
@@ -516,78 +535,35 @@ class DocumentSchema:
return schema
@staticmethod
def get_index_settings() -> dict[str, Any]:
"""
Standard settings for reasonable local index and search performance.
"""
return {
"index": {
"number_of_shards": 1,
"number_of_replicas": 1,
# Required for vector search.
"knn": True,
"knn.algo_param.ef_search": EF_SEARCH,
}
}
@staticmethod
def get_index_settings_for_aws_managed_opensearch_st_dev() -> dict[str, Any]:
"""
Settings for AWS-managed OpenSearch.
Our AWS-managed OpenSearch cluster has 3 data nodes in 3 availability
zones.
- We use 3 shards to distribute load across all data nodes.
- We use 2 replicas to ensure each shard has a copy in each
availability zone. This is a hard requirement from AWS. The number
of data copies, including the primary (not a replica) copy, must be
divisible by the number of AZs.
"""
return {
"index": {
"number_of_shards": 3,
"number_of_replicas": 2,
# Required for vector search.
"knn": True,
"knn.algo_param.ef_search": EF_SEARCH,
}
}
@staticmethod
def get_index_settings_for_aws_managed_opensearch_mt_cloud() -> dict[str, Any]:
"""
Settings for AWS-managed OpenSearch in multi-tenant cloud.
324 shards very roughly targets a storage load of ~30Gb per shard, which
according to AWS OpenSearch documentation is within a good target range.
As documented above we need 2 replicas for a total of 3 copies of the
data because the cluster is configured with 3-AZ awareness.
"""
return {
"index": {
"number_of_shards": 324,
"number_of_replicas": 2,
# Required for vector search.
"knn": True,
"knn.algo_param.ef_search": EF_SEARCH,
}
}
@staticmethod
def get_index_settings_based_on_environment() -> dict[str, Any]:
"""
Returns the index settings based on the environment.
"""
if USING_AWS_MANAGED_OPENSEARCH:
# NOTE: The number of data copies, including the primary (not a
# replica) copy, must be divisible by the number of AZs.
if MULTI_TENANT:
return (
DocumentSchema.get_index_settings_for_aws_managed_opensearch_mt_cloud()
)
number_of_shards = 324
number_of_replicas = 2
else:
return (
DocumentSchema.get_index_settings_for_aws_managed_opensearch_st_dev()
)
number_of_shards = 3
number_of_replicas = 2
else:
return DocumentSchema.get_index_settings()
number_of_shards = 1
number_of_replicas = 1
if OPENSEARCH_INDEX_NUM_SHARDS is not None:
number_of_shards = OPENSEARCH_INDEX_NUM_SHARDS
if OPENSEARCH_INDEX_NUM_REPLICAS is not None:
number_of_replicas = OPENSEARCH_INDEX_NUM_REPLICAS
return {
"index": {
"number_of_shards": number_of_shards,
"number_of_replicas": number_of_replicas,
# Required for vector search.
"knn": True,
"knn.algo_param.ef_search": EF_SEARCH,
}
}

View File

@@ -7,6 +7,7 @@ from uuid import UUID
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
from onyx.configs.app_configs import OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED
from onyx.configs.app_configs import OPENSEARCH_PROFILING_DISABLED
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
@@ -15,7 +16,7 @@ from onyx.context.search.models import Tag
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import ASSUMED_DOCUMENT_AGE_DAYS
from onyx.document_index.opensearch.constants import (
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
)
from onyx.document_index.opensearch.constants import (
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW,
@@ -235,9 +236,17 @@ class DocumentQuery:
# returning some number of results less than the index max allowed
# return size.
"size": DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW,
"_source": get_full_document,
# By default exclude retrieving the vector fields in order to save
# on retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
}
if not get_full_document:
# If we explicitly do not want the underlying document, we will only
# retrieve IDs.
final_get_ids_query["_source"] = False
if not OPENSEARCH_PROFILING_DISABLED:
final_get_ids_query["profile"] = True
@@ -332,7 +341,7 @@ class DocumentQuery:
# TODO(andrei, yuhong): We can tune this more dynamically based on
# num_hits.
max_results_per_subquery = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
max_results_per_subquery = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
query_text, query_vector, vector_candidates=max_results_per_subquery
@@ -356,9 +365,6 @@ class DocumentQuery:
attached_document_ids=index_filters.attached_document_ids,
hierarchy_node_ids=index_filters.hierarchy_node_ids,
)
match_highlights_configuration = (
DocumentQuery._get_match_highlights_configuration()
)
# See https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
hybrid_search_query: dict[str, Any] = {
@@ -385,16 +391,183 @@ class DocumentQuery:
final_hybrid_search_body: dict[str, Any] = {
"query": hybrid_search_query,
"size": num_hits,
"highlight": match_highlights_configuration,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
# Exclude retrieving the vector fields in order to save on
# retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
}
# Explain is for scoring breakdowns.
if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED:
final_hybrid_search_body["highlight"] = (
DocumentQuery._get_match_highlights_configuration()
)
# Explain is for scoring breakdowns. Setting this significantly
# increases query latency.
if OPENSEARCH_EXPLAIN_ENABLED:
final_hybrid_search_body["explain"] = True
return final_hybrid_search_body
@staticmethod
def get_keyword_search_query(
query_text: str,
num_hits: int,
tenant_state: TenantState,
index_filters: IndexFilters,
include_hidden: bool,
) -> dict[str, Any]:
"""Returns a final keyword search query.
This query can be directly supplied to the OpenSearch client.
Args:
query_text: The text to query for.
num_hits: The final number of hits to return.
tenant_state: Tenant state containing the tenant ID.
index_filters: Filters for the keyword search query.
include_hidden: Whether to include hidden documents.
Returns:
A dictionary representing the final keyword search query.
"""
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
raise ValueError(
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
)
keyword_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
include_hidden=include_hidden,
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
# now. This should not cause any issues but it can introduce
# redundant filters in queries that may affect performance.
access_control_list=index_filters.access_control_list,
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
attached_document_ids=index_filters.attached_document_ids,
hierarchy_node_ids=index_filters.hierarchy_node_ids,
)
keyword_search_query = (
DocumentQuery._get_title_content_combined_keyword_search_query(
query_text, search_filters=keyword_search_filters
)
)
final_keyword_search_query: dict[str, Any] = {
"query": keyword_search_query,
"size": num_hits,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
# Exclude retrieving the vector fields in order to save on
# retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
}
if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED:
final_keyword_search_query["highlight"] = (
DocumentQuery._get_match_highlights_configuration()
)
if not OPENSEARCH_PROFILING_DISABLED:
final_keyword_search_query["profile"] = True
# Explain is for scoring breakdowns. Setting this significantly
# increases query latency.
if OPENSEARCH_EXPLAIN_ENABLED:
final_keyword_search_query["explain"] = True
return final_keyword_search_query
@staticmethod
def get_semantic_search_query(
query_embedding: list[float],
num_hits: int,
tenant_state: TenantState,
index_filters: IndexFilters,
include_hidden: bool,
) -> dict[str, Any]:
"""Returns a final semantic search query.
This query can be directly supplied to the OpenSearch client.
Args:
query_embedding: The vector embedding of the text to query for.
num_hits: The final number of hits to return.
tenant_state: Tenant state containing the tenant ID.
index_filters: Filters for the semantic search query.
include_hidden: Whether to include hidden documents.
Returns:
A dictionary representing the final semantic search query.
"""
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
raise ValueError(
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
)
semantic_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
include_hidden=include_hidden,
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
# now. This should not cause any issues but it can introduce
# redundant filters in queries that may affect performance.
access_control_list=index_filters.access_control_list,
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
attached_document_ids=index_filters.attached_document_ids,
hierarchy_node_ids=index_filters.hierarchy_node_ids,
)
semantic_search_query = (
DocumentQuery._get_content_vector_similarity_search_query(
query_embedding,
vector_candidates=num_hits,
search_filters=semantic_search_filters,
)
)
final_semantic_search_query: dict[str, Any] = {
"query": semantic_search_query,
"size": num_hits,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
# Exclude retrieving the vector fields in order to save on
# retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
}
if not OPENSEARCH_PROFILING_DISABLED:
final_semantic_search_query["profile"] = True
# Explain is for scoring breakdowns. Setting this significantly
# increases query latency.
if OPENSEARCH_EXPLAIN_ENABLED:
final_semantic_search_query["explain"] = True
return final_semantic_search_query
@staticmethod
def get_random_search_query(
tenant_state: TenantState,
@@ -446,6 +619,11 @@ class DocumentQuery:
},
"size": num_to_retrieve,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
# Exclude retrieving the vector fields in order to save on
# retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
}
if not OPENSEARCH_PROFILING_DISABLED:
final_random_search_query["profile"] = True
@@ -460,7 +638,7 @@ class DocumentQuery:
# search. This is higher than the number of results because the scoring
# is hybrid. For a detailed breakdown, see where the default value is
# set.
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
) -> list[dict[str, Any]]:
"""Returns subqueries for hybrid search.
@@ -546,7 +724,7 @@ class DocumentQuery:
@staticmethod
def _get_title_vector_similarity_search_query(
query_vector: list[float],
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
) -> dict[str, Any]:
return {
"knn": {
@@ -560,9 +738,10 @@ class DocumentQuery:
@staticmethod
def _get_content_vector_similarity_search_query(
query_vector: list[float],
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
search_filters: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
return {
query = {
"knn": {
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
@@ -571,11 +750,19 @@ class DocumentQuery:
}
}
if search_filters is not None:
query["knn"][CONTENT_VECTOR_FIELD_NAME]["filter"] = {
"bool": {"filter": search_filters}
}
return query
@staticmethod
def _get_title_content_combined_keyword_search_query(
query_text: str,
search_filters: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
return {
query = {
"bool": {
"should": [
{
@@ -616,10 +803,19 @@ class DocumentQuery:
}
}
},
]
],
# Ensure at least one term from the query is present in the
# document. This defaults to 1, unless a filter or must clause
# is supplied, in which case it defaults to 0.
"minimum_should_match": 1,
}
}
if search_filters is not None:
query["bool"]["filter"] = search_filters
return query
@staticmethod
def _get_search_filters(
tenant_state: TenantState,

View File

@@ -6,7 +6,6 @@ import re
import time
import urllib
import zipfile
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
@@ -462,7 +461,7 @@ class VespaIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""

View File

@@ -1,8 +1,6 @@
import concurrent.futures
import logging
import random
from collections.abc import Generator
from collections.abc import Iterable
from typing import Any
from uuid import UUID
@@ -10,7 +8,6 @@ import httpx
from pydantic import BaseModel
from retry import retry
from onyx.configs.app_configs import CHUNKS_PER_BATCH
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
from onyx.configs.app_configs import RERANK_COUNT
from onyx.configs.chat_configs import DOC_TIME_DECAY
@@ -321,7 +318,7 @@ class VespaDocumentIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
@@ -341,31 +338,22 @@ class VespaDocumentIndex(DocumentIndex):
# Vespa has restrictions on valid characters, yet document IDs come from
# external w.r.t. this class. We need to sanitize them.
#
# Instead of materializing all cleaned chunks upfront, we stream them
# through a generator that cleans IDs and builds the original-ID mapping
# incrementally as chunks flow into Vespa.
def _clean_and_track(
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
id_map: dict[str, str],
seen_ids: set[str],
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
"""Cleans chunk IDs and builds the original-ID mapping
incrementally as chunks flow through, avoiding a separate
materialization pass."""
for chunk in chunks_iter:
original_id = chunk.source_document.id
cleaned = clean_chunk_id_copy(chunk)
cleaned_id = cleaned.source_document.id
# Needed so the final DocumentInsertionRecord returned can have
# the original document ID. cleaned_chunks might not contain IDs
# exactly as callers supplied them.
id_map[cleaned_id] = original_id
seen_ids.add(cleaned_id)
yield cleaned
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
clean_chunk_id_copy(chunk) for chunk in chunks
]
assert len(cleaned_chunks) == len(
chunks
), "Bug: Cleaned chunks and input chunks have different lengths."
new_document_id_to_original_document_id: dict[str, str] = {}
all_cleaned_doc_ids: set[str] = set()
# Needed so the final DocumentInsertionRecord returned can have the
# original document ID. cleaned_chunks might not contain IDs exactly as
# callers supplied them.
new_document_id_to_original_document_id: dict[str, str] = dict()
for i, cleaned_chunk in enumerate(cleaned_chunks):
old_chunk = chunks[i]
new_document_id_to_original_document_id[
cleaned_chunk.source_document.id
] = old_chunk.source_document.id
existing_docs: set[str] = set()
@@ -421,16 +409,8 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
# Insert new Vespa documents, streaming through the cleaning
# pipeline so chunks are never fully materialized.
cleaned_chunks = _clean_and_track(
chunks,
new_document_id_to_original_document_id,
all_cleaned_doc_ids,
)
for chunk_batch in batch_generator(
cleaned_chunks, min(BATCH_SIZE, CHUNKS_PER_BATCH)
):
# Insert new Vespa documents.
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
batch_index_vespa_chunks(
chunks=chunk_batch,
index_name=self._index_name,
@@ -439,6 +419,10 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
all_cleaned_doc_ids: set[str] = {
chunk.source_document.id for chunk in cleaned_chunks
}
return [
DocumentInsertionRecord(
document_id=new_document_id_to_original_document_id[cleaned_doc_id],

View File

@@ -88,6 +88,7 @@ class OnyxErrorCode(Enum):
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
BAD_GATEWAY = ("BAD_GATEWAY", 502)
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
HOOK_EXECUTION_FAILED = ("HOOK_EXECUTION_FAILED", 502)
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
def __init__(self, code: str, status_code: int) -> None:

View File

@@ -88,9 +88,13 @@ def summarize_image_with_error_handling(
try:
return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)
except UnsupportedImageFormatError:
magic_hex = image_data[:8].hex() if image_data else "empty"
logger.info(
"Skipping image summarization due to unsupported MIME type for %s",
"Skipping image summarization due to unsupported MIME type "
"for %s (magic_bytes=%s, size=%d bytes)",
context_name,
magic_hex,
len(image_data),
)
return None
@@ -134,9 +138,23 @@ def _summarize_image(
return summary
except Exception as e:
error_msg = f"Summarization failed. Messages: {messages}"
error_msg = error_msg[:1024]
raise ValueError(error_msg) from e
# Extract structured details from LiteLLM exceptions when available,
# rather than dumping the full messages payload (which contains base64
# image data and produces enormous, unreadable error logs).
str_e = str(e)
if len(str_e) > 512:
str_e = str_e[:512] + "... (truncated)"
parts = [f"Summarization failed: {type(e).__name__}: {str_e}"]
status_code = getattr(e, "status_code", None)
llm_provider = getattr(e, "llm_provider", None)
model = getattr(e, "model", None)
if status_code is not None:
parts.append(f"status_code={status_code}")
if llm_provider is not None:
parts.append(f"llm_provider={llm_provider}")
if model is not None:
parts.append(f"model={model}")
raise ValueError(" | ".join(parts)) from e
def _encode_image_for_llm_prompt(image_data: bytes) -> str:

View File

@@ -0,0 +1,330 @@
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
Usage (Celery tasks and FastAPI handlers):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
)
if isinstance(result, HookSkipped):
# no active hook configured — continue with original behavior
...
elif isinstance(result, HookSoftFailed):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is the response payload dict from the customer's endpoint
...
is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:
NetworkError (DNS, connection refused) → False (cannot reach the server)
HTTP 401 / 403 → False (api_key revoked or invalid)
TimeoutException → None (server may be slow, skip write)
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
Unknown exception → None (no signal, skip write)
Non-JSON / non-dict response → None (server responded, skip write)
Success (2xx, valid dict) → True (confirmed reachable)
None means "leave the current value unchanged" — no DB round-trip is made.
DB session design
-----------------
The executor uses three sessions:
1. Caller's session (db_session) — used only for the hook lookup read. All
needed fields are extracted from the Hook object before the HTTP call, so
the caller's session is not held open during the external HTTP request.
2. Log session — a separate short-lived session opened after the HTTP call
completes to write the HookExecutionLog row on failure. Success runs are
not recorded. Committed independently of everything else.
3. Reachable session — a second short-lived session to update is_reachable on
the Hook. Kept separate from the log session so a concurrent hook deletion
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
prevent the execution log from being written. This update is best-effort.
"""
import json
import time
from typing import Any
import httpx
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.utils.logger import setup_logger
logger = setup_logger()
class HookSkipped:
"""No active hook configured for this hook point."""
class HookSoftFailed:
"""Hook was called but failed with SOFT fail strategy — continuing."""
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
class _HttpOutcome(BaseModel):
"""Structured result of an HTTP hook call, returned by _process_response."""
is_success: bool
updated_is_reachable: (
bool | None
) # True/False = write to DB, None = unchanged (skip write)
status_code: int | None
error_message: str | None
response_payload: dict[str, Any] | None
def _lookup_hook(
db_session: Session,
hook_point: HookPoint,
) -> Hook | HookSkipped:
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
No HTTP call is made and no DB writes are performed for any HookSkipped path.
There is nothing to log and no reachability information to update.
"""
if not HOOKS_AVAILABLE:
return HookSkipped()
hook = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if hook is None or not hook.is_active:
return HookSkipped()
if not hook.endpoint_url:
return HookSkipped()
return hook
def _process_response(
*,
response: httpx.Response | None,
exc: Exception | None,
timeout: float,
) -> _HttpOutcome:
"""Process the result of an HTTP call and return a structured outcome.
Called after the client.post() try/except. If post() raised, exc is set and
response is None. Otherwise response is set and exc is None. Handles
raise_for_status(), JSON decoding, and the dict shape check.
"""
if exc is not None:
if isinstance(exc, httpx.NetworkError):
msg = f"Hook network error (endpoint unreachable): {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False,
status_code=None,
error_message=msg,
response_payload=None,
)
if isinstance(exc, httpx.TimeoutException):
msg = f"Hook timed out after {timeout}s: {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # timeout doesn't indicate unreachability
status_code=None,
error_message=msg,
response_payload=None,
)
msg = f"Hook call failed: {exc}"
logger.exception(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # unknown error — don't make assumptions
status_code=None,
error_message=msg,
response_payload=None,
)
if response is None:
raise ValueError(
"exactly one of response or exc must be non-None; both are None"
)
status_code = response.status_code
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
logger.warning(msg, exc_info=e)
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
# so the operator knows to update it. All other HTTP errors keep is_reachable
# as-is (server is up, the request just failed for application reasons).
auth_failed = e.response.status_code in (401, 403)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False if auth_failed else None,
status_code=status_code,
error_message=msg,
response_payload=None,
)
try:
response_payload = response.json()
except (json.JSONDecodeError, httpx.DecodingError) as e:
msg = f"Hook returned non-JSON response: {e}"
logger.warning(msg, exc_info=e)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
if not isinstance(response_payload, dict):
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
logger.warning(msg)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
return _HttpOutcome(
is_success=True,
updated_is_reachable=True,
status_code=status_code,
error_message=None,
response_payload=response_payload,
)
def _persist_result(
*,
hook_id: int,
outcome: _HttpOutcome,
duration_ms: int,
) -> None:
"""Write the execution log on failure and optionally update is_reachable, each
in its own session so a failure in one does not affect the other."""
# Only write the execution log on failure — success runs are not recorded.
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
# deleted between the initial lookup and here).
if not outcome.is_success:
try:
with get_session_with_current_tenant() as log_session:
create_hook_execution_log__no_commit(
db_session=log_session,
hook_id=hook_id,
is_success=False,
error_message=outcome.error_message,
status_code=outcome.status_code,
duration_ms=duration_ms,
)
log_session.commit()
except Exception:
logger.exception(
f"Failed to persist hook execution log for hook_id={hook_id}"
)
# Update is_reachable separately — best-effort, non-critical.
# None means the value is unchanged (set by the caller to skip the no-op write).
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
# concurrently deleted, so keep this isolated from the log write above.
if outcome.updated_is_reachable is not None:
try:
with get_session_with_current_tenant() as reachable_session:
update_hook__no_commit(
db_session=reachable_session,
hook_id=hook_id,
is_reachable=outcome.updated_is_reachable,
)
reachable_session.commit()
except Exception:
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously."""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
"active hooks without an endpoint_url must be rejected by _lookup_hook"
)
start = time.monotonic()
response: httpx.Response | None = None
exc: Exception | None = None
try:
api_key: str | None = (
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
)
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(timeout=timeout) as client:
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
outcome = outcome.model_copy(update={"updated_is_reachable": None})
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
if not outcome.is_success:
if fail_strategy == HookFailStrategy.HARD:
raise OnyxError(
OnyxErrorCode.HOOK_EXECUTION_FAILED,
outcome.error_message or "Hook execution failed.",
)
logger.warning(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if outcome.response_payload is None:
raise ValueError(
f"response_payload is None for successful hook call (hook_id={hook_id})"
)
return outcome.response_payload

View File

@@ -0,0 +1,121 @@
from datetime import datetime
from enum import Enum
from typing import Annotated
from typing import Any
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from pydantic import SecretStr
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
NonEmptySecretStr = Annotated[SecretStr, Field(min_length=1)]
# ---------------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------------
class HookCreateRequest(BaseModel):
name: str = Field(min_length=1)
hook_point: HookPoint
endpoint_url: str = Field(min_length=1)
api_key: NonEmptySecretStr | None = None
fail_strategy: HookFailStrategy | None = None # if None, uses HookPointSpec default
timeout_seconds: float | None = Field(
default=None, gt=0
) # if None, uses HookPointSpec default
@field_validator("name", "endpoint_url")
@classmethod
def no_whitespace_only(cls, v: str) -> str:
if not v.strip():
raise ValueError("cannot be whitespace-only.")
return v
class HookUpdateRequest(BaseModel):
name: str | None = None
endpoint_url: str | None = None
api_key: NonEmptySecretStr | None = None
fail_strategy: HookFailStrategy | None = None
timeout_seconds: float | None = Field(default=None, gt=0)
@model_validator(mode="after")
def require_at_least_one_field(self) -> "HookUpdateRequest":
if not self.model_fields_set:
raise ValueError("At least one field must be provided for an update.")
if "name" in self.model_fields_set and not (self.name or "").strip():
raise ValueError("name cannot be cleared.")
if (
"endpoint_url" in self.model_fields_set
and not (self.endpoint_url or "").strip()
):
raise ValueError("endpoint_url cannot be cleared.")
if "fail_strategy" in self.model_fields_set and self.fail_strategy is None:
raise ValueError(
"fail_strategy cannot be null; omit the field to leave it unchanged."
)
if "timeout_seconds" in self.model_fields_set and self.timeout_seconds is None:
raise ValueError(
"timeout_seconds cannot be null; omit the field to leave it unchanged."
)
return self
# ---------------------------------------------------------------------------
# Response models
# ---------------------------------------------------------------------------
class HookPointMetaResponse(BaseModel):
hook_point: HookPoint
display_name: str
description: str
docs_url: str | None
input_schema: dict[str, Any]
output_schema: dict[str, Any]
default_timeout_seconds: float
default_fail_strategy: HookFailStrategy
fail_hard_description: str
class HookResponse(BaseModel):
id: int
name: str
hook_point: HookPoint
# Nullable to match the DB column — endpoint_url is required on creation but
# future hook point types may not use an external endpoint (e.g. built-in handlers).
endpoint_url: str | None
fail_strategy: HookFailStrategy
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
is_active: bool
is_reachable: bool | None
creator_email: str | None
created_at: datetime
updated_at: datetime
class HookValidateStatus(str, Enum):
passed = "passed" # server responded (any status except 401/403)
auth_failed = "auth_failed" # server responded with 401 or 403
timeout = (
"timeout" # TCP connected, but read/write timed out (server exists but slow)
)
cannot_connect = "cannot_connect" # could not connect to the server
class HookValidateResponse(BaseModel):
status: HookValidateStatus
error_message: str | None = None
class HookExecutionRecord(BaseModel):
error_message: str | None = None
status_code: int | None = None
duration_ms: int | None = None
created_at: datetime

View File

View File

@@ -0,0 +1,75 @@
from typing import Any
from typing import ClassVar
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
_REQUIRED_ATTRS = (
"hook_point",
"display_name",
"description",
"default_timeout_seconds",
"fail_hard_description",
"default_fail_strategy",
"payload_model",
"response_model",
)
class HookPointSpec:
"""Static metadata and contract for a pipeline hook point.
Each concrete subclass represents exactly one hook point and is instantiated
once at startup, registered in onyx.hooks.registry._REGISTRY. Prefer
get_hook_point_spec() or get_all_specs() from the registry over direct
instantiation.
Each hook point is a concrete subclass of this class. Onyx engineers
own these definitions — customers never touch this code.
Subclasses must define all attributes as class-level constants.
payload_model and response_model must be Pydantic BaseModel subclasses;
input_schema and output_schema are derived from them automatically.
"""
hook_point: HookPoint
display_name: str
description: str
default_timeout_seconds: float
fail_hard_description: str
default_fail_strategy: HookFailStrategy
docs_url: str | None = None
payload_model: ClassVar[type[BaseModel]]
response_model: ClassVar[type[BaseModel]]
# Computed once at class definition time from payload_model / response_model.
input_schema: ClassVar[dict[str, Any]]
output_schema: ClassVar[dict[str, Any]]
def __init_subclass__(cls, **kwargs: object) -> None:
"""Enforce that every concrete subclass declares all required class attributes.
Called automatically by Python whenever a class inherits from HookPointSpec.
Abstract subclasses (those still carrying unimplemented abstract methods) are
skipped — they are intermediate base classes and may not yet define everything.
Only fully concrete subclasses are validated, ensuring a clear TypeError at
import time rather than a confusing AttributeError at runtime.
"""
super().__init_subclass__(**kwargs)
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
if missing:
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
for attr in ("payload_model", "response_model"):
val = getattr(cls, attr, None)
if val is None or not (
isinstance(val, type) and issubclass(val, BaseModel)
):
raise TypeError(
f"{cls.__name__}.{attr} must be a Pydantic BaseModel subclass, got {val!r}"
)
cls.input_schema = cls.payload_model.model_json_schema()
cls.output_schema = cls.response_model.model_json_schema()

View File

@@ -0,0 +1,31 @@
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
# TODO(@Bo-Onyx): define payload and response fields
class DocumentIngestionPayload(BaseModel):
pass
class DocumentIngestionResponse(BaseModel):
pass
class DocumentIngestionSpec(HookPointSpec):
"""Hook point that runs during document ingestion.
# TODO(@Bo-Onyx): define call site, input/output schema, and timeout budget.
"""
hook_point = HookPoint.DOCUMENT_INGESTION
display_name = "Document Ingestion"
description = "Runs during document ingestion. Allows filtering or transforming documents before indexing."
default_timeout_seconds = 30.0
fail_hard_description = "The document will not be indexed."
default_fail_strategy = HookFailStrategy.HARD
payload_model = DocumentIngestionPayload
response_model = DocumentIngestionResponse

View File

@@ -0,0 +1,70 @@
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
class QueryProcessingPayload(BaseModel):
model_config = ConfigDict(extra="forbid")
query: str = Field(description="The raw query string exactly as the user typed it.")
user_email: str | None = Field(
description="Email of the user submitting the query, or null if unauthenticated."
)
chat_session_id: str = Field(
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
)
class QueryProcessingResponse(BaseModel):
# Intentionally permissive — customer endpoints may return extra fields.
query: str | None = Field(
default=None,
description=(
"The query to use in the pipeline. "
"Null, empty string, or absent = reject the query."
),
)
rejection_message: str | None = Field(
default=None,
description="Message shown to the user when the query is rejected. Falls back to a generic message if not provided.",
)
class QueryProcessingSpec(HookPointSpec):
"""Hook point that runs on every user query before it enters the pipeline.
Call site: inside handle_stream_message_objects() in
backend/onyx/chat/process_message.py, immediately after message_text is
assigned from the request and before create_new_chat_message() saves it.
This is the earliest possible point in the query pipeline:
- Raw query — unmodified, exactly as the user typed it
- No side effects yet — message has not been saved to DB
- User identity is available for user-specific logic
Supported use cases:
- Query rejection: block queries based on content or user context
- Query rewriting: normalize, expand, or modify the query
- PII removal: scrub sensitive data before the LLM sees it
- Access control: reject queries from certain users or groups
- Query auditing: log or track queries based on business rules
"""
hook_point = HookPoint.QUERY_PROCESSING
display_name = "Query Processing"
description = (
"Runs on every user query before it enters the pipeline. "
"Allows rewriting, filtering, or rejecting queries."
)
default_timeout_seconds = 5.0 # user is actively waiting — keep tight
fail_hard_description = (
"The query will be blocked and the user will see an error message."
)
default_fail_strategy = HookFailStrategy.HARD
payload_model = QueryProcessingPayload
response_model = QueryProcessingResponse

View File

@@ -0,0 +1,45 @@
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
from onyx.hooks.points.document_ingestion import DocumentIngestionSpec
from onyx.hooks.points.query_processing import QueryProcessingSpec
# Internal: use `monkeypatch.setattr(registry_module, "_REGISTRY", {...})` to override in tests.
_REGISTRY: dict[HookPoint, HookPointSpec] = {
HookPoint.DOCUMENT_INGESTION: DocumentIngestionSpec(),
HookPoint.QUERY_PROCESSING: QueryProcessingSpec(),
}
def validate_registry() -> None:
"""Assert that every HookPoint enum value has a registered spec.
Call once at application startup (e.g. from the FastAPI lifespan hook).
Raises RuntimeError if any hook point is missing a spec.
"""
missing = set(HookPoint) - set(_REGISTRY)
if missing:
raise RuntimeError(
f"Hook point(s) have no registered spec: {missing}. "
"Add an entry to onyx.hooks.registry._REGISTRY."
)
def get_hook_point_spec(hook_point: HookPoint) -> HookPointSpec:
"""Returns the spec for a given hook point.
Raises ValueError if the hook point has no registered spec — this is a
programmer error; every HookPoint enum value must have a corresponding spec
in _REGISTRY.
"""
try:
return _REGISTRY[hook_point]
except KeyError:
raise ValueError(
f"No spec registered for hook point {hook_point!r}. "
"Add an entry to onyx.hooks.registry._REGISTRY."
)
def get_all_specs() -> list[HookPointSpec]:
"""Returns the specs for all registered hook points."""
return list(_REGISTRY.values())

View File

@@ -0,0 +1,5 @@
from onyx.configs.app_configs import HOOK_ENABLED
from shared_configs.configs import MULTI_TENANT
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import contextlib
from collections.abc import Generator
@@ -21,7 +19,7 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
@@ -87,21 +85,14 @@ class DocumentIndexingBatchAdapter:
) as transaction:
yield transaction
def prepare_enrichment(
def build_metadata_aware_chunks(
self,
context: DocumentBatchPrepareContext,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
chunks: list[DocAwareChunk],
) -> DocumentChunkEnricher:
"""Do all DB lookups once and return a per-chunk enricher."""
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_new_chunk_cnt: dict[str, int] = {
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
context: DocumentBatchPrepareContext,
) -> BuildMetadataAwareChunksResult:
"""Enrich chunks with access, document sets, boosts, token counts, and hierarchy."""
no_access = DocumentAccess.build(
user_emails=[],
@@ -111,30 +102,67 @@ class DocumentIndexingBatchAdapter:
is_public=False,
)
return DocumentChunkEnricher(
doc_id_to_access_info=get_access_for_documents(
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_access_info = get_access_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
doc_id_to_document_set = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=updatable_ids, db_session=self.db_session
),
doc_id_to_document_set={
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
},
doc_id_to_ancestor_ids=self._get_ancestor_ids_for_documents(
context.updatable_docs, tenant_id
),
id_to_boost_map=context.id_to_boost_map,
doc_id_to_previous_chunk_cnt={
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
db_session=self.db_session,
)
},
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
no_access=no_access,
tenant_id=tenant_id,
)
}
doc_id_to_previous_chunk_cnt: dict[str, int] = {
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
db_session=self.db_session,
)
}
doc_id_to_new_chunk_cnt: dict[str, int] = {
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks_with_embeddings:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
# Get ancestor hierarchy node IDs for each document
doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents(
context.updatable_docs, tenant_id
)
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=doc_id_to_access_info.get(chunk.source_document.id, no_access),
document_sets=set(
doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
context.id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in context.id_to_boost_map
else DEFAULT_BOOST
),
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
ancestor_hierarchy_node_ids=doc_id_to_ancestor_ids[
chunk.source_document.id
],
)
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
return BuildMetadataAwareChunksResult(
chunks=access_aware_chunks,
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=doc_id_to_new_chunk_cnt,
user_file_id_to_raw_text={},
user_file_id_to_token_count={},
)
def _get_ancestor_ids_for_documents(
@@ -175,7 +203,7 @@ class DocumentIndexingBatchAdapter:
context: DocumentBatchPrepareContext,
updatable_chunk_data: list[UpdatableChunkData],
filtered_documents: list[Document],
enrichment: DocumentChunkEnricher,
result: BuildMetadataAwareChunksResult,
) -> None:
"""Finalize DB updates, store plaintext, and mark docs as indexed."""
updatable_ids = [doc.id for doc in context.updatable_docs]
@@ -199,7 +227,7 @@ class DocumentIndexingBatchAdapter:
update_docs_chunk_count__no_commit(
document_ids=updatable_ids,
doc_id_to_chunk_count=enrichment.doc_id_to_new_chunk_cnt,
doc_id_to_chunk_count=result.doc_id_to_new_chunk_cnt,
db_session=self.db_session,
)
@@ -221,52 +249,3 @@ class DocumentIndexingBatchAdapter:
)
self.db_session.commit()
class DocumentChunkEnricher:
"""Pre-computed metadata for per-chunk enrichment of connector documents."""
def __init__(
self,
doc_id_to_access_info: dict[str, DocumentAccess],
doc_id_to_document_set: dict[str, list[str]],
doc_id_to_ancestor_ids: dict[str, list[int]],
id_to_boost_map: dict[str, int],
doc_id_to_previous_chunk_cnt: dict[str, int],
doc_id_to_new_chunk_cnt: dict[str, int],
no_access: DocumentAccess,
tenant_id: str,
) -> None:
self._doc_id_to_access_info = doc_id_to_access_info
self._doc_id_to_document_set = doc_id_to_document_set
self._doc_id_to_ancestor_ids = doc_id_to_ancestor_ids
self._id_to_boost_map = id_to_boost_map
self._no_access = no_access
self._tenant_id = tenant_id
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk:
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=self._doc_id_to_access_info.get(
chunk.source_document.id, self._no_access
),
document_sets=set(
self._doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
self._id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in self._id_to_boost_map
else DEFAULT_BOOST
),
tenant_id=self._tenant_id,
aggregated_chunk_boost_factor=score,
ancestor_hierarchy_node_ids=self._doc_id_to_ancestor_ids[
chunk.source_document.id
],
)

View File

@@ -1,9 +1,6 @@
from __future__ import annotations
import contextlib
import datetime
import time
from collections import defaultdict
from collections.abc import Generator
from uuid import UUID
@@ -27,7 +24,7 @@ from onyx.db.user_file import fetch_persona_ids_for_user_files
from onyx.db.user_file import fetch_user_project_ids_for_user_files
from onyx.file_store.utils import store_user_file_plaintext
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
@@ -104,20 +101,13 @@ class UserFileIndexingAdapter:
f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts for user files: {[doc.id for doc in documents]}"
)
def prepare_enrichment(
def build_metadata_aware_chunks(
self,
context: DocumentBatchPrepareContext,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
chunks: list[DocAwareChunk],
) -> UserFileChunkEnricher:
"""Do all DB lookups and pre-compute file metadata from chunks."""
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_new_chunk_cnt: dict[str, int] = defaultdict(int)
content_by_file: dict[str, list[str]] = defaultdict(list)
for chunk in chunks:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
content_by_file[chunk.source_document.id].append(chunk.content)
context: DocumentBatchPrepareContext,
) -> BuildMetadataAwareChunksResult:
no_access = DocumentAccess.build(
user_emails=[],
@@ -127,6 +117,7 @@ class UserFileIndexingAdapter:
is_public=False,
)
updatable_ids = [doc.id for doc in context.updatable_docs]
user_file_id_to_project_ids = fetch_user_project_ids_for_user_files(
user_file_ids=updatable_ids,
db_session=self.db_session,
@@ -147,6 +138,17 @@ class UserFileIndexingAdapter:
)
}
user_file_id_to_new_chunk_cnt: dict[str, int] = {
user_file_id: len(
[
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == user_file_id
]
)
for user_file_id in updatable_ids
}
# Initialize tokenizer used for token count calculation
try:
llm = get_default_llm()
@@ -161,9 +163,15 @@ class UserFileIndexingAdapter:
user_file_id_to_raw_text: dict[str, str] = {}
user_file_id_to_token_count: dict[str, int | None] = {}
for user_file_id in updatable_ids:
contents = content_by_file.get(user_file_id)
if contents:
combined_content = " ".join(contents)
user_file_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == user_file_id
]
if user_file_chunks:
combined_content = " ".join(
[chunk.content for chunk in user_file_chunks]
)
user_file_id_to_raw_text[str(user_file_id)] = combined_content
token_count = (
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
@@ -173,16 +181,28 @@ class UserFileIndexingAdapter:
user_file_id_to_raw_text[str(user_file_id)] = ""
user_file_id_to_token_count[str(user_file_id)] = None
return UserFileChunkEnricher(
user_file_id_to_access=user_file_id_to_access,
user_file_id_to_project_ids=user_file_id_to_project_ids,
user_file_id_to_persona_ids=user_file_id_to_persona_ids,
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=user_file_id_to_access.get(chunk.source_document.id, no_access),
document_sets=set(),
user_project=user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
boost=DEFAULT_BOOST,
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
)
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
return BuildMetadataAwareChunksResult(
chunks=access_aware_chunks,
doc_id_to_previous_chunk_cnt=user_file_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
doc_id_to_new_chunk_cnt=user_file_id_to_new_chunk_cnt,
user_file_id_to_raw_text=user_file_id_to_raw_text,
user_file_id_to_token_count=user_file_id_to_token_count,
no_access=no_access,
tenant_id=tenant_id,
)
def _notify_assistant_owners_if_files_ready(
@@ -226,7 +246,7 @@ class UserFileIndexingAdapter:
context: DocumentBatchPrepareContext,
updatable_chunk_data: list[UpdatableChunkData], # noqa: ARG002
filtered_documents: list[Document], # noqa: ARG002
enrichment: UserFileChunkEnricher,
result: BuildMetadataAwareChunksResult,
) -> None:
user_file_ids = [doc.id for doc in context.updatable_docs]
@@ -243,10 +263,8 @@ class UserFileIndexingAdapter:
user_file.last_project_sync_at = datetime.datetime.now(
datetime.timezone.utc
)
user_file.chunk_count = enrichment.doc_id_to_new_chunk_cnt.get(
str(user_file.id), 0
)
user_file.token_count = enrichment.user_file_id_to_token_count[
user_file.chunk_count = result.doc_id_to_new_chunk_cnt[str(user_file.id)]
user_file.token_count = result.user_file_id_to_token_count[
str(user_file.id)
]
@@ -258,54 +276,8 @@ class UserFileIndexingAdapter:
# Store the plaintext in the file store for faster retrieval
# NOTE: this creates its own session to avoid committing the overall
# transaction.
for user_file_id, raw_text in enrichment.user_file_id_to_raw_text.items():
for user_file_id, raw_text in result.user_file_id_to_raw_text.items():
store_user_file_plaintext(
user_file_id=UUID(user_file_id),
plaintext_content=raw_text,
)
class UserFileChunkEnricher:
"""Pre-computed metadata for per-chunk enrichment of user-uploaded files."""
def __init__(
self,
user_file_id_to_access: dict[str, DocumentAccess],
user_file_id_to_project_ids: dict[str, list[int]],
user_file_id_to_persona_ids: dict[str, list[int]],
doc_id_to_previous_chunk_cnt: dict[str, int],
doc_id_to_new_chunk_cnt: dict[str, int],
user_file_id_to_raw_text: dict[str, str],
user_file_id_to_token_count: dict[str, int | None],
no_access: DocumentAccess,
tenant_id: str,
) -> None:
self._user_file_id_to_access = user_file_id_to_access
self._user_file_id_to_project_ids = user_file_id_to_project_ids
self._user_file_id_to_persona_ids = user_file_id_to_persona_ids
self._no_access = no_access
self._tenant_id = tenant_id
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
self.user_file_id_to_raw_text = user_file_id_to_raw_text
self.user_file_id_to_token_count = user_file_id_to_token_count
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk:
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=self._user_file_id_to_access.get(
chunk.source_document.id, self._no_access
),
document_sets=set(),
user_project=self._user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=self._user_file_id_to_persona_ids.get(
chunk.source_document.id, []
),
boost=DEFAULT_BOOST,
tenant_id=self._tenant_id,
aggregated_chunk_boost_factor=score,
)

View File

@@ -1,6 +1,5 @@
from collections import defaultdict
from collections.abc import Callable
from typing import cast
from typing import Protocol
from pydantic import BaseModel
@@ -92,15 +91,6 @@ class IndexingPipelineResult(BaseModel):
failures: list[ConnectorFailure]
@classmethod
def empty(cls, total_docs: int) -> "IndexingPipelineResult":
return cls(
new_docs=0,
total_docs=total_docs,
total_chunks=0,
failures=[],
)
class IndexingPipelineProtocol(Protocol):
def __call__(
@@ -405,6 +395,12 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
llm = get_default_llm_with_vision()
if not llm:
if get_image_extraction_and_analysis_enabled():
logger.warning(
"Image analysis is enabled but no vision-capable LLM is "
"available — images will not be summarized. Configure a "
"vision model in the admin LLM settings."
)
# Even without LLM, we still convert to IndexingDocument with base Sections
return [
IndexingDocument(
@@ -676,7 +672,12 @@ def index_doc_batch(
filtered_documents = filter_fnc(document_batch)
context = adapter.prepare(filtered_documents, ignore_time_skip)
if not context:
return IndexingPipelineResult.empty(len(filtered_documents))
return IndexingPipelineResult(
new_docs=0,
total_docs=len(filtered_documents),
total_chunks=0,
failures=[],
)
# Convert documents to IndexingDocument objects with processed section
# logger.debug("Processing image sections")
@@ -747,20 +748,14 @@ def index_doc_batch(
# we still write data here for the immediate and most likely correct sync, but
# to resolve this, an update of the last modified field at the end of this loop
# always triggers a final metadata sync via the celery queue
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=chunks_with_embeddings,
chunk_content_scores=chunk_content_scores,
tenant_id=tenant_id,
chunks=cast(list[DocAwareChunk], chunks_with_embeddings),
context=context,
)
metadata_aware_chunks = [
enricher.enrich_chunk(chunk, score)
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
]
short_descriptor_list = [
chunk.to_short_descriptor() for chunk in metadata_aware_chunks
]
short_descriptor_list = [chunk.to_short_descriptor() for chunk in result.chunks]
short_descriptor_log = str(short_descriptor_list)[:1024]
logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
@@ -775,10 +770,10 @@ def index_doc_batch(
vector_db_write_failures,
) = write_chunks_to_vector_db_with_backoff(
document_index=document_index,
chunks=metadata_aware_chunks,
chunks=result.chunks,
index_batch_params=IndexBatchParams(
doc_id_to_previous_chunk_cnt=enricher.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=enricher.doc_id_to_new_chunk_cnt,
doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
tenant_id=tenant_id,
large_chunks_enabled=chunker.enable_large_chunks,
),
@@ -820,7 +815,7 @@ def index_doc_batch(
context=context,
updatable_chunk_data=updatable_chunk_data,
filtered_documents=filtered_documents,
enrichment=enricher,
result=result,
)
assert primary_doc_idx_insertion_records is not None

View File

@@ -235,16 +235,12 @@ class UpdatableChunkData(BaseModel):
boost_score: float
class ChunkEnrichmentContext(Protocol):
"""Returned by prepare_enrichment. Holds pre-computed metadata lookups
and provides per-chunk enrichment."""
class BuildMetadataAwareChunksResult(BaseModel):
chunks: list[DocMetadataAwareIndexChunk]
doc_id_to_previous_chunk_cnt: dict[str, int]
doc_id_to_new_chunk_cnt: dict[str, int]
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk: ...
user_file_id_to_raw_text: dict[str, str]
user_file_id_to_token_count: dict[str, int | None]
class IndexingBatchAdapter(Protocol):
@@ -258,17 +254,18 @@ class IndexingBatchAdapter(Protocol):
) -> Generator[TransactionalContext, None, None]:
"""Provide a transaction/row-lock context for critical updates."""
def prepare_enrichment(
def build_metadata_aware_chunks(
self,
context: "DocumentBatchPrepareContext",
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
chunks: list[DocAwareChunk],
) -> ChunkEnrichmentContext: ...
context: "DocumentBatchPrepareContext",
) -> BuildMetadataAwareChunksResult: ...
def post_index(
self,
context: "DocumentBatchPrepareContext",
updatable_chunk_data: list[UpdatableChunkData],
filtered_documents: list[Document],
enrichment: ChunkEnrichmentContext,
result: BuildMetadataAwareChunksResult,
) -> None: ...

View File

@@ -168,10 +168,23 @@ def get_default_llm_with_vision(
if model_supports_image_input(
default_model.name, default_model.llm_provider.provider
):
logger.info(
"Using default vision model: %s (provider=%s)",
default_model.name,
default_model.llm_provider.provider,
)
return create_vision_llm(
LLMProviderView.from_model(default_model.llm_provider),
default_model.name,
)
else:
logger.warning(
"Default vision model %s (provider=%s) does not support "
"image input — falling back to searching all providers",
default_model.name,
default_model.llm_provider.provider,
)
# Fall back to searching all providers
models = fetch_existing_models(
db_session=db_session,
@@ -179,6 +192,10 @@ def get_default_llm_with_vision(
)
if not models:
logger.warning(
"No LLM models with VISION or CHAT flow type found — "
"image summarization will be disabled"
)
return None
for model in models:
@@ -200,11 +217,25 @@ def get_default_llm_with_vision(
for model in sorted_models:
if model_supports_image_input(model.name, model.llm_provider.provider):
logger.info(
"Using fallback vision model: %s (provider=%s)",
model.name,
model.llm_provider.provider,
)
return create_vision_llm(
provider_map[model.llm_provider_id],
model.name,
)
checked_models = [
f"{m.name} (provider={m.llm_provider.provider})" for m in sorted_models
]
logger.warning(
"No vision-capable model found among %d candidates: %s"
"image summarization will be disabled",
len(sorted_models),
", ".join(checked_models),
)
return None

View File

@@ -530,6 +530,11 @@ class LitellmLLM(LLM):
):
messages = _strip_tool_content_from_messages(messages)
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
# reject requests where tool_choice is explicitly null.
if tools and tool_choice is not None:
optional_kwargs["tool_choice"] = tool_choice
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
@@ -538,7 +543,6 @@ class LitellmLLM(LLM):
custom_llm_provider=self._custom_llm_provider or None,
messages=messages,
tools=tools,
tool_choice=tool_choice,
stream=stream,
temperature=temperature,
timeout=timeout_override or self._timeout,

View File

@@ -62,6 +62,7 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.error_handling.exceptions import register_onyx_exception_handlers
from onyx.file_store.file_store import get_default_file_store
from onyx.hooks.registry import validate_registry
from onyx.server.api_key.api import router as api_key_router
from onyx.server.auth_check import check_router_auth
from onyx.server.documents.cc_pair import router as cc_pair_router
@@ -76,6 +77,7 @@ from onyx.server.features.default_assistant.api import (
)
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.hierarchy.api import router as hierarchy_router
from onyx.server.features.hooks.api import router as hook_router
from onyx.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
@@ -308,6 +310,7 @@ def validate_no_vector_db_settings() -> None:
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
validate_no_vector_db_settings()
validate_cache_backend_settings()
validate_registry()
# Set recursion limit
if SYSTEM_RECURSION_LIMIT is not None:
@@ -451,6 +454,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
register_onyx_exception_handlers(application)
include_router_with_global_prefix_prepended(application, hook_router)
include_router_with_global_prefix_prepended(application, password_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, query_router)

View File

@@ -479,7 +479,9 @@ def is_zip_file(file: UploadFile) -> bool:
def upload_files(
files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
files: list[UploadFile],
file_origin: FileOrigin = FileOrigin.CONNECTOR,
unzip: bool = True,
) -> FileUploadResponse:
# Skip directories and known macOS metadata entries
@@ -502,31 +504,46 @@ def upload_files(
if seen_zip:
raise HTTPException(status_code=400, detail=SEEN_ZIP_DETAIL)
seen_zip = True
# Validate the zip by opening it (catches corrupt/non-zip files)
with zipfile.ZipFile(file.file, "r") as zf:
zip_metadata_file_id = save_zip_metadata_to_file_store(
zf, file_store
)
for file_info in zf.namelist():
if zf.getinfo(file_info).is_dir():
continue
if not should_process_file(file_info):
continue
sub_file_bytes = zf.read(file_info)
mime_type, __ = mimetypes.guess_type(file_info)
if mime_type is None:
mime_type = "application/octet-stream"
file_id = file_store.save_file(
content=BytesIO(sub_file_bytes),
display_name=os.path.basename(file_info),
file_origin=file_origin,
file_type=mime_type,
if unzip:
zip_metadata_file_id = save_zip_metadata_to_file_store(
zf, file_store
)
deduped_file_paths.append(file_id)
deduped_file_names.append(os.path.basename(file_info))
for file_info in zf.namelist():
if zf.getinfo(file_info).is_dir():
continue
if not should_process_file(file_info):
continue
sub_file_bytes = zf.read(file_info)
mime_type, __ = mimetypes.guess_type(file_info)
if mime_type is None:
mime_type = "application/octet-stream"
file_id = file_store.save_file(
content=BytesIO(sub_file_bytes),
display_name=os.path.basename(file_info),
file_origin=file_origin,
file_type=mime_type,
)
deduped_file_paths.append(file_id)
deduped_file_names.append(os.path.basename(file_info))
continue
# Store the zip as-is (unzip=False)
file.file.seek(0)
file_id = file_store.save_file(
content=file.file,
display_name=file.filename,
file_origin=file_origin,
file_type=file.content_type or "application/zip",
)
deduped_file_paths.append(file_id)
deduped_file_names.append(file.filename)
continue
# Since we can't render docx files in the UI,
@@ -613,9 +630,10 @@ def _fetch_and_check_file_connector_cc_pair_permissions(
@router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS)
def upload_files_api(
files: list[UploadFile],
unzip: bool = True,
_: User = Depends(current_curator_or_admin_user),
) -> FileUploadResponse:
return upload_files(files, FileOrigin.OTHER)
return upload_files(files, FileOrigin.OTHER, unzip=unzip)
@router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS)

View File

@@ -0,0 +1,453 @@
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Query
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.constants import UNSET
from onyx.db.constants import UnsetType
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.hook import create_hook__no_commit
from onyx.db.hook import delete_hook__no_commit
from onyx.db.hook import get_hook_by_id
from onyx.db.hook import get_hook_execution_logs
from onyx.db.hook import get_hooks
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.api_dependencies import require_hook_enabled
from onyx.hooks.models import HookCreateRequest
from onyx.hooks.models import HookExecutionRecord
from onyx.hooks.models import HookPointMetaResponse
from onyx.hooks.models import HookResponse
from onyx.hooks.models import HookUpdateRequest
from onyx.hooks.models import HookValidateResponse
from onyx.hooks.models import HookValidateStatus
from onyx.hooks.registry import get_all_specs
from onyx.hooks.registry import get_hook_point_spec
from onyx.utils.logger import setup_logger
from onyx.utils.url import SSRFException
from onyx.utils.url import validate_outbound_http_url
logger = setup_logger()
# ---------------------------------------------------------------------------
# SSRF protection
# ---------------------------------------------------------------------------
def _check_ssrf_safety(endpoint_url: str) -> None:
"""Raise OnyxError if endpoint_url could be used for SSRF.
Delegates to validate_outbound_http_url with https_only=True.
"""
try:
validate_outbound_http_url(endpoint_url, https_only=True)
except (SSRFException, ValueError) as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookResponse:
return HookResponse(
id=hook.id,
name=hook.name,
hook_point=hook.hook_point,
endpoint_url=hook.endpoint_url,
fail_strategy=hook.fail_strategy,
timeout_seconds=hook.timeout_seconds,
is_active=hook.is_active,
is_reachable=hook.is_reachable,
creator_email=(
creator_email
if creator_email is not None
else (hook.creator.email if hook.creator else None)
),
created_at=hook.created_at,
updated_at=hook.updated_at,
)
def _get_hook_or_404(
db_session: Session,
hook_id: int,
include_creator: bool = False,
) -> Hook:
hook = get_hook_by_id(
db_session=db_session,
hook_id=hook_id,
include_creator=include_creator,
)
if hook is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook {hook_id} not found.")
return hook
def _raise_for_validation_failure(validation: HookValidateResponse) -> None:
"""Raise an appropriate OnyxError for a non-passed validation result."""
if validation.status == HookValidateStatus.auth_failed:
raise OnyxError(OnyxErrorCode.CREDENTIAL_INVALID, validation.error_message)
if validation.status == HookValidateStatus.timeout:
raise OnyxError(
OnyxErrorCode.GATEWAY_TIMEOUT,
f"Endpoint validation failed: {validation.error_message}",
)
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Endpoint validation failed: {validation.error_message}",
)
def _validate_endpoint(
endpoint_url: str,
api_key: str | None,
timeout_seconds: float,
) -> HookValidateResponse:
"""Check whether endpoint_url is reachable by sending an empty POST request.
We use POST since hook endpoints expect POST requests. The server will typically
respond with 4xx (missing/invalid body) — that is fine. Any HTTP response means
the server is up and routable. A 401/403 response returns auth_failed
(not reachable — indicates the api_key is invalid).
Timeout handling:
- ConnectTimeout: TCP handshake never completed → cannot_connect.
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
(operator should consider increasing timeout_seconds).
- All other exceptions → cannot_connect.
"""
_check_ssrf_safety(endpoint_url)
headers: dict[str, str] = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
try:
with httpx.Client(timeout=timeout_seconds, follow_redirects=False) as client:
response = client.post(endpoint_url, headers=headers)
if response.status_code in (401, 403):
return HookValidateResponse(
status=HookValidateStatus.auth_failed,
error_message=f"Authentication failed (HTTP {response.status_code})",
)
return HookValidateResponse(status=HookValidateStatus.passed)
except httpx.TimeoutException as exc:
# ConnectTimeout: TCP handshake never completed → cannot_connect.
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
if isinstance(exc, httpx.ConnectTimeout):
logger.warning(
"Hook endpoint validation: connect timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
logger.warning(
"Hook endpoint validation: read/write timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.timeout,
error_message="Endpoint timed out — consider increasing timeout_seconds.",
)
except Exception as exc:
logger.warning(
"Hook endpoint validation: connection error for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
# ---------------------------------------------------------------------------
# Routers
# ---------------------------------------------------------------------------
router = APIRouter(prefix="/admin/hooks")
# ---------------------------------------------------------------------------
# Hook endpoints
# ---------------------------------------------------------------------------
@router.get("/specs")
def get_hook_point_specs(
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
) -> list[HookPointMetaResponse]:
return [
HookPointMetaResponse(
hook_point=spec.hook_point,
display_name=spec.display_name,
description=spec.description,
docs_url=spec.docs_url,
input_schema=spec.input_schema,
output_schema=spec.output_schema,
default_timeout_seconds=spec.default_timeout_seconds,
default_fail_strategy=spec.default_fail_strategy,
fail_hard_description=spec.fail_hard_description,
)
for spec in get_all_specs()
]
@router.get("")
def list_hooks(
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> list[HookResponse]:
hooks = get_hooks(db_session=db_session, include_creator=True)
return [_hook_to_response(h) for h in hooks]
@router.post("")
def create_hook(
req: HookCreateRequest,
user: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
"""Create a new hook. The endpoint is validated before persisting — creation fails if
the endpoint cannot be reached or the api_key is invalid. Hooks are created inactive;
use POST /{hook_id}/activate once ready to receive traffic."""
spec = get_hook_point_spec(req.hook_point)
api_key = req.api_key.get_secret_value() if req.api_key else None
validation = _validate_endpoint(
endpoint_url=req.endpoint_url,
api_key=api_key,
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
)
if validation.status != HookValidateStatus.passed:
_raise_for_validation_failure(validation)
hook = create_hook__no_commit(
db_session=db_session,
name=req.name,
hook_point=req.hook_point,
endpoint_url=req.endpoint_url,
api_key=api_key,
fail_strategy=req.fail_strategy or spec.default_fail_strategy,
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
creator_id=user.id,
)
hook.is_reachable = True
db_session.commit()
return _hook_to_response(hook, creator_email=user.email)
@router.get("/{hook_id}")
def get_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = _get_hook_or_404(db_session, hook_id, include_creator=True)
return _hook_to_response(hook)
@router.patch("/{hook_id}")
def update_hook(
hook_id: int,
req: HookUpdateRequest,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
"""Update hook fields. If endpoint_url, api_key, or timeout_seconds changes, the
endpoint is re-validated using the effective values. For active hooks the update is
rejected on validation failure, keeping live traffic unaffected. For inactive hooks
the update goes through regardless and is_reachable is updated to reflect the result.
Note: if an active hook's endpoint is currently down, even a timeout_seconds-only
increase will be rejected. The recovery flow is: deactivate → update → reactivate.
"""
# api_key: UNSET = no change, None = clear, value = update
api_key: str | None | UnsetType
if "api_key" not in req.model_fields_set:
api_key = UNSET
elif req.api_key is None:
api_key = None
else:
api_key = req.api_key.get_secret_value()
endpoint_url_changing = "endpoint_url" in req.model_fields_set
api_key_changing = not isinstance(api_key, UnsetType)
timeout_changing = "timeout_seconds" in req.model_fields_set
validated_is_reachable: bool | None = None
if endpoint_url_changing or api_key_changing or timeout_changing:
existing = _get_hook_or_404(db_session, hook_id)
effective_url: str = (
req.endpoint_url if endpoint_url_changing else existing.endpoint_url # type: ignore[assignment] # endpoint_url is required on create and cannot be cleared on update
)
effective_api_key: str | None = (
(api_key if not isinstance(api_key, UnsetType) else None)
if api_key_changing
else (
existing.api_key.get_value(apply_mask=False)
if existing.api_key
else None
)
)
effective_timeout: float = (
req.timeout_seconds if timeout_changing else existing.timeout_seconds # type: ignore[assignment] # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest)
)
validation = _validate_endpoint(
endpoint_url=effective_url,
api_key=effective_api_key,
timeout_seconds=effective_timeout,
)
if existing.is_active and validation.status != HookValidateStatus.passed:
_raise_for_validation_failure(validation)
validated_is_reachable = validation.status == HookValidateStatus.passed
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
name=req.name,
endpoint_url=(req.endpoint_url if endpoint_url_changing else UNSET),
api_key=api_key,
fail_strategy=req.fail_strategy,
timeout_seconds=req.timeout_seconds,
is_reachable=validated_is_reachable,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
@router.delete("/{hook_id}")
def delete_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> None:
delete_hook__no_commit(db_session=db_session, hook_id=hook_id)
db_session.commit()
@router.post("/{hook_id}/activate")
def activate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = _get_hook_or_404(db_session, hook_id)
if not hook.endpoint_url:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
)
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
validation = _validate_endpoint(
endpoint_url=hook.endpoint_url,
api_key=api_key,
timeout_seconds=hook.timeout_seconds,
)
if validation.status != HookValidateStatus.passed:
# Persist is_reachable=False in a separate session so the request
# session has no commits on the failure path and the transaction
# boundary stays clean.
if hook.is_reachable is not False:
with get_session_with_current_tenant() as side_session:
update_hook__no_commit(
db_session=side_session, hook_id=hook_id, is_reachable=False
)
side_session.commit()
_raise_for_validation_failure(validation)
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
is_active=True,
is_reachable=True,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
@router.post("/{hook_id}/validate")
def validate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookValidateResponse:
hook = _get_hook_or_404(db_session, hook_id)
if not hook.endpoint_url:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
)
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
validation = _validate_endpoint(
endpoint_url=hook.endpoint_url,
api_key=api_key,
timeout_seconds=hook.timeout_seconds,
)
validation_passed = validation.status == HookValidateStatus.passed
if hook.is_reachable != validation_passed:
update_hook__no_commit(
db_session=db_session, hook_id=hook_id, is_reachable=validation_passed
)
db_session.commit()
return validation
@router.post("/{hook_id}/deactivate")
def deactivate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
is_active=False,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
# ---------------------------------------------------------------------------
# Execution log endpoints
# ---------------------------------------------------------------------------
@router.get("/{hook_id}/execution-logs")
def list_hook_execution_logs(
hook_id: int,
limit: int = Query(default=10, ge=1, le=100),
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> list[HookExecutionRecord]:
_get_hook_or_404(db_session, hook_id)
logs = get_hook_execution_logs(db_session=db_session, hook_id=hook_id, limit=limit)
return [
HookExecutionRecord(
error_message=log.error_message,
status_code=log.status_code,
duration_ms=log.duration_ms,
created_at=log.created_at,
)
for log in logs
]

View File

@@ -1,3 +1,4 @@
import hashlib
import mimetypes
from io import BytesIO
from typing import Any
@@ -83,6 +84,14 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
def __init__(self, tool_id: int, emitter: Emitter) -> None:
super().__init__(emitter=emitter)
self._id = tool_id
# Cache of (filename, content_hash) -> ci_file_id to avoid re-uploading
# the same file on every tool call iteration within the same agent session.
# Filename is included in the key so two files with identical bytes but
# different names each get their own upload slot.
# TTL assumption: code-interpreter file TTLs (typically hours) greatly
# exceed the lifetime of a single agent session (at most MAX_LLM_CYCLES
# iterations, typically a few minutes), so stale-ID eviction is not needed.
self._uploaded_file_cache: dict[tuple[str, str], str] = {}
@property
def id(self) -> int:
@@ -182,8 +191,13 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
for ind, chat_file in enumerate(chat_files):
file_name = chat_file.filename or f"file_{ind}"
try:
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)
content_hash = hashlib.sha256(chat_file.content).hexdigest()
cache_key = (file_name, content_hash)
ci_file_id = self._uploaded_file_cache.get(cache_key)
if ci_file_id is None:
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)
self._uploaded_file_cache[cache_key] = ci_file_id
# Stage for execution
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
@@ -299,14 +313,10 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
)
# Cleanup staged input files
for file_mapping in files_to_stage:
try:
client.delete_file(file_mapping["file_id"])
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
)
# Note: staged input files are intentionally not deleted here because
# _uploaded_file_cache reuses their file_ids across iterations. They are
# orphaned when the session ends, but the code interpreter cleans up
# stale files on its own TTL.
# Emit file_ids once files are processed
if generated_file_ids:

View File

@@ -74,7 +74,7 @@ def make_structured_onyx_request_id(prefix: str, request_url: str) -> str:
def _make_onyx_request_id(prefix: str, hash_input: str) -> str:
"""helper function to return an id given a string input"""
hash_obj = hashlib.md5(hash_input.encode("utf-8"))
hash_obj = hashlib.md5(hash_input.encode("utf-8"), usedforsecurity=False)
hash_bytes = hash_obj.digest()[:6] # Truncate to 6 bytes
# 6 bytes becomes 8 bytes. we shouldn't need to strip but just in case

View File

@@ -140,10 +140,20 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
return validated_ip, hostname, port
def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
def validate_outbound_http_url(
url: str,
*,
allow_private_network: bool = False,
https_only: bool = False,
) -> str:
"""
Validate a URL that will be used by backend outbound HTTP calls.
Args:
url: The URL to validate.
allow_private_network: If True, skip private/reserved IP checks.
https_only: If True, reject http:// URLs (only https:// is allowed).
Returns:
A normalized URL string with surrounding whitespace removed.
@@ -157,7 +167,12 @@ def validate_outbound_http_url(url: str, *, allow_private_network: bool = False)
parsed = urlparse(normalized_url)
if parsed.scheme not in ("http", "https"):
if https_only:
if parsed.scheme != "https":
raise SSRFException(
f"Invalid URL scheme '{parsed.scheme}'. Only https is allowed."
)
elif parsed.scheme not in ("http", "https"):
raise SSRFException(
f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
)

View File

@@ -752,7 +752,7 @@ pypandoc-binary==1.16.2
# via onyx
pyparsing==3.2.5
# via httplib2
pypdf==6.8.0
pypdf==6.9.1
# via
# onyx
# unstructured-client

View File

@@ -0,0 +1,274 @@
"""
External dependency unit tests for user file delete queue protections.
Verifies that the three mechanisms added to check_for_user_file_delete work
correctly:
1. Queue depth backpressure when the broker queue exceeds
USER_FILE_DELETE_MAX_QUEUE_DEPTH, no new tasks are enqueued.
2. Per-file Redis guard key if the guard key for a file already exists in
Redis, that file is skipped even though it is still in DELETING status.
3. Task expiry every send_task call carries expires=
CELERY_USER_FILE_DELETE_TASK_EXPIRES so that stale queued tasks are
discarded by workers automatically.
Also verifies that delete_user_file_impl clears the guard key the moment
it is picked up by a worker.
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
on the task class so no real broker is needed.
"""
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from unittest.mock import PropertyMock
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_delete_lock_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_delete_queued_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
check_for_user_file_delete,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
process_single_user_file_delete,
)
from onyx.configs.constants import CELERY_USER_FILE_DELETE_TASK_EXPIRES
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import USER_FILE_DELETE_MAX_QUEUE_DEPTH
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.redis.redis_pool import get_redis_client
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PATCH_QUEUE_LEN = (
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
)
def _create_deleting_user_file(db_session: Session, user_id: object) -> UserFile:
"""Insert a UserFile in DELETING status and return it."""
uf = UserFile(
id=uuid4(),
user_id=user_id,
file_id=f"test_file_{uuid4().hex[:8]}",
name=f"test_{uuid4().hex[:8]}.txt",
file_type="text/plain",
status=UserFileStatus.DELETING,
)
db_session.add(uf)
db_session.commit()
db_session.refresh(uf)
return uf
@contextmanager
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
"""Patch the ``app`` property on *task*'s class so that ``self.app``
inside the task function returns *mock_app*.
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
the actual task instance. We patch ``app`` on that instance's class
(a unique Celery-generated Task subclass) so the mock is scoped to this
task only.
"""
task_instance = task.run.__self__
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield
# ---------------------------------------------------------------------------
# Test classes
# ---------------------------------------------------------------------------
class TestDeleteQueueDepthBackpressure:
"""Protection 1: skip all enqueuing when the broker queue is too deep."""
def test_no_tasks_enqueued_when_queue_over_limit(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""When the queue depth exceeds the limit the beat cycle is skipped."""
user = create_test_user(db_session, "del_bp_user")
_create_deleting_user_file(db_session, user.id)
mock_app = MagicMock()
with (
_patch_task_app(check_for_user_file_delete, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=USER_FILE_DELETE_MAX_QUEUE_DEPTH + 1),
):
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
mock_app.send_task.assert_not_called()
class TestDeletePerFileGuardKey:
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
def test_guarded_file_not_re_enqueued(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""A file whose guard key is already set in Redis is skipped."""
user = create_test_user(db_session, "del_guard_user")
uf = _create_deleting_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_delete_queued_key(uf.id)
redis_client.setex(guard_key, CELERY_USER_FILE_DELETE_TASK_EXPIRES, 1)
mock_app = MagicMock()
try:
with (
_patch_task_app(check_for_user_file_delete, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
# send_task must not have been called with this specific file's ID
for call in mock_app.send_task.call_args_list:
kwargs = call.kwargs.get("kwargs", {})
assert kwargs.get("user_file_id") != str(
uf.id
), f"File {uf.id} should have been skipped because its guard key exists"
finally:
redis_client.delete(guard_key)
def test_guard_key_exists_in_redis_after_enqueue(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""After a file is enqueued its guard key is present in Redis with a TTL."""
user = create_test_user(db_session, "del_guard_set_user")
uf = _create_deleting_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_delete_queued_key(uf.id)
redis_client.delete(guard_key) # clean slate
mock_app = MagicMock()
try:
with (
_patch_task_app(check_for_user_file_delete, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
assert redis_client.exists(
guard_key
), "Guard key should be set in Redis after enqueue"
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
assert (
0 < ttl <= CELERY_USER_FILE_DELETE_TASK_EXPIRES
), f"Guard key TTL {ttl}s is outside the expected range (0, {CELERY_USER_FILE_DELETE_TASK_EXPIRES}]"
finally:
redis_client.delete(guard_key)
class TestDeleteTaskExpiry:
"""Protection 3: every send_task call includes an expires value."""
def test_send_task_called_with_expires(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""send_task is called with the correct queue, task name, and expires."""
user = create_test_user(db_session, "del_expires_user")
uf = _create_deleting_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_delete_queued_key(uf.id)
redis_client.delete(guard_key)
mock_app = MagicMock()
try:
with (
_patch_task_app(check_for_user_file_delete, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
# At least one task should have been submitted (for our file)
assert (
mock_app.send_task.call_count >= 1
), "Expected at least one task to be submitted"
# Every submitted task must carry expires
for call in mock_app.send_task.call_args_list:
assert call.args[0] == OnyxCeleryTask.DELETE_SINGLE_USER_FILE
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_DELETE
assert (
call.kwargs.get("expires") == CELERY_USER_FILE_DELETE_TASK_EXPIRES
), "Task must be submitted with the correct expires value to prevent stale task accumulation"
finally:
redis_client.delete(guard_key)
class TestDeleteWorkerClearsGuardKey:
"""process_single_user_file_delete removes the guard key when it picks up a task."""
def test_guard_key_deleted_on_pickup(
self,
tenant_context: None, # noqa: ARG002
) -> None:
"""The guard key is deleted before the worker does any real work.
We simulate an already-locked file so delete_user_file_impl returns
early but crucially, after the guard key deletion.
"""
user_file_id = str(uuid4())
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_delete_queued_key(user_file_id)
# Simulate the guard key set when the beat enqueued the task
redis_client.setex(guard_key, CELERY_USER_FILE_DELETE_TASK_EXPIRES, 1)
assert redis_client.exists(guard_key), "Guard key must exist before pickup"
# Hold the per-file delete lock so the worker exits early without
# touching the database or file store.
lock_key = _user_file_delete_lock_key(user_file_id)
delete_lock = redis_client.lock(lock_key, timeout=10)
acquired = delete_lock.acquire(blocking=False)
assert acquired, "Should be able to acquire the delete lock for this test"
try:
process_single_user_file_delete.run(
user_file_id=user_file_id,
tenant_id=TEST_TENANT_ID,
)
finally:
if delete_lock.owned():
delete_lock.release()
assert not redis_client.exists(
guard_key
), "Guard key should be deleted when the worker picks up the task"

View File

@@ -153,13 +153,15 @@ class TestAdapterWritesBothMetadataFields:
doc = chunk.source_document
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert aware_chunk.user_project == []
@@ -188,13 +190,15 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert project.id in aware_chunk.user_project
assert aware_chunk.personas == []
@@ -225,13 +229,14 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert project.id in aware_chunk.user_project
@@ -256,13 +261,14 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert aware_chunk.personas == []
assert aware_chunk.user_project == []
@@ -294,11 +300,12 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}

View File

@@ -297,6 +297,10 @@ def index_batch_params(
class TestDocumentIndexOld:
"""Tests the old DocumentIndex interface."""
# TODO(ENG-3864)(andrei): Re-enable this test.
@pytest.mark.xfail(
reason="Flaky test: Retrieved chunks vary non-deterministically before and after changing user projects and personas. Likely a timing issue with the index being updated."
)
def test_update_single_can_clear_user_projects_and_personas(
self,
document_indices: list[DocumentIndex],

View File

@@ -29,6 +29,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
)
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
from onyx.document_index.opensearch.schema import DocumentSchema
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.search import DocumentQuery
@@ -96,6 +97,23 @@ def _patch_hybrid_search_normalization_pipeline(
)
def _patch_opensearch_match_highlights_disabled(
monkeypatch: pytest.MonkeyPatch, disabled: bool
) -> None:
"""
Patches OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED wherever necessary for this
test file.
"""
monkeypatch.setattr(
"onyx.configs.app_configs.OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED",
disabled,
)
monkeypatch.setattr(
"onyx.document_index.opensearch.search.OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED",
disabled,
)
def _create_test_document_chunk(
document_id: str,
content: str,
@@ -226,7 +244,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
# Under test.
# Should not raise.
@@ -242,7 +260,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test.
@@ -271,7 +289,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
@@ -285,7 +303,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
# Under test and postcondition.
# Should return False before creation.
@@ -305,7 +323,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test.
@@ -340,7 +358,7 @@ class TestOpenSearchClient:
},
},
}
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=initial_mappings, settings=settings)
# Under test.
@@ -383,7 +401,7 @@ class TestOpenSearchClient:
"test_field": {"type": "keyword"},
},
}
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=initial_mappings, settings=settings)
# Under test and postcondition.
@@ -418,7 +436,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
# Create once - should succeed.
test_client.create_index(mappings=mappings, settings=settings)
@@ -461,7 +479,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
doc = _create_test_document_chunk(
@@ -489,7 +507,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
docs = [
@@ -520,7 +538,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
doc = _create_test_document_chunk(
@@ -548,7 +566,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
original_doc = _create_test_document_chunk(
@@ -583,7 +601,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=False
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test and postcondition.
@@ -602,7 +620,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
doc = _create_test_document_chunk(
@@ -638,7 +656,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test.
@@ -659,7 +677,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index multiple documents.
@@ -735,7 +753,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Create a document to update.
@@ -784,7 +802,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test and postcondition.
@@ -804,11 +822,12 @@ class TestOpenSearchClient:
"""Tests all hybrid search configurations and pipelines."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents.
docs = {
@@ -881,8 +900,12 @@ class TestOpenSearchClient:
)
# Make sure the chunk contents are preserved.
for i, chunk in enumerate(results):
assert (
chunk.document_chunk == docs[chunk.document_chunk.document_id]
expected = docs[chunk.document_chunk.document_id]
assert chunk.document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(expected, k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
# Make sure score reporting seems reasonable (it should not be None
# or 0).
@@ -906,7 +929,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Note no documents were indexed.
@@ -942,12 +965,13 @@ class TestOpenSearchClient:
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
@@ -1038,7 +1062,12 @@ class TestOpenSearchClient:
# ordered; we're just assuming which doc will be the first result here.
assert results[0].document_chunk.document_id == "public-doc"
# Make sure the chunk contents are preserved.
assert results[0].document_chunk == docs["public-doc"]
assert results[0].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["public-doc"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
# Make sure score reporting seems reasonable (it should not be None
# or 0).
assert results[0].score
@@ -1046,7 +1075,12 @@ class TestOpenSearchClient:
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
# Same for the second result.
assert results[1].document_chunk.document_id == "private-doc-user-a"
assert results[1].document_chunk == docs["private-doc-user-a"]
assert results[1].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["private-doc-user-a"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[1].score
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
@@ -1062,11 +1096,12 @@ class TestOpenSearchClient:
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with varying relevance to the query.
@@ -1193,7 +1228,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Although very unlikely in practice, let's use the same doc ID just to
@@ -1286,7 +1321,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Don't index any documents.
@@ -1313,7 +1348,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index chunks for two different documents.
@@ -1381,7 +1416,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
@@ -1458,7 +1493,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index docs with various ages.
@@ -1550,7 +1585,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index chunks for two different documents, one hidden one not.
@@ -1599,4 +1634,281 @@ class TestOpenSearchClient:
for result in results:
# Note each result must be from doc 1, which is not hidden.
expected_result = doc1_chunks[result.document_chunk.chunk_index]
assert result.document_chunk == expected_result
assert result.document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(expected_result, k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
def test_keyword_search(
self,
test_client: OpenSearchIndexClient,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""
Tests keyword search with filters for ACL, hidden documents, and tenant
isolation.
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
docs = {
"public-doc": _create_test_document_chunk(
document_id="public-doc",
chunk_index=0,
content="Public document content",
hidden=False,
tenant_state=tenant_x,
),
"hidden-doc": _create_test_document_chunk(
document_id="hidden-doc",
chunk_index=0,
content="Hidden document content, spooky",
hidden=True,
tenant_state=tenant_x,
),
"private-doc-user-a": _create_test_document_chunk(
document_id="private-doc-user-a",
chunk_index=0,
content="Private document content, btw my SSN is 123-45-6789",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
# Tests that we don't return documents that don't match keywords at
# all, even if they match filters.
"private-but-not-relevant-doc-user-a": _create_test_document_chunk(
document_id="private-but-not-relevant-doc-user-a",
chunk_index=0,
content="This text should not match the query at all",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
"private-doc-user-b": _create_test_document_chunk(
document_id="private-doc-user-b",
chunk_index=0,
content="Private document content, btw my SSN is 987-65-4321",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-b@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
document_id="should-not-exist-from-tenant-x-pov",
chunk_index=0,
content="This is an entirely different tenant, x should never see this",
# Make this as permissive as possible to exercise tenant
# isolation.
hidden=False,
tenant_state=tenant_y,
),
}
for doc in docs.values():
test_client.index_document(document=doc, tenant_state=doc.tenant_id)
# Refresh index to make documents searchable.
test_client.refresh_index()
# Should not match private-but-not-relevant-doc-user-a.
query_text = "document content"
search_body = DocumentQuery.get_keyword_search_query(
query_text=query_text,
num_hits=5,
tenant_state=tenant_x,
# The user should only be able to see their private docs. tenant_id
# in this object is not relevant.
index_filters=IndexFilters(
access_control_list=[prefix_user_email("user-a@example.com")],
tenant_id=None,
),
include_hidden=False,
)
# Under test.
results = test_client.search(body=search_body, search_pipeline_id=None)
# Postcondition.
# Should only get the public, non-hidden document, and the private
# document for which the user has access.
assert len(results) == 2
# This should be the highest-ranked result, as a higher percentage of
# the content matches the query.
assert results[0].document_chunk.document_id == "public-doc"
# Make sure the chunk contents are preserved.
assert results[0].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["public-doc"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
# Make sure score reporting seems reasonable (it should not be None
# or 0).
assert results[0].score
# Make sure there is some kind of match highlight.
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
# Same for the second result.
assert results[1].document_chunk.document_id == "private-doc-user-a"
assert results[1].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["private-doc-user-a"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[1].score
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
assert results[1].score < results[0].score
def test_semantic_search(
self,
test_client: OpenSearchIndexClient,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""
Tests semantic search with filters for ACL, hidden documents, and tenant
isolation.
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
docs = {
"public-doc": _create_test_document_chunk(
document_id="public-doc",
chunk_index=0,
content="Public document content",
hidden=False,
tenant_state=tenant_x,
# Make this identical to the query vector to test that this
# result is returned first.
content_vector=_generate_test_vector(0.6),
),
"hidden-doc": _create_test_document_chunk(
document_id="hidden-doc",
chunk_index=0,
content="Hidden document content, spooky",
hidden=True,
tenant_state=tenant_x,
),
"private-doc-user-a": _create_test_document_chunk(
document_id="private-doc-user-a",
chunk_index=0,
content="Private document content, btw my SSN is 123-45-6789",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
# Make this different from the query vector to test that this
# result is returned second.
content_vector=_generate_test_vector(0.5),
),
"private-doc-user-b": _create_test_document_chunk(
document_id="private-doc-user-b",
chunk_index=0,
content="Private document content, btw my SSN is 987-65-4321",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-b@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
document_id="should-not-exist-from-tenant-x-pov",
chunk_index=0,
content="This is an entirely different tenant, x should never see this",
# Make this as permissive as possible to exercise tenant
# isolation.
hidden=False,
tenant_state=tenant_y,
),
}
for doc in docs.values():
test_client.index_document(document=doc, tenant_state=doc.tenant_id)
# Refresh index to make documents searchable.
test_client.refresh_index()
query_vector = _generate_test_vector(0.6)
search_body = DocumentQuery.get_semantic_search_query(
query_embedding=query_vector,
num_hits=5,
tenant_state=tenant_x,
# The user should only be able to see their private docs. tenant_id
# in this object is not relevant.
index_filters=IndexFilters(
access_control_list=[prefix_user_email("user-a@example.com")],
tenant_id=None,
),
include_hidden=False,
)
# Under test.
results = test_client.search(body=search_body, search_pipeline_id=None)
# Postcondition.
# Should only get the public, non-hidden document, and the private
# document for which the user has access.
assert len(results) == 2
# We explicitly expect this to be the highest-ranked result.
assert results[0].document_chunk.document_id == "public-doc"
# Make sure the chunk contents are preserved.
assert results[0].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["public-doc"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[0].score == 1.0
# Same for the second result.
assert results[1].document_chunk.document_id == "private-doc-user-a"
assert results[1].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["private-doc-user-a"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[1].score
assert 0.0 < results[1].score < 1.0

View File

@@ -31,7 +31,6 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
)
from onyx.configs.constants import PUBLIC_DOC_PAT
from onyx.configs.constants import SOURCE_TYPE
from onyx.context.search.models import IndexFilters
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import Document
from onyx.db.models import OpenSearchDocumentMigrationRecord
@@ -44,6 +43,7 @@ from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.search import DocumentQuery
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
@@ -70,6 +70,7 @@ from onyx.document_index.vespa_constants import SOURCE_LINKS
from onyx.document_index.vespa_constants import TITLE
from onyx.document_index.vespa_constants import TITLE_EMBEDDING
from onyx.document_index.vespa_constants import USER_PROJECT
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
from tests.external_dependency_unit.full_setup import ensure_full_deployment_setup
@@ -78,24 +79,22 @@ CHUNK_COUNT = 5
def _get_document_chunks_from_opensearch(
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
opensearch_client: OpenSearchIndexClient,
document_id: str,
tenant_state: TenantState,
) -> list[DocumentChunk]:
opensearch_client.refresh_index()
filters = IndexFilters(access_control_list=None, tenant_id=current_tenant_id)
query_body = DocumentQuery.get_from_document_id_query(
document_id=document_id,
tenant_state=TenantState(tenant_id=current_tenant_id, multitenant=False),
index_filters=filters,
include_hidden=False,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
)
search_hits = opensearch_client.search(
body=query_body,
search_pipeline_id=None,
)
return [search_hit.document_chunk for search_hit in search_hits]
results: list[DocumentChunk] = []
for i in range(CHUNK_COUNT):
document_chunk_id: str = get_opensearch_doc_chunk_id(
tenant_state=tenant_state,
document_id=document_id,
chunk_index=i,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
)
result = opensearch_client.get_document(document_chunk_id)
results.append(result)
return results
def _delete_document_chunks_from_opensearch(
@@ -452,10 +451,13 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
for chunks in document_chunks.values():
all_chunks.extend(chunks)
vespa_document_index.index_raw_chunks(all_chunks)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
# Under test.
result = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -477,7 +479,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -522,6 +524,9 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
for chunks in document_chunks.values():
all_chunks.extend(chunks)
vespa_document_index.index_raw_chunks(all_chunks)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
# Run the initial batch. To simulate partial progress we will mock the
# redis lock to return True for the first invocation of .owned() and
@@ -536,7 +541,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
return_value=mock_redis_client,
):
result_1 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
assert result_1 is True
@@ -559,7 +564,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Under test.
# Run the remainder of the migration.
result_2 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -583,7 +588,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -630,6 +635,9 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
for chunks in document_chunks.values():
all_chunks.extend(chunks)
vespa_document_index.index_raw_chunks(all_chunks)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
# Run the initial batch. To simulate partial progress we will mock the
# redis lock to return True for the first invocation of .owned() and
@@ -646,7 +654,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
return_value=mock_redis_client,
):
result_1 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
assert result_1 is True
@@ -691,7 +699,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
),
):
result_2 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -728,7 +736,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
),
):
result_3 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -752,7 +760,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -840,24 +848,25 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
chunk["content"] = (
f"Different content {chunk[CHUNK_ID]} for {test_documents[0].id}"
)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
chunks_for_document_in_opensearch, _ = (
transform_vespa_chunks_to_opensearch_chunks(
document_in_opensearch,
TenantState(tenant_id=get_current_tenant_id(), multitenant=False),
tenant_state,
{},
)
)
opensearch_client.bulk_index_documents(
documents=chunks_for_document_in_opensearch,
tenant_state=TenantState(
tenant_id=get_current_tenant_id(), multitenant=False
),
tenant_state=tenant_state,
update_if_exists=True,
)
# Under test.
result = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -878,7 +887,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -922,11 +931,14 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
for chunks in document_chunks.values():
all_chunks.extend(chunks)
vespa_document_index.index_raw_chunks(all_chunks)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
# Under test.
# First run.
result_1 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -947,7 +959,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -960,7 +972,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Under test.
# Second run.
result_2 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -982,7 +994,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)

View File

@@ -1219,15 +1219,16 @@ def test_code_interpreter_receives_chat_files(
finally:
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
# Verify: file uploaded, code executed via streaming, staged file cleaned up
# Verify: file uploaded and code executed via streaming.
assert len(mock_ci_server.get_requests(method="POST", path="/v1/files")) == 1
assert (
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
)
delete_requests = mock_ci_server.get_requests(method="DELETE")
assert len(delete_requests) == 1
assert delete_requests[0].path.startswith("/v1/files/")
# Staged input files are intentionally NOT deleted — PythonTool caches their
# file IDs across agent-loop iterations to avoid re-uploading on every call.
# The code interpreter cleans them up via its own TTL.
assert len(mock_ci_server.get_requests(method="DELETE")) == 0
execute_body = mock_ci_server.get_requests(
method="POST", path="/v1/execute/stream"

View File

@@ -14,6 +14,7 @@ from __future__ import annotations
import os
import subprocess
import sys
import time
import uuid
from collections.abc import Generator
@@ -28,6 +29,9 @@ _BACKEND_DIR = os.path.normpath(
os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
)
_DROP_SCHEMA_MAX_RETRIES = 3
_DROP_SCHEMA_RETRY_DELAY_SEC = 2
# ---------------------------------------------------------------------------
# Helpers
@@ -50,6 +54,39 @@ def _run_script(
)
def _force_drop_schema(engine: Engine, schema: str) -> None:
"""Terminate backends using *schema* then drop it, retrying on deadlock.
Background Celery workers may discover test schemas (they match the
``tenant_`` prefix) and hold locks on tables inside them. A bare
``DROP SCHEMA … CASCADE`` can deadlock with those workers, so we
first kill their connections and retry if we still hit a deadlock.
"""
for attempt in range(_DROP_SCHEMA_MAX_RETRIES):
try:
with engine.connect() as conn:
conn.execute(
text(
"""
SELECT pg_terminate_backend(l.pid)
FROM pg_locks l
JOIN pg_class c ON c.oid = l.relation
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = :schema
AND l.pid != pg_backend_pid()
"""
),
{"schema": schema},
)
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
return
except Exception:
if attempt == _DROP_SCHEMA_MAX_RETRIES - 1:
raise
time.sleep(_DROP_SCHEMA_RETRY_DELAY_SEC)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@@ -104,9 +141,7 @@ def tenant_schema_at_head(
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
_force_drop_schema(engine, schema)
@pytest.fixture
@@ -123,9 +158,7 @@ def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
_force_drop_schema(engine, schema)
@pytest.fixture
@@ -150,9 +183,7 @@ def tenant_schema_bad_rev(engine: Engine) -> Generator[str, None, None]:
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
_force_drop_schema(engine, schema)
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,63 @@
"""
Unit test verifying that the upload API path sends tasks with expires=.
The upload_files_to_user_files_with_indexing function must include expires=
on every send_task call to prevent phantom task accumulation if the worker
is down or slow.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.models import UserFile
from onyx.db.projects import upload_files_to_user_files_with_indexing
def _make_mock_user_file() -> MagicMock:
uf = MagicMock(spec=UserFile)
uf.id = str(uuid4())
return uf
@patch("onyx.db.projects.get_current_tenant_id", return_value="test_tenant")
@patch("onyx.db.projects.create_user_files")
@patch(
"onyx.background.celery.versioned_apps.client.app",
new_callable=MagicMock,
)
def test_send_task_includes_expires(
mock_client_app: MagicMock,
mock_create: MagicMock,
mock_tenant: MagicMock, # noqa: ARG001
) -> None:
"""Every send_task call from the upload path must include expires=."""
user_files = [_make_mock_user_file(), _make_mock_user_file()]
mock_create.return_value = MagicMock(
user_files=user_files,
rejected_files=[],
id_to_temp_id={},
)
mock_user = MagicMock()
mock_db_session = MagicMock()
upload_files_to_user_files_with_indexing(
files=[],
project_id=None,
user=mock_user,
temp_id_map=None,
db_session=mock_db_session,
)
assert mock_client_app.send_task.call_count == len(user_files)
for call in mock_client_app.send_task.call_args_list:
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
assert (
call.kwargs.get("expires") == CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
), "send_task must include expires= to prevent phantom task accumulation"

View File

@@ -1,208 +0,0 @@
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.access.models import DocumentAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.indexing.models import DocMetadataAwareIndexChunk
def _make_chunk(
doc_id: str,
chunk_id: int,
) -> DocMetadataAwareIndexChunk:
"""Creates a minimal DocMetadataAwareIndexChunk for testing."""
doc = Document(
id=doc_id,
sections=[TextSection(text="test", link="http://test.com")],
source=DocumentSource.FILE,
semantic_identifier="test_doc",
metadata={},
)
access = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=True,
)
return DocMetadataAwareIndexChunk(
chunk_id=chunk_id,
blurb="test",
content="test content",
source_links={0: "http://test.com"},
image_file_id=None,
section_continuation=False,
source_document=doc,
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0,
embeddings={"full_embedding": [0.1] * 10, "mini_chunk_embeddings": []},
title_embedding=[0.1] * 10,
tenant_id="test_tenant",
access=access,
document_sets=set(),
user_project=[],
personas=[],
boost=0,
aggregated_chunk_boost_factor=1.0,
ancestor_hierarchy_node_ids=[],
)
def _make_index() -> OpenSearchDocumentIndex:
"""Creates an OpenSearchDocumentIndex with a mocked client."""
mock_client = MagicMock()
mock_client.bulk_index_documents = MagicMock()
tenant_state = TenantState(tenant_id="test_tenant", multitenant=False)
index = OpenSearchDocumentIndex.__new__(OpenSearchDocumentIndex)
index._index_name = "test_index"
index._client = mock_client
index._tenant_state = tenant_state
return index
def _make_metadata(doc_id: str, chunk_count: int) -> IndexingMetadata:
return IndexingMetadata(
doc_id_to_chunk_cnt_diff={
doc_id: IndexingMetadata.ChunkCounts(
old_chunk_cnt=0,
new_chunk_cnt=chunk_count,
),
},
)
@patch("onyx.document_index.opensearch.opensearch_document_index.CHUNKS_PER_BATCH", 100)
def test_single_doc_under_batch_limit_flushes_once() -> None:
"""A document with fewer chunks than CHUNKS_PER_BATCH should flush once."""
index = _make_index()
doc_id = "doc_1"
num_chunks = 50
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
assert index._client.bulk_index_documents.call_count == 1
batch_arg = index._client.bulk_index_documents.call_args_list[0]
assert len(batch_arg.kwargs["documents"]) == num_chunks
@patch("onyx.document_index.opensearch.opensearch_document_index.CHUNKS_PER_BATCH", 100)
def test_single_doc_over_batch_limit_flushes_multiple_times() -> None:
"""A document with more chunks than CHUNKS_PER_BATCH should flush multiple times."""
index = _make_index()
doc_id = "doc_1"
num_chunks = 250
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
# 250 chunks / 100 per batch = 3 flushes (100 + 100 + 50)
assert index._client.bulk_index_documents.call_count == 3
batch_sizes = [
len(call.kwargs["documents"])
for call in index._client.bulk_index_documents.call_args_list
]
assert batch_sizes == [100, 100, 50]
@patch("onyx.document_index.opensearch.opensearch_document_index.CHUNKS_PER_BATCH", 100)
def test_single_doc_exactly_at_batch_limit() -> None:
"""A document with exactly CHUNKS_PER_BATCH chunks should flush once
(the flush happens on the next chunk, not at the boundary)."""
index = _make_index()
doc_id = "doc_1"
num_chunks = 100
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
# 100 chunks hit the >= check on chunk 101 which doesn't exist,
# so final flush handles all 100
# Actually: the elif fires when len(current_chunks) >= 100, which happens
# when current_chunks has 100 items and the 101st chunk arrives.
# With exactly 100 chunks, the 100th chunk makes len == 99, then appended -> 100.
# No 101st chunk arrives, so the final flush handles all 100.
assert index._client.bulk_index_documents.call_count == 1
@patch("onyx.document_index.opensearch.opensearch_document_index.CHUNKS_PER_BATCH", 100)
def test_single_doc_one_over_batch_limit() -> None:
"""101 chunks for one doc: first 100 flushed when the 101st arrives, then
the 101st is flushed at the end."""
index = _make_index()
doc_id = "doc_1"
num_chunks = 101
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
assert index._client.bulk_index_documents.call_count == 2
batch_sizes = [
len(call.kwargs["documents"])
for call in index._client.bulk_index_documents.call_args_list
]
assert batch_sizes == [100, 1]
@patch("onyx.document_index.opensearch.opensearch_document_index.CHUNKS_PER_BATCH", 100)
def test_multiple_docs_each_under_limit_flush_per_doc() -> None:
"""Multiple documents each under the batch limit should flush once per document."""
index = _make_index()
chunks = []
for doc_idx in range(3):
doc_id = f"doc_{doc_idx}"
for chunk_idx in range(50):
chunks.append(_make_chunk(doc_id, chunk_idx))
metadata = IndexingMetadata(
doc_id_to_chunk_cnt_diff={
f"doc_{i}": IndexingMetadata.ChunkCounts(old_chunk_cnt=0, new_chunk_cnt=50)
for i in range(3)
},
)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
# 3 documents = 3 flushes (one per doc boundary + final)
assert index._client.bulk_index_documents.call_count == 3
@patch("onyx.document_index.opensearch.opensearch_document_index.CHUNKS_PER_BATCH", 100)
def test_delete_called_once_per_document() -> None:
"""Even with multiple flushes for a single document, delete should only be
called once per document."""
index = _make_index()
doc_id = "doc_1"
num_chunks = 250
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0) as mock_delete:
index.index(chunks, metadata)
mock_delete.assert_called_once_with(doc_id, None)

View File

@@ -1,272 +0,0 @@
"""Unit tests for OpenSearchDocumentIndex.index().
These tests mock the OpenSearch client and verify the buffered
flush-by-document logic, DocumentInsertionRecord construction, and
delete-before-insert semantics.
"""
from collections.abc import Iterator
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.access.models import DocumentAccess
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import TextSection
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
def _make_chunk(
doc_id: str,
chunk_id: int = 0,
content: str = "test content",
chunk_count: int | None = None,
) -> DocMetadataAwareIndexChunk:
doc = Document(
id=doc_id,
semantic_identifier="test_doc",
sections=[TextSection(text=content, link=None)],
source=DocumentSource.NOT_APPLICABLE,
metadata={},
chunk_count=chunk_count,
)
index_chunk = IndexChunk(
chunk_id=chunk_id,
blurb=content[:50],
content=content,
source_links=None,
image_file_id=None,
section_continuation=False,
source_document=doc,
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
contextual_rag_reserved_tokens=0,
doc_summary="",
chunk_context="",
mini_chunk_texts=None,
large_chunk_id=None,
embeddings=ChunkEmbedding(
full_embedding=[0.1] * 10,
mini_chunk_embeddings=[],
),
title_embedding=[0.1] * 10,
)
access = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=True,
)
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=index_chunk,
access=access,
document_sets=set(),
user_project=[],
personas=[],
boost=0,
aggregated_chunk_boost_factor=1.0,
tenant_id="test_tenant",
)
def _make_indexing_metadata(
doc_ids: list[str],
old_counts: list[int],
new_counts: list[int],
) -> IndexingMetadata:
return IndexingMetadata(
doc_id_to_chunk_cnt_diff={
doc_id: IndexingMetadata.ChunkCounts(
old_chunk_cnt=old,
new_chunk_cnt=new,
)
for doc_id, old, new in zip(doc_ids, old_counts, new_counts)
}
)
def _make_os_index(mock_client: MagicMock) -> OpenSearchDocumentIndex:
"""Create an OpenSearchDocumentIndex with a mocked client."""
with patch.object(
OpenSearchDocumentIndex,
"__init__",
lambda _self, *_a, **_kw: None,
):
idx = OpenSearchDocumentIndex.__new__(OpenSearchDocumentIndex)
idx._index_name = "test_index"
idx._tenant_state = TenantState(tenant_id="test_tenant", multitenant=False)
idx._client = mock_client
return idx
def test_index_single_new_doc() -> None:
"""Indexing a single new document returns one record with already_existed=False."""
mock_client = MagicMock()
mock_client.bulk_index_documents.return_value = None
idx = _make_os_index(mock_client)
# Patch delete to return 0 (no existing chunks)
with patch.object(idx, "delete", return_value=0) as mock_delete:
chunk = _make_chunk("doc1")
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[1])
results = idx.index(chunks=[chunk], indexing_metadata=metadata)
assert len(results) == 1
assert results[0].document_id == "doc1"
assert results[0].already_existed is False
mock_delete.assert_called_once()
mock_client.bulk_index_documents.assert_called_once()
def test_index_existing_doc_already_existed_true() -> None:
"""Re-indexing a doc with previous chunks returns already_existed=True."""
mock_client = MagicMock()
mock_client.bulk_index_documents.return_value = None
idx = _make_os_index(mock_client)
with patch.object(idx, "delete", return_value=5):
chunk = _make_chunk("doc1")
metadata = _make_indexing_metadata(["doc1"], old_counts=[5], new_counts=[1])
results = idx.index(chunks=[chunk], indexing_metadata=metadata)
assert len(results) == 1
assert results[0].already_existed is True
def test_index_multiple_docs_flushed_separately() -> None:
"""Chunks from different documents are flushed in separate bulk calls."""
mock_client = MagicMock()
mock_client.bulk_index_documents.return_value = None
idx = _make_os_index(mock_client)
with patch.object(idx, "delete", return_value=0):
chunks = [
_make_chunk("doc1", chunk_id=0),
_make_chunk("doc1", chunk_id=1),
_make_chunk("doc2", chunk_id=0),
]
metadata = _make_indexing_metadata(
["doc1", "doc2"], old_counts=[0, 0], new_counts=[2, 1]
)
results = idx.index(chunks=chunks, indexing_metadata=metadata)
result_map = {r.document_id: r.already_existed for r in results}
assert len(result_map) == 2
assert result_map["doc1"] is False
assert result_map["doc2"] is False
# Two separate flushes: one for doc1 (2 chunks), one for doc2 (1 chunk)
assert mock_client.bulk_index_documents.call_count == 2
def test_index_deletes_before_inserting() -> None:
"""For each document, delete is called before bulk_index_documents."""
mock_client = MagicMock()
mock_client.bulk_index_documents.return_value = None
call_order: list[str] = []
idx = _make_os_index(mock_client)
def track_delete(*_args: object, **_kwargs: object) -> int:
call_order.append("delete")
return 3
def track_bulk(*_args: object, **_kwargs: object) -> None:
call_order.append("bulk_index")
mock_client.bulk_index_documents.side_effect = track_bulk
with patch.object(idx, "delete", side_effect=track_delete):
chunk = _make_chunk("doc1")
metadata = _make_indexing_metadata(["doc1"], old_counts=[3], new_counts=[1])
idx.index(chunks=[chunk], indexing_metadata=metadata)
assert call_order == ["delete", "bulk_index"]
def test_index_delete_called_once_per_doc() -> None:
"""Delete is called only once per document, even with multiple chunks."""
mock_client = MagicMock()
mock_client.bulk_index_documents.return_value = None
idx = _make_os_index(mock_client)
with patch.object(idx, "delete", return_value=0) as mock_delete:
# 3 chunks, all same doc — should only delete once
chunks = [_make_chunk("doc1", chunk_id=i) for i in range(3)]
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[3])
idx.index(chunks=chunks, indexing_metadata=metadata)
mock_delete.assert_called_once()
def test_index_flushes_on_doc_boundary() -> None:
"""When doc ID changes in the stream, the previous doc's chunks are flushed."""
mock_client = MagicMock()
mock_client.bulk_index_documents.return_value = None
idx = _make_os_index(mock_client)
bulk_call_chunk_counts: list[int] = []
def track_bulk(documents: list[object], **_kwargs: object) -> None:
bulk_call_chunk_counts.append(len(documents))
mock_client.bulk_index_documents.side_effect = track_bulk
with patch.object(idx, "delete", return_value=0):
chunks = [
_make_chunk("doc1", chunk_id=0),
_make_chunk("doc1", chunk_id=1),
_make_chunk("doc1", chunk_id=2),
_make_chunk("doc2", chunk_id=0),
_make_chunk("doc2", chunk_id=1),
]
metadata = _make_indexing_metadata(
["doc1", "doc2"], old_counts=[0, 0], new_counts=[3, 2]
)
idx.index(chunks=chunks, indexing_metadata=metadata)
# First flush: 3 chunks for doc1, second flush: 2 chunks for doc2
assert bulk_call_chunk_counts == [3, 2]
def test_index_with_generator_input() -> None:
"""The index method works with a generator (iterable) input, not just lists."""
mock_client = MagicMock()
mock_client.bulk_index_documents.return_value = None
idx = _make_os_index(mock_client)
consumed: list[int] = []
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
consumed.append(i)
yield _make_chunk("doc1", chunk_id=i)
with patch.object(idx, "delete", return_value=0):
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[3])
results = idx.index(chunks=chunk_gen(), indexing_metadata=metadata)
assert consumed == [0, 1, 2]
assert len(results) == 1

View File

@@ -1,417 +0,0 @@
"""Unit tests for VespaDocumentIndex.index().
These tests mock all external I/O (HTTP calls, thread pools) and verify
the streaming logic, ID cleaning/mapping, and DocumentInsertionRecord
construction.
"""
from collections.abc import Iterator
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
from onyx.access.models import DocumentAccess
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import TextSection
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
def _make_chunk(
doc_id: str,
chunk_id: int = 0,
content: str = "test content",
) -> DocMetadataAwareIndexChunk:
doc = Document(
id=doc_id,
semantic_identifier="test_doc",
sections=[TextSection(text=content, link=None)],
source=DocumentSource.NOT_APPLICABLE,
metadata={},
)
index_chunk = IndexChunk(
chunk_id=chunk_id,
blurb=content[:50],
content=content,
source_links=None,
image_file_id=None,
section_continuation=False,
source_document=doc,
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
contextual_rag_reserved_tokens=0,
doc_summary="",
chunk_context="",
mini_chunk_texts=None,
large_chunk_id=None,
embeddings=ChunkEmbedding(
full_embedding=[0.1] * 10,
mini_chunk_embeddings=[],
),
title_embedding=None,
)
access = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=True,
)
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=index_chunk,
access=access,
document_sets=set(),
user_project=[],
personas=[],
boost=0,
aggregated_chunk_boost_factor=1.0,
tenant_id="test_tenant",
)
def _make_indexing_metadata(
doc_ids: list[str],
old_counts: list[int],
new_counts: list[int],
) -> IndexingMetadata:
return IndexingMetadata(
doc_id_to_chunk_cnt_diff={
doc_id: IndexingMetadata.ChunkCounts(
old_chunk_cnt=old,
new_chunk_cnt=new,
)
for doc_id, old, new in zip(doc_ids, old_counts, new_counts)
}
)
def _stub_enrich(
doc_id: str,
old_chunk_cnt: int,
) -> EnrichedDocumentIndexingInfo:
"""Build an EnrichedDocumentIndexingInfo that says 'no chunks to delete'
when old_chunk_cnt == 0, or 'has existing chunks' otherwise."""
return EnrichedDocumentIndexingInfo(
doc_id=doc_id,
chunk_start_index=0,
old_version=False,
chunk_end_index=old_chunk_cnt,
)
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
@patch(
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
return_value=[],
)
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
def test_index_single_new_doc(
mock_enrich: MagicMock,
mock_get_chunk_ids: MagicMock, # noqa: ARG001
mock_delete: MagicMock, # noqa: ARG001
mock_batch_index: MagicMock,
) -> None:
"""Indexing a single new document returns one record with already_existed=False."""
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)
index = VespaDocumentIndex(
index_name="test_index",
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
large_chunks_enabled=False,
httpx_client=MagicMock(),
)
chunk = _make_chunk("doc1")
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[1])
results = index.index(chunks=[chunk], indexing_metadata=metadata)
assert len(results) == 1
assert results[0].document_id == "doc1"
assert results[0].already_existed is False
# batch_index_vespa_chunks should be called once with a single cleaned chunk
mock_batch_index.assert_called_once()
call_kwargs = mock_batch_index.call_args
indexed_chunks = call_kwargs.kwargs["chunks"]
assert len(indexed_chunks) == 1
assert indexed_chunks[0].source_document.id == "doc1"
assert call_kwargs.kwargs["index_name"] == "test_index"
assert call_kwargs.kwargs["multitenant"] is False
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
@patch(
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
return_value=[],
)
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
def test_index_existing_doc_already_existed_true(
mock_enrich: MagicMock,
mock_get_chunk_ids: MagicMock,
mock_delete: MagicMock,
mock_batch_index: MagicMock,
) -> None:
"""Re-indexing a doc with previous chunks deletes old chunks, indexes
new ones, and returns already_existed=True."""
fake_chunk_ids = [uuid4(), uuid4()]
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=5)
mock_get_chunk_ids.return_value = fake_chunk_ids
index = VespaDocumentIndex(
index_name="test_index",
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
large_chunks_enabled=False,
httpx_client=MagicMock(),
)
chunk = _make_chunk("doc1")
metadata = _make_indexing_metadata(["doc1"], old_counts=[5], new_counts=[1])
results = index.index(chunks=[chunk], indexing_metadata=metadata)
assert len(results) == 1
assert results[0].already_existed is True
# Old chunks should be deleted
mock_delete.assert_called_once()
delete_kwargs = mock_delete.call_args.kwargs
assert delete_kwargs["doc_chunk_ids"] == fake_chunk_ids
assert delete_kwargs["index_name"] == "test_index"
# New chunk should be indexed
mock_batch_index.assert_called_once()
indexed_chunks = mock_batch_index.call_args.kwargs["chunks"]
assert len(indexed_chunks) == 1
assert indexed_chunks[0].source_document.id == "doc1"
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
@patch(
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
return_value=[],
)
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
def test_index_multiple_docs(
mock_enrich: MagicMock,
mock_get_chunk_ids: MagicMock, # noqa: ARG001
mock_delete: MagicMock, # noqa: ARG001
mock_batch_index: MagicMock,
) -> None:
"""Indexing multiple documents returns one record per unique document."""
mock_enrich.side_effect = [
_stub_enrich("doc1", old_chunk_cnt=0),
_stub_enrich("doc2", old_chunk_cnt=3),
]
index = VespaDocumentIndex(
index_name="test_index",
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
large_chunks_enabled=False,
httpx_client=MagicMock(),
)
chunks = [
_make_chunk("doc1", chunk_id=0),
_make_chunk("doc1", chunk_id=1),
_make_chunk("doc2", chunk_id=0),
]
metadata = _make_indexing_metadata(
["doc1", "doc2"], old_counts=[0, 3], new_counts=[2, 1]
)
results = index.index(chunks=chunks, indexing_metadata=metadata)
result_map = {r.document_id: r.already_existed for r in results}
assert len(result_map) == 2
assert result_map["doc1"] is False
assert result_map["doc2"] is True
# All 3 chunks fit in one batch (BATCH_SIZE=128), so one call
mock_batch_index.assert_called_once()
indexed_chunks = mock_batch_index.call_args.kwargs["chunks"]
assert len(indexed_chunks) == 3
indexed_doc_ids = [c.source_document.id for c in indexed_chunks]
assert indexed_doc_ids == ["doc1", "doc1", "doc2"]
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
@patch(
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
return_value=[],
)
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
def test_index_cleans_doc_ids(
mock_enrich: MagicMock,
mock_get_chunk_ids: MagicMock, # noqa: ARG001
mock_delete: MagicMock, # noqa: ARG001
mock_batch_index: MagicMock,
) -> None:
"""Documents with invalid Vespa characters get cleaned IDs, but
the returned DocumentInsertionRecord uses the original ID."""
doc_id_with_quote = "doc'1"
mock_enrich.return_value = _stub_enrich(doc_id_with_quote, old_chunk_cnt=0)
index = VespaDocumentIndex(
index_name="test_index",
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
large_chunks_enabled=False,
httpx_client=MagicMock(),
)
chunk = _make_chunk(doc_id_with_quote)
metadata = _make_indexing_metadata(
[doc_id_with_quote], old_counts=[0], new_counts=[1]
)
results = index.index(chunks=[chunk], indexing_metadata=metadata)
assert len(results) == 1
# The returned ID should be the original (unclean) ID
assert results[0].document_id == doc_id_with_quote
# The chunk passed to batch_index_vespa_chunks should have the cleaned ID
indexed_chunks = mock_batch_index.call_args.kwargs["chunks"]
assert len(indexed_chunks) == 1
assert indexed_chunks[0].source_document.id == "doc_1" # quote replaced with _
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
@patch(
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
return_value=[],
)
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
def test_index_deduplicates_doc_ids_in_results(
mock_enrich: MagicMock,
mock_get_chunk_ids: MagicMock, # noqa: ARG001
mock_delete: MagicMock, # noqa: ARG001
mock_batch_index: MagicMock,
) -> None:
"""Multiple chunks from the same document produce only one
DocumentInsertionRecord."""
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)
index = VespaDocumentIndex(
index_name="test_index",
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
large_chunks_enabled=False,
httpx_client=MagicMock(),
)
chunks = [_make_chunk("doc1", chunk_id=i) for i in range(5)]
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[5])
results = index.index(chunks=chunks, indexing_metadata=metadata)
assert len(results) == 1
assert results[0].document_id == "doc1"
# All 5 chunks should be passed to batch_index_vespa_chunks
mock_batch_index.assert_called_once()
indexed_chunks = mock_batch_index.call_args.kwargs["chunks"]
assert len(indexed_chunks) == 5
assert all(c.source_document.id == "doc1" for c in indexed_chunks)
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
@patch(
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
return_value=[],
)
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
@patch(
"onyx.document_index.vespa.vespa_document_index.BATCH_SIZE",
3,
)
def test_index_respects_batch_size(
mock_enrich: MagicMock,
mock_get_chunk_ids: MagicMock, # noqa: ARG001
mock_delete: MagicMock, # noqa: ARG001
mock_batch_index: MagicMock,
) -> None:
"""When chunks exceed BATCH_SIZE, batch_index_vespa_chunks is called
multiple times with correctly sized batches."""
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)
index = VespaDocumentIndex(
index_name="test_index",
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
large_chunks_enabled=False,
httpx_client=MagicMock(),
)
chunks = [_make_chunk("doc1", chunk_id=i) for i in range(7)]
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[7])
results = index.index(chunks=chunks, indexing_metadata=metadata)
assert len(results) == 1
# With BATCH_SIZE=3 and 7 chunks: batches of 3, 3, 1
assert mock_batch_index.call_count == 3
batch_sizes = [len(c.kwargs["chunks"]) for c in mock_batch_index.call_args_list]
assert batch_sizes == [3, 3, 1]
# Verify all chunks are accounted for and in order
all_indexed = [
chunk for c in mock_batch_index.call_args_list for chunk in c.kwargs["chunks"]
]
assert len(all_indexed) == 7
assert [c.chunk_id for c in all_indexed] == list(range(7))
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
@patch(
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
return_value=[],
)
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
def test_index_streams_chunks_lazily(
mock_enrich: MagicMock,
mock_get_chunk_ids: MagicMock, # noqa: ARG001
mock_delete: MagicMock, # noqa: ARG001
mock_batch_index: MagicMock, # noqa: ARG001
) -> None:
"""Chunks are consumed lazily via a generator, not materialized upfront."""
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)
index = VespaDocumentIndex(
index_name="test_index",
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
large_chunks_enabled=False,
httpx_client=MagicMock(),
)
consumed: list[int] = []
def chunk_generator() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
consumed.append(i)
yield _make_chunk("doc1", chunk_id=i)
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[3])
# Before calling index, nothing consumed
gen = chunk_generator()
assert len(consumed) == 0
results = index.index(chunks=gen, indexing_metadata=metadata)
# After calling index, all chunks should have been consumed
assert consumed == [0, 1, 2]
assert len(results) == 1

View File

@@ -0,0 +1,45 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer (pypdf)
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 1
/Kids [ 4 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Parent 2 0 R
>>
endobj
xref
0 5
0000000000 65535 f
0000000015 00000 n
0000000054 00000 n
0000000113 00000 n
0000000162 00000 n
trailer
<<
/Size 5
/Root 3 0 R
/Info 1 0 R
>>
startxref
256
%%EOF

View File

@@ -0,0 +1,89 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer (pypdf)
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 2
/Kids [ 4 0 R 6 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 5 0 R
/Parent 2 0 R
>>
endobj
5 0 obj
<<
/Length 47
>>
stream
BT /F1 12 Tf 50 150 Td (Page one content) Tj ET
endstream
endobj
6 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 7 0 R
/Parent 2 0 R
>>
endobj
7 0 obj
<<
/Length 47
>>
stream
BT /F1 12 Tf 50 150 Td (Page two content) Tj ET
endstream
endobj
xref
0 8
0000000000 65535 f
0000000015 00000 n
0000000054 00000 n
0000000119 00000 n
0000000168 00000 n
0000000349 00000 n
0000000446 00000 n
0000000627 00000 n
trailer
<<
/Size 8
/Root 3 0 R
/Info 1 0 R
>>
startxref
724
%%EOF

View File

@@ -0,0 +1,62 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer (pypdf)
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 1
/Kids [ 4 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 5 0 R
/Parent 2 0 R
>>
endobj
5 0 obj
<<
/Length 42
>>
stream
BT /F1 12 Tf 50 150 Td (Hello World) Tj ET
endstream
endobj
xref
0 6
0000000000 65535 f
0000000015 00000 n
0000000054 00000 n
0000000113 00000 n
0000000162 00000 n
0000000343 00000 n
trailer
<<
/Size 6
/Root 3 0 R
/Info 1 0 R
>>
startxref
435
%%EOF

View File

@@ -0,0 +1,64 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer (pypdf)
/Title (My Title)
/Author (Jane Doe)
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 1
/Kids [ 4 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 5 0 R
/Parent 2 0 R
>>
endobj
5 0 obj
<<
/Length 35
>>
stream
BT /F1 12 Tf 50 150 Td (test) Tj ET
endstream
endobj
xref
0 6
0000000000 65535 f
0000000015 00000 n
0000000091 00000 n
0000000150 00000 n
0000000199 00000 n
0000000380 00000 n
trailer
<<
/Size 6
/Root 3 0 R
/Info 1 0 R
>>
startxref
465
%%EOF

View File

@@ -0,0 +1,89 @@
"""
Unit tests for image summarization error handling.
Verifies that:
1. LLM errors produce actionable error messages (not base64 dumps)
2. Unsupported MIME type logs include the magic bytes and size
3. The ValueError raised on LLM failure preserves the original exception
"""
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.file_processing.image_summarization import _summarize_image
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
class TestSummarizeImageErrorMessage:
"""_summarize_image must not dump base64 image data into error messages."""
def test_error_message_contains_exception_type_not_base64(self) -> None:
"""The ValueError should contain the original exception info, not message payloads."""
mock_llm = MagicMock()
mock_llm.invoke.side_effect = RuntimeError("Connection timeout")
# A fake base64-encoded image string (should NOT appear in the error)
fake_encoded = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg..."
with pytest.raises(ValueError, match="RuntimeError: Connection timeout"):
_summarize_image(fake_encoded, mock_llm, query="test")
def test_error_message_does_not_contain_base64(self) -> None:
"""Ensure base64 data is never included in the error message."""
mock_llm = MagicMock()
mock_llm.invoke.side_effect = RuntimeError("API error")
fake_encoded = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUA"
with pytest.raises(ValueError) as exc_info:
_summarize_image(fake_encoded, mock_llm)
error_str = str(exc_info.value)
assert "base64" not in error_str
assert "iVBOR" not in error_str
def test_original_exception_is_chained(self) -> None:
"""The ValueError should chain the original exception via __cause__."""
mock_llm = MagicMock()
original = RuntimeError("upstream failure")
mock_llm.invoke.side_effect = original
with pytest.raises(ValueError) as exc_info:
_summarize_image("data:image/png;base64,abc", mock_llm)
assert exc_info.value.__cause__ is original
class TestUnsupportedMimeTypeLogging:
"""summarize_image_with_error_handling should log useful info for unsupported formats."""
@patch(
"onyx.file_processing.image_summarization.summarize_image_pipeline",
side_effect=__import__(
"onyx.file_processing.image_summarization",
fromlist=["UnsupportedImageFormatError"],
).UnsupportedImageFormatError("unsupported"),
)
def test_logs_magic_bytes_and_size(
self, mock_pipeline: MagicMock # noqa: ARG002
) -> None:
"""The info log should include magic bytes hex and image size."""
mock_llm = MagicMock()
# TIFF magic bytes (not in the supported list)
image_data = b"\x49\x49\x2a\x00" + b"\x00" * 100
with patch("onyx.file_processing.image_summarization.logger") as mock_logger:
result = summarize_image_with_error_handling(
llm=mock_llm,
image_data=image_data,
context_name="test_image.tiff",
)
assert result is None
mock_logger.info.assert_called_once()
log_args = mock_logger.info.call_args
# Check the format string args contain magic bytes and size
assert "49492a00" in str(log_args)
assert "104" in str(log_args) # 4 + 100 bytes

View File

@@ -0,0 +1,141 @@
"""
Unit tests verifying that LiteLLM error details are extracted and surfaced
in image summarization error messages.
When the LLM call fails, the error handler should include the status_code,
llm_provider, and model from LiteLLM exceptions so operators can diagnose
the root cause (rate limit, content filter, unsupported vision, etc.)
without needing to dig through LiteLLM internals.
"""
from unittest.mock import MagicMock
import pytest
from onyx.file_processing.image_summarization import _summarize_image
def _make_litellm_style_error(
*,
message: str = "API error",
status_code: int | None = None,
llm_provider: str | None = None,
model: str | None = None,
) -> RuntimeError:
"""Create an exception with LiteLLM-style attributes."""
exc = RuntimeError(message)
if status_code is not None:
exc.status_code = status_code # type: ignore[attr-defined]
if llm_provider is not None:
exc.llm_provider = llm_provider # type: ignore[attr-defined]
if model is not None:
exc.model = model # type: ignore[attr-defined]
return exc
class TestLiteLLMErrorExtraction:
"""Verify that LiteLLM error attributes are included in the ValueError."""
def test_status_code_included(self) -> None:
mock_llm = MagicMock()
mock_llm.invoke.side_effect = _make_litellm_style_error(
message="Content filter triggered",
status_code=400,
llm_provider="azure",
model="gpt-4o",
)
with pytest.raises(ValueError, match="status_code=400"):
_summarize_image("data:image/png;base64,abc", mock_llm)
def test_llm_provider_included(self) -> None:
mock_llm = MagicMock()
mock_llm.invoke.side_effect = _make_litellm_style_error(
message="Bad request",
status_code=400,
llm_provider="azure",
)
with pytest.raises(ValueError, match="llm_provider=azure"):
_summarize_image("data:image/png;base64,abc", mock_llm)
def test_model_included(self) -> None:
mock_llm = MagicMock()
mock_llm.invoke.side_effect = _make_litellm_style_error(
message="Bad request",
model="gpt-4o",
)
with pytest.raises(ValueError, match="model=gpt-4o"):
_summarize_image("data:image/png;base64,abc", mock_llm)
def test_all_fields_in_single_message(self) -> None:
mock_llm = MagicMock()
mock_llm.invoke.side_effect = _make_litellm_style_error(
message="Rate limit exceeded",
status_code=429,
llm_provider="azure",
model="gpt-4o",
)
with pytest.raises(ValueError) as exc_info:
_summarize_image("data:image/png;base64,abc", mock_llm)
msg = str(exc_info.value)
assert "status_code=429" in msg
assert "llm_provider=azure" in msg
assert "model=gpt-4o" in msg
assert "Rate limit exceeded" in msg
def test_plain_exception_without_litellm_attrs(self) -> None:
"""Non-LiteLLM exceptions should still produce a useful message."""
mock_llm = MagicMock()
mock_llm.invoke.side_effect = ConnectionError("Connection refused")
with pytest.raises(ValueError) as exc_info:
_summarize_image("data:image/png;base64,abc", mock_llm)
msg = str(exc_info.value)
assert "ConnectionError" in msg
assert "Connection refused" in msg
# Should not contain status_code/llm_provider/model
assert "status_code" not in msg
assert "llm_provider" not in msg
def test_no_base64_in_error(self) -> None:
"""Error messages must not contain the full base64 image payload.
Some LiteLLM exceptions echo the request body (including base64 images)
in their message. The truncation guard ensures the bulk of such a
payload is stripped from the re-raised ValueError.
"""
mock_llm = MagicMock()
# Build a long base64-like payload that exceeds the 512-char truncation
fake_b64_payload = "iVBORw0KGgo" * 100 # ~1100 chars
fake_b64 = f"data:image/png;base64,{fake_b64_payload}"
mock_llm.invoke.side_effect = RuntimeError(
f"Request failed for payload: {fake_b64}"
)
with pytest.raises(ValueError) as exc_info:
_summarize_image(fake_b64, mock_llm)
msg = str(exc_info.value)
# The full payload must not appear (truncation should have kicked in)
assert fake_b64_payload not in msg
assert "truncated" in msg
def test_long_error_message_truncated(self) -> None:
"""Exception messages longer than 512 chars are truncated."""
mock_llm = MagicMock()
long_msg = "x" * 1000
mock_llm.invoke.side_effect = RuntimeError(long_msg)
with pytest.raises(ValueError) as exc_info:
_summarize_image("data:image/png;base64,abc", mock_llm)
msg = str(exc_info.value)
assert "truncated" in msg
# The full 1000-char string should not appear
assert long_msg not in msg

View File

@@ -0,0 +1,124 @@
"""Unit tests for pypdf-dependent PDF processing functions.
Tests cover:
- read_pdf_file: text extraction, metadata, encrypted PDFs, image extraction
- pdf_to_text: convenience wrapper
- is_pdf_protected: password protection detection
Fixture PDFs live in ./fixtures/ and are pre-built so the test layer has no
dependency on pypdf internals (pypdf.generic).
"""
from io import BytesIO
from pathlib import Path
from onyx.file_processing.extract_file_text import pdf_to_text
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.password_validation import is_pdf_protected
FIXTURES = Path(__file__).parent / "fixtures"
def _load(name: str) -> BytesIO:
return BytesIO((FIXTURES / name).read_bytes())
# ── read_pdf_file ────────────────────────────────────────────────────────
class TestReadPdfFile:
def test_basic_text_extraction(self) -> None:
text, _, images = read_pdf_file(_load("simple.pdf"))
assert "Hello World" in text
assert images == []
def test_multi_page_text_extraction(self) -> None:
text, _, _ = read_pdf_file(_load("multipage.pdf"))
assert "Page one content" in text
assert "Page two content" in text
def test_metadata_extraction(self) -> None:
_, pdf_metadata, _ = read_pdf_file(_load("with_metadata.pdf"))
assert pdf_metadata.get("Title") == "My Title"
assert pdf_metadata.get("Author") == "Jane Doe"
def test_encrypted_pdf_with_correct_password(self) -> None:
text, _, _ = read_pdf_file(_load("encrypted.pdf"), pdf_pass="pass123")
assert "Secret Content" in text
def test_encrypted_pdf_without_password(self) -> None:
text, _, _ = read_pdf_file(_load("encrypted.pdf"))
assert text == ""
def test_encrypted_pdf_with_wrong_password(self) -> None:
text, _, _ = read_pdf_file(_load("encrypted.pdf"), pdf_pass="wrong")
assert text == ""
def test_empty_pdf(self) -> None:
text, _, _ = read_pdf_file(_load("empty.pdf"))
assert text.strip() == ""
def test_invalid_pdf_returns_empty(self) -> None:
text, _, images = read_pdf_file(BytesIO(b"this is not a pdf"))
assert text == ""
assert images == []
def test_image_extraction_disabled_by_default(self) -> None:
_, _, images = read_pdf_file(_load("with_image.pdf"))
assert images == []
def test_image_extraction_collects_images(self) -> None:
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
assert len(images) == 1
img_bytes, img_name = images[0]
assert len(img_bytes) > 0
assert img_name # non-empty name
def test_image_callback_streams_instead_of_collecting(self) -> None:
"""With image_callback, images are streamed via callback and not accumulated."""
collected: list[tuple[bytes, str]] = []
def callback(data: bytes, name: str) -> None:
collected.append((data, name))
_, _, images = read_pdf_file(
_load("with_image.pdf"), extract_images=True, image_callback=callback
)
# Callback received the image
assert len(collected) == 1
assert len(collected[0][0]) > 0
# Returned list is empty when callback is used
assert images == []
# ── pdf_to_text ──────────────────────────────────────────────────────────
class TestPdfToText:
def test_returns_text(self) -> None:
assert "Hello World" in pdf_to_text(_load("simple.pdf"))
def test_with_password(self) -> None:
assert "Secret Content" in pdf_to_text(
_load("encrypted.pdf"), pdf_pass="pass123"
)
def test_encrypted_without_password_returns_empty(self) -> None:
assert pdf_to_text(_load("encrypted.pdf")) == ""
# ── is_pdf_protected ─────────────────────────────────────────────────────
class TestIsPdfProtected:
def test_unprotected_pdf(self) -> None:
assert is_pdf_protected(_load("simple.pdf")) is False
def test_protected_pdf(self) -> None:
assert is_pdf_protected(_load("encrypted.pdf")) is True
def test_preserves_file_position(self) -> None:
pdf = _load("simple.pdf")
pdf.seek(42)
is_pdf_protected(pdf)
assert pdf.tell() == 42

View File

@@ -0,0 +1,19 @@
import pytest
from pydantic import BaseModel
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
def test_init_subclass_raises_for_missing_attrs() -> None:
with pytest.raises(TypeError, match="must define class attributes"):
class IncompleteSpec(HookPointSpec):
hook_point = HookPoint.QUERY_PROCESSING
# missing display_name, description, payload_model, response_model, etc.
class _Payload(BaseModel):
pass
payload_model = _Payload
response_model = _Payload

View File

@@ -0,0 +1,541 @@
"""Unit tests for the hook executor."""
import json
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
_RESPONSE_PAYLOAD: dict[str, Any] = {"rewritten_query": "better test"}
def _make_hook(
*,
is_active: bool = True,
endpoint_url: str | None = "https://hook.example.com/query",
api_key: MagicMock | None = None,
timeout_seconds: float = 5.0,
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
hook_id: int = 1,
is_reachable: bool | None = None,
) -> MagicMock:
hook = MagicMock()
hook.is_active = is_active
hook.endpoint_url = endpoint_url
hook.api_key = api_key
hook.timeout_seconds = timeout_seconds
hook.id = hook_id
hook.fail_strategy = fail_strategy
hook.is_reachable = is_reachable
return hook
def _make_api_key(value: str) -> MagicMock:
api_key = MagicMock()
api_key.get_value.return_value = value
return api_key
def _make_response(
*,
status_code: int = 200,
json_return: Any = _RESPONSE_PAYLOAD,
json_side_effect: Exception | None = None,
) -> MagicMock:
"""Build a response mock with controllable json() behaviour."""
response = MagicMock()
response.status_code = status_code
if json_side_effect is not None:
response.json.side_effect = json_side_effect
else:
response.json.return_value = json_return
return response
def _setup_client(
mock_client_cls: MagicMock,
*,
response: MagicMock | None = None,
side_effect: Exception | None = None,
) -> MagicMock:
"""Wire up the httpx.Client mock and return the inner client.
If side_effect is an httpx.HTTPStatusError, it is raised from
raise_for_status() (matching real httpx behaviour) and post() returns a
response mock with the matching status_code set. All other exceptions are
raised directly from post().
"""
mock_client = MagicMock()
if isinstance(side_effect, httpx.HTTPStatusError):
error_response = MagicMock()
error_response.status_code = side_effect.response.status_code
error_response.raise_for_status.side_effect = side_effect
mock_client.post = MagicMock(return_value=error_response)
else:
mock_client.post = MagicMock(
side_effect=side_effect, return_value=response if not side_effect else None
)
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
return mock_client
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def db_session() -> MagicMock:
return MagicMock()
# ---------------------------------------------------------------------------
# Early-exit guards (no HTTP call, no DB writes)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"hooks_available,hook",
[
# HOOKS_AVAILABLE=False exits before the DB lookup — hook is irrelevant.
pytest.param(False, None, id="hooks_not_available"),
pytest.param(True, None, id="hook_not_found"),
pytest.param(True, _make_hook(is_active=False), id="hook_inactive"),
pytest.param(True, _make_hook(endpoint_url=None), id="no_endpoint_url"),
],
)
def test_early_exit_returns_skipped_with_no_db_writes(
db_session: MagicMock,
hooks_available: bool,
hook: MagicMock | None,
) -> None:
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", hooks_available),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSkipped)
mock_update.assert_not_called()
mock_log.assert_not_called()
# ---------------------------------------------------------------------------
# Successful HTTP call
# ---------------------------------------------------------------------------
def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> None:
hook = _make_hook()
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, response=_make_response())
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert result == _RESPONSE_PAYLOAD
_, update_kwargs = mock_update.call_args
assert update_kwargs["is_reachable"] is True
mock_log.assert_not_called()
def test_success_skips_reachable_write_when_already_true(db_session: MagicMock) -> None:
"""Deduplication guard: a hook already at is_reachable=True that succeeds
must not trigger a DB write."""
hook = _make_hook(is_reachable=True)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, response=_make_response())
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert result == _RESPONSE_PAYLOAD
mock_update.assert_not_called()
def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
"""response.json() returning a non-dict (e.g. list) must be treated as failure.
The server responded, so is_reachable is not updated."""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response(json_return=["unexpected", "list"]),
)
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "non-dict" in (log_kwargs["error_message"] or "")
mock_update.assert_not_called()
def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
"""response.json() raising must be treated as failure with SOFT strategy.
The server responded, so is_reachable is not updated."""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response(
json_side_effect=json.JSONDecodeError("not JSON", "", 0)
),
)
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "non-JSON" in (log_kwargs["error_message"] or "")
mock_update.assert_not_called()
# ---------------------------------------------------------------------------
# HTTP failure paths
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"exception,fail_strategy,expected_type,expected_is_reachable",
[
# NetworkError → is_reachable=False
pytest.param(
httpx.ConnectError("refused"),
HookFailStrategy.SOFT,
HookSoftFailed,
False,
id="connect_error_soft",
),
pytest.param(
httpx.ConnectError("refused"),
HookFailStrategy.HARD,
OnyxError,
False,
id="connect_error_hard",
),
# 401/403 → is_reachable=False (api_key revoked)
pytest.param(
httpx.HTTPStatusError(
"401",
request=MagicMock(),
response=MagicMock(status_code=401, text="Unauthorized"),
),
HookFailStrategy.SOFT,
HookSoftFailed,
False,
id="auth_401_soft",
),
pytest.param(
httpx.HTTPStatusError(
"403",
request=MagicMock(),
response=MagicMock(status_code=403, text="Forbidden"),
),
HookFailStrategy.HARD,
OnyxError,
False,
id="auth_403_hard",
),
# TimeoutException → no is_reachable write (None)
pytest.param(
httpx.TimeoutException("timeout"),
HookFailStrategy.SOFT,
HookSoftFailed,
None,
id="timeout_soft",
),
pytest.param(
httpx.TimeoutException("timeout"),
HookFailStrategy.HARD,
OnyxError,
None,
id="timeout_hard",
),
# Other HTTP errors → no is_reachable write (None)
pytest.param(
httpx.HTTPStatusError(
"500",
request=MagicMock(),
response=MagicMock(status_code=500, text="error"),
),
HookFailStrategy.SOFT,
HookSoftFailed,
None,
id="http_status_error_soft",
),
pytest.param(
httpx.HTTPStatusError(
"500",
request=MagicMock(),
response=MagicMock(status_code=500, text="error"),
),
HookFailStrategy.HARD,
OnyxError,
None,
id="http_status_error_hard",
),
],
)
def test_http_failure_paths(
db_session: MagicMock,
exception: Exception,
fail_strategy: HookFailStrategy,
expected_type: type,
expected_is_reachable: bool | None,
) -> None:
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, side_effect=exception)
if expected_type is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, expected_type)
if expected_is_reachable is None:
mock_update.assert_not_called()
else:
mock_update.assert_called_once()
_, kwargs = mock_update.call_args
assert kwargs["is_reachable"] is expected_is_reachable
# ---------------------------------------------------------------------------
# Authorization header
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"api_key_value,expect_auth_header",
[
pytest.param("secret-token", True, id="api_key_present"),
pytest.param(None, False, id="api_key_absent"),
],
)
def test_authorization_header(
db_session: MagicMock,
api_key_value: str | None,
expect_auth_header: bool,
) -> None:
api_key = _make_api_key(api_key_value) if api_key_value else None
hook = _make_hook(api_key=api_key)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit"),
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
mock_client = _setup_client(mock_client_cls, response=_make_response())
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
_, call_kwargs = mock_client.post.call_args
if expect_auth_header:
assert call_kwargs["headers"]["Authorization"] == f"Bearer {api_key_value}"
else:
assert "Authorization" not in call_kwargs["headers"]
# ---------------------------------------------------------------------------
# Persist session failure
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"http_exception,expected_result",
[
pytest.param(None, _RESPONSE_PAYLOAD, id="success_path"),
pytest.param(httpx.ConnectError("refused"), OnyxError, id="hard_fail_path"),
],
)
def test_persist_session_failure_is_swallowed(
db_session: MagicMock,
http_exception: Exception | None,
expected_result: Any,
) -> None:
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor.get_session_with_current_tenant",
side_effect=RuntimeError("DB unavailable"),
),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response() if not http_exception else None,
side_effect=http_exception,
)
if expected_result is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert result == expected_result
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:
"""is_reachable update failing (e.g. concurrent hook deletion) must not
prevent the execution log from being written.
Simulates the production failure path: update_hook__no_commit raises
OnyxError(NOT_FOUND) as it would if the hook was concurrently deleted
between the initial lookup and the reachable update.
"""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch(
"onyx.hooks.executor.update_hook__no_commit",
side_effect=OnyxError(OnyxErrorCode.NOT_FOUND, "hook deleted"),
),
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, side_effect=httpx.ConnectError("refused"))
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
mock_log.assert_called_once()

View File

@@ -0,0 +1,86 @@
import pytest
from pydantic import ValidationError
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.models import HookCreateRequest
from onyx.hooks.models import HookUpdateRequest
def test_hook_update_request_rejects_empty() -> None:
# No fields supplied at all
with pytest.raises(ValidationError, match="At least one field must be provided"):
HookUpdateRequest()
def test_hook_update_request_rejects_null_name_when_only_field() -> None:
# Explicitly setting name=None is rejected as name cannot be cleared
with pytest.raises(ValidationError, match="name cannot be cleared"):
HookUpdateRequest(name=None)
def test_hook_update_request_accepts_single_field() -> None:
req = HookUpdateRequest(name="new name")
assert req.name == "new name"
def test_hook_update_request_accepts_partial_fields() -> None:
req = HookUpdateRequest(fail_strategy=HookFailStrategy.SOFT, timeout_seconds=10.0)
assert req.fail_strategy == HookFailStrategy.SOFT
assert req.timeout_seconds == 10.0
assert req.name is None
def test_hook_update_request_rejects_null_name() -> None:
with pytest.raises(ValidationError, match="name cannot be cleared"):
HookUpdateRequest(name=None, fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_rejects_empty_name() -> None:
with pytest.raises(ValidationError, match="name cannot be cleared"):
HookUpdateRequest(name="", fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_rejects_null_endpoint_url() -> None:
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
HookUpdateRequest(endpoint_url=None, fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_rejects_empty_endpoint_url() -> None:
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
HookUpdateRequest(endpoint_url="", fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_allows_null_api_key() -> None:
# api_key=null is valid — means "clear the api key"
req = HookUpdateRequest(api_key=None)
assert req.api_key is None
assert "api_key" in req.model_fields_set
def test_hook_update_request_rejects_whitespace_name() -> None:
with pytest.raises(ValidationError, match="name cannot be cleared"):
HookUpdateRequest(name=" ", fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_rejects_whitespace_endpoint_url() -> None:
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
HookUpdateRequest(endpoint_url=" ", fail_strategy=HookFailStrategy.SOFT)
def test_hook_create_request_rejects_whitespace_name() -> None:
with pytest.raises(ValidationError, match="whitespace-only"):
HookCreateRequest(
name=" ",
hook_point=HookPoint.QUERY_PROCESSING,
endpoint_url="https://example.com/hook",
)
def test_hook_create_request_rejects_whitespace_endpoint_url() -> None:
with pytest.raises(ValidationError, match="whitespace-only"):
HookCreateRequest(
name="my hook",
hook_point=HookPoint.QUERY_PROCESSING,
endpoint_url=" ",
)

View File

@@ -0,0 +1,62 @@
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.query_processing import QueryProcessingSpec
def test_hook_point_is_query_processing() -> None:
assert QueryProcessingSpec().hook_point == HookPoint.QUERY_PROCESSING
def test_default_fail_strategy_is_hard() -> None:
assert QueryProcessingSpec().default_fail_strategy == HookFailStrategy.HARD
def test_default_timeout_seconds() -> None:
# User is actively waiting — 5s is the documented contract for this hook point
assert QueryProcessingSpec().default_timeout_seconds == 5.0
def test_input_schema_required_fields() -> None:
schema = QueryProcessingSpec().input_schema
assert schema["type"] == "object"
required = schema["required"]
assert "query" in required
assert "user_email" in required
assert "chat_session_id" in required
def test_input_schema_chat_session_id_is_string() -> None:
props = QueryProcessingSpec().input_schema["properties"]
assert props["chat_session_id"]["type"] == "string"
def test_input_schema_query_is_string() -> None:
props = QueryProcessingSpec().input_schema["properties"]
assert props["query"]["type"] == "string"
def test_input_schema_user_email_is_nullable() -> None:
props = QueryProcessingSpec().input_schema["properties"]
# Pydantic v2 emits anyOf for nullable fields
assert any(s.get("type") == "null" for s in props["user_email"]["anyOf"])
def test_output_schema_query_is_optional() -> None:
# query defaults to None (absent = reject); not required in the schema
schema = QueryProcessingSpec().output_schema
assert "query" not in schema.get("required", [])
def test_output_schema_query_is_nullable() -> None:
# null means "reject the query"; Pydantic v2 emits anyOf for nullable fields
props = QueryProcessingSpec().output_schema["properties"]
assert any(s.get("type") == "null" for s in props["query"]["anyOf"])
def test_output_schema_rejection_message_is_optional() -> None:
schema = QueryProcessingSpec().output_schema
assert "rejection_message" not in schema.get("required", [])
def test_input_schema_no_additional_properties() -> None:
assert QueryProcessingSpec().input_schema.get("additionalProperties") is False

View File

@@ -0,0 +1,47 @@
import pytest
from onyx.db.enums import HookPoint
from onyx.hooks import registry as registry_module
from onyx.hooks.registry import get_all_specs
from onyx.hooks.registry import get_hook_point_spec
from onyx.hooks.registry import validate_registry
def test_registry_covers_all_hook_points() -> None:
"""Every HookPoint enum member must have a registered spec."""
assert {s.hook_point for s in get_all_specs()} == set(
HookPoint
), f"Missing specs for: {set(HookPoint) - {s.hook_point for s in get_all_specs()}}"
def test_get_hook_point_spec_returns_correct_spec() -> None:
for hook_point in HookPoint:
spec = get_hook_point_spec(hook_point)
assert spec.hook_point == hook_point
def test_get_all_specs_returns_all() -> None:
specs = get_all_specs()
assert len(specs) == len(HookPoint)
assert {s.hook_point for s in specs} == set(HookPoint)
def test_get_hook_point_spec_raises_for_unregistered(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""get_hook_point_spec raises ValueError when a hook point has no spec."""
monkeypatch.setattr(registry_module, "_REGISTRY", {})
with pytest.raises(ValueError, match="No spec registered for hook point"):
get_hook_point_spec(HookPoint.QUERY_PROCESSING)
def test_validate_registry_passes() -> None:
validate_registry() # should not raise with the real registry
def test_validate_registry_raises_for_incomplete(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(registry_module, "_REGISTRY", {})
with pytest.raises(RuntimeError, match="Hook point\\(s\\) have no registered spec"):
validate_registry()

View File

@@ -116,7 +116,7 @@ def _run_adapter_build(
project_ids_map: dict[str, list[int]],
persona_ids_map: dict[str, list[int]],
) -> list[DocMetadataAwareIndexChunk]:
"""Helper that runs UserFileIndexingAdapter.prepare_enrichment + enrich_chunk
"""Helper that runs UserFileIndexingAdapter.build_metadata_aware_chunks
with all external dependencies mocked."""
from onyx.indexing.adapters.user_file_indexing_adapter import (
UserFileIndexingAdapter,
@@ -155,12 +155,14 @@ def _run_adapter_build(
side_effect=Exception("no LLM in tests"),
),
):
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id="test_tenant",
chunks=[chunk],
context=context,
)
return [enricher.enrich_chunk(chunk, 1.0)]
return result.chunks
def test_build_metadata_aware_chunks_includes_persona_ids() -> None:

View File

@@ -256,7 +256,6 @@ def test_multiple_tool_calls(default_multi_llm: LitellmLLM) -> None:
{"role": "user", "content": "What's the weather and time in New York?"}
],
tools=tools,
tool_choice=None,
stream=True,
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
@@ -412,7 +411,6 @@ def test_multiple_tool_calls_streaming(default_multi_llm: LitellmLLM) -> None:
{"role": "user", "content": "What's the weather and time in New York?"}
],
tools=tools,
tool_choice=None,
stream=True,
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
@@ -1431,3 +1429,36 @@ def test_strip_tool_content_merges_consecutive_tool_results() -> None:
assert "sunny 72F" in merged
assert "tc_2" in merged
assert "headline news" in merged
def test_no_tool_choice_sent_when_no_tools(default_multi_llm: LitellmLLM) -> None:
"""Regression test for providers (e.g. Fireworks) that reject tool_choice=null.
When no tools are provided, tool_choice must not be forwarded to
litellm.completion() at all — not even as None.
"""
messages: LanguageModelInput = [UserMessage(content="Hello!")]
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello!"),
finish_reason="stop",
index=0,
)
],
model="gpt-3.5-turbo",
),
]
with patch("litellm.completion") as mock_completion:
mock_completion.return_value = mock_stream_chunks
default_multi_llm.invoke(messages, tools=None)
_, kwargs = mock_completion.call_args
assert (
"tool_choice" not in kwargs
), "tool_choice must not be sent to providers when no tools are provided"

View File

@@ -0,0 +1,130 @@
"""
Unit tests for vision model selection logging in get_default_llm_with_vision.
Verifies that operators get clear feedback about:
1. Which vision model was selected and why
2. When the default vision model doesn't support image input
3. When no vision-capable model exists at all
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.llm.factory import get_default_llm_with_vision
_FACTORY = "onyx.llm.factory"
def _make_mock_model(
*,
name: str = "gpt-4o",
provider: str = "openai",
provider_id: int = 1,
flow_types: list[str] | None = None,
) -> MagicMock:
model = MagicMock()
model.name = name
model.llm_provider_id = provider_id
model.llm_provider.provider = provider
model.llm_model_flow_types = flow_types or []
return model
@patch(f"{_FACTORY}.get_session_with_current_tenant")
@patch(f"{_FACTORY}.fetch_default_vision_model")
@patch(f"{_FACTORY}.model_supports_image_input", return_value=True)
@patch(f"{_FACTORY}.llm_from_provider")
@patch(f"{_FACTORY}.LLMProviderView")
@patch(f"{_FACTORY}.logger")
def test_logs_when_using_default_vision_model(
mock_logger: MagicMock,
mock_provider_view: MagicMock, # noqa: ARG001
mock_llm_from: MagicMock, # noqa: ARG001
mock_supports: MagicMock, # noqa: ARG001
mock_fetch_default: MagicMock,
mock_session: MagicMock, # noqa: ARG001
) -> None:
mock_fetch_default.return_value = _make_mock_model(name="gpt-4o", provider="azure")
get_default_llm_with_vision()
mock_logger.info.assert_called_once()
log_msg = mock_logger.info.call_args[0][0]
assert "default vision model" in log_msg.lower()
@patch(f"{_FACTORY}.get_session_with_current_tenant")
@patch(f"{_FACTORY}.fetch_default_vision_model")
@patch(f"{_FACTORY}.model_supports_image_input", return_value=False)
@patch(f"{_FACTORY}.fetch_existing_models", return_value=[])
@patch(f"{_FACTORY}.logger")
def test_warns_when_default_model_lacks_vision(
mock_logger: MagicMock,
mock_fetch_models: MagicMock, # noqa: ARG001
mock_supports: MagicMock, # noqa: ARG001
mock_fetch_default: MagicMock,
mock_session: MagicMock, # noqa: ARG001
) -> None:
mock_fetch_default.return_value = _make_mock_model(
name="text-only-model", provider="azure"
)
result = get_default_llm_with_vision()
assert result is None
# Should have warned about the default model not supporting vision
warning_calls = [
call
for call in mock_logger.warning.call_args_list
if "does not support" in str(call)
]
assert len(warning_calls) >= 1
@patch(f"{_FACTORY}.get_session_with_current_tenant")
@patch(f"{_FACTORY}.fetch_default_vision_model", return_value=None)
@patch(f"{_FACTORY}.fetch_existing_models", return_value=[])
@patch(f"{_FACTORY}.logger")
def test_warns_when_no_models_exist(
mock_logger: MagicMock,
mock_fetch_models: MagicMock, # noqa: ARG001
mock_fetch_default: MagicMock, # noqa: ARG001
mock_session: MagicMock, # noqa: ARG001
) -> None:
result = get_default_llm_with_vision()
assert result is None
mock_logger.warning.assert_called_once()
log_msg = mock_logger.warning.call_args[0][0]
assert "no llm models" in log_msg.lower()
@patch(f"{_FACTORY}.get_session_with_current_tenant")
@patch(f"{_FACTORY}.fetch_default_vision_model", return_value=None)
@patch(f"{_FACTORY}.fetch_existing_models")
@patch(f"{_FACTORY}.model_supports_image_input", return_value=False)
@patch(f"{_FACTORY}.LLMProviderView")
@patch(f"{_FACTORY}.logger")
def test_warns_when_no_model_supports_vision(
mock_logger: MagicMock,
mock_provider_view: MagicMock, # noqa: ARG001
mock_supports: MagicMock, # noqa: ARG001
mock_fetch_models: MagicMock,
mock_fetch_default: MagicMock, # noqa: ARG001
mock_session: MagicMock, # noqa: ARG001
) -> None:
mock_fetch_models.return_value = [
_make_mock_model(name="text-model-1", provider="openai"),
_make_mock_model(name="text-model-2", provider="azure", provider_id=2),
]
result = get_default_llm_with_vision()
assert result is None
warning_calls = [
call
for call in mock_logger.warning.call_args_list
if "no vision-capable model" in str(call).lower()
]
assert len(warning_calls) == 1

View File

@@ -0,0 +1,278 @@
"""Unit tests for onyx.server.features.hooks.api helpers.
Covers:
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
- _validate_endpoint: httpx exception → HookValidateStatus mapping
ConnectTimeout → cannot_connect (TCP handshake never completed)
ConnectError → cannot_connect (DNS / TLS failure)
ReadTimeout et al. → timeout (TCP connected, server slow)
Any other exc → cannot_connect
- _raise_for_validation_failure: HookValidateStatus → OnyxError mapping
"""
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.models import HookValidateResponse
from onyx.hooks.models import HookValidateStatus
from onyx.server.features.hooks.api import _check_ssrf_safety
from onyx.server.features.hooks.api import _raise_for_validation_failure
from onyx.server.features.hooks.api import _validate_endpoint
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_URL = "https://example.com/hook"
_API_KEY = "secret"
_TIMEOUT = 5.0
def _mock_response(status_code: int) -> MagicMock:
response = MagicMock()
response.status_code = status_code
return response
# ---------------------------------------------------------------------------
# _check_ssrf_safety
# ---------------------------------------------------------------------------
class TestCheckSsrfSafety:
def _call(self, url: str) -> None:
_check_ssrf_safety(url)
# --- scheme checks ---
def test_https_is_allowed(self) -> None:
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
self._call("https://example.com/hook") # must not raise
@pytest.mark.parametrize(
"url", ["http://example.com/hook", "ftp://example.com/hook"]
)
def test_non_https_scheme_rejected(self, url: str) -> None:
with pytest.raises(OnyxError) as exc_info:
self._call(url)
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert "https" in (exc_info.value.detail or "").lower()
# --- private IP blocklist ---
@pytest.mark.parametrize(
"ip",
[
pytest.param("127.0.0.1", id="loopback"),
pytest.param("10.0.0.1", id="RFC1918-A"),
pytest.param("172.16.0.1", id="RFC1918-B"),
pytest.param("192.168.1.1", id="RFC1918-C"),
pytest.param("169.254.169.254", id="link-local-IMDS"),
pytest.param("100.64.0.1", id="shared-address-space"),
pytest.param("::1", id="IPv6-loopback"),
pytest.param("fc00::1", id="IPv6-ULA"),
pytest.param("fe80::1", id="IPv6-link-local"),
],
)
def test_private_ip_is_blocked(self, ip: str) -> None:
with (
patch("onyx.utils.url.socket.getaddrinfo") as mock_dns,
pytest.raises(OnyxError) as exc_info,
):
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
self._call("https://internal.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert ip in (exc_info.value.detail or "")
def test_public_ip_is_allowed(self) -> None:
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
self._call("https://example.com/hook") # must not raise
def test_dns_resolution_failure_raises(self) -> None:
import socket
with (
patch(
"onyx.utils.url.socket.getaddrinfo",
side_effect=socket.gaierror("name not found"),
),
pytest.raises(OnyxError) as exc_info,
):
self._call("https://no-such-host.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
# ---------------------------------------------------------------------------
# _validate_endpoint
# ---------------------------------------------------------------------------
class TestValidateEndpoint:
def _call(self, *, api_key: str | None = _API_KEY) -> HookValidateResponse:
# Bypass SSRF check — tested separately in TestCheckSsrfSafety.
with patch("onyx.server.features.hooks.api._check_ssrf_safety"):
return _validate_endpoint(
endpoint_url=_URL,
api_key=api_key,
timeout_seconds=_TIMEOUT,
)
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_2xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(200)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_5xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(500)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize("status_code", [401, 403])
def test_401_403_returns_auth_failed(
self, mock_client_cls: MagicMock, status_code: int
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(status_code)
)
result = self._call()
assert result.status == HookValidateStatus.auth_failed
assert str(status_code) in (result.error_message or "")
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_4xx_non_auth_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(422)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_timeout_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectTimeout("timed out")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize(
"exc",
[
httpx.ReadTimeout("read timeout"),
httpx.WriteTimeout("write timeout"),
httpx.PoolTimeout("pool timeout"),
],
)
def test_read_write_pool_timeout_returns_timeout(
self, mock_client_cls: MagicMock, exc: httpx.TimeoutException
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = exc
assert self._call().status == HookValidateStatus.timeout
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_error_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
# Covers DNS failures, TLS errors, and other connection-level errors.
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectError("name resolution failed")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_arbitrary_exception_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
ConnectionRefusedError("refused")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_api_key_sent_as_bearer(self, mock_client_cls: MagicMock) -> None:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = _mock_response(200)
self._call(api_key="mykey")
_, kwargs = mock_post.call_args
assert kwargs["headers"]["Authorization"] == "Bearer mykey"
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_no_api_key_omits_auth_header(self, mock_client_cls: MagicMock) -> None:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = _mock_response(200)
self._call(api_key=None)
_, kwargs = mock_post.call_args
assert "Authorization" not in kwargs["headers"]
# ---------------------------------------------------------------------------
# _raise_for_validation_failure
# ---------------------------------------------------------------------------
class TestRaiseForValidationFailure:
@pytest.mark.parametrize(
"status, expected_code",
[
(HookValidateStatus.auth_failed, OnyxErrorCode.CREDENTIAL_INVALID),
(HookValidateStatus.timeout, OnyxErrorCode.GATEWAY_TIMEOUT),
(HookValidateStatus.cannot_connect, OnyxErrorCode.BAD_GATEWAY),
],
)
def test_raises_correct_error_code(
self, status: HookValidateStatus, expected_code: OnyxErrorCode
) -> None:
validation = HookValidateResponse(status=status, error_message="some error")
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.error_code == expected_code
def test_auth_failed_passes_error_message_directly(self) -> None:
validation = HookValidateResponse(
status=HookValidateStatus.auth_failed, error_message="bad credentials"
)
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.detail == "bad credentials"
@pytest.mark.parametrize(
"status", [HookValidateStatus.timeout, HookValidateStatus.cannot_connect]
)
def test_timeout_and_cannot_connect_wrap_error_message(
self, status: HookValidateStatus
) -> None:
validation = HookValidateResponse(status=status, error_message="raw error")
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.detail == "Endpoint validation failed: raw error"
# ---------------------------------------------------------------------------
# HookValidateStatus enum string values (API contract)
# ---------------------------------------------------------------------------
class TestHookValidateStatusValues:
@pytest.mark.parametrize(
"status, expected",
[
(HookValidateStatus.passed, "passed"),
(HookValidateStatus.auth_failed, "auth_failed"),
(HookValidateStatus.timeout, "timeout"),
(HookValidateStatus.cannot_connect, "cannot_connect"),
],
)
def test_string_values(self, status: HookValidateStatus, expected: str) -> None:
assert status == expected

View File

@@ -0,0 +1,109 @@
import io
import zipfile
from unittest.mock import MagicMock
from unittest.mock import patch
from zipfile import BadZipFile
import pytest
from fastapi import UploadFile
from starlette.datastructures import Headers
from onyx.configs.constants import FileOrigin
from onyx.server.documents.connector import upload_files
def _create_test_zip() -> bytes:
"""Create a simple in-memory zip file containing two text files."""
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
zf.writestr("file1.txt", "hello")
zf.writestr("file2.txt", "world")
return buf.getvalue()
def _make_upload_file(content: bytes, filename: str, content_type: str) -> UploadFile:
return UploadFile(
file=io.BytesIO(content),
filename=filename,
headers=Headers({"content-type": content_type}),
)
@patch("onyx.server.documents.connector.get_default_file_store")
def test_upload_zip_with_unzip_true_extracts_files(
mock_get_store: MagicMock,
) -> None:
"""When unzip=True (default), a zip upload is extracted into individual files."""
mock_store = MagicMock()
mock_store.save_file.side_effect = lambda **kwargs: f"id-{kwargs['display_name']}"
mock_get_store.return_value = mock_store
zip_bytes = _create_test_zip()
upload = _make_upload_file(zip_bytes, "test.zip", "application/zip")
result = upload_files([upload], FileOrigin.CONNECTOR)
# Should have extracted the two individual files, not stored the zip itself
assert len(result.file_paths) == 2
assert "id-file1.txt" in result.file_paths
assert "id-file2.txt" in result.file_paths
assert "file1.txt" in result.file_names
assert "file2.txt" in result.file_names
@patch("onyx.server.documents.connector.get_default_file_store")
def test_upload_zip_with_unzip_false_stores_zip_as_is(
mock_get_store: MagicMock,
) -> None:
"""When unzip=False, the zip file is stored as-is without extraction."""
mock_store = MagicMock()
mock_store.save_file.return_value = "zip-file-id"
mock_get_store.return_value = mock_store
zip_bytes = _create_test_zip()
upload = _make_upload_file(zip_bytes, "site_export.zip", "application/zip")
result = upload_files([upload], FileOrigin.CONNECTOR, unzip=False)
# Should store exactly one file (the zip itself)
assert len(result.file_paths) == 1
assert result.file_paths[0] == "zip-file-id"
assert result.file_names == ["site_export.zip"]
# No zip metadata should be created
assert result.zip_metadata_file_id is None
# Verify the stored content is a valid zip
saved_content: io.BytesIO = mock_store.save_file.call_args[1]["content"]
saved_content.seek(0)
with zipfile.ZipFile(saved_content, "r") as zf:
assert set(zf.namelist()) == {"file1.txt", "file2.txt"}
@patch("onyx.server.documents.connector.get_default_file_store")
def test_upload_invalid_zip_with_unzip_false_raises(
mock_get_store: MagicMock,
) -> None:
"""An invalid zip is rejected even when unzip=False (validation still runs)."""
mock_get_store.return_value = MagicMock()
bad_zip = _make_upload_file(b"not a zip", "bad.zip", "application/zip")
with pytest.raises(BadZipFile):
upload_files([bad_zip], FileOrigin.CONNECTOR, unzip=False)
@patch("onyx.server.documents.connector.get_default_file_store")
def test_upload_multiple_zips_rejected_when_unzip_false(
mock_get_store: MagicMock,
) -> None:
"""The seen_zip guard rejects a second zip even when unzip=False."""
mock_store = MagicMock()
mock_store.save_file.return_value = "zip-id"
mock_get_store.return_value = mock_store
zip_bytes = _create_test_zip()
zip1 = _make_upload_file(zip_bytes, "a.zip", "application/zip")
zip2 = _make_upload_file(zip_bytes, "b.zip", "application/zip")
with pytest.raises(Exception, match="Only one zip file"):
upload_files([zip1, zip2], FileOrigin.CONNECTOR, unzip=False)

View File

@@ -0,0 +1,208 @@
"""Unit tests for PythonTool file-upload caching.
Verifies that PythonTool reuses code-interpreter file IDs across multiple
run() calls within the same session instead of re-uploading identical content
on every agent loop iteration.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.tools.models import ChatFile
from onyx.tools.models import PythonToolOverrideKwargs
from onyx.tools.tool_implementations.python.code_interpreter_client import (
StreamResultEvent,
)
from onyx.tools.tool_implementations.python.python_tool import PythonTool
TOOL_MODULE = "onyx.tools.tool_implementations.python.python_tool"
def _make_stream_result() -> StreamResultEvent:
return StreamResultEvent(
exit_code=0,
timed_out=False,
duration_ms=10,
files=[],
)
def _make_tool() -> PythonTool:
emitter = MagicMock()
return PythonTool(tool_id=1, emitter=emitter)
def _make_override(files: list[ChatFile]) -> PythonToolOverrideKwargs:
return PythonToolOverrideKwargs(chat_files=files)
def _run_tool(tool: PythonTool, mock_client: MagicMock, files: list[ChatFile]) -> None:
"""Call tool.run() with a mocked CodeInterpreterClient context manager."""
from onyx.server.query_and_chat.placement import Placement
mock_client.execute_streaming.return_value = iter([_make_stream_result()])
ctx = MagicMock()
ctx.__enter__ = MagicMock(return_value=mock_client)
ctx.__exit__ = MagicMock(return_value=False)
placement = Placement(turn_index=0, tab_index=0)
override = _make_override(files)
with patch(f"{TOOL_MODULE}.CodeInterpreterClient", return_value=ctx):
tool.run(placement=placement, override_kwargs=override, code="print('hi')")
# ---------------------------------------------------------------------------
# Cache hit: same content uploaded in a second call reuses the file_id
# ---------------------------------------------------------------------------
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
def test_same_file_uploaded_only_once_across_two_runs() -> None:
tool = _make_tool()
client = MagicMock()
client.upload_file.return_value = "file-id-abc"
pptx_content = b"fake pptx bytes"
files = [ChatFile(filename="report.pptx", content=pptx_content)]
_run_tool(tool, client, files)
_run_tool(tool, client, files)
# upload_file should only have been called once across both runs
client.upload_file.assert_called_once_with(pptx_content, "report.pptx")
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
def test_cached_file_id_is_staged_on_second_run() -> None:
tool = _make_tool()
client = MagicMock()
client.upload_file.return_value = "file-id-abc"
files = [ChatFile(filename="data.pptx", content=b"content")]
_run_tool(tool, client, files)
# On the second run, execute_streaming should still receive the file
client.execute_streaming.return_value = iter([_make_stream_result()])
ctx = MagicMock()
ctx.__enter__ = MagicMock(return_value=client)
ctx.__exit__ = MagicMock(return_value=False)
from onyx.server.query_and_chat.placement import Placement
placement = Placement(turn_index=1, tab_index=0)
with patch(f"{TOOL_MODULE}.CodeInterpreterClient", return_value=ctx):
tool.run(
placement=placement,
override_kwargs=_make_override(files),
code="print('hi')",
)
# The second execute_streaming call should include the file
_, kwargs = client.execute_streaming.call_args
staged_files = kwargs.get("files") or []
assert any(f["file_id"] == "file-id-abc" for f in staged_files)
# ---------------------------------------------------------------------------
# Cache miss: different content triggers a new upload
# ---------------------------------------------------------------------------
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
def test_different_file_content_uploaded_separately() -> None:
tool = _make_tool()
client = MagicMock()
client.upload_file.side_effect = ["file-id-v1", "file-id-v2"]
file_v1 = ChatFile(filename="report.pptx", content=b"version 1")
file_v2 = ChatFile(filename="report.pptx", content=b"version 2")
_run_tool(tool, client, [file_v1])
_run_tool(tool, client, [file_v2])
assert client.upload_file.call_count == 2
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
def test_multiple_distinct_files_each_uploaded_once() -> None:
tool = _make_tool()
client = MagicMock()
client.upload_file.side_effect = ["id-a", "id-b"]
files = [
ChatFile(filename="a.pptx", content=b"aaa"),
ChatFile(filename="b.xlsx", content=b"bbb"),
]
_run_tool(tool, client, files)
_run_tool(tool, client, files)
# Two distinct files — each uploaded exactly once
assert client.upload_file.call_count == 2
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
def test_same_content_different_filename_uploaded_separately() -> None:
# Identical bytes but different names must each get their own upload slot
# so both files appear under their respective paths in the workspace.
tool = _make_tool()
client = MagicMock()
client.upload_file.side_effect = ["id-v1", "id-v2"]
same_bytes = b"shared content"
files = [
ChatFile(filename="report_v1.csv", content=same_bytes),
ChatFile(filename="report_v2.csv", content=same_bytes),
]
_run_tool(tool, client, files)
assert client.upload_file.call_count == 2
# ---------------------------------------------------------------------------
# No cross-instance sharing: a fresh PythonTool re-uploads everything
# ---------------------------------------------------------------------------
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
def test_new_tool_instance_re_uploads_file() -> None:
client = MagicMock()
client.upload_file.side_effect = ["id-session-1", "id-session-2"]
files = [ChatFile(filename="deck.pptx", content=b"slide data")]
tool_session_1 = _make_tool()
_run_tool(tool_session_1, client, files)
tool_session_2 = _make_tool()
_run_tool(tool_session_2, client, files)
# Different instances — each uploads independently
assert client.upload_file.call_count == 2
# ---------------------------------------------------------------------------
# Upload failure: failed upload is not cached, retried next run
# ---------------------------------------------------------------------------
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
def test_upload_failure_not_cached() -> None:
tool = _make_tool()
client = MagicMock()
# First call raises, second succeeds
client.upload_file.side_effect = [Exception("network error"), "file-id-ok"]
files = [ChatFile(filename="slides.pptx", content=b"data")]
# First run — upload fails, file is skipped but not cached
_run_tool(tool, client, files)
# Second run — should attempt upload again
_run_tool(tool, client, files)
assert client.upload_file.call_count == 2

93
cubic.yaml Normal file
View File

@@ -0,0 +1,93 @@
# yaml-language-server: $schema=https://cubic.dev/schema/cubic-repository-config.schema.json
version: 1
reviews:
enabled: true
sensitivity: medium
incremental_commits: true
check_drafts: false
custom_instructions: |
Use explicit type annotations for variables to enhance code clarity,
especially when moving type hints around in the code.
Use `contributing_guides/best_practices.md` as core review context.
Prefer consistency with existing patterns, fix issues in code you touch,
avoid tacking new features onto muddy interfaces, fail loudly instead of
silently swallowing errors, keep code strictly typed, preserve clear state
boundaries, remove duplicate or dead logic, break up overly long functions,
avoid hidden import-time side effects, respect module boundaries, and favor
correctness-by-construction over relying on callers to use an API correctly.
Reference these files for additional context:
- `contributing_guides/best_practices.md` — Best practices for contributing to the codebase
- `CLAUDE.md` — Project instructions and coding standards
- `backend/alembic/README.md` — Migration guidance, including multi-tenant migration behavior
- `deployment/helm/charts/onyx/values-lite.yaml` — Lite deployment Helm values and service assumptions
- `deployment/docker_compose/docker-compose.onyx-lite.yml` — Lite deployment Docker Compose overlay and disabled service behavior
ignore:
files:
- greptile.json
- cubic.yaml
custom_rules:
- name: TODO format
description: >
Whenever a TODO is added, there must always be an associated name or
ticket in the style of TODO(name): ... or TODO(1234): ...
- name: Frontend standards
description: >
For frontend changes, enforce all standards described in the
web/AGENTS.md file.
include:
- web/**
- desktop/**
- name: No debugging code
description: >
Remove temporary debugging code before merging to production,
especially tenant-specific debugging logs.
- name: No hardcoded booleans
description: >
When hardcoding a boolean variable to a constant value, remove the
variable entirely and clean up all places where it's used rather than
just setting it to a constant.
- name: Multi-tenant awareness
description: >
Code changes must consider both multi-tenant and single-tenant
deployments. In multi-tenant mode, preserve tenant isolation, ensure
tenant context is propagated correctly, and avoid assumptions that only
hold for a single shared schema or globally shared state. In
single-tenant mode, avoid introducing unnecessary tenant-specific
requirements or cloud-only control-plane dependencies.
- name: Onyx lite compatibility
description: >
Code changes must consider both regular Onyx deployments and Onyx lite
deployments. Lite deployments disable the vector DB, Redis, model
servers, and background workers by default, use PostgreSQL-backed
cache/auth/file storage, and rely on the API server to handle
background work. Do not assume those services are available unless the
code path is explicitly limited to full deployments.
- name: OnyxError over HTTPException
description: >
Never raise HTTPException directly in business code. Use
`raise OnyxError(OnyxErrorCode.XXX, "message")` from
`onyx.error_handling.exceptions`. A global FastAPI exception handler
converts OnyxError into structured JSON responses with
{"error_code": "...", "detail": "..."}. Error codes are defined in
`onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors
with dynamic HTTP status codes, use `status_code_override`:
`raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`.
include:
- backend/**/*.py
issues:
fix_with_cubic_buttons: true
pr_comment_fixes: true
fix_commits_to_pr: true

View File

@@ -489,20 +489,18 @@ services:
- "${HOST_PORT_80:-80}:80"
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
minio:
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1

View File

@@ -290,25 +290,20 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
- ../data/certbot/conf:/etc/letsencrypt
- ../data/certbot/www:/var/www/certbot
# sleep a little bit to allow the web_server / api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod"
env_file:
- .env.nginx
environment:

View File

@@ -314,21 +314,19 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
- ../data/sslcerts:/etc/nginx/sslcerts
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod.no-letsencrypt"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod.no-letsencrypt"
env_file:
- .env.nginx
environment:

View File

@@ -333,25 +333,20 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
- ../data/certbot/conf:/etc/letsencrypt
- ../data/certbot/www:/var/www/certbot
# sleep a little bit to allow the web_server / api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod"
env_file:
- .env.nginx
environment:

View File

@@ -202,20 +202,18 @@ services:
ports:
- "${NGINX_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
minio:
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1

View File

@@ -477,7 +477,10 @@ services:
- "${HOST_PORT_80:-80}:80"
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
# Mount templates read-only; the startup command copies them into
# the writable /etc/nginx/conf.d/ inside the container. This avoids
# "Permission denied" errors on Windows Docker bind mounts.
- ../data/nginx:/nginx-templates:ro
# PRODUCTION: Add SSL certificate volumes for HTTPS support:
# - ../data/certbot/conf:/etc/letsencrypt
# - ../data/certbot/www:/var/www/certbot
@@ -489,12 +492,13 @@ services:
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not receive any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
# PRODUCTION: Change to app.conf.template.prod for production nginx config
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
cache:
image: redis:7.4-alpine

File diff suppressed because it is too large Load Diff

View File

@@ -96,8 +96,8 @@ fi
# When --lite is passed as a flag, lower resource thresholds early (before the
# resource check). When lite is chosen interactively, the thresholds are adjusted
# inside the new-deployment flow, after the resource check has already passed
# with the standard thresholds — which is the safer direction.
# after the resource check has already passed with the standard thresholds —
# which is the safer direction.
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
@@ -110,9 +110,6 @@ LITE_COMPOSE_FILE="docker-compose.onyx-lite.yml"
# Build the -f flags for docker compose.
# Pass "true" as $1 to auto-detect a previously-downloaded lite overlay
# (used by shutdown/delete-data so users don't need to remember --lite).
# Without the argument, the lite overlay is only included when --lite was
# explicitly passed — preventing install/start from silently staying in
# lite mode just because the file exists on disk from a prior run.
compose_file_args() {
local auto_detect="${1:-false}"
local args="-f docker-compose.yml"
@@ -177,7 +174,7 @@ ensure_file() {
# --- Interactive prompt helpers ---
is_interactive() {
[[ "$NO_PROMPT" = false ]] && [[ -t 0 ]]
[[ "$NO_PROMPT" = false ]]
}
prompt_or_default() {
@@ -207,6 +204,16 @@ prompt_yn_or_default() {
fi
}
confirm_action() {
local description="$1"
prompt_yn_or_default "Install ${description}? (Y/n) [default: Y] " "Y"
if [[ "$REPLY" =~ ^[Nn] ]]; then
print_warning "Skipping: ${description}"
return 1
fi
return 0
}
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
@@ -395,6 +402,11 @@ fi
if ! command -v docker &> /dev/null; then
if [[ "$OSTYPE" == "linux-gnu"* ]] || [[ -n "${WSL_DISTRO_NAME:-}" ]]; then
print_info "Docker is required but not installed."
if ! confirm_action "Docker Engine"; then
print_error "Docker is required to run Onyx."
exit 1
fi
install_docker_linux
if ! command -v docker &> /dev/null; then
print_error "Docker installation failed."
@@ -411,7 +423,11 @@ if command -v docker &> /dev/null \
&& ! command -v docker-compose &> /dev/null \
&& { [[ "$OSTYPE" == "linux-gnu"* ]] || [[ -n "${WSL_DISTRO_NAME:-}" ]]; }; then
print_info "Docker Compose not found — installing plugin..."
print_info "Docker Compose is required but not installed."
if ! confirm_action "Docker Compose plugin"; then
print_error "Docker Compose is required to run Onyx."
exit 1
fi
COMPOSE_ARCH="$(uname -m)"
COMPOSE_URL="https://github.com/docker/compose/releases/latest/download/docker-compose-linux-${COMPOSE_ARCH}"
COMPOSE_DIR="/usr/local/lib/docker/cli-plugins"
@@ -562,10 +578,31 @@ version_compare() {
# Check Docker daemon
if ! docker info &> /dev/null; then
print_error "Docker daemon is not running. Please start Docker."
exit 1
if [[ "$OSTYPE" == "darwin"* ]]; then
print_info "Docker daemon is not running. Starting Docker Desktop..."
open -a Docker
# Wait up to 120 seconds for Docker to be ready
DOCKER_WAIT=0
DOCKER_MAX_WAIT=120
while ! docker info &> /dev/null; do
if [ $DOCKER_WAIT -ge $DOCKER_MAX_WAIT ]; then
print_error "Docker Desktop did not start within ${DOCKER_MAX_WAIT} seconds."
print_info "Please start Docker Desktop manually and re-run this script."
exit 1
fi
printf "\r\033[KWaiting for Docker Desktop to start... (%ds)" "$DOCKER_WAIT"
sleep 2
DOCKER_WAIT=$((DOCKER_WAIT + 2))
done
echo ""
print_success "Docker Desktop is now running"
else
print_error "Docker daemon is not running. Please start Docker."
exit 1
fi
else
print_success "Docker daemon is running"
fi
print_success "Docker daemon is running"
# Check Docker resources
print_step "Verifying Docker resources"
@@ -705,25 +742,48 @@ if [ "$COMPOSE_VERSION" != "dev" ] && version_compare "$COMPOSE_VERSION" "2.24.0
print_info "Proceeding with installation despite Docker Compose version compatibility issues..."
fi
# Handle lite overlay: ensure it if --lite, clean up stale copies otherwise
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
if [[ "$LITE_MODE" = false ]]; then
print_info "Which deployment mode would you like?"
echo ""
echo " 1) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
echo " LLM chat, tools, file uploads, and Projects still work"
echo " 2) Standard - Full deployment with search, connectors, and RAG"
echo ""
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
echo ""
case "$REPLY" in
2)
print_info "Selected: Standard mode"
;;
*)
LITE_MODE=true
print_info "Selected: Lite mode"
;;
esac
else
print_info "Deployment mode: Lite (set via --lite flag)"
fi
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
print_error "--include-craft cannot be used with Lite mode."
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
exit 1
fi
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
fi
# Handle lite overlay file based on selected mode
if [[ "$LITE_MODE" = true ]]; then
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
elif [[ -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" ]]; then
if [[ -f "${INSTALL_ROOT}/deployment/.env" ]]; then
print_warning "Existing lite overlay found but --lite was not passed."
prompt_yn_or_default "Remove lite overlay and switch to standard mode? (y/N): " "n"
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
print_info "Keeping existing lite overlay. Pass --lite to keep using lite mode."
LITE_MODE=true
else
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed lite overlay (switching to standard mode)"
fi
else
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed previous lite overlay (switching to standard mode)"
fi
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed previous lite overlay (switching to standard mode)"
fi
ensure_file "${INSTALL_ROOT}/deployment/env.template" \
@@ -745,6 +805,7 @@ print_success "All configuration files ready"
# Set up deployment configuration
print_step "Setting up deployment configs"
ENV_FILE="${INSTALL_ROOT}/deployment/.env"
ENV_TEMPLATE="${INSTALL_ROOT}/deployment/env.template"
# Check if services are already running
if [ -d "${INSTALL_ROOT}/deployment" ] && [ -f "${INSTALL_ROOT}/deployment/docker-compose.yml" ]; then
# Determine compose command
@@ -785,22 +846,22 @@ if [ -f "$ENV_FILE" ]; then
if [ "$REPLY" = "update" ]; then
print_info "Update selected. Which tag would you like to deploy?"
echo ""
echo "• Press Enter for latest (recommended)"
echo "• Press Enter for edge (recommended)"
echo "• Type a specific tag (e.g., v0.1.0)"
echo ""
if [ "$INCLUDE_CRAFT" = true ]; then
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
VERSION="$REPLY"
else
prompt_or_default "Enter tag [default: latest]: " "latest"
prompt_or_default "Enter tag [default: edge]: " "edge"
VERSION="$REPLY"
fi
echo ""
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
print_info "Selected: craft-latest (Craft enabled)"
elif [ "$VERSION" = "latest" ]; then
print_info "Selected: Latest version"
elif [ "$VERSION" = "edge" ]; then
print_info "Selected: edge (latest nightly)"
else
print_info "Selected: $VERSION"
fi
@@ -852,45 +913,6 @@ else
print_info "No existing .env file found. Setting up new deployment..."
echo ""
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
if [[ "$LITE_MODE" = false ]]; then
print_info "Which deployment mode would you like?"
echo ""
echo " 1) Standard - Full deployment with search, connectors, and RAG"
echo " 2) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
echo " LLM chat, tools, file uploads, and Projects still work"
echo ""
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
echo ""
case "$REPLY" in
2)
LITE_MODE=true
print_info "Selected: Lite mode"
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
;;
*)
print_info "Selected: Standard mode"
;;
esac
else
print_info "Deployment mode: Lite (set via --lite flag)"
fi
# Validate lite + craft combination (could now be set interactively)
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
print_error "--include-craft cannot be used with Lite mode."
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
exit 1
fi
# Adjust resource expectations for lite mode
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
fi
# Ask for version
print_info "Which tag would you like to deploy?"
echo ""
@@ -901,18 +923,18 @@ else
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
VERSION="$REPLY"
else
echo "• Press Enter for latest (recommended)"
echo "• Press Enter for edge (recommended)"
echo "• Type a specific tag (e.g., v0.1.0)"
echo ""
prompt_or_default "Enter tag [default: latest]: " "latest"
prompt_or_default "Enter tag [default: edge]: " "edge"
VERSION="$REPLY"
fi
echo ""
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
print_info "Selected: craft-latest (Craft enabled)"
elif [ "$VERSION" = "latest" ]; then
print_info "Selected: Latest tag"
elif [ "$VERSION" = "edge" ]; then
print_info "Selected: edge (latest nightly)"
else
print_info "Selected: $VERSION"
fi
@@ -1070,20 +1092,39 @@ fi
export HOST_PORT=$AVAILABLE_PORT
print_success "Using port $AVAILABLE_PORT for nginx"
# Determine if we're using the latest tag or a craft tag (both should force pull)
# Determine if we're using a floating tag (edge, latest, craft-*) that should force pull
# Read IMAGE_TAG from .env file and remove any quotes or whitespace
CURRENT_IMAGE_TAG=$(grep "^IMAGE_TAG=" "$ENV_FILE" | head -1 | cut -d'=' -f2 | tr -d ' "'"'"'')
if [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
if [ "$CURRENT_IMAGE_TAG" = "edge" ] || [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
USE_LATEST=true
if [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
print_info "Using craft tag '$CURRENT_IMAGE_TAG' - will force pull and recreate containers"
else
print_info "Using 'latest' tag - will force pull and recreate containers"
print_info "Using '$CURRENT_IMAGE_TAG' tag - will force pull and recreate containers"
fi
else
USE_LATEST=false
fi
# For pinned version tags, re-download config files from that tag so the
# compose file matches the images being pulled (the initial download used main).
if [[ "$USE_LATEST" = false ]] && [[ "$USE_LOCAL_FILES" = false ]]; then
PINNED_BASE="https://raw.githubusercontent.com/onyx-dot-app/onyx/${CURRENT_IMAGE_TAG}/deployment"
print_info "Fetching config files matching tag ${CURRENT_IMAGE_TAG}..."
if download_file "${PINNED_BASE}/docker_compose/docker-compose.yml" "${INSTALL_ROOT}/deployment/docker-compose.yml" 2>/dev/null; then
download_file "${PINNED_BASE}/data/nginx/app.conf.template" "${INSTALL_ROOT}/data/nginx/app.conf.template" 2>/dev/null || true
download_file "${PINNED_BASE}/data/nginx/run-nginx.sh" "${INSTALL_ROOT}/data/nginx/run-nginx.sh" 2>/dev/null || true
chmod +x "${INSTALL_ROOT}/data/nginx/run-nginx.sh"
if [[ "$LITE_MODE" = true ]]; then
download_file "${PINNED_BASE}/docker_compose/${LITE_COMPOSE_FILE}" \
"${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" 2>/dev/null || true
fi
print_success "Config files updated to match ${CURRENT_IMAGE_TAG}"
else
print_warning "Tag ${CURRENT_IMAGE_TAG} not found on GitHub — using main branch configs"
fi
fi
# Pull Docker images with reduced output
print_step "Pulling Docker images"
print_info "This may take several minutes depending on your internet connection..."

View File

@@ -127,6 +127,7 @@ Inputs (common):
- `name` (default `onyx`), `region` (default `us-west-2`), `tags`
- `postgres_username`, `postgres_password`
- `create_vpc` (default true) or existing VPC details and `s3_vpc_endpoint_id`
- WAF controls such as `waf_allowed_ip_cidrs`, `waf_common_rule_set_count_rules`, rate limits, geo restrictions, and logging retention
### `vpc`
- Builds a VPC sized for EKS with multiple private and public subnets

View File

@@ -88,6 +88,8 @@ module "waf" {
tags = local.merged_tags
# WAF configuration with sensible defaults
allowed_ip_cidrs = var.waf_allowed_ip_cidrs
common_rule_set_count_rules = var.waf_common_rule_set_count_rules
rate_limit_requests_per_5_minutes = var.waf_rate_limit_requests_per_5_minutes
api_rate_limit_requests_per_5_minutes = var.waf_api_rate_limit_requests_per_5_minutes
geo_restriction_countries = var.waf_geo_restriction_countries

View File

@@ -117,6 +117,18 @@ variable "waf_rate_limit_requests_per_5_minutes" {
default = 2000
}
variable "waf_allowed_ip_cidrs" {
type = list(string)
description = "Optional IPv4 CIDR ranges allowed through the WAF. Leave empty to disable IP allowlisting."
default = []
}
variable "waf_common_rule_set_count_rules" {
type = list(string)
description = "Subrules within AWSManagedRulesCommonRuleSet to override to COUNT instead of BLOCK."
default = []
}
variable "waf_api_rate_limit_requests_per_5_minutes" {
type = number
description = "Rate limit for API requests per 5 minutes per IP address"

View File

@@ -1,6 +1,20 @@
locals {
name = var.name
tags = var.tags
name = var.name
tags = var.tags
ip_allowlist_enabled = length(var.allowed_ip_cidrs) > 0
managed_rule_priority = local.ip_allowlist_enabled ? 1 : 0
}
resource "aws_wafv2_ip_set" "allowed_ips" {
count = local.ip_allowlist_enabled ? 1 : 0
name = "${local.name}-allowed-ips"
description = "IP allowlist for ${local.name}"
scope = "REGIONAL"
ip_address_version = "IPV4"
addresses = var.allowed_ip_cidrs
tags = local.tags
}
# AWS WAFv2 Web ACL
@@ -13,10 +27,38 @@ resource "aws_wafv2_web_acl" "main" {
allow {}
}
dynamic "rule" {
for_each = local.ip_allowlist_enabled ? [1] : []
content {
name = "BlockRequestsOutsideAllowedIPs"
priority = 1
action {
block {}
}
statement {
not_statement {
statement {
ip_set_reference_statement {
arn = aws_wafv2_ip_set.allowed_ips[0].arn
}
}
}
}
visibility_config {
cloudwatch_metrics_enabled = true
metric_name = "BlockRequestsOutsideAllowedIPsMetric"
sampled_requests_enabled = true
}
}
}
# AWS Managed Rules - Core Rule Set
rule {
name = "AWSManagedRulesCommonRuleSet"
priority = 1
priority = 1 + local.managed_rule_priority
override_action {
none {}
@@ -26,6 +68,16 @@ resource "aws_wafv2_web_acl" "main" {
managed_rule_group_statement {
name = "AWSManagedRulesCommonRuleSet"
vendor_name = "AWS"
dynamic "rule_action_override" {
for_each = var.common_rule_set_count_rules
content {
name = rule_action_override.value
action_to_use {
count {}
}
}
}
}
}
@@ -39,7 +91,7 @@ resource "aws_wafv2_web_acl" "main" {
# AWS Managed Rules - Known Bad Inputs
rule {
name = "AWSManagedRulesKnownBadInputsRuleSet"
priority = 2
priority = 2 + local.managed_rule_priority
override_action {
none {}
@@ -62,7 +114,7 @@ resource "aws_wafv2_web_acl" "main" {
# Rate Limiting Rule
rule {
name = "RateLimitRule"
priority = 3
priority = 3 + local.managed_rule_priority
action {
block {}
@@ -87,7 +139,7 @@ resource "aws_wafv2_web_acl" "main" {
for_each = length(var.geo_restriction_countries) > 0 ? [1] : []
content {
name = "GeoRestrictionRule"
priority = 4
priority = 4 + local.managed_rule_priority
action {
block {}
@@ -110,7 +162,7 @@ resource "aws_wafv2_web_acl" "main" {
# IP Rate Limiting
rule {
name = "APIRateLimitRule"
priority = 5
priority = 5 + local.managed_rule_priority
action {
block {}
@@ -133,7 +185,7 @@ resource "aws_wafv2_web_acl" "main" {
# SQL Injection Protection
rule {
name = "AWSManagedRulesSQLiRuleSet"
priority = 6
priority = 6 + local.managed_rule_priority
override_action {
none {}
@@ -156,7 +208,7 @@ resource "aws_wafv2_web_acl" "main" {
# Anonymous IP Protection
rule {
name = "AWSManagedRulesAnonymousIpList"
priority = 7
priority = 7 + local.managed_rule_priority
override_action {
none {}

View File

@@ -9,6 +9,18 @@ variable "tags" {
default = {}
}
variable "allowed_ip_cidrs" {
type = list(string)
description = "Optional IPv4 CIDR ranges allowed to reach the application. Leave empty to disable IP allowlisting."
default = []
}
variable "common_rule_set_count_rules" {
type = list(string)
description = "Subrules within AWSManagedRulesCommonRuleSet to override to COUNT instead of BLOCK."
default = []
}
variable "rate_limit_requests_per_5_minutes" {
type = number
description = "Rate limit for requests per 5 minutes per IP address"

1
desktop/AGENTS.md Symbolic link
View File

@@ -0,0 +1 @@
../web/AGENTS.md

1
desktop/CLAUDE.md Symbolic link
View File

@@ -0,0 +1 @@
AGENTS.md

View File

@@ -8,7 +8,7 @@
"name": "widget",
"version": "0.1.0",
"dependencies": {
"next": "^16.1.5",
"next": "^16.1.7",
"react": "^19",
"react-dom": "^19",
"react-markdown": "^10.1.0"
@@ -1023,9 +1023,9 @@
}
},
"node_modules/@next/env": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.5.tgz",
"integrity": "sha512-CRSCPJiSZoi4Pn69RYBDI9R7YK2g59vLexPQFXY0eyw+ILevIenCywzg+DqmlBik9zszEnw2HLFOUlLAcJbL7g==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.7.tgz",
"integrity": "sha512-rJJbIdJB/RQr2F1nylZr/PJzamvNNhfr3brdKP6s/GW850jbtR70QlSfFselvIBbcPUOlQwBakexjFzqLzF6pg==",
"license": "MIT"
},
"node_modules/@next/eslint-plugin-next": {
@@ -1039,9 +1039,9 @@
}
},
"node_modules/@next/swc-darwin-arm64": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.5.tgz",
"integrity": "sha512-eK7Wdm3Hjy/SCL7TevlH0C9chrpeOYWx2iR7guJDaz4zEQKWcS1IMVfMb9UKBFMg1XgzcPTYPIp1Vcpukkjg6Q==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.7.tgz",
"integrity": "sha512-b2wWIE8sABdyafc4IM8r5Y/dS6kD80JRtOGrUiKTsACFQfWWgUQ2NwoUX1yjFMXVsAwcQeNpnucF2ZrujsBBPg==",
"cpu": [
"arm64"
],
@@ -1055,9 +1055,9 @@
}
},
"node_modules/@next/swc-darwin-x64": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.5.tgz",
"integrity": "sha512-foQscSHD1dCuxBmGkbIr6ScAUF6pRoDZP6czajyvmXPAOFNnQUJu2Os1SGELODjKp/ULa4fulnBWoHV3XdPLfA==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.7.tgz",
"integrity": "sha512-zcnVaaZulS1WL0Ss38R5Q6D2gz7MtBu8GZLPfK+73D/hp4GFMrC2sudLky1QibfV7h6RJBJs/gOFvYP0X7UVlQ==",
"cpu": [
"x64"
],
@@ -1071,9 +1071,9 @@
}
},
"node_modules/@next/swc-linux-arm64-gnu": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.5.tgz",
"integrity": "sha512-qNIb42o3C02ccIeSeKjacF3HXotGsxh/FMk/rSRmCzOVMtoWH88odn2uZqF8RLsSUWHcAqTgYmPD3pZ03L9ZAA==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.7.tgz",
"integrity": "sha512-2ant89Lux/Q3VyC8vNVg7uBaFVP9SwoK2jJOOR0L8TQnX8CAYnh4uctAScy2Hwj2dgjVHqHLORQZJ2wH6VxhSQ==",
"cpu": [
"arm64"
],
@@ -1087,9 +1087,9 @@
}
},
"node_modules/@next/swc-linux-arm64-musl": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.5.tgz",
"integrity": "sha512-U+kBxGUY1xMAzDTXmuVMfhaWUZQAwzRaHJ/I6ihtR5SbTVUEaDRiEU9YMjy1obBWpdOBuk1bcm+tsmifYSygfw==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.7.tgz",
"integrity": "sha512-uufcze7LYv0FQg9GnNeZ3/whYfo+1Q3HnQpm16o6Uyi0OVzLlk2ZWoY7j07KADZFY8qwDbsmFnMQP3p3+Ftprw==",
"cpu": [
"arm64"
],
@@ -1103,9 +1103,9 @@
}
},
"node_modules/@next/swc-linux-x64-gnu": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.5.tgz",
"integrity": "sha512-gq2UtoCpN7Ke/7tKaU7i/1L7eFLfhMbXjNghSv0MVGF1dmuoaPeEVDvkDuO/9LVa44h5gqpWeJ4mRRznjDv7LA==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.7.tgz",
"integrity": "sha512-KWVf2gxYvHtvuT+c4MBOGxuse5TD7DsMFYSxVxRBnOzok/xryNeQSjXgxSv9QpIVlaGzEn/pIuI6Koosx8CGWA==",
"cpu": [
"x64"
],
@@ -1119,9 +1119,9 @@
}
},
"node_modules/@next/swc-linux-x64-musl": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.5.tgz",
"integrity": "sha512-bQWSE729PbXT6mMklWLf8dotislPle2L70E9q6iwETYEOt092GDn0c+TTNj26AjmeceSsC4ndyGsK5nKqHYXjQ==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.7.tgz",
"integrity": "sha512-HguhaGwsGr1YAGs68uRKc4aGWxLET+NevJskOcCAwXbwj0fYX0RgZW2gsOCzr9S11CSQPIkxmoSbuVaBp4Z3dA==",
"cpu": [
"x64"
],
@@ -1135,9 +1135,9 @@
}
},
"node_modules/@next/swc-win32-arm64-msvc": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.5.tgz",
"integrity": "sha512-LZli0anutkIllMtTAWZlDqdfvjWX/ch8AFK5WgkNTvaqwlouiD1oHM+WW8RXMiL0+vAkAJyAGEzPPjO+hnrSNQ==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.7.tgz",
"integrity": "sha512-S0n3KrDJokKTeFyM/vGGGR8+pCmXYrjNTk2ZozOL1C/JFdfUIL9O1ATaJOl5r2POe56iRChbsszrjMAdWSv7kQ==",
"cpu": [
"arm64"
],
@@ -1151,9 +1151,9 @@
}
},
"node_modules/@next/swc-win32-x64-msvc": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.5.tgz",
"integrity": "sha512-7is37HJTNQGhjPpQbkKjKEboHYQnCgpVt/4rBrrln0D9nderNxZ8ZWs8w1fAtzUx7wEyYjQ+/13myFgFj6K2Ng==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.7.tgz",
"integrity": "sha512-mwgtg8CNZGYm06LeEd+bNnOUfwOyNem/rOiP14Lsz+AnUY92Zq/LXwtebtUiaeVkhbroRCQ0c8GlR4UT1U+0yg==",
"cpu": [
"x64"
],
@@ -2564,12 +2564,15 @@
"dev": true
},
"node_modules/baseline-browser-mapping": {
"version": "2.9.14",
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.14.tgz",
"integrity": "sha512-B0xUquLkiGLgHhpPBqvl7GWegWBUNuujQ6kXd/r1U38ElPT6Ok8KZ8e+FpUGEc2ZoRQUzq/aUnaKFc/svWUGSg==",
"version": "2.10.8",
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.10.8.tgz",
"integrity": "sha512-PCLz/LXGBsNTErbtB6i5u4eLpHeMfi93aUv5duMmj6caNu6IphS4q6UevDnL36sZQv9lrP11dbPKGMaXPwMKfQ==",
"license": "Apache-2.0",
"bin": {
"baseline-browser-mapping": "dist/cli.js"
"baseline-browser-mapping": "dist/cli.cjs"
},
"engines": {
"node": ">=6.0.0"
}
},
"node_modules/brace-expansion": {
@@ -5926,14 +5929,14 @@
"dev": true
},
"node_modules/next": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/next/-/next-16.1.5.tgz",
"integrity": "sha512-f+wE+NSbiQgh3DSAlTaw2FwY5yGdVViAtp8TotNQj4kk4Q8Bh1sC/aL9aH+Rg1YAVn18OYXsRDT7U/079jgP7w==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/next/-/next-16.1.7.tgz",
"integrity": "sha512-WM0L7WrSvKwoLegLYr6V+mz+RIofqQgVAfHhMp9a88ms0cFX8iX9ew+snpWlSBwpkURJOUdvCEt3uLl3NNzvWg==",
"license": "MIT",
"dependencies": {
"@next/env": "16.1.5",
"@next/env": "16.1.7",
"@swc/helpers": "0.5.15",
"baseline-browser-mapping": "^2.8.3",
"baseline-browser-mapping": "^2.9.19",
"caniuse-lite": "^1.0.30001579",
"postcss": "8.4.31",
"styled-jsx": "5.1.6"
@@ -5945,14 +5948,14 @@
"node": ">=20.9.0"
},
"optionalDependencies": {
"@next/swc-darwin-arm64": "16.1.5",
"@next/swc-darwin-x64": "16.1.5",
"@next/swc-linux-arm64-gnu": "16.1.5",
"@next/swc-linux-arm64-musl": "16.1.5",
"@next/swc-linux-x64-gnu": "16.1.5",
"@next/swc-linux-x64-musl": "16.1.5",
"@next/swc-win32-arm64-msvc": "16.1.5",
"@next/swc-win32-x64-msvc": "16.1.5",
"@next/swc-darwin-arm64": "16.1.7",
"@next/swc-darwin-x64": "16.1.7",
"@next/swc-linux-arm64-gnu": "16.1.7",
"@next/swc-linux-arm64-musl": "16.1.7",
"@next/swc-linux-x64-gnu": "16.1.7",
"@next/swc-linux-x64-musl": "16.1.7",
"@next/swc-win32-arm64-msvc": "16.1.7",
"@next/swc-win32-x64-msvc": "16.1.7",
"sharp": "^0.34.4"
},
"peerDependencies": {

View File

@@ -9,7 +9,7 @@
"lint": "next lint"
},
"dependencies": {
"next": "^16.1.5",
"next": "^16.1.7",
"react": "^19",
"react-dom": "^19",
"react-markdown": "^10.1.0"

View File

@@ -65,7 +65,7 @@
},
{
"scope": ["web/**"],
"rule": "For frontend changes (changes that touch the /web directory), make sure to enforce all standards described in the web/STANDARDS.md file."
"rule": "For frontend changes (changes that touch the /web directory), make sure to enforce all standards described in the web/AGENTS.md file."
},
{
"scope": [],
@@ -85,7 +85,7 @@
},
{
"scope": ["backend/**/*.py"],
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"message\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
}
],
"files": [

Some files were not shown because too many files have changed in this diff Show More