Compare commits

...

66 Commits

Author SHA1 Message Date
Bo-Onyx
4da54f34b3 address comments 2026-03-23 13:11:52 -07:00
Bo-Onyx
5ebb80aeb2 address comments 2026-03-22 12:11:03 -07:00
Bo-Onyx
f77d5d2d01 feat(hook): integrate query processing hook point 2026-03-20 17:37:01 -07:00
Bo-Onyx
30e7a831a5 feat(hook): Add hook management API (#9513)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-20 22:53:59 +00:00
Evan Lohn
276261c96d fix: windows installer (#9507) 2026-03-20 22:53:46 +00:00
Bo-Onyx
205f1410e4 chore(hook): Hook executor. (#9467) 2026-03-20 22:47:01 +00:00
Bo-Onyx
a93d154c27 feat(hook): improve on hook point definition (#9522) 2026-03-20 22:20:42 +00:00
Jamison Lahman
1361879bd0 fix(fe): clicking outside chat area keeps chat input focused (#9521) 2026-03-20 19:22:11 +00:00
Justin Tahara
c58cc320b2 feat(tf): Port over WAF updates (#9520) 2026-03-20 18:45:09 +00:00
Jamison Lahman
461350958a fix(fe): dim project name in sidebar color (#9519) 2026-03-20 17:47:49 +00:00
Raunak Bhagat
50dde0be1a chore: edit AGENTS.md and CLAUDE.md files (#9486) 2026-03-20 00:59:30 +00:00
acaprau
199e1df453 feat(opensearch): Add functions for keyword and semantic retrieval (#9479)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-20 00:48:01 +00:00
Justin Tahara
996b674840 feat(backend): Adding procps (#9509) 2026-03-19 23:26:36 +00:00
Justin Tahara
5413723ccc feat(ods): Rerun run-ci workflow (#9501) 2026-03-19 22:11:59 +00:00
Evan Lohn
9660056a51 fix: drive rate limit retry (#9498) 2026-03-19 21:32:08 +00:00
Fizza Mukhtar
3105177238 fix(llm): don't send tool_choice when no tools are provided (#9224) 2026-03-19 21:26:46 +00:00
Evan Lohn
24bb4bda8b feat: windows installer and install improvements (#9476) 2026-03-19 20:47:44 +00:00
Raunak Bhagat
9532af4ceb chore: move Hoverable story (#9495) 2026-03-19 20:40:27 +00:00
Jamison Lahman
0a913f6af5 fix(fe): fix memories immediately losing focus on click (#9493) 2026-03-19 20:15:34 +00:00
Justin Tahara
fe30c55199 fix(code interpreter): Caching files (#9484) 2026-03-19 19:32:37 +00:00
Jamison Lahman
2cf0a65dd3 chore(fe): reduce padding on elements at the bottom of modal headers (#9488) 2026-03-19 19:27:37 +00:00
Nikolas Garza
659416f363 feat(admin): groups page - list page and group cards (#9453) 2026-03-19 18:23:15 +00:00
Raunak Bhagat
40aecbc4b9 refactor(fe): move table to opal, update size API (#9438) 2026-03-19 17:23:41 +00:00
Jamison Lahman
710b39074f chore(fe): remove opal-button* class names (#9471) 2026-03-19 02:15:00 +00:00
acaprau
8fe2f67d38 chore(opensearch): Allow disabling match highlights via env var; default to disabled (#9436) 2026-03-19 00:43:17 +00:00
Justin Tahara
f00aaf9fc0 fix(agents): Agents are Private by Default (#9465) 2026-03-19 00:01:46 +00:00
Bo-Onyx
5b2426b002 chore(hooks): Define Hook Point in the backend (#9391) 2026-03-18 23:43:26 +00:00
Justin Tahara
ba6ab0245b fix(celery): add dedup guardrails to user file delete queue (#9454)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 23:38:52 +00:00
Justin Tahara
b64ebb57e1 fix(logging): extract LiteLLM error details in image summarization failures (#9458) 2026-03-18 23:29:04 +00:00
Justin Tahara
2fcfdbabde fix(celery): add task expiry to upload API send_task call (#9456)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 23:17:08 +00:00
Justin Tahara
ea1a2749c1 fix(image): add diagnostic logging to vision model selection (#9460) 2026-03-18 22:06:56 +00:00
Justin Tahara
73c4e22588 fix(image): stop dumping base64 image data into error logs (#9457) 2026-03-18 21:43:55 +00:00
Jamison Lahman
fceaac6e13 fix(fe): make indexing attempt error rows click to show trace (#9463)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-18 21:38:53 +00:00
Jamison Lahman
e8bf45cfd2 feat(fe): "Full Exception Trace" modal uses CodePreview rendering (#9464)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-18 21:04:55 +00:00
Bo-Onyx
13ff648fcd chore(hooks): Add Celery task to remove hook running records older than 30 days (#9433) 2026-03-18 21:03:01 +00:00
Jamison Lahman
ae8268afb1 fix(fe): truncate connector names in table (#9459) 2026-03-18 20:59:49 +00:00
acaprau
b338bd9e97 feat(opensearch): Can override number of shards and replicas via env var (#9431) 2026-03-18 20:16:05 +00:00
acaprau
0dcc90a042 fix(opensearch): Exclude retrieving vectors during hybrid and random search (#9430) 2026-03-18 20:13:12 +00:00
Jamison Lahman
0f6a6693d3 fix(fe): truncate project name in sidebar button (#9462) 2026-03-18 20:06:09 +00:00
Jamison Lahman
e32cc450b2 fix(fe): update connector indexing error modal (#9426)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-18 11:57:28 -07:00
Jamison Lahman
732fb71edf chore(tests): unit tests for pdf processing (#9452)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 18:31:37 +00:00
dependabot[bot]
ca3320c0e0 chore(deps): bump pypdf from 6.8.0 to 6.9.1 (#9450)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-18 17:52:50 +00:00
Jamison Lahman
d7c554aca7 chore(ruff): fix and enable S324 (#9451) 2026-03-18 17:26:29 +00:00
dependabot[bot]
69e5c19695 chore(deps): bump next from 16.1.5 to 16.1.7 in /examples/widget (#9425)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-18 09:25:27 -07:00
Nikolas Garza
b4ce1c7a97 chore: bump next to 16.1.7 (#9423) 2026-03-18 09:22:40 -07:00
Jamison Lahman
cd64a91154 fix(fe): display name on attachment file card hover (#9446) 2026-03-18 16:13:21 +00:00
Danelegend
c282cdc096 fix(file upload): Allow zip file upload via query param (#9432) 2026-03-18 07:32:07 +00:00
Jamison Lahman
b1de1c59b6 chore(playwright): projects screenshot is main container only (#9440) 2026-03-18 05:35:30 +00:00
acaprau
64d484039f chore(opensearch): Disable test_update_single_can_clear_user_projects_and_personas (#9434) 2026-03-18 00:40:29 +00:00
Jamison Lahman
0530095b71 fix(fe): replace users table buttons with LineItems (#9435) 2026-03-17 23:45:15 +00:00
acaprau
23280d5b91 fix(opensearch): Fix env var mismatch issue with configuring subquery results; set default to 100 (#9428) 2026-03-17 16:01:45 -07:00
Bo-Onyx
229442679c chore(hooks): Add db CRUD (#9411) 2026-03-17 22:36:50 +00:00
Jamison Lahman
95a192fb0f chore(devtools): upgrade ods: 0.7.0->0.7.1 (#9429) 2026-03-17 15:18:21 -07:00
Yuhong Sun
6bd96ec906 chore: Scripts for search quality eval (#9421) 2026-03-17 14:53:32 -07:00
Jamison Lahman
a1ec88269f chore(docker): configurable api_server resource limits (#9424)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-17 14:52:21 -07:00
Raunak Bhagat
b929518c34 refactor: update size variant and dimension names in opal components (#9416)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-03-17 21:43:19 +00:00
Justin Tahara
479220e774 chore(opensearch): Make Password Default Empty (#9415) 2026-03-17 21:41:08 +00:00
dependabot[bot]
d3e0acf905 chore(deps): bump pyasn1 from 0.6.2 to 0.6.3 (#9417)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-17 21:28:07 +00:00
acaprau
cbd1a344f2 feat(opensearch): Configure hybrid search subquery groups and pipelines via env var (#9407) 2026-03-17 21:25:53 +00:00
Jamison Lahman
e79264b69b chore(fe): prefer INTERNAL_URL (#9419) 2026-03-17 14:17:05 -07:00
Justin Tahara
1e0a8e9a0e fix(llm): surface masked OpenAI quota failures (#9308) 2026-03-17 21:11:43 +00:00
Jamison Lahman
b7841a513d chore(devtools): ods backend scans for available ports (#9418) 2026-03-17 14:09:23 -07:00
dependabot[bot]
c779bf722d chore(deps): bump next from 16.1.5 to 16.1.7 in /backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web (#9420)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-17 14:03:51 -07:00
Jamison Lahman
a5aff0d199 chore(fe): rm dead stackTraceModalContent (#9412) 2026-03-17 20:42:29 +00:00
Wenxi
8ed170b070 fix: make ConnectorCredentialPair name required (#9408)
Co-authored-by: Ciaran Sweet <ciaran@developmentseed.org>
2026-03-17 18:54:34 +00:00
Raunak Bhagat
c890cd4767 feat(llm-config): replace AdvancedOptions with unified ModelsAccessField (#9270)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 18:19:52 +00:00
258 changed files with 14572 additions and 13186 deletions

279
AGENTS.md
View File

@@ -167,284 +167,7 @@ web/
## Frontend Standards
### 1. Import Standards
**Always use absolute imports with the `@` prefix.**
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
```typescript
// ✅ Good
import { Button } from "@/components/ui/button";
import { useAuth } from "@/hooks/useAuth";
import { Text } from "@/refresh-components/texts/Text";
// ❌ Bad
import { Button } from "../../../components/ui/button";
import { useAuth } from "./hooks/useAuth";
```
### 2. React Component Functions
**Prefer regular functions over arrow functions for React components.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
function UserProfile({ userId }: UserProfileProps) {
return <div>User Profile</div>
}
// ❌ Bad
const UserProfile = ({ userId }: UserProfileProps) => {
return <div>User Profile</div>
}
```
### 3. Props Interface Extraction
**Extract prop types into their own interface definitions.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
interface UserCardProps {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
return <div>User Card</div>
}
// ❌ Bad
function UserCard({
user,
showActions = false,
onEdit
}: {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}) {
return <div>User Card</div>
}
```
### 4. Spacing Guidelines
**Prefer padding over margins for spacing.**
**Reason:** We want to consolidate usage to paddings instead of margins.
```typescript
// ✅ Good
<div className="p-4 space-y-2">
<div className="p-2">Content</div>
</div>
// ❌ Bad
<div className="m-4 space-y-2">
<div className="m-2">Content</div>
</div>
```
### 5. Tailwind Dark Mode
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
```typescript
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
<div className="bg-background-neutral-03 text-text-02">
Content
</div>
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
export const GithubIcon = createLogoIcon(githubLightIcon, {
monochromatic: true, // Will apply dark:invert internally
});
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
});
// ❌ Bad - Manual dark mode overrides
<div className="bg-white dark:bg-black text-black dark:text-white">
Content
</div>
```
### 6. Class Name Utilities
**Use the `cn` utility instead of raw string formatting for classNames.**
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
```typescript
import { cn } from '@/lib/utils'
// ✅ Good
<div className={cn(
'base-class',
isActive && 'active-class',
className
)}>
Content
</div>
// ❌ Bad
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
Content
</div>
```
### 7. Custom Hooks Organization
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
**Reason:** This is just a layout preference. Keeps code clean.
```typescript
// web/src/hooks/useUserData.ts
export function useUserData(userId: string) {
// hook implementation
}
// web/src/hooks/useLocalStorage.ts
export function useLocalStorage<T>(key: string, initialValue: T) {
// hook implementation
}
```
### 8. Icon Usage
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
```typescript
// ✅ Good
import SvgX from "@/icons/x";
import SvgMoreHorizontal from "@/icons/more-horizontal";
// ❌ Bad
import { User } from "lucide-react";
import { FiSearch } from "react-icons/fi";
```
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
If you need help with this step, reach out to `raunak@onyx.app`.
### 9. Text Rendering
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
```typescript
// ✅ Good
import { Text } from '@/refresh-components/texts/Text'
function UserCard({ name }: { name: string }) {
return (
<Text
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
text03
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
mainAction
>
{name}
</Text>
)
}
// ❌ Bad
function UserCard({ name }: { name: string }) {
return (
<div>
<h2>{name}</h2>
<p>User details</p>
</div>
)
}
```
### 10. Component Usage
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
```typescript
// ✅ Good
import Button from '@/refresh-components/buttons/Button'
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
import SvgPlusCircle from '@/icons/plus-circle'
function ContactForm() {
return (
<form>
<InputTypeIn placeholder="Search..." />
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
</form>
)
}
// ❌ Bad
function ContactForm() {
return (
<form>
<input placeholder="Name" />
<textarea placeholder="Message" />
<button type="submit">Submit</button>
</form>
)
}
```
### 11. Colors
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
**Available color categories:**
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
- **Actions:** `action-link-XX`, `action-danger-XX`
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
```typescript
// ✅ Good - Use custom Onyx color classes
<div className="bg-background-neutral-01 border border-border-02" />
<div className="bg-background-tint-02 border border-border-01" />
<div className="bg-status-success-01" />
<div className="bg-action-link-01" />
<div className="bg-theme-primary-05" />
// ❌ Bad - Do NOT use standard Tailwind colors
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
<div className="bg-white border border-slate-200" />
<div className="bg-green-100 text-green-700" />
<div className="bg-blue-100 text-blue-600" />
<div className="bg-indigo-500" />
```
### 12. Data Fetching
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
Frontend standards for the `web/` and `desktop/` projects live in `web/AGENTS.md`.
## Database & Migrations

View File

@@ -47,6 +47,8 @@ RUN apt-get update && \
gcc \
nano \
vim \
# Install procps so kubernetes exec sessions can use ps aux for debugging
procps \
libjemalloc2 \
&& \
rm -rf /var/lib/apt/lists/* && \

View File

@@ -317,6 +317,7 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.hierarchyfetching",
"onyx.background.celery.tasks.hooks",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",

View File

@@ -14,6 +14,7 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.hooks.utils import HOOKS_AVAILABLE
from shared_configs.configs import MULTI_TENANT
# choosing 15 minutes because it roughly gives us enough time to process many tasks
@@ -361,6 +362,19 @@ if not MULTI_TENANT:
tasks_to_schedule.extend(beat_task_templates)
if HOOKS_AVAILABLE:
tasks_to_schedule.append(
{
"name": "hook-execution-log-cleanup",
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
"schedule": timedelta(days=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def generate_cloud_tasks(
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float

View File

@@ -0,0 +1,35 @@
from celery import shared_task
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.hook import cleanup_old_execution_logs__no_commit
from onyx.utils.logger import setup_logger
logger = setup_logger()
_HOOK_EXECUTION_LOG_RETENTION_DAYS: int = 30
@shared_task(
name=OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def hook_execution_log_cleanup_task(*, tenant_id: str) -> None: # noqa: ARG001
try:
with get_session_with_current_tenant() as db_session:
deleted: int = cleanup_old_execution_logs__no_commit(
db_session=db_session,
max_age_days=_HOOK_EXECUTION_LOG_RETENTION_DAYS,
)
db_session.commit()
if deleted:
logger.info(
f"Deleted {deleted} hook execution log(s) older than "
f"{_HOOK_EXECUTION_LOG_RETENTION_DAYS} days."
)
except Exception:
logger.exception("Failed to clean up hook execution logs")
raise

View File

@@ -24,6 +24,7 @@ from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_DELETE_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
@@ -33,6 +34,7 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_DELETE_MAX_QUEUE_DEPTH
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
@@ -91,6 +93,17 @@ def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a delete_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_DELETE_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_DELETE_QUEUED_PREFIX}:{user_file_id}"
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
return celery_get_queue_length(
@@ -546,7 +559,23 @@ def process_single_user_file(
ignore_result=True,
)
def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with DELETING status and enqueue per-file tasks."""
"""Scan for user files with DELETING status and enqueue per-file tasks.
Three mechanisms prevent queue runaway (mirrors check_user_file_processing):
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_DELETE_MAX_QUEUE_DEPTH items we skip this beat cycle entirely.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_DELETE_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still DELETING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_DELETE_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
"""
task_logger.info("check_for_user_file_delete - Starting")
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
@@ -555,8 +584,23 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
)
if not lock.acquire(blocking=False):
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
# NOTE: must use the broker's Redis client (not redis_client) because
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_for_user_file_delete - Queue depth {queue_len} exceeds "
f"{USER_FILE_DELETE_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -568,23 +612,40 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
.all()
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_DELETE,
priority=OnyxCeleryPriority.HIGH,
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_delete_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_DELETE_TASK_EXPIRES,
nx=True,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
try:
self.app.send_task(
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_DELETE,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_DELETE_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
except Exception as e:
task_logger.exception(
f"check_for_user_file_delete - Error enqueuing deletes - {e.__class__.__name__}"
)
return None
finally:
if lock.owned():
lock.release()
task_logger.info(
f"check_for_user_file_delete - Enqueued {enqueued} tasks for tenant={tenant_id}"
f"check_for_user_file_delete - Enqueued {enqueued} tasks, skipped_guard={skipped_guard} for tenant={tenant_id}"
)
return None
@@ -602,6 +663,9 @@ def delete_user_file_impl(
file_lock: RedisLock | None = None
if redis_locking:
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the queued guard so the beat can re-enqueue if deletion fails
# and the file remains in DELETING status.
redis_client.delete(_user_file_delete_queued_key(user_file_id))
file_lock = redis_client.lock(
_user_file_delete_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,

View File

@@ -297,7 +297,9 @@ class PostgresCacheBackend(CacheBackend):
def _lock_id_for(self, name: str) -> int:
"""Map *name* to a 64-bit signed int for ``pg_advisory_lock``."""
h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
h = hashlib.md5(
f"{self._tenant_id}:{name}".encode(), usedforsecurity=False
).digest()
return struct.unpack("q", h[:8])[0]

View File

@@ -36,9 +36,11 @@ from onyx.db.memory import add_memory
from onyx.db.memory import update_memory_at_index
from onyx.db.memory import UserMemoryContext
from onyx.db.models import Persona
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import is_true_openai_model
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
from onyx.server.query_and_chat.placement import Placement
@@ -72,6 +74,70 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
class EmptyLLMResponseError(RuntimeError):
"""Raised when the streamed LLM response completes without a usable answer."""
def __init__(
self,
*,
provider: str,
model: str,
tool_choice: ToolChoiceOptions,
client_error_msg: str,
error_code: str = "EMPTY_LLM_RESPONSE",
is_retryable: bool = True,
) -> None:
super().__init__(client_error_msg)
self.provider = provider
self.model = model
self.tool_choice = tool_choice
self.client_error_msg = client_error_msg
self.error_code = error_code
self.is_retryable = is_retryable
def _build_empty_llm_response_error(
llm: LLM,
llm_step_result: LlmStepResult,
tool_choice: ToolChoiceOptions,
) -> EmptyLLMResponseError:
provider = llm.config.model_provider
model = llm.config.model_name
# OpenAI quota exhaustion has reached us as a streamed "stop" with zero content.
# When the stream is completely empty and there is no reasoning/tool output, surface
# the likely account-level cause instead of a generic tool-calling error.
if (
not llm_step_result.reasoning
and provider == LlmProviderNames.OPENAI
and is_true_openai_model(provider, model)
):
return EmptyLLMResponseError(
provider=provider,
model=model,
tool_choice=tool_choice,
client_error_msg=(
"The selected OpenAI model returned an empty streamed response "
"before producing any tokens. This commonly happens when the API "
"key or project has no remaining quota or billing is not enabled. "
"Verify quota and billing for this key and try again."
),
error_code="BUDGET_EXCEEDED",
is_retryable=False,
)
return EmptyLLMResponseError(
provider=provider,
model=model,
tool_choice=tool_choice,
client_error_msg=(
"The selected model returned no final answer before the stream "
"completed. No text or tool calls were received from the upstream "
"provider."
),
)
def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
"""Detect XML-style marshaled tool calls emitted as plain text."""
if not text:
@@ -613,7 +679,12 @@ def run_llm_loop(
)
citation_processor.update_citation_mapping(project_citation_mapping)
llm_step_result: LlmStepResult | None = None
llm_step_result = LlmStepResult(
reasoning=None,
answer=None,
tool_calls=None,
raw_answer=None,
)
# Pass the total budget to construct_message_history, which will handle token allocation
available_tokens = llm.config.max_input_tokens
@@ -1084,12 +1155,18 @@ def run_llm_loop(
# As long as 1 tool with citeable documents is called at any point, we ask the LLM to try to cite
should_cite_documents = True
if not llm_step_result or not llm_step_result.answer:
if not llm_step_result.answer and not llm_step_result.tool_calls:
raise _build_empty_llm_response_error(
llm=llm,
llm_step_result=llm_step_result,
tool_choice=tool_choice,
)
if not llm_step_result.answer:
raise RuntimeError(
"The LLM did not return an answer. "
"Typically this is an issue with LLMs that do not support tool calling natively, "
"or the model serving API is not configured correctly. "
"This may also happen with models that are lower quality outputting invalid tool calls."
"The LLM did not return a final answer after tool execution. "
"Typically this indicates invalid tool-call output, a model/provider mismatch, "
"or serving API misconfiguration."
)
emitter.emit(

View File

@@ -1013,6 +1013,10 @@ def run_llm_step_pkt_generator(
accumulated_reasoning = ""
accumulated_answer = ""
accumulated_raw_answer = ""
stream_chunk_count = 0
actionable_chunk_count = 0
empty_chunk_count = 0
finish_reasons: set[str] = set()
xml_tool_call_content_filter = _XmlToolCallContentFilter()
processor_state: Any = None
@@ -1145,6 +1149,7 @@ def run_llm_step_pkt_generator(
user_identity=user_identity,
timeout_override=timeout_override,
):
stream_chunk_count += 1
if packet.usage:
usage = packet.usage
span_generation.span_data.usage = {
@@ -1154,16 +1159,21 @@ def run_llm_step_pkt_generator(
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
}
# Note: LLM cost tracking is now handled in multi_llm.py
finish_reason = packet.choice.finish_reason
if finish_reason:
finish_reasons.add(str(finish_reason))
delta = packet.choice.delta
# Weird behavior from some model providers, just log and ignore for now
if (
delta.content is None
not delta.content
and delta.reasoning_content is None
and delta.tool_calls is None
and not delta.tool_calls
):
empty_chunk_count += 1
logger.warning(
f"LLM packet is empty (no contents, reasoning or tool calls). Skipping: {packet}"
"LLM packet is empty (no content, reasoning, or tool calls). "
f"finish_reason={finish_reason}. Skipping: {packet}"
)
continue
@@ -1172,6 +1182,8 @@ def run_llm_step_pkt_generator(
time.monotonic() - stream_start_time
)
first_action_recorded = True
if _delta_has_action(delta):
actionable_chunk_count += 1
if custom_token_processor:
# The custom token processor can modify the deltas for specific custom logic
@@ -1307,6 +1319,15 @@ def run_llm_step_pkt_generator(
else:
logger.debug("Tool calls: []")
if actionable_chunk_count == 0:
logger.warning(
"LLM stream completed with no actionable deltas. "
f"chunks={stream_chunk_count}, empty_chunks={empty_chunk_count}, "
f"finish_reasons={sorted(finish_reasons)}, "
f"provider={llm.config.model_provider}, model={llm.config.model_name}, "
f"tool_choice={tool_choice}, tools_sent={len(tool_definitions)}"
)
return (
LlmStepResult(
reasoning=accumulated_reasoning if accumulated_reasoning else None,

View File

@@ -29,6 +29,7 @@ from onyx.chat.compression import compress_chat_history
from onyx.chat.compression import find_summary_for_branch
from onyx.chat.compression import get_compression_params
from onyx.chat.emitter import get_default_emitter
from onyx.chat.llm_loop import EmptyLLMResponseError
from onyx.chat.llm_loop import run_llm_loop
from onyx.chat.models import AnswerStream
from onyx.chat.models import ChatBasicResponse
@@ -58,6 +59,7 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@@ -67,11 +69,19 @@ from onyx.db.models import UserFile
from onyx.db.projects import get_user_files_from_project
from onyx.db.tools import get_tools
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import verify_user_files
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingPayload
from onyx.hooks.points.query_processing import QueryProcessingResponse
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
@@ -423,6 +433,32 @@ def determine_search_params(
)
def _resolve_query_processing_hook_result(
hook_result: BaseModel | HookSkipped | HookSoftFailed,
message_text: str,
) -> str:
"""Apply the Query Processing hook result to the message text.
Returns the (possibly rewritten) message text, or raises OnyxError with
QUERY_REJECTED if the hook signals rejection (query is null or empty).
HookSkipped and HookSoftFailed are pass-throughs — the original text is
returned unchanged.
"""
if isinstance(hook_result, (HookSkipped, HookSoftFailed)):
return message_text
if not isinstance(hook_result, QueryProcessingResponse):
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"Expected QueryProcessingResponse from hook, got {type(hook_result).__name__}",
)
if not (hook_result.query and hook_result.query.strip()):
raise OnyxError(
OnyxErrorCode.QUERY_REJECTED,
hook_result.rejection_message or "Your query was rejected.",
)
return hook_result.query.strip()
def handle_stream_message_objects(
new_msg_req: SendMessageRequest,
user: User,
@@ -483,6 +519,7 @@ def handle_stream_message_objects(
persona = chat_session.persona
message_text = new_msg_req.message
user_identity = LLMUserIdentity(
user_id=llm_user_identifier, session_id=str(chat_session.id)
)
@@ -574,6 +611,27 @@ def handle_stream_message_objects(
if parent_message.message_type == MessageType.USER:
user_message = parent_message
else:
# New message — run the Query Processing hook before saving to DB.
# Skipped on regeneration: the message already exists and was accepted previously.
# Skip the hook for empty/whitespace-only messages — no meaningful query
# to process, and SendMessageRequest.message has no min_length guard.
if message_text.strip():
hook_result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=QueryProcessingPayload(
query=message_text,
# Pass None for anonymous users or authenticated users without an email
# (e.g. some SSO flows). QueryProcessingPayload.user_email is str | None,
# so None is accepted and serialised as null in both cases.
user_email=None if user.is_anonymous else user.email,
chat_session_id=str(chat_session.id),
).model_dump(),
)
message_text = _resolve_query_processing_hook_result(
hook_result, message_text
)
user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=parent_message,
@@ -913,6 +971,17 @@ def handle_stream_message_objects(
state_container=state_container,
)
except OnyxError as e:
if e.error_code is not OnyxErrorCode.QUERY_REJECTED:
log_onyx_error(e)
yield StreamingError(
error=e.detail,
error_code=e.error_code.code,
is_retryable=e.status_code >= 500,
)
db_session.rollback()
return
except ValueError as e:
logger.exception("Failed to process chat message.")
@@ -925,9 +994,28 @@ def handle_stream_message_objects(
db_session.rollback()
return
except EmptyLLMResponseError as e:
stack_trace = traceback.format_exc()
logger.warning(
"LLM returned an empty response "
f"(provider={e.provider}, model={e.model}, tool_choice={e.tool_choice})"
)
yield StreamingError(
error=e.client_error_msg,
stack_trace=stack_trace,
error_code=e.error_code,
is_retryable=e.is_retryable,
details={
"model": e.model,
"provider": e.provider,
"tool_choice": e.tool_choice.value,
},
)
db_session.rollback()
except Exception as e:
logger.exception(f"Failed to process chat message due to {e}")
error_msg = str(e)
stack_trace = traceback.format_exc()
if llm:
@@ -1046,10 +1134,46 @@ def llm_loop_completion_handle(
)
def remove_answer_citations(answer: str) -> str:
pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"
_CITATION_LINK_START_PATTERN = re.compile(r"\s*\[\[\d+\]\]\(")
return re.sub(pattern, "", answer)
def _find_markdown_link_end(text: str, destination_start: int) -> int | None:
depth = 0
i = destination_start
while i < len(text):
curr = text[i]
if curr == "\\":
i += 2
continue
if curr == "(":
depth += 1
elif curr == ")":
if depth == 0:
return i
depth -= 1
i += 1
return None
def remove_answer_citations(answer: str) -> str:
stripped_parts: list[str] = []
cursor = 0
while match := _CITATION_LINK_START_PATTERN.search(answer, cursor):
stripped_parts.append(answer[cursor : match.start()])
link_end = _find_markdown_link_end(answer, match.end())
if link_end is None:
stripped_parts.append(answer[match.start() :])
return "".join(stripped_parts)
cursor = link_end + 1
stripped_parts.append(answer[cursor:])
return "".join(stripped_parts)
@log_function_time()
@@ -1087,8 +1211,11 @@ def gather_stream(
raise ValueError("Message ID is required")
if answer is None:
# This should never be the case as these non-streamed flows do not have a stop-generation signal
raise RuntimeError("Answer was not generated")
if error_msg is not None:
answer = ""
else:
# This should never be the case as these non-streamed flows do not have a stop-generation signal
raise RuntimeError("Answer was not generated")
return ChatBasicResponse(
answer=answer,

View File

@@ -278,14 +278,17 @@ USING_AWS_MANAGED_OPENSEARCH = (
OPENSEARCH_PROFILING_DISABLED = (
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
)
# Whether to disable match highlights for OpenSearch. Defaults to True for now
# as we investigate query performance.
OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED = (
os.environ.get("OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED", "true").lower() == "true"
)
# When enabled, OpenSearch returns detailed score breakdowns for each hit.
# Useful for debugging and tuning search relevance. Has ~10-30% performance overhead according to documentation.
# Seems for Hybrid Search in practice, the impact is actually more like 1000x slower.
OPENSEARCH_EXPLAIN_ENABLED = (
os.environ.get("OPENSEARCH_EXPLAIN_ENABLED", "").lower() == "true"
)
# Analyzer used for full-text fields (title, content). Use OpenSearch built-in analyzer
# names (e.g. "english", "standard", "german"). Affects stemming and tokenization;
# existing indices need reindexing after a change.
@@ -318,8 +321,16 @@ VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE = int(
os.environ.get("OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE") or 500
)
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = int(
os.environ.get("OPENSEARCH_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES") or 0
# If set, will override the default number of shards and replicas for the index.
OPENSEARCH_INDEX_NUM_SHARDS: int | None = (
int(os.environ["OPENSEARCH_INDEX_NUM_SHARDS"])
if os.environ.get("OPENSEARCH_INDEX_NUM_SHARDS", None) is not None
else None
)
OPENSEARCH_INDEX_NUM_REPLICAS: int | None = (
int(os.environ["OPENSEARCH_INDEX_NUM_REPLICAS"])
if os.environ.get("OPENSEARCH_INDEX_NUM_REPLICAS", None) is not None
else None
)
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"

View File

@@ -177,6 +177,14 @@ USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
# How long a queued user-file-delete task is valid before workers discard it.
# Mirrors the processing task expiry to prevent indefinite queue growth when
# files are stuck in DELETING status and the beat keeps re-enqueuing them.
CELERY_USER_FILE_DELETE_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Max queue depth before the delete beat stops enqueuing more delete tasks.
USER_FILE_DELETE_MAX_QUEUE_DEPTH = 500
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
@@ -469,6 +477,9 @@ class OnyxRedisLocks:
USER_FILE_PROJECT_SYNC_QUEUED_PREFIX = "da_lock:user_file_project_sync_queued"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
# Short-lived key set when a delete task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a delete task is already queued.
USER_FILE_DELETE_QUEUED_PREFIX = "da_lock:user_file_delete_queued"
# Release notes
RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"
@@ -597,6 +608,9 @@ class OnyxCeleryTask:
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"
# Hook execution log retention
HOOK_EXECUTION_LOG_CLEANUP_TASK = "hook_execution_log_cleanup_task"
# Sandbox cleanup
CLEANUP_IDLE_SANDBOXES = "cleanup_idle_sandboxes"
CLEANUP_OLD_SNAPSHOTS = "cleanup_old_snapshots"

View File

@@ -157,9 +157,7 @@ def _execute_single_retrieval(
logger.error(f"Error executing request: {e}")
raise e
elif _is_rate_limit_error(e):
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
results = _execute_with_retry(retrieval_function(**request_kwargs))
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.debug(f"Error executing request: {e}")

View File

@@ -511,7 +511,7 @@ def add_credential_to_connector(
user: User,
connector_id: int,
credential_id: int,
cc_pair_name: str | None,
cc_pair_name: str,
access_type: AccessType,
groups: list[int] | None,
auto_sync_options: dict | None = None,

233
backend/onyx/db/hook.py Normal file
View File

@@ -0,0 +1,233 @@
import datetime
from uuid import UUID
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.db.constants import UNSET
from onyx.db.constants import UnsetType
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.models import Hook
from onyx.db.models import HookExecutionLog
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
# ── Hook CRUD ────────────────────────────────────────────────────────────
def get_hook_by_id(
*,
db_session: Session,
hook_id: int,
include_deleted: bool = False,
include_creator: bool = False,
) -> Hook | None:
stmt = select(Hook).where(Hook.id == hook_id)
if not include_deleted:
stmt = stmt.where(Hook.deleted.is_(False))
if include_creator:
stmt = stmt.options(selectinload(Hook.creator))
return db_session.scalar(stmt)
def get_non_deleted_hook_by_hook_point(
*,
db_session: Session,
hook_point: HookPoint,
include_creator: bool = False,
) -> Hook | None:
stmt = (
select(Hook).where(Hook.hook_point == hook_point).where(Hook.deleted.is_(False))
)
if include_creator:
stmt = stmt.options(selectinload(Hook.creator))
return db_session.scalar(stmt)
def get_hooks(
*,
db_session: Session,
include_deleted: bool = False,
include_creator: bool = False,
) -> list[Hook]:
stmt = select(Hook)
if not include_deleted:
stmt = stmt.where(Hook.deleted.is_(False))
if include_creator:
stmt = stmt.options(selectinload(Hook.creator))
stmt = stmt.order_by(Hook.hook_point, Hook.created_at.desc())
return list(db_session.scalars(stmt).all())
def create_hook__no_commit(
*,
db_session: Session,
name: str,
hook_point: HookPoint,
endpoint_url: str | None = None,
api_key: str | None = None,
fail_strategy: HookFailStrategy,
timeout_seconds: float,
is_active: bool = False,
creator_id: UUID | None = None,
) -> Hook:
"""Create a new hook for the given hook point.
At most one non-deleted hook per hook point is allowed. Raises
OnyxError(CONFLICT) if a hook already exists, including under concurrent
duplicate creates where the partial unique index fires an IntegrityError.
"""
existing = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if existing:
raise OnyxError(
OnyxErrorCode.CONFLICT,
f"A hook for '{hook_point.value}' already exists (id={existing.id}).",
)
hook = Hook(
name=name,
hook_point=hook_point,
endpoint_url=endpoint_url,
api_key=api_key,
fail_strategy=fail_strategy,
timeout_seconds=timeout_seconds,
is_active=is_active,
creator_id=creator_id,
)
# Use a savepoint so that a failed insert only rolls back this operation,
# not the entire outer transaction.
savepoint = db_session.begin_nested()
try:
db_session.add(hook)
savepoint.commit()
except IntegrityError as exc:
savepoint.rollback()
if "ix_hook_one_non_deleted_per_point" in str(exc.orig):
raise OnyxError(
OnyxErrorCode.CONFLICT,
f"A hook for '{hook_point.value}' already exists.",
)
raise # re-raise unrelated integrity errors (FK violations, etc.)
return hook
def update_hook__no_commit(
*,
db_session: Session,
hook_id: int,
name: str | None = None,
endpoint_url: str | None | UnsetType = UNSET,
api_key: str | None | UnsetType = UNSET,
fail_strategy: HookFailStrategy | None = None,
timeout_seconds: float | None = None,
is_active: bool | None = None,
is_reachable: bool | None = None,
include_creator: bool = False,
) -> Hook:
"""Update hook fields.
Sentinel conventions:
- endpoint_url, api_key: pass UNSET to leave unchanged; pass None to clear.
- name, fail_strategy, timeout_seconds, is_active, is_reachable: pass None to leave unchanged.
"""
hook = get_hook_by_id(
db_session=db_session, hook_id=hook_id, include_creator=include_creator
)
if hook is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.")
if name is not None:
hook.name = name
if not isinstance(endpoint_url, UnsetType):
hook.endpoint_url = endpoint_url
if not isinstance(api_key, UnsetType):
hook.api_key = api_key # type: ignore[assignment] # EncryptedString coerces str → SensitiveValue at the ORM level
if fail_strategy is not None:
hook.fail_strategy = fail_strategy
if timeout_seconds is not None:
hook.timeout_seconds = timeout_seconds
if is_active is not None:
hook.is_active = is_active
if is_reachable is not None:
hook.is_reachable = is_reachable
db_session.flush()
return hook
def delete_hook__no_commit(
*,
db_session: Session,
hook_id: int,
) -> None:
hook = get_hook_by_id(db_session=db_session, hook_id=hook_id)
if hook is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.")
hook.deleted = True
hook.is_active = False
db_session.flush()
# ── HookExecutionLog CRUD ────────────────────────────────────────────────
def create_hook_execution_log__no_commit(
*,
db_session: Session,
hook_id: int,
is_success: bool,
error_message: str | None = None,
status_code: int | None = None,
duration_ms: int | None = None,
) -> HookExecutionLog:
log = HookExecutionLog(
hook_id=hook_id,
is_success=is_success,
error_message=error_message,
status_code=status_code,
duration_ms=duration_ms,
)
db_session.add(log)
db_session.flush()
return log
def get_hook_execution_logs(
*,
db_session: Session,
hook_id: int,
limit: int,
) -> list[HookExecutionLog]:
stmt = (
select(HookExecutionLog)
.where(HookExecutionLog.hook_id == hook_id)
.order_by(HookExecutionLog.created_at.desc())
.limit(limit)
)
return list(db_session.scalars(stmt).all())
def cleanup_old_execution_logs__no_commit(
*,
db_session: Session,
max_age_days: int,
) -> int:
"""Delete execution logs older than max_age_days. Returns the number of rows deleted."""
cutoff = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=max_age_days
)
result: CursorResult = db_session.execute( # type: ignore[assignment]
delete(HookExecutionLog)
.where(HookExecutionLog.created_at < cutoff)
.execution_options(synchronize_session=False)
)
return result.rowcount

View File

@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from starlette.background import BackgroundTasks
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -144,6 +145,7 @@ def upload_files_to_user_files_with_indexing(
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
logger.info(
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"

View File

@@ -2,6 +2,7 @@ import time
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.db.connector_credential_pair import get_connector_credential_pairs
@@ -149,6 +150,9 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
Returns None if search settings did not change, or the old search settings if they
did change.
"""
if DISABLE_VECTOR_DB:
return None
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = max(len(all_cc_pairs) - 1, 0)

View File

@@ -1,3 +1,4 @@
import json
import logging
import time
from contextlib import AbstractContextManager
@@ -18,6 +19,7 @@ from onyx.configs.app_configs import OPENSEARCH_HOST
from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.search import DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW
from onyx.utils.logger import setup_logger
@@ -56,8 +58,8 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
# Maps schema property name to a list of highlighted snippets with match
# terms wrapped in tags (e.g. "something <hi>keyword</hi> other thing").
match_highlights: dict[str, list[str]] = {}
# Score explanation from OpenSearch when "explain": true is set in the query.
# Contains detailed breakdown of how the score was calculated.
# Score explanation from OpenSearch when "explain": true is set in the
# query. Contains detailed breakdown of how the score was calculated.
explanation: dict[str, Any] | None = None
@@ -833,9 +835,13 @@ class OpenSearchIndexClient(OpenSearchClient):
@log_function_time(print_only=True, debug_only=True)
def search(
self, body: dict[str, Any], search_pipeline_id: str | None
) -> list[SearchHit[DocumentChunk]]:
) -> list[SearchHit[DocumentChunkWithoutVectors]]:
"""Searches the index.
NOTE: Does not return vector fields. In order to take advantage of
performance benefits, the search body should exclude the schema's vector
fields.
TODO(andrei): Ideally we could check that every field in the body is
present in the index, to avoid a class of runtime bugs that could easily
be caught during development. Or change the function signature to accept
@@ -883,7 +889,7 @@ class OpenSearchIndexClient(OpenSearchClient):
raise_on_timeout=True,
)
search_hits: list[SearchHit[DocumentChunk]] = []
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = []
for hit in hits:
document_chunk_source: dict[str, Any] | None = hit.get("_source")
if not document_chunk_source:
@@ -893,8 +899,10 @@ class OpenSearchIndexClient(OpenSearchClient):
document_chunk_score = hit.get("_score", None)
match_highlights: dict[str, list[str]] = hit.get("highlight", {})
explanation: dict[str, Any] | None = hit.get("_explanation", None)
search_hit = SearchHit[DocumentChunk](
document_chunk=DocumentChunk.model_validate(document_chunk_source),
search_hit = SearchHit[DocumentChunkWithoutVectors](
document_chunk=DocumentChunkWithoutVectors.model_validate(
document_chunk_source
),
score=document_chunk_score,
match_highlights=match_highlights,
explanation=explanation,
@@ -1055,7 +1063,7 @@ class OpenSearchIndexClient(OpenSearchClient):
f"Body: {get_new_body_without_vectors(body)}\n"
f"Search pipeline ID: {search_pipeline_id}\n"
f"Phase took: {phase_took}\n"
f"Profile: {profile}\n"
f"Profile: {json.dumps(profile, indent=2)}\n"
)
if timed_out:
error_str = f"OpenSearch client error: Search timed out for index {self._index_name}."

View File

@@ -1,12 +1,23 @@
# Default value for the maximum number of tokens a chunk can hold, if none is
# specified when creating an index.
from onyx.configs.app_configs import (
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
)
import os
from enum import Enum
DEFAULT_MAX_CHUNK_SIZE = 512
# By default OpenSearch will only return a maximum of this many results in a
# given search. This value is configurable in the index settings.
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000
# For documents which do not have a value for LAST_UPDATED_FIELD_NAME, we assume
# that the document was last updated this many days ago for the purpose of time
# cutoff filtering during retrieval.
ASSUMED_DOCUMENT_AGE_DAYS = 90
# Size of the dynamic list used to consider elements during kNN graph creation.
# Higher values improve search quality but increase indexing time. Values
# typically range between 100 - 512.
@@ -26,10 +37,10 @@ M = 32 # Set relatively high for better accuracy.
# we have a much higher chance of all 10 of the final desired docs showing up
# and getting scored. In worse situations, the final 10 docs don't even show up
# as the final 10 (worse than just a miss at the reranking step).
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = (
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
if OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES > 0
else 750
# Defaults to 100 for now. Initially this defaulted to 750 but we were seeing
# poor search performance.
DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES = int(
os.environ.get("DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES", 100)
)
# Number of vectors to examine to decide the top k neighbors for the HNSW
@@ -39,23 +50,43 @@ DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = (
# larger than k, you can provide the size parameter to limit the final number of
# results to k." from
# https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#ef_search
EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
EF_SEARCH = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
# Since the titles are included in the contents, the embedding matches are
# heavily downweighted as they act as a boost rather than an independent scoring
# component.
SEARCH_TITLE_VECTOR_WEIGHT = 0.1
SEARCH_CONTENT_VECTOR_WEIGHT = 0.45
# Single keyword weight for both title and content (merged from former title
# keyword + content keyword).
SEARCH_KEYWORD_WEIGHT = 0.45
# NOTE: It is critical that the order of these weights matches the order of the
# sub-queries in the hybrid search.
HYBRID_SEARCH_NORMALIZATION_WEIGHTS = [
SEARCH_TITLE_VECTOR_WEIGHT,
SEARCH_CONTENT_VECTOR_WEIGHT,
SEARCH_KEYWORD_WEIGHT,
]
class HybridSearchSubqueryConfiguration(Enum):
TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD = 1
# Current default.
CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD = 2
assert sum(HYBRID_SEARCH_NORMALIZATION_WEIGHTS) == 1.0
# Will raise and block application start if HYBRID_SEARCH_SUBQUERY_CONFIGURATION
# is set but not a valid value. If not set, defaults to
# CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD.
HYBRID_SEARCH_SUBQUERY_CONFIGURATION: HybridSearchSubqueryConfiguration = (
HybridSearchSubqueryConfiguration(
int(os.environ["HYBRID_SEARCH_SUBQUERY_CONFIGURATION"])
)
if os.environ.get("HYBRID_SEARCH_SUBQUERY_CONFIGURATION", None) is not None
else HybridSearchSubqueryConfiguration.CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
)
class HybridSearchNormalizationPipeline(Enum):
# Current default.
MIN_MAX = 1
# NOTE: Using z-score normalization is better for hybrid search from a
# theoretical standpoint. Empirically on a small dataset of up to 10K docs,
# it's not very different. Likely more impactful at scale.
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
ZSCORE = 2
# Will raise and block application start if HYBRID_SEARCH_NORMALIZATION_PIPELINE
# is set but not a valid value. If not set, defaults to MIN_MAX.
HYBRID_SEARCH_NORMALIZATION_PIPELINE: HybridSearchNormalizationPipeline = (
HybridSearchNormalizationPipeline(
int(os.environ["HYBRID_SEARCH_NORMALIZATION_PIPELINE"])
)
if os.environ.get("HYBRID_SEARCH_NORMALIZATION_PIPELINE", None) is not None
else HybridSearchNormalizationPipeline.MIN_MAX
)

View File

@@ -47,6 +47,7 @@ from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
from onyx.document_index.opensearch.schema import DocumentSchema
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
@@ -55,16 +56,13 @@ from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
from onyx.document_index.opensearch.search import DocumentQuery
from onyx.document_index.opensearch.search import (
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
get_min_max_normalization_pipeline_name_and_config,
)
from onyx.document_index.opensearch.search import (
MIN_MAX_NORMALIZATION_PIPELINE_NAME,
get_normalization_pipeline_name_and_config,
)
from onyx.document_index.opensearch.search import (
ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
)
from onyx.document_index.opensearch.search import (
ZSCORE_NORMALIZATION_PIPELINE_NAME,
get_zscore_normalization_pipeline_name_and_config,
)
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import Document
@@ -103,18 +101,24 @@ def set_cluster_state(client: OpenSearchClient) -> None:
"is not the first time running Onyx against this instance of OpenSearch, these "
"settings have likely already been set. Not taking any further action..."
)
client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
min_max_normalization_pipeline_name, min_max_normalization_pipeline_config = (
get_min_max_normalization_pipeline_name_and_config()
)
zscore_normalization_pipeline_name, zscore_normalization_pipeline_config = (
get_zscore_normalization_pipeline_name_and_config()
)
client.create_search_pipeline(
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
pipeline_id=min_max_normalization_pipeline_name,
pipeline_body=min_max_normalization_pipeline_config,
)
client.create_search_pipeline(
pipeline_id=zscore_normalization_pipeline_name,
pipeline_body=zscore_normalization_pipeline_config,
)
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
chunk: DocumentChunk,
chunk: DocumentChunkWithoutVectors,
score: float | None,
highlights: dict[str, list[str]],
) -> InferenceChunkUncleaned:
@@ -877,7 +881,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
)
results: list[InferenceChunk] = []
for chunk_request in chunk_requests:
search_hits: list[SearchHit[DocumentChunk]] = []
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = []
query_body = DocumentQuery.get_from_document_id_query(
document_id=chunk_request.document_id,
tenant_state=self._tenant_state,
@@ -940,17 +944,92 @@ class OpenSearchDocumentIndex(DocumentIndex):
index_filters=filters,
include_hidden=False,
)
# NOTE: Using z-score normalization here because it's better for hybrid
# search from a theoretical standpoint. Empirically on a small dataset
# of up to 10K docs, it's not very different. Likely more impactful at
# scale.
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
normalization_pipeline_name, _ = get_normalization_pipeline_name_and_config()
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
search_pipeline_id=normalization_pipeline_name,
)
# Good place for a breakpoint to inspect the search hits if you have
# "explain" enabled.
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
)
for search_hit in search_hits
]
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
inference_chunks_uncleaned
)
return inference_chunks
def keyword_retrieval(
self,
query: str,
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
logger.debug(
f"[OpenSearchDocumentIndex] Keyword retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
query_body = DocumentQuery.get_keyword_search_query(
query_text=query,
num_hits=num_to_retrieve,
tenant_state=self._tenant_state,
# NOTE: Index filters includes metadata tags which were filtered
# for invalid unicode at indexing time. In theory it would be
# ideal to do filtering here as well, in practice we never did
# that in the Vespa codepath and have not seen issues in
# production, so we deliberately conform to the existing logic
# in order to not unknowningly introduce a possible bug.
index_filters=filters,
include_hidden=False,
)
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
)
for search_hit in search_hits
]
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
inference_chunks_uncleaned
)
return inference_chunks
def semantic_retrieval(
self,
query_embedding: Embedding,
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
logger.debug(
f"[OpenSearchDocumentIndex] Semantic retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
query_body = DocumentQuery.get_semantic_search_query(
query_embedding=query_embedding,
num_hits=num_to_retrieve,
tenant_state=self._tenant_state,
# NOTE: Index filters includes metadata tags which were filtered
# for invalid unicode at indexing time. In theory it would be
# ideal to do filtering here as well, in practice we never did
# that in the Vespa codepath and have not seen issues in
# production, so we deliberately conform to the existing logic
# in order to not unknowningly introduce a possible bug.
index_filters=filters,
include_hidden=False,
)
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
)
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
@@ -977,7 +1056,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
index_filters=filters,
num_to_retrieve=num_to_retrieve,
)
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
)

View File

@@ -11,6 +11,8 @@ from pydantic import model_serializer
from pydantic import model_validator
from pydantic import SerializerFunctionWrapHandler
from onyx.configs.app_configs import OPENSEARCH_INDEX_NUM_REPLICAS
from onyx.configs.app_configs import OPENSEARCH_INDEX_NUM_SHARDS
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
from onyx.document_index.interfaces_new import TenantState
@@ -100,9 +102,9 @@ def set_or_convert_timezone_to_utc(value: datetime) -> datetime:
return value
class DocumentChunk(BaseModel):
class DocumentChunkWithoutVectors(BaseModel):
"""
Represents a chunk of a document in the OpenSearch index.
Represents a chunk of a document in the OpenSearch index without vectors.
The names of these fields are based on the OpenSearch schema. Changes to the
schema require changes here. See get_document_schema.
@@ -124,9 +126,7 @@ class DocumentChunk(BaseModel):
# Either both should be None or both should be non-None.
title: str | None = None
title_vector: list[float] | None = None
content: str
content_vector: list[float]
source_type: str
# A list of key-value pairs separated by INDEX_SEPARATOR. See
@@ -176,19 +176,9 @@ class DocumentChunk(BaseModel):
def __str__(self) -> str:
return (
f"DocumentChunk(document_id={self.document_id}, chunk_index={self.chunk_index}, "
f"content length={len(self.content)}, content vector length={len(self.content_vector)}, "
f"tenant_id={self.tenant_id.tenant_id})"
f"content length={len(self.content)}, tenant_id={self.tenant_id.tenant_id})."
)
@model_validator(mode="after")
def check_title_and_title_vector_are_consistent(self) -> Self:
# title and title_vector should both either be None or not.
if self.title is not None and self.title_vector is None:
raise ValueError("Bug: Title vector must not be None if title is not None.")
if self.title_vector is not None and self.title is None:
raise ValueError("Bug: Title must not be None if title vector is not None.")
return self
@model_serializer(mode="wrap")
def serialize_model(
self, handler: SerializerFunctionWrapHandler
@@ -305,6 +295,35 @@ class DocumentChunk(BaseModel):
return TenantState(tenant_id=value, multitenant=MULTI_TENANT)
class DocumentChunk(DocumentChunkWithoutVectors):
"""Represents a chunk of a document in the OpenSearch index.
The names of these fields are based on the OpenSearch schema. Changes to the
schema require changes here. See get_document_schema.
"""
model_config = {"frozen": True}
title_vector: list[float] | None = None
content_vector: list[float]
def __str__(self) -> str:
return (
f"DocumentChunk(document_id={self.document_id}, chunk_index={self.chunk_index}, "
f"content length={len(self.content)}, content vector length={len(self.content_vector)}, "
f"tenant_id={self.tenant_id.tenant_id})"
)
@model_validator(mode="after")
def check_title_and_title_vector_are_consistent(self) -> Self:
# title and title_vector should both either be None or not.
if self.title is not None and self.title_vector is None:
raise ValueError("Bug: Title vector must not be None if title is not None.")
if self.title_vector is not None and self.title is None:
raise ValueError("Bug: Title must not be None if title vector is not None.")
return self
class DocumentSchema:
"""
Represents the schema and indexing strategies of the OpenSearch index.
@@ -516,78 +535,35 @@ class DocumentSchema:
return schema
@staticmethod
def get_index_settings() -> dict[str, Any]:
"""
Standard settings for reasonable local index and search performance.
"""
return {
"index": {
"number_of_shards": 1,
"number_of_replicas": 1,
# Required for vector search.
"knn": True,
"knn.algo_param.ef_search": EF_SEARCH,
}
}
@staticmethod
def get_index_settings_for_aws_managed_opensearch_st_dev() -> dict[str, Any]:
"""
Settings for AWS-managed OpenSearch.
Our AWS-managed OpenSearch cluster has 3 data nodes in 3 availability
zones.
- We use 3 shards to distribute load across all data nodes.
- We use 2 replicas to ensure each shard has a copy in each
availability zone. This is a hard requirement from AWS. The number
of data copies, including the primary (not a replica) copy, must be
divisible by the number of AZs.
"""
return {
"index": {
"number_of_shards": 3,
"number_of_replicas": 2,
# Required for vector search.
"knn": True,
"knn.algo_param.ef_search": EF_SEARCH,
}
}
@staticmethod
def get_index_settings_for_aws_managed_opensearch_mt_cloud() -> dict[str, Any]:
"""
Settings for AWS-managed OpenSearch in multi-tenant cloud.
324 shards very roughly targets a storage load of ~30Gb per shard, which
according to AWS OpenSearch documentation is within a good target range.
As documented above we need 2 replicas for a total of 3 copies of the
data because the cluster is configured with 3-AZ awareness.
"""
return {
"index": {
"number_of_shards": 324,
"number_of_replicas": 2,
# Required for vector search.
"knn": True,
"knn.algo_param.ef_search": EF_SEARCH,
}
}
@staticmethod
def get_index_settings_based_on_environment() -> dict[str, Any]:
"""
Returns the index settings based on the environment.
"""
if USING_AWS_MANAGED_OPENSEARCH:
# NOTE: The number of data copies, including the primary (not a
# replica) copy, must be divisible by the number of AZs.
if MULTI_TENANT:
return (
DocumentSchema.get_index_settings_for_aws_managed_opensearch_mt_cloud()
)
number_of_shards = 324
number_of_replicas = 2
else:
return (
DocumentSchema.get_index_settings_for_aws_managed_opensearch_st_dev()
)
number_of_shards = 3
number_of_replicas = 2
else:
return DocumentSchema.get_index_settings()
number_of_shards = 1
number_of_replicas = 1
if OPENSEARCH_INDEX_NUM_SHARDS is not None:
number_of_shards = OPENSEARCH_INDEX_NUM_SHARDS
if OPENSEARCH_INDEX_NUM_REPLICAS is not None:
number_of_replicas = OPENSEARCH_INDEX_NUM_REPLICAS
return {
"index": {
"number_of_shards": number_of_shards,
"number_of_replicas": number_of_replicas,
# Required for vector search.
"knn": True,
"knn.algo_param.ef_search": EF_SEARCH,
}
}

View File

@@ -7,16 +7,28 @@ from uuid import UUID
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
from onyx.configs.app_configs import OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED
from onyx.configs.app_configs import OPENSEARCH_PROFILING_DISABLED
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import Tag
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import ASSUMED_DOCUMENT_AGE_DAYS
from onyx.document_index.opensearch.constants import (
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
)
from onyx.document_index.opensearch.constants import HYBRID_SEARCH_NORMALIZATION_WEIGHTS
from onyx.document_index.opensearch.constants import (
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW,
)
from onyx.document_index.opensearch.constants import (
HYBRID_SEARCH_NORMALIZATION_PIPELINE,
)
from onyx.document_index.opensearch.constants import (
HYBRID_SEARCH_SUBQUERY_CONFIGURATION,
)
from onyx.document_index.opensearch.constants import HybridSearchNormalizationPipeline
from onyx.document_index.opensearch.constants import HybridSearchSubqueryConfiguration
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME
from onyx.document_index.opensearch.schema import CHUNK_INDEX_FIELD_NAME
@@ -43,49 +55,113 @@ from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
# TODO(andrei): Turn all magic dictionaries to pydantic models.
MIN_MAX_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_min_max"
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
"description": "Normalization for keyword and vector scores using min-max",
"phase_results_processors": [
{
# https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
"normalization-processor": {
"normalization": {"technique": "min_max"},
"combination": {
"technique": "arithmetic_mean",
"parameters": {"weights": HYBRID_SEARCH_NORMALIZATION_WEIGHTS},
},
def _get_hybrid_search_normalization_weights() -> list[float]:
if (
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
is HybridSearchSubqueryConfiguration.TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
):
# Since the titles are included in the contents, the embedding matches
# are heavily downweighted as they act as a boost rather than an
# independent scoring component.
search_title_vector_weight = 0.1
search_content_vector_weight = 0.45
# Single keyword weight for both title and content (merged from former
# title keyword + content keyword).
search_keyword_weight = 0.45
# NOTE: It is critical that the order of these weights matches the order
# of the sub-queries in the hybrid search.
hybrid_search_normalization_weights = [
search_title_vector_weight,
search_content_vector_weight,
search_keyword_weight,
]
elif (
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
is HybridSearchSubqueryConfiguration.CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
):
search_content_vector_weight = 0.5
# Single keyword weight for both title and content (merged from former
# title keyword + content keyword).
search_keyword_weight = 0.5
# NOTE: It is critical that the order of these weights matches the order
# of the sub-queries in the hybrid search.
hybrid_search_normalization_weights = [
search_content_vector_weight,
search_keyword_weight,
]
else:
raise ValueError(
f"Bug: Unhandled hybrid search subquery configuration: {HYBRID_SEARCH_SUBQUERY_CONFIGURATION}."
)
assert (
sum(hybrid_search_normalization_weights) == 1.0
), "Bug: Hybrid search normalization weights do not sum to 1.0."
return hybrid_search_normalization_weights
def get_min_max_normalization_pipeline_name_and_config() -> tuple[str, dict[str, Any]]:
min_max_normalization_pipeline_name = "normalization_pipeline_min_max"
min_max_normalization_pipeline_config: dict[str, Any] = {
"description": "Normalization for keyword and vector scores using min-max",
"phase_results_processors": [
{
# https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
"normalization-processor": {
"normalization": {"technique": "min_max"},
"combination": {
"technique": "arithmetic_mean",
"parameters": {
"weights": _get_hybrid_search_normalization_weights()
},
},
}
}
}
],
}
],
}
return min_max_normalization_pipeline_name, min_max_normalization_pipeline_config
ZSCORE_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_zscore"
ZSCORE_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
"description": "Normalization for keyword and vector scores using z-score",
"phase_results_processors": [
{
# https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
"normalization-processor": {
"normalization": {"technique": "z_score"},
"combination": {
"technique": "arithmetic_mean",
"parameters": {"weights": HYBRID_SEARCH_NORMALIZATION_WEIGHTS},
},
def get_zscore_normalization_pipeline_name_and_config() -> tuple[str, dict[str, Any]]:
zscore_normalization_pipeline_name = "normalization_pipeline_zscore"
zscore_normalization_pipeline_config: dict[str, Any] = {
"description": "Normalization for keyword and vector scores using z-score",
"phase_results_processors": [
{
# https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
"normalization-processor": {
"normalization": {"technique": "z_score"},
"combination": {
"technique": "arithmetic_mean",
"parameters": {
"weights": _get_hybrid_search_normalization_weights()
},
},
}
}
}
],
}
],
}
return zscore_normalization_pipeline_name, zscore_normalization_pipeline_config
# By default OpenSearch will only return a maximum of this many results in a
# given search. This value is configurable in the index settings.
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000
# For documents which do not have a value for LAST_UPDATED_FIELD_NAME, we assume
# that the document was last updated this many days ago for the purpose of time
# cutoff filtering during retrieval.
ASSUMED_DOCUMENT_AGE_DAYS = 90
def get_normalization_pipeline_name_and_config() -> tuple[str, dict[str, Any]]:
if (
HYBRID_SEARCH_NORMALIZATION_PIPELINE
is HybridSearchNormalizationPipeline.MIN_MAX
):
return get_min_max_normalization_pipeline_name_and_config()
elif (
HYBRID_SEARCH_NORMALIZATION_PIPELINE is HybridSearchNormalizationPipeline.ZSCORE
):
return get_zscore_normalization_pipeline_name_and_config()
else:
raise ValueError(
f"Bug: Unhandled hybrid search normalization pipeline: {HYBRID_SEARCH_NORMALIZATION_PIPELINE}."
)
class DocumentQuery:
@@ -160,9 +236,17 @@ class DocumentQuery:
# returning some number of results less than the index max allowed
# return size.
"size": DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW,
"_source": get_full_document,
# By default exclude retrieving the vector fields in order to save
# on retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
}
if not get_full_document:
# If we explicitly do not want the underlying document, we will only
# retrieve IDs.
final_get_ids_query["_source"] = False
if not OPENSEARCH_PROFILING_DISABLED:
final_get_ids_query["profile"] = True
@@ -257,7 +341,7 @@ class DocumentQuery:
# TODO(andrei, yuhong): We can tune this more dynamically based on
# num_hits.
max_results_per_subquery = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
max_results_per_subquery = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
query_text, query_vector, vector_candidates=max_results_per_subquery
@@ -281,9 +365,6 @@ class DocumentQuery:
attached_document_ids=index_filters.attached_document_ids,
hierarchy_node_ids=index_filters.hierarchy_node_ids,
)
match_highlights_configuration = (
DocumentQuery._get_match_highlights_configuration()
)
# See https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
hybrid_search_query: dict[str, Any] = {
@@ -310,16 +391,183 @@ class DocumentQuery:
final_hybrid_search_body: dict[str, Any] = {
"query": hybrid_search_query,
"size": num_hits,
"highlight": match_highlights_configuration,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
# Exclude retrieving the vector fields in order to save on
# retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
}
# Explain is for scoring breakdowns.
if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED:
final_hybrid_search_body["highlight"] = (
DocumentQuery._get_match_highlights_configuration()
)
# Explain is for scoring breakdowns. Setting this significantly
# increases query latency.
if OPENSEARCH_EXPLAIN_ENABLED:
final_hybrid_search_body["explain"] = True
return final_hybrid_search_body
@staticmethod
def get_keyword_search_query(
query_text: str,
num_hits: int,
tenant_state: TenantState,
index_filters: IndexFilters,
include_hidden: bool,
) -> dict[str, Any]:
"""Returns a final keyword search query.
This query can be directly supplied to the OpenSearch client.
Args:
query_text: The text to query for.
num_hits: The final number of hits to return.
tenant_state: Tenant state containing the tenant ID.
index_filters: Filters for the keyword search query.
include_hidden: Whether to include hidden documents.
Returns:
A dictionary representing the final keyword search query.
"""
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
raise ValueError(
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
)
keyword_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
include_hidden=include_hidden,
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
# now. This should not cause any issues but it can introduce
# redundant filters in queries that may affect performance.
access_control_list=index_filters.access_control_list,
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
attached_document_ids=index_filters.attached_document_ids,
hierarchy_node_ids=index_filters.hierarchy_node_ids,
)
keyword_search_query = (
DocumentQuery._get_title_content_combined_keyword_search_query(
query_text, search_filters=keyword_search_filters
)
)
final_keyword_search_query: dict[str, Any] = {
"query": keyword_search_query,
"size": num_hits,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
# Exclude retrieving the vector fields in order to save on
# retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
}
if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED:
final_keyword_search_query["highlight"] = (
DocumentQuery._get_match_highlights_configuration()
)
if not OPENSEARCH_PROFILING_DISABLED:
final_keyword_search_query["profile"] = True
# Explain is for scoring breakdowns. Setting this significantly
# increases query latency.
if OPENSEARCH_EXPLAIN_ENABLED:
final_keyword_search_query["explain"] = True
return final_keyword_search_query
@staticmethod
def get_semantic_search_query(
query_embedding: list[float],
num_hits: int,
tenant_state: TenantState,
index_filters: IndexFilters,
include_hidden: bool,
) -> dict[str, Any]:
"""Returns a final semantic search query.
This query can be directly supplied to the OpenSearch client.
Args:
query_embedding: The vector embedding of the text to query for.
num_hits: The final number of hits to return.
tenant_state: Tenant state containing the tenant ID.
index_filters: Filters for the semantic search query.
include_hidden: Whether to include hidden documents.
Returns:
A dictionary representing the final semantic search query.
"""
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
raise ValueError(
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
)
semantic_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
include_hidden=include_hidden,
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
# now. This should not cause any issues but it can introduce
# redundant filters in queries that may affect performance.
access_control_list=index_filters.access_control_list,
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
attached_document_ids=index_filters.attached_document_ids,
hierarchy_node_ids=index_filters.hierarchy_node_ids,
)
semantic_search_query = (
DocumentQuery._get_content_vector_similarity_search_query(
query_embedding,
vector_candidates=num_hits,
search_filters=semantic_search_filters,
)
)
final_semantic_search_query: dict[str, Any] = {
"query": semantic_search_query,
"size": num_hits,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
# Exclude retrieving the vector fields in order to save on
# retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
}
if not OPENSEARCH_PROFILING_DISABLED:
final_semantic_search_query["profile"] = True
# Explain is for scoring breakdowns. Setting this significantly
# increases query latency.
if OPENSEARCH_EXPLAIN_ENABLED:
final_semantic_search_query["explain"] = True
return final_semantic_search_query
@staticmethod
def get_random_search_query(
tenant_state: TenantState,
@@ -371,6 +619,11 @@ class DocumentQuery:
},
"size": num_to_retrieve,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
# Exclude retrieving the vector fields in order to save on
# retrieval cost as we don't need them upstream.
"_source": {
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
},
}
if not OPENSEARCH_PROFILING_DISABLED:
final_random_search_query["profile"] = True
@@ -385,7 +638,7 @@ class DocumentQuery:
# search. This is higher than the number of results because the scoring
# is hybrid. For a detailed breakdown, see where the default value is
# set.
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
) -> list[dict[str, Any]]:
"""Returns subqueries for hybrid search.
@@ -395,20 +648,18 @@ class DocumentQuery:
The return of this function is not sufficient to be directly supplied to
the OpenSearch client. See get_hybrid_search_query.
Matches:
- Title vector
- Content vector
- Keyword (title + content, match and phrase)
Normalization is not performed here.
The weights of each of these subqueries should be configured in a search
pipeline.
The exact subqueries executed depend on the
HYBRID_SEARCH_SUBQUERY_CONFIGURATION setting.
NOTE: For OpenSearch, 5 is the maximum number of query clauses allowed
in a single hybrid query. Source:
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
NOTE: Each query is independent during the search phase, there is no
NOTE: Each query is independent during the search phase; there is no
backfilling of scores for missing query components. What this means is
that if a document was a good vector match but did not show up for
keyword, it gets a score of 0 for the keyword component of the hybrid
@@ -437,74 +688,133 @@ class DocumentQuery:
similarity search.
"""
# Build sub-queries for hybrid search. Order must match normalization
# pipeline weights: title vector, content vector, keyword (title + content).
hybrid_search_queries: list[dict[str, Any]] = [
# 1. Title vector search
{
"knn": {
TITLE_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": vector_candidates,
}
}
},
# 2. Content vector search
{
"knn": {
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": vector_candidates,
}
}
},
# 3. Keyword (title + content) match and phrase search.
{
"bool": {
"should": [
{
"match": {
TITLE_FIELD_NAME: {
"query": query_text,
"operator": "or",
# The title fields are strongly discounted as they are included in the content.
# It just acts as a minor boost
"boost": 0.1,
}
}
},
{
"match_phrase": {
TITLE_FIELD_NAME: {
"query": query_text,
"slop": 1,
"boost": 0.2,
}
}
},
{
"match": {
CONTENT_FIELD_NAME: {
"query": query_text,
"operator": "or",
"boost": 1.0,
}
}
},
{
"match_phrase": {
CONTENT_FIELD_NAME: {
"query": query_text,
"slop": 1,
"boost": 1.5,
}
}
},
]
}
},
]
# pipeline weights.
if (
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
is HybridSearchSubqueryConfiguration.TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
):
return [
DocumentQuery._get_title_vector_similarity_search_query(
query_vector, vector_candidates
),
DocumentQuery._get_content_vector_similarity_search_query(
query_vector, vector_candidates
),
DocumentQuery._get_title_content_combined_keyword_search_query(
query_text
),
]
elif (
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
is HybridSearchSubqueryConfiguration.CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
):
return [
DocumentQuery._get_content_vector_similarity_search_query(
query_vector, vector_candidates
),
DocumentQuery._get_title_content_combined_keyword_search_query(
query_text
),
]
else:
raise ValueError(
f"Bug: Unhandled hybrid search subquery configuration: {HYBRID_SEARCH_SUBQUERY_CONFIGURATION}"
)
return hybrid_search_queries
@staticmethod
def _get_title_vector_similarity_search_query(
query_vector: list[float],
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
) -> dict[str, Any]:
return {
"knn": {
TITLE_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": vector_candidates,
}
}
}
@staticmethod
def _get_content_vector_similarity_search_query(
query_vector: list[float],
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
search_filters: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
query = {
"knn": {
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": vector_candidates,
}
}
}
if search_filters is not None:
query["knn"][CONTENT_VECTOR_FIELD_NAME]["filter"] = {
"bool": {"filter": search_filters}
}
return query
@staticmethod
def _get_title_content_combined_keyword_search_query(
query_text: str,
search_filters: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
query = {
"bool": {
"should": [
{
"match": {
TITLE_FIELD_NAME: {
"query": query_text,
"operator": "or",
# The title fields are strongly discounted as they are included in the content.
# It just acts as a minor boost
"boost": 0.1,
}
}
},
{
"match_phrase": {
TITLE_FIELD_NAME: {
"query": query_text,
"slop": 1,
"boost": 0.2,
}
}
},
{
"match": {
CONTENT_FIELD_NAME: {
"query": query_text,
"operator": "or",
"boost": 1.0,
}
}
},
{
"match_phrase": {
CONTENT_FIELD_NAME: {
"query": query_text,
"slop": 1,
"boost": 1.5,
}
}
},
],
# Ensure at least one term from the query is present in the
# document. This defaults to 1, unless a filter or must clause
# is supplied, in which case it defaults to 0.
"minimum_should_match": 1,
}
}
if search_filters is not None:
query["bool"]["filter"] = search_filters
return query
@staticmethod
def _get_search_filters(

View File

@@ -44,6 +44,7 @@ class OnyxErrorCode(Enum):
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
INVALID_INPUT = ("INVALID_INPUT", 400)
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
QUERY_REJECTED = ("QUERY_REJECTED", 400)
# ------------------------------------------------------------------
# Not Found (404)
@@ -88,6 +89,7 @@ class OnyxErrorCode(Enum):
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
BAD_GATEWAY = ("BAD_GATEWAY", 502)
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
HOOK_EXECUTION_FAILED = ("HOOK_EXECUTION_FAILED", 502)
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
def __init__(self, code: str, status_code: int) -> None:

View File

@@ -88,9 +88,13 @@ def summarize_image_with_error_handling(
try:
return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)
except UnsupportedImageFormatError:
magic_hex = image_data[:8].hex() if image_data else "empty"
logger.info(
"Skipping image summarization due to unsupported MIME type for %s",
"Skipping image summarization due to unsupported MIME type "
"for %s (magic_bytes=%s, size=%d bytes)",
context_name,
magic_hex,
len(image_data),
)
return None
@@ -134,9 +138,23 @@ def _summarize_image(
return summary
except Exception as e:
error_msg = f"Summarization failed. Messages: {messages}"
error_msg = error_msg[:1024]
raise ValueError(error_msg) from e
# Extract structured details from LiteLLM exceptions when available,
# rather than dumping the full messages payload (which contains base64
# image data and produces enormous, unreadable error logs).
str_e = str(e)
if len(str_e) > 512:
str_e = str_e[:512] + "... (truncated)"
parts = [f"Summarization failed: {type(e).__name__}: {str_e}"]
status_code = getattr(e, "status_code", None)
llm_provider = getattr(e, "llm_provider", None)
model = getattr(e, "model", None)
if status_code is not None:
parts.append(f"status_code={status_code}")
if llm_provider is not None:
parts.append(f"llm_provider={llm_provider}")
if model is not None:
parts.append(f"model={model}")
raise ValueError(" | ".join(parts)) from e
def _encode_image_for_llm_prompt(image_data: bytes) -> str:

View File

@@ -0,0 +1,387 @@
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
Usage (Celery tasks and FastAPI handlers):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
)
if isinstance(result, HookSkipped):
# no active hook configured — continue with original behavior
...
elif isinstance(result, HookSoftFailed):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is a validated Pydantic model instance (spec.response_model)
...
is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:
NetworkError (DNS, connection refused) → False (cannot reach the server)
HTTP 401 / 403 → False (api_key revoked or invalid)
TimeoutException → None (server may be slow, skip write)
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
Unknown exception → None (no signal, skip write)
Non-JSON / non-dict response → None (server responded, skip write)
Success (2xx, valid dict) → True (confirmed reachable)
None means "leave the current value unchanged" — no DB round-trip is made.
DB session design
-----------------
The executor uses three sessions:
1. Caller's session (db_session) — used only for the hook lookup read. All
needed fields are extracted from the Hook object before the HTTP call, so
the caller's session is not held open during the external HTTP request.
2. Log session — a separate short-lived session opened after the HTTP call
completes to write the HookExecutionLog row on failure. Success runs are
not recorded. Committed independently of everything else.
3. Reachable session — a second short-lived session to update is_reachable on
the Hook. Kept separate from the log session so a concurrent hook deletion
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
prevent the execution log from being written. This update is best-effort.
"""
import json
import time
from typing import Any
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.registry import get_hook_point_spec
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.utils.logger import setup_logger
logger = setup_logger()
class HookSkipped:
"""No active hook configured for this hook point."""
class HookSoftFailed:
"""Hook was called but failed with SOFT fail strategy — continuing."""
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
class _HttpOutcome(BaseModel):
"""Structured result of an HTTP hook call, returned by _process_response."""
is_success: bool
updated_is_reachable: (
bool | None
) # True/False = write to DB, None = unchanged (skip write)
status_code: int | None
error_message: str | None
response_payload: dict[str, Any] | None
def _lookup_hook(
db_session: Session,
hook_point: HookPoint,
) -> Hook | HookSkipped:
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
No HTTP call is made and no DB writes are performed for any HookSkipped path.
There is nothing to log and no reachability information to update.
"""
if not HOOKS_AVAILABLE:
return HookSkipped()
hook = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if hook is None or not hook.is_active:
return HookSkipped()
if not hook.endpoint_url:
return HookSkipped()
return hook
def _process_response(
*,
response: httpx.Response | None,
exc: Exception | None,
timeout: float,
) -> _HttpOutcome:
"""Process the result of an HTTP call and return a structured outcome.
Called after the client.post() try/except. If post() raised, exc is set and
response is None. Otherwise response is set and exc is None. Handles
raise_for_status(), JSON decoding, and the dict shape check.
"""
if exc is not None:
if isinstance(exc, httpx.NetworkError):
msg = f"Hook network error (endpoint unreachable): {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False,
status_code=None,
error_message=msg,
response_payload=None,
)
if isinstance(exc, httpx.TimeoutException):
msg = f"Hook timed out after {timeout}s: {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # timeout doesn't indicate unreachability
status_code=None,
error_message=msg,
response_payload=None,
)
msg = f"Hook call failed: {exc}"
logger.exception(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # unknown error — don't make assumptions
status_code=None,
error_message=msg,
response_payload=None,
)
if response is None:
raise ValueError(
"exactly one of response or exc must be non-None; both are None"
)
status_code = response.status_code
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
logger.warning(msg, exc_info=e)
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
# so the operator knows to update it. All other HTTP errors keep is_reachable
# as-is (server is up, the request just failed for application reasons).
auth_failed = e.response.status_code in (401, 403)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False if auth_failed else None,
status_code=status_code,
error_message=msg,
response_payload=None,
)
try:
response_payload = response.json()
except (json.JSONDecodeError, httpx.DecodingError) as e:
msg = f"Hook returned non-JSON response: {e}"
logger.warning(msg, exc_info=e)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
if not isinstance(response_payload, dict):
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
logger.warning(msg)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
return _HttpOutcome(
is_success=True,
updated_is_reachable=True,
status_code=status_code,
error_message=None,
response_payload=response_payload,
)
def _persist_result(
*,
hook_id: int,
outcome: _HttpOutcome,
duration_ms: int,
) -> None:
"""Write the execution log on failure and optionally update is_reachable, each
in its own session so a failure in one does not affect the other."""
# Only write the execution log on failure — success runs are not recorded.
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
# deleted between the initial lookup and here).
if not outcome.is_success:
try:
with get_session_with_current_tenant() as log_session:
create_hook_execution_log__no_commit(
db_session=log_session,
hook_id=hook_id,
is_success=False,
error_message=outcome.error_message,
status_code=outcome.status_code,
duration_ms=duration_ms,
)
log_session.commit()
except Exception:
logger.exception(
f"Failed to persist hook execution log for hook_id={hook_id}"
)
# Update is_reachable separately — best-effort, non-critical.
# None means the value is unchanged (set by the caller to skip the no-op write).
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
# concurrently deleted, so keep this isolated from the log write above.
if outcome.updated_is_reachable is not None:
try:
with get_session_with_current_tenant() as reachable_session:
update_hook__no_commit(
db_session=reachable_session,
hook_id=hook_id,
is_reachable=outcome.updated_is_reachable,
)
reachable_session.commit()
except Exception:
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _execute_hook_inner(
hook: Hook,
payload: dict[str, Any],
) -> BaseModel | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
hook_point = hook.hook_point # extract before HTTP call per design intent
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
"active hooks without an endpoint_url must be rejected by _lookup_hook"
)
start = time.monotonic()
response: httpx.Response | None = None
exc: Exception | None = None
try:
api_key: str | None = (
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
)
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against the spec's response_model.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: BaseModel | None = None
if outcome.is_success and outcome.response_payload is not None:
spec = get_hook_point_spec(hook_point)
try:
validated_model = spec.response_model.model_validate(
outcome.response_payload
)
except ValidationError as e:
msg = f"Hook response failed validation against {spec.response_model.__name__}: {e}"
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
outcome = outcome.model_copy(update={"updated_is_reachable": None})
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
if not outcome.is_success:
if fail_strategy == HookFailStrategy.HARD:
raise OnyxError(
OnyxErrorCode.HOOK_EXECUTION_FAILED,
outcome.error_message or "Hook execution failed.",
)
logger.warning(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if validated_model is None:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"validated_model is None for successful hook call (hook_id={hook_id})",
)
return validated_model
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
) -> BaseModel | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
"""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
fail_strategy = hook.fail_strategy
hook_id = hook.id
try:
return _execute_hook_inner(hook, payload)
except Exception:
if fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
)
return HookSoftFailed()
raise

View File

@@ -0,0 +1,121 @@
from datetime import datetime
from enum import Enum
from typing import Annotated
from typing import Any
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from pydantic import SecretStr
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
NonEmptySecretStr = Annotated[SecretStr, Field(min_length=1)]
# ---------------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------------
class HookCreateRequest(BaseModel):
name: str = Field(min_length=1)
hook_point: HookPoint
endpoint_url: str = Field(min_length=1)
api_key: NonEmptySecretStr | None = None
fail_strategy: HookFailStrategy | None = None # if None, uses HookPointSpec default
timeout_seconds: float | None = Field(
default=None, gt=0
) # if None, uses HookPointSpec default
@field_validator("name", "endpoint_url")
@classmethod
def no_whitespace_only(cls, v: str) -> str:
if not v.strip():
raise ValueError("cannot be whitespace-only.")
return v
class HookUpdateRequest(BaseModel):
name: str | None = None
endpoint_url: str | None = None
api_key: NonEmptySecretStr | None = None
fail_strategy: HookFailStrategy | None = None
timeout_seconds: float | None = Field(default=None, gt=0)
@model_validator(mode="after")
def require_at_least_one_field(self) -> "HookUpdateRequest":
if not self.model_fields_set:
raise ValueError("At least one field must be provided for an update.")
if "name" in self.model_fields_set and not (self.name or "").strip():
raise ValueError("name cannot be cleared.")
if (
"endpoint_url" in self.model_fields_set
and not (self.endpoint_url or "").strip()
):
raise ValueError("endpoint_url cannot be cleared.")
if "fail_strategy" in self.model_fields_set and self.fail_strategy is None:
raise ValueError(
"fail_strategy cannot be null; omit the field to leave it unchanged."
)
if "timeout_seconds" in self.model_fields_set and self.timeout_seconds is None:
raise ValueError(
"timeout_seconds cannot be null; omit the field to leave it unchanged."
)
return self
# ---------------------------------------------------------------------------
# Response models
# ---------------------------------------------------------------------------
class HookPointMetaResponse(BaseModel):
hook_point: HookPoint
display_name: str
description: str
docs_url: str | None
input_schema: dict[str, Any]
output_schema: dict[str, Any]
default_timeout_seconds: float
default_fail_strategy: HookFailStrategy
fail_hard_description: str
class HookResponse(BaseModel):
id: int
name: str
hook_point: HookPoint
# Nullable to match the DB column — endpoint_url is required on creation but
# future hook point types may not use an external endpoint (e.g. built-in handlers).
endpoint_url: str | None
fail_strategy: HookFailStrategy
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
is_active: bool
is_reachable: bool | None
creator_email: str | None
created_at: datetime
updated_at: datetime
class HookValidateStatus(str, Enum):
passed = "passed" # server responded (any status except 401/403)
auth_failed = "auth_failed" # server responded with 401 or 403
timeout = (
"timeout" # TCP connected, but read/write timed out (server exists but slow)
)
cannot_connect = "cannot_connect" # could not connect to the server
class HookValidateResponse(BaseModel):
status: HookValidateStatus
error_message: str | None = None
class HookExecutionRecord(BaseModel):
error_message: str | None = None
status_code: int | None = None
duration_ms: int | None = None
created_at: datetime

View File

View File

@@ -0,0 +1,74 @@
from typing import Any
from typing import ClassVar
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
_REQUIRED_ATTRS = (
"hook_point",
"display_name",
"description",
"default_timeout_seconds",
"fail_hard_description",
"default_fail_strategy",
"payload_model",
"response_model",
)
class HookPointSpec:
"""Static metadata and contract for a pipeline hook point.
Each concrete subclass represents exactly one hook point and is instantiated
once at startup, registered in onyx.hooks.registry._REGISTRY. Prefer
get_hook_point_spec() or get_all_specs() from the registry over direct
instantiation.
Each hook point is a concrete subclass of this class. Onyx engineers
own these definitions — customers never touch this code.
Subclasses must define all attributes as class-level constants.
payload_model and response_model must be Pydantic BaseModel subclasses;
input_schema and output_schema are derived from them automatically.
"""
hook_point: HookPoint
display_name: str
description: str
default_timeout_seconds: float
fail_hard_description: str
default_fail_strategy: HookFailStrategy
docs_url: str | None = None
payload_model: ClassVar[type[BaseModel]]
response_model: ClassVar[type[BaseModel]]
# Computed once at class definition time from payload_model / response_model.
input_schema: ClassVar[dict[str, Any]]
output_schema: ClassVar[dict[str, Any]]
def __init_subclass__(cls, **kwargs: object) -> None:
"""Enforce that every subclass declares all required class attributes.
Called automatically by Python whenever a class inherits from HookPointSpec.
Raises TypeError at import time if any required attribute is missing or if
payload_model / response_model are not Pydantic BaseModel subclasses.
input_schema and output_schema are derived automatically from the models.
"""
super().__init_subclass__(**kwargs)
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
if missing:
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
for attr in ("payload_model", "response_model"):
val = getattr(cls, attr, None)
if val is None or not (
isinstance(val, type) and issubclass(val, BaseModel)
):
raise TypeError(
f"{cls.__name__}.{attr} must be a Pydantic BaseModel subclass, got {val!r}"
)
cls.input_schema = cls.payload_model.model_json_schema()
cls.output_schema = cls.response_model.model_json_schema()

View File

@@ -0,0 +1,31 @@
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
# TODO(@Bo-Onyx): define payload and response fields
class DocumentIngestionPayload(BaseModel):
pass
class DocumentIngestionResponse(BaseModel):
pass
class DocumentIngestionSpec(HookPointSpec):
"""Hook point that runs during document ingestion.
# TODO(@Bo-Onyx): define call site, input/output schema, and timeout budget.
"""
hook_point = HookPoint.DOCUMENT_INGESTION
display_name = "Document Ingestion"
description = "Runs during document ingestion. Allows filtering or transforming documents before indexing."
default_timeout_seconds = 30.0
fail_hard_description = "The document will not be indexed."
default_fail_strategy = HookFailStrategy.HARD
payload_model = DocumentIngestionPayload
response_model = DocumentIngestionResponse

View File

@@ -0,0 +1,70 @@
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
class QueryProcessingPayload(BaseModel):
model_config = ConfigDict(extra="forbid")
query: str = Field(description="The raw query string exactly as the user typed it.")
user_email: str | None = Field(
description="Email of the user submitting the query, or null if unauthenticated."
)
chat_session_id: str = Field(
description="UUID of the chat session, formatted as a hyphenated lowercase string (e.g. '550e8400-e29b-41d4-a716-446655440000'). Always present — the session is guaranteed to exist by the time this hook fires."
)
class QueryProcessingResponse(BaseModel):
# Intentionally permissive — customer endpoints may return extra fields.
query: str | None = Field(
default=None,
description=(
"The query to use in the pipeline. "
"Null, empty string, whitespace-only, or absent = reject the query."
),
)
rejection_message: str | None = Field(
default=None,
description="Message shown to the user when the query is rejected. Falls back to a generic message if not provided.",
)
class QueryProcessingSpec(HookPointSpec):
"""Hook point that runs on every user query before it enters the pipeline.
Call site: inside handle_stream_message_objects() in
backend/onyx/chat/process_message.py, immediately after message_text is
assigned from the request and before create_new_chat_message() saves it.
This is the earliest possible point in the query pipeline:
- Raw query — unmodified, exactly as the user typed it
- No side effects yet — message has not been saved to DB
- User identity is available for user-specific logic
Supported use cases:
- Query rejection: block queries based on content or user context
- Query rewriting: normalize, expand, or modify the query
- PII removal: scrub sensitive data before the LLM sees it
- Access control: reject queries from certain users or groups
- Query auditing: log or track queries based on business rules
"""
hook_point = HookPoint.QUERY_PROCESSING
display_name = "Query Processing"
description = (
"Runs on every user query before it enters the pipeline. "
"Allows rewriting, filtering, or rejecting queries."
)
default_timeout_seconds = 5.0 # user is actively waiting — keep tight
fail_hard_description = (
"The query will be blocked and the user will see an error message."
)
default_fail_strategy = HookFailStrategy.HARD
payload_model = QueryProcessingPayload
response_model = QueryProcessingResponse

View File

@@ -0,0 +1,45 @@
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
from onyx.hooks.points.document_ingestion import DocumentIngestionSpec
from onyx.hooks.points.query_processing import QueryProcessingSpec
# Internal: use `monkeypatch.setattr(registry_module, "_REGISTRY", {...})` to override in tests.
_REGISTRY: dict[HookPoint, HookPointSpec] = {
HookPoint.DOCUMENT_INGESTION: DocumentIngestionSpec(),
HookPoint.QUERY_PROCESSING: QueryProcessingSpec(),
}
def validate_registry() -> None:
"""Assert that every HookPoint enum value has a registered spec.
Call once at application startup (e.g. from the FastAPI lifespan hook).
Raises RuntimeError if any hook point is missing a spec.
"""
missing = set(HookPoint) - set(_REGISTRY)
if missing:
raise RuntimeError(
f"Hook point(s) have no registered spec: {missing}. "
"Add an entry to onyx.hooks.registry._REGISTRY."
)
def get_hook_point_spec(hook_point: HookPoint) -> HookPointSpec:
"""Returns the spec for a given hook point.
Raises ValueError if the hook point has no registered spec — this is a
programmer error; every HookPoint enum value must have a corresponding spec
in _REGISTRY.
"""
try:
return _REGISTRY[hook_point]
except KeyError:
raise ValueError(
f"No spec registered for hook point {hook_point!r}. "
"Add an entry to onyx.hooks.registry._REGISTRY."
)
def get_all_specs() -> list[HookPointSpec]:
"""Returns the specs for all registered hook points."""
return list(_REGISTRY.values())

View File

@@ -0,0 +1,5 @@
from onyx.configs.app_configs import HOOK_ENABLED
from shared_configs.configs import MULTI_TENANT
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT

View File

@@ -395,6 +395,12 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
llm = get_default_llm_with_vision()
if not llm:
if get_image_extraction_and_analysis_enabled():
logger.warning(
"Image analysis is enabled but no vision-capable LLM is "
"available — images will not be summarized. Configure a "
"vision model in the admin LLM settings."
)
# Even without LLM, we still convert to IndexingDocument with base Sections
return [
IndexingDocument(

View File

@@ -168,10 +168,23 @@ def get_default_llm_with_vision(
if model_supports_image_input(
default_model.name, default_model.llm_provider.provider
):
logger.info(
"Using default vision model: %s (provider=%s)",
default_model.name,
default_model.llm_provider.provider,
)
return create_vision_llm(
LLMProviderView.from_model(default_model.llm_provider),
default_model.name,
)
else:
logger.warning(
"Default vision model %s (provider=%s) does not support "
"image input — falling back to searching all providers",
default_model.name,
default_model.llm_provider.provider,
)
# Fall back to searching all providers
models = fetch_existing_models(
db_session=db_session,
@@ -179,6 +192,10 @@ def get_default_llm_with_vision(
)
if not models:
logger.warning(
"No LLM models with VISION or CHAT flow type found — "
"image summarization will be disabled"
)
return None
for model in models:
@@ -200,11 +217,25 @@ def get_default_llm_with_vision(
for model in sorted_models:
if model_supports_image_input(model.name, model.llm_provider.provider):
logger.info(
"Using fallback vision model: %s (provider=%s)",
model.name,
model.llm_provider.provider,
)
return create_vision_llm(
provider_map[model.llm_provider_id],
model.name,
)
checked_models = [
f"{m.name} (provider={m.llm_provider.provider})" for m in sorted_models
]
logger.warning(
"No vision-capable model found among %d candidates: %s"
"image summarization will be disabled",
len(sorted_models),
", ".join(checked_models),
)
return None

View File

@@ -530,6 +530,11 @@ class LitellmLLM(LLM):
):
messages = _strip_tool_content_from_messages(messages)
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
# reject requests where tool_choice is explicitly null.
if tools and tool_choice is not None:
optional_kwargs["tool_choice"] = tool_choice
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
@@ -538,7 +543,6 @@ class LitellmLLM(LLM):
custom_llm_provider=self._custom_llm_provider or None,
messages=messages,
tools=tools,
tool_choice=tool_choice,
stream=stream,
temperature=temperature,
timeout=timeout_override or self._timeout,

View File

@@ -219,13 +219,26 @@ def litellm_exception_to_error_msg(
"ratelimiterror"
):
upstream_detail = upstream_detail.split(":", 1)[1].strip()
error_msg = (
f"{provider_name} rate limit: {upstream_detail}"
if upstream_detail
else f"{provider_name} rate limit exceeded: Please slow down your requests and try again later."
)
error_code = "RATE_LIMIT"
is_retryable = True
upstream_detail_lower = upstream_detail.lower()
if (
"insufficient_quota" in upstream_detail_lower
or "exceeded your current quota" in upstream_detail_lower
):
error_msg = (
f"{provider_name} quota exceeded: {upstream_detail}"
if upstream_detail
else f"{provider_name} quota exceeded: Verify billing and quota for this API key."
)
error_code = "BUDGET_EXCEEDED"
is_retryable = False
else:
error_msg = (
f"{provider_name} rate limit: {upstream_detail}"
if upstream_detail
else f"{provider_name} rate limit exceeded: Please slow down your requests and try again later."
)
error_code = "RATE_LIMIT"
is_retryable = True
elif isinstance(core_exception, ServiceUnavailableError):
provider_name = (
llm.config.model_provider

View File

@@ -62,6 +62,7 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.error_handling.exceptions import register_onyx_exception_handlers
from onyx.file_store.file_store import get_default_file_store
from onyx.hooks.registry import validate_registry
from onyx.server.api_key.api import router as api_key_router
from onyx.server.auth_check import check_router_auth
from onyx.server.documents.cc_pair import router as cc_pair_router
@@ -76,6 +77,7 @@ from onyx.server.features.default_assistant.api import (
)
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.hierarchy.api import router as hierarchy_router
from onyx.server.features.hooks.api import router as hook_router
from onyx.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
@@ -308,6 +310,7 @@ def validate_no_vector_db_settings() -> None:
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
validate_no_vector_db_settings()
validate_cache_backend_settings()
validate_registry()
# Set recursion limit
if SYSTEM_RECURSION_LIMIT is not None:
@@ -451,6 +454,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
register_onyx_exception_handlers(application)
include_router_with_global_prefix_prepended(application, hook_router)
include_router_with_global_prefix_prepended(application, password_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, query_router)

View File

@@ -479,7 +479,9 @@ def is_zip_file(file: UploadFile) -> bool:
def upload_files(
files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
files: list[UploadFile],
file_origin: FileOrigin = FileOrigin.CONNECTOR,
unzip: bool = True,
) -> FileUploadResponse:
# Skip directories and known macOS metadata entries
@@ -502,31 +504,46 @@ def upload_files(
if seen_zip:
raise HTTPException(status_code=400, detail=SEEN_ZIP_DETAIL)
seen_zip = True
# Validate the zip by opening it (catches corrupt/non-zip files)
with zipfile.ZipFile(file.file, "r") as zf:
zip_metadata_file_id = save_zip_metadata_to_file_store(
zf, file_store
)
for file_info in zf.namelist():
if zf.getinfo(file_info).is_dir():
continue
if not should_process_file(file_info):
continue
sub_file_bytes = zf.read(file_info)
mime_type, __ = mimetypes.guess_type(file_info)
if mime_type is None:
mime_type = "application/octet-stream"
file_id = file_store.save_file(
content=BytesIO(sub_file_bytes),
display_name=os.path.basename(file_info),
file_origin=file_origin,
file_type=mime_type,
if unzip:
zip_metadata_file_id = save_zip_metadata_to_file_store(
zf, file_store
)
deduped_file_paths.append(file_id)
deduped_file_names.append(os.path.basename(file_info))
for file_info in zf.namelist():
if zf.getinfo(file_info).is_dir():
continue
if not should_process_file(file_info):
continue
sub_file_bytes = zf.read(file_info)
mime_type, __ = mimetypes.guess_type(file_info)
if mime_type is None:
mime_type = "application/octet-stream"
file_id = file_store.save_file(
content=BytesIO(sub_file_bytes),
display_name=os.path.basename(file_info),
file_origin=file_origin,
file_type=mime_type,
)
deduped_file_paths.append(file_id)
deduped_file_names.append(os.path.basename(file_info))
continue
# Store the zip as-is (unzip=False)
file.file.seek(0)
file_id = file_store.save_file(
content=file.file,
display_name=file.filename,
file_origin=file_origin,
file_type=file.content_type or "application/zip",
)
deduped_file_paths.append(file_id)
deduped_file_names.append(file.filename)
continue
# Since we can't render docx files in the UI,
@@ -613,9 +630,10 @@ def _fetch_and_check_file_connector_cc_pair_permissions(
@router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS)
def upload_files_api(
files: list[UploadFile],
unzip: bool = True,
_: User = Depends(current_curator_or_admin_user),
) -> FileUploadResponse:
return upload_files(files, FileOrigin.OTHER)
return upload_files(files, FileOrigin.OTHER, unzip=unzip)
@router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS)

View File

@@ -408,7 +408,7 @@ class FailedConnectorIndexingStatus(BaseModel):
"""Simplified version of ConnectorIndexingStatus for failed indexing attempts"""
cc_pair_id: int
name: str | None
name: str
error_msg: str | None
is_deletable: bool
connector_id: int
@@ -422,7 +422,7 @@ class ConnectorStatus(BaseModel):
"""
cc_pair_id: int
name: str | None
name: str
connector: ConnectorSnapshot
credential: CredentialSnapshot
access_type: AccessType
@@ -453,7 +453,7 @@ class DocsCountOperator(str, Enum):
class ConnectorIndexingStatusLite(BaseModel):
cc_pair_id: int
name: str | None
name: str
source: DocumentSource
access_type: AccessType
cc_pair_status: ConnectorCredentialPairStatus
@@ -488,7 +488,7 @@ class ConnectorCredentialPairIdentifier(BaseModel):
class ConnectorCredentialPairMetadata(BaseModel):
name: str | None = None
name: str
access_type: AccessType
auto_sync_options: dict[str, Any] | None = None
groups: list[int] = Field(default_factory=list)
@@ -501,7 +501,7 @@ class CCStatusUpdateRequest(BaseModel):
class ConnectorCredentialPairDescriptor(BaseModel):
id: int
name: str | None = None
name: str
connector: ConnectorSnapshot
credential: CredentialSnapshot
access_type: AccessType
@@ -511,7 +511,7 @@ class CCPairSummary(BaseModel):
"""Simplified connector-credential pair information with just essential data"""
id: int
name: str | None
name: str
source: DocumentSource
access_type: AccessType

View File

@@ -15,7 +15,7 @@
"date-fns": "^4.1.0",
"embla-carousel-react": "^8.6.0",
"lucide-react": "^0.562.0",
"next": "16.1.5",
"next": "16.1.7",
"next-themes": "^0.4.6",
"radix-ui": "^1.4.3",
"react": "19.2.3",
@@ -1711,9 +1711,9 @@
}
},
"node_modules/@next/env": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.5.tgz",
"integrity": "sha512-CRSCPJiSZoi4Pn69RYBDI9R7YK2g59vLexPQFXY0eyw+ILevIenCywzg+DqmlBik9zszEnw2HLFOUlLAcJbL7g==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.7.tgz",
"integrity": "sha512-rJJbIdJB/RQr2F1nylZr/PJzamvNNhfr3brdKP6s/GW850jbtR70QlSfFselvIBbcPUOlQwBakexjFzqLzF6pg==",
"license": "MIT"
},
"node_modules/@next/eslint-plugin-next": {
@@ -1727,9 +1727,9 @@
}
},
"node_modules/@next/swc-darwin-arm64": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.5.tgz",
"integrity": "sha512-eK7Wdm3Hjy/SCL7TevlH0C9chrpeOYWx2iR7guJDaz4zEQKWcS1IMVfMb9UKBFMg1XgzcPTYPIp1Vcpukkjg6Q==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.7.tgz",
"integrity": "sha512-b2wWIE8sABdyafc4IM8r5Y/dS6kD80JRtOGrUiKTsACFQfWWgUQ2NwoUX1yjFMXVsAwcQeNpnucF2ZrujsBBPg==",
"cpu": [
"arm64"
],
@@ -1743,9 +1743,9 @@
}
},
"node_modules/@next/swc-darwin-x64": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.5.tgz",
"integrity": "sha512-foQscSHD1dCuxBmGkbIr6ScAUF6pRoDZP6czajyvmXPAOFNnQUJu2Os1SGELODjKp/ULa4fulnBWoHV3XdPLfA==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.7.tgz",
"integrity": "sha512-zcnVaaZulS1WL0Ss38R5Q6D2gz7MtBu8GZLPfK+73D/hp4GFMrC2sudLky1QibfV7h6RJBJs/gOFvYP0X7UVlQ==",
"cpu": [
"x64"
],
@@ -1759,9 +1759,9 @@
}
},
"node_modules/@next/swc-linux-arm64-gnu": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.5.tgz",
"integrity": "sha512-qNIb42o3C02ccIeSeKjacF3HXotGsxh/FMk/rSRmCzOVMtoWH88odn2uZqF8RLsSUWHcAqTgYmPD3pZ03L9ZAA==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.7.tgz",
"integrity": "sha512-2ant89Lux/Q3VyC8vNVg7uBaFVP9SwoK2jJOOR0L8TQnX8CAYnh4uctAScy2Hwj2dgjVHqHLORQZJ2wH6VxhSQ==",
"cpu": [
"arm64"
],
@@ -1775,9 +1775,9 @@
}
},
"node_modules/@next/swc-linux-arm64-musl": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.5.tgz",
"integrity": "sha512-U+kBxGUY1xMAzDTXmuVMfhaWUZQAwzRaHJ/I6ihtR5SbTVUEaDRiEU9YMjy1obBWpdOBuk1bcm+tsmifYSygfw==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.7.tgz",
"integrity": "sha512-uufcze7LYv0FQg9GnNeZ3/whYfo+1Q3HnQpm16o6Uyi0OVzLlk2ZWoY7j07KADZFY8qwDbsmFnMQP3p3+Ftprw==",
"cpu": [
"arm64"
],
@@ -1791,9 +1791,9 @@
}
},
"node_modules/@next/swc-linux-x64-gnu": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.5.tgz",
"integrity": "sha512-gq2UtoCpN7Ke/7tKaU7i/1L7eFLfhMbXjNghSv0MVGF1dmuoaPeEVDvkDuO/9LVa44h5gqpWeJ4mRRznjDv7LA==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.7.tgz",
"integrity": "sha512-KWVf2gxYvHtvuT+c4MBOGxuse5TD7DsMFYSxVxRBnOzok/xryNeQSjXgxSv9QpIVlaGzEn/pIuI6Koosx8CGWA==",
"cpu": [
"x64"
],
@@ -1807,9 +1807,9 @@
}
},
"node_modules/@next/swc-linux-x64-musl": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.5.tgz",
"integrity": "sha512-bQWSE729PbXT6mMklWLf8dotislPle2L70E9q6iwETYEOt092GDn0c+TTNj26AjmeceSsC4ndyGsK5nKqHYXjQ==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.7.tgz",
"integrity": "sha512-HguhaGwsGr1YAGs68uRKc4aGWxLET+NevJskOcCAwXbwj0fYX0RgZW2gsOCzr9S11CSQPIkxmoSbuVaBp4Z3dA==",
"cpu": [
"x64"
],
@@ -1823,9 +1823,9 @@
}
},
"node_modules/@next/swc-win32-arm64-msvc": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.5.tgz",
"integrity": "sha512-LZli0anutkIllMtTAWZlDqdfvjWX/ch8AFK5WgkNTvaqwlouiD1oHM+WW8RXMiL0+vAkAJyAGEzPPjO+hnrSNQ==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.7.tgz",
"integrity": "sha512-S0n3KrDJokKTeFyM/vGGGR8+pCmXYrjNTk2ZozOL1C/JFdfUIL9O1ATaJOl5r2POe56iRChbsszrjMAdWSv7kQ==",
"cpu": [
"arm64"
],
@@ -1839,9 +1839,9 @@
}
},
"node_modules/@next/swc-win32-x64-msvc": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.5.tgz",
"integrity": "sha512-7is37HJTNQGhjPpQbkKjKEboHYQnCgpVt/4rBrrln0D9nderNxZ8ZWs8w1fAtzUx7wEyYjQ+/13myFgFj6K2Ng==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.7.tgz",
"integrity": "sha512-mwgtg8CNZGYm06LeEd+bNnOUfwOyNem/rOiP14Lsz+AnUY92Zq/LXwtebtUiaeVkhbroRCQ0c8GlR4UT1U+0yg==",
"cpu": [
"x64"
],
@@ -4971,12 +4971,15 @@
"license": "MIT"
},
"node_modules/baseline-browser-mapping": {
"version": "2.9.17",
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.17.tgz",
"integrity": "sha512-agD0MgJFUP/4nvjqzIB29zRPUuCF7Ge6mEv9s8dHrtYD7QWXRcx75rOADE/d5ah1NI+0vkDl0yorDd5U852IQQ==",
"version": "2.10.8",
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.10.8.tgz",
"integrity": "sha512-PCLz/LXGBsNTErbtB6i5u4eLpHeMfi93aUv5duMmj6caNu6IphS4q6UevDnL36sZQv9lrP11dbPKGMaXPwMKfQ==",
"license": "Apache-2.0",
"bin": {
"baseline-browser-mapping": "dist/cli.js"
"baseline-browser-mapping": "dist/cli.cjs"
},
"engines": {
"node": ">=6.0.0"
}
},
"node_modules/body-parser": {
@@ -8975,14 +8978,14 @@
}
},
"node_modules/next": {
"version": "16.1.5",
"resolved": "https://registry.npmjs.org/next/-/next-16.1.5.tgz",
"integrity": "sha512-f+wE+NSbiQgh3DSAlTaw2FwY5yGdVViAtp8TotNQj4kk4Q8Bh1sC/aL9aH+Rg1YAVn18OYXsRDT7U/079jgP7w==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/next/-/next-16.1.7.tgz",
"integrity": "sha512-WM0L7WrSvKwoLegLYr6V+mz+RIofqQgVAfHhMp9a88ms0cFX8iX9ew+snpWlSBwpkURJOUdvCEt3uLl3NNzvWg==",
"license": "MIT",
"dependencies": {
"@next/env": "16.1.5",
"@next/env": "16.1.7",
"@swc/helpers": "0.5.15",
"baseline-browser-mapping": "^2.8.3",
"baseline-browser-mapping": "^2.9.19",
"caniuse-lite": "^1.0.30001579",
"postcss": "8.4.31",
"styled-jsx": "5.1.6"
@@ -8994,14 +8997,14 @@
"node": ">=20.9.0"
},
"optionalDependencies": {
"@next/swc-darwin-arm64": "16.1.5",
"@next/swc-darwin-x64": "16.1.5",
"@next/swc-linux-arm64-gnu": "16.1.5",
"@next/swc-linux-arm64-musl": "16.1.5",
"@next/swc-linux-x64-gnu": "16.1.5",
"@next/swc-linux-x64-musl": "16.1.5",
"@next/swc-win32-arm64-msvc": "16.1.5",
"@next/swc-win32-x64-msvc": "16.1.5",
"@next/swc-darwin-arm64": "16.1.7",
"@next/swc-darwin-x64": "16.1.7",
"@next/swc-linux-arm64-gnu": "16.1.7",
"@next/swc-linux-arm64-musl": "16.1.7",
"@next/swc-linux-x64-gnu": "16.1.7",
"@next/swc-linux-x64-musl": "16.1.7",
"@next/swc-win32-arm64-msvc": "16.1.7",
"@next/swc-win32-x64-msvc": "16.1.7",
"sharp": "^0.34.4"
},
"peerDependencies": {

View File

@@ -16,7 +16,7 @@
"date-fns": "^4.1.0",
"embla-carousel-react": "^8.6.0",
"lucide-react": "^0.562.0",
"next": "16.1.5",
"next": "16.1.7",
"next-themes": "^0.4.6",
"radix-ui": "^1.4.3",
"react": "19.2.3",

View File

@@ -0,0 +1,453 @@
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Query
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.constants import UNSET
from onyx.db.constants import UnsetType
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.hook import create_hook__no_commit
from onyx.db.hook import delete_hook__no_commit
from onyx.db.hook import get_hook_by_id
from onyx.db.hook import get_hook_execution_logs
from onyx.db.hook import get_hooks
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.api_dependencies import require_hook_enabled
from onyx.hooks.models import HookCreateRequest
from onyx.hooks.models import HookExecutionRecord
from onyx.hooks.models import HookPointMetaResponse
from onyx.hooks.models import HookResponse
from onyx.hooks.models import HookUpdateRequest
from onyx.hooks.models import HookValidateResponse
from onyx.hooks.models import HookValidateStatus
from onyx.hooks.registry import get_all_specs
from onyx.hooks.registry import get_hook_point_spec
from onyx.utils.logger import setup_logger
from onyx.utils.url import SSRFException
from onyx.utils.url import validate_outbound_http_url
logger = setup_logger()
# ---------------------------------------------------------------------------
# SSRF protection
# ---------------------------------------------------------------------------
def _check_ssrf_safety(endpoint_url: str) -> None:
"""Raise OnyxError if endpoint_url could be used for SSRF.
Delegates to validate_outbound_http_url with https_only=True.
"""
try:
validate_outbound_http_url(endpoint_url, https_only=True)
except (SSRFException, ValueError) as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookResponse:
return HookResponse(
id=hook.id,
name=hook.name,
hook_point=hook.hook_point,
endpoint_url=hook.endpoint_url,
fail_strategy=hook.fail_strategy,
timeout_seconds=hook.timeout_seconds,
is_active=hook.is_active,
is_reachable=hook.is_reachable,
creator_email=(
creator_email
if creator_email is not None
else (hook.creator.email if hook.creator else None)
),
created_at=hook.created_at,
updated_at=hook.updated_at,
)
def _get_hook_or_404(
db_session: Session,
hook_id: int,
include_creator: bool = False,
) -> Hook:
hook = get_hook_by_id(
db_session=db_session,
hook_id=hook_id,
include_creator=include_creator,
)
if hook is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook {hook_id} not found.")
return hook
def _raise_for_validation_failure(validation: HookValidateResponse) -> None:
"""Raise an appropriate OnyxError for a non-passed validation result."""
if validation.status == HookValidateStatus.auth_failed:
raise OnyxError(OnyxErrorCode.CREDENTIAL_INVALID, validation.error_message)
if validation.status == HookValidateStatus.timeout:
raise OnyxError(
OnyxErrorCode.GATEWAY_TIMEOUT,
f"Endpoint validation failed: {validation.error_message}",
)
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Endpoint validation failed: {validation.error_message}",
)
def _validate_endpoint(
endpoint_url: str,
api_key: str | None,
timeout_seconds: float,
) -> HookValidateResponse:
"""Check whether endpoint_url is reachable by sending an empty POST request.
We use POST since hook endpoints expect POST requests. The server will typically
respond with 4xx (missing/invalid body) — that is fine. Any HTTP response means
the server is up and routable. A 401/403 response returns auth_failed
(not reachable — indicates the api_key is invalid).
Timeout handling:
- ConnectTimeout: TCP handshake never completed → cannot_connect.
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
(operator should consider increasing timeout_seconds).
- All other exceptions → cannot_connect.
"""
_check_ssrf_safety(endpoint_url)
headers: dict[str, str] = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
try:
with httpx.Client(timeout=timeout_seconds, follow_redirects=False) as client:
response = client.post(endpoint_url, headers=headers)
if response.status_code in (401, 403):
return HookValidateResponse(
status=HookValidateStatus.auth_failed,
error_message=f"Authentication failed (HTTP {response.status_code})",
)
return HookValidateResponse(status=HookValidateStatus.passed)
except httpx.TimeoutException as exc:
# ConnectTimeout: TCP handshake never completed → cannot_connect.
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
if isinstance(exc, httpx.ConnectTimeout):
logger.warning(
"Hook endpoint validation: connect timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
logger.warning(
"Hook endpoint validation: read/write timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.timeout,
error_message="Endpoint timed out — consider increasing timeout_seconds.",
)
except Exception as exc:
logger.warning(
"Hook endpoint validation: connection error for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
# ---------------------------------------------------------------------------
# Routers
# ---------------------------------------------------------------------------
router = APIRouter(prefix="/admin/hooks")
# ---------------------------------------------------------------------------
# Hook endpoints
# ---------------------------------------------------------------------------
@router.get("/specs")
def get_hook_point_specs(
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
) -> list[HookPointMetaResponse]:
return [
HookPointMetaResponse(
hook_point=spec.hook_point,
display_name=spec.display_name,
description=spec.description,
docs_url=spec.docs_url,
input_schema=spec.input_schema,
output_schema=spec.output_schema,
default_timeout_seconds=spec.default_timeout_seconds,
default_fail_strategy=spec.default_fail_strategy,
fail_hard_description=spec.fail_hard_description,
)
for spec in get_all_specs()
]
@router.get("")
def list_hooks(
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> list[HookResponse]:
hooks = get_hooks(db_session=db_session, include_creator=True)
return [_hook_to_response(h) for h in hooks]
@router.post("")
def create_hook(
req: HookCreateRequest,
user: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
"""Create a new hook. The endpoint is validated before persisting — creation fails if
the endpoint cannot be reached or the api_key is invalid. Hooks are created inactive;
use POST /{hook_id}/activate once ready to receive traffic."""
spec = get_hook_point_spec(req.hook_point)
api_key = req.api_key.get_secret_value() if req.api_key else None
validation = _validate_endpoint(
endpoint_url=req.endpoint_url,
api_key=api_key,
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
)
if validation.status != HookValidateStatus.passed:
_raise_for_validation_failure(validation)
hook = create_hook__no_commit(
db_session=db_session,
name=req.name,
hook_point=req.hook_point,
endpoint_url=req.endpoint_url,
api_key=api_key,
fail_strategy=req.fail_strategy or spec.default_fail_strategy,
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
creator_id=user.id,
)
hook.is_reachable = True
db_session.commit()
return _hook_to_response(hook, creator_email=user.email)
@router.get("/{hook_id}")
def get_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = _get_hook_or_404(db_session, hook_id, include_creator=True)
return _hook_to_response(hook)
@router.patch("/{hook_id}")
def update_hook(
hook_id: int,
req: HookUpdateRequest,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
"""Update hook fields. If endpoint_url, api_key, or timeout_seconds changes, the
endpoint is re-validated using the effective values. For active hooks the update is
rejected on validation failure, keeping live traffic unaffected. For inactive hooks
the update goes through regardless and is_reachable is updated to reflect the result.
Note: if an active hook's endpoint is currently down, even a timeout_seconds-only
increase will be rejected. The recovery flow is: deactivate → update → reactivate.
"""
# api_key: UNSET = no change, None = clear, value = update
api_key: str | None | UnsetType
if "api_key" not in req.model_fields_set:
api_key = UNSET
elif req.api_key is None:
api_key = None
else:
api_key = req.api_key.get_secret_value()
endpoint_url_changing = "endpoint_url" in req.model_fields_set
api_key_changing = not isinstance(api_key, UnsetType)
timeout_changing = "timeout_seconds" in req.model_fields_set
validated_is_reachable: bool | None = None
if endpoint_url_changing or api_key_changing or timeout_changing:
existing = _get_hook_or_404(db_session, hook_id)
effective_url: str = (
req.endpoint_url if endpoint_url_changing else existing.endpoint_url # type: ignore[assignment] # endpoint_url is required on create and cannot be cleared on update
)
effective_api_key: str | None = (
(api_key if not isinstance(api_key, UnsetType) else None)
if api_key_changing
else (
existing.api_key.get_value(apply_mask=False)
if existing.api_key
else None
)
)
effective_timeout: float = (
req.timeout_seconds if timeout_changing else existing.timeout_seconds # type: ignore[assignment] # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest)
)
validation = _validate_endpoint(
endpoint_url=effective_url,
api_key=effective_api_key,
timeout_seconds=effective_timeout,
)
if existing.is_active and validation.status != HookValidateStatus.passed:
_raise_for_validation_failure(validation)
validated_is_reachable = validation.status == HookValidateStatus.passed
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
name=req.name,
endpoint_url=(req.endpoint_url if endpoint_url_changing else UNSET),
api_key=api_key,
fail_strategy=req.fail_strategy,
timeout_seconds=req.timeout_seconds,
is_reachable=validated_is_reachable,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
@router.delete("/{hook_id}")
def delete_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> None:
delete_hook__no_commit(db_session=db_session, hook_id=hook_id)
db_session.commit()
@router.post("/{hook_id}/activate")
def activate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = _get_hook_or_404(db_session, hook_id)
if not hook.endpoint_url:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
)
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
validation = _validate_endpoint(
endpoint_url=hook.endpoint_url,
api_key=api_key,
timeout_seconds=hook.timeout_seconds,
)
if validation.status != HookValidateStatus.passed:
# Persist is_reachable=False in a separate session so the request
# session has no commits on the failure path and the transaction
# boundary stays clean.
if hook.is_reachable is not False:
with get_session_with_current_tenant() as side_session:
update_hook__no_commit(
db_session=side_session, hook_id=hook_id, is_reachable=False
)
side_session.commit()
_raise_for_validation_failure(validation)
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
is_active=True,
is_reachable=True,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
@router.post("/{hook_id}/validate")
def validate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookValidateResponse:
hook = _get_hook_or_404(db_session, hook_id)
if not hook.endpoint_url:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
)
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
validation = _validate_endpoint(
endpoint_url=hook.endpoint_url,
api_key=api_key,
timeout_seconds=hook.timeout_seconds,
)
validation_passed = validation.status == HookValidateStatus.passed
if hook.is_reachable != validation_passed:
update_hook__no_commit(
db_session=db_session, hook_id=hook_id, is_reachable=validation_passed
)
db_session.commit()
return validation
@router.post("/{hook_id}/deactivate")
def deactivate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
is_active=False,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
# ---------------------------------------------------------------------------
# Execution log endpoints
# ---------------------------------------------------------------------------
@router.get("/{hook_id}/execution-logs")
def list_hook_execution_logs(
hook_id: int,
limit: int = Query(default=10, ge=1, le=100),
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> list[HookExecutionRecord]:
_get_hook_or_404(db_session, hook_id)
logs = get_hook_execution_logs(db_session=db_session, hook_id=hook_id, limit=limit)
return [
HookExecutionRecord(
error_message=log.error_message,
status_code=log.status_code,
duration_ms=log.duration_ms,
created_at=log.created_at,
)
for log in logs
]

View File

@@ -1,239 +0,0 @@
from __future__ import annotations
import re
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import OpenUrlDocuments
from onyx.server.query_and_chat.streaming_models import OpenUrlStart
from onyx.server.query_and_chat.streaming_models import OpenUrlUrls
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SearchToolDocumentsDelta
from onyx.server.query_and_chat.streaming_models import SearchToolQueriesDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
_CANNOT_SHOW_STEP_RESULTS_STR = "[Cannot display step results]"
def _adjust_message_text_for_agent_search_results(
adjusted_message_text: str,
final_documents: list[SavedSearchDoc], # noqa: ARG001
) -> str:
# Remove all [Q<integer>] patterns (sub-question citations)
return re.sub(r"\[Q\d+\]", "", adjusted_message_text)
def _replace_d_citations_with_links(
message_text: str, final_documents: list[SavedSearchDoc]
) -> str:
def replace_citation(match: re.Match[str]) -> str:
d_number = match.group(1)
try:
doc_index = int(d_number) - 1
if 0 <= doc_index < len(final_documents):
doc = final_documents[doc_index]
link = doc.link if doc.link else ""
return f"[[{d_number}]]({link})"
return match.group(0)
except (ValueError, IndexError):
return match.group(0)
return re.sub(r"\[D(\d+)\]", replace_citation, message_text)
def create_message_packets(
message_text: str,
final_documents: list[SavedSearchDoc] | None,
turn_index: int,
is_legacy_agentic: bool = False,
) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=AgentResponseStart(
final_documents=SearchDoc.from_saved_search_docs(final_documents or []),
),
)
)
adjusted_message_text = message_text
if is_legacy_agentic:
if final_documents is not None:
adjusted_message_text = _adjust_message_text_for_agent_search_results(
message_text, final_documents
)
adjusted_message_text = _replace_d_citations_with_links(
adjusted_message_text, final_documents
)
else:
adjusted_message_text = re.sub(r"\[Q\d+\]", "", message_text)
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=AgentResponseDelta(
content=adjusted_message_text,
),
),
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=SectionEnd(),
)
)
return packets
def create_citation_packets(
citation_info_list: list[CitationInfo], turn_index: int
) -> list[Packet]:
packets: list[Packet] = []
# Emit each citation as a separate CitationInfo packet
for citation_info in citation_info_list:
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=citation_info,
)
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
return packets
def create_reasoning_packets(reasoning_text: str, turn_index: int) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(placement=Placement(turn_index=turn_index), obj=ReasoningStart())
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=ReasoningDelta(
reasoning=reasoning_text,
),
),
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
return packets
def create_image_generation_packets(
images: list[GeneratedImage], turn_index: int
) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=ImageGenerationToolStart(),
)
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=ImageGenerationFinal(images=images),
),
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
return packets
def create_fetch_packets(
fetch_docs: list[SavedSearchDoc],
urls: list[str],
turn_index: int,
) -> list[Packet]:
packets: list[Packet] = []
# Emit start packet
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=OpenUrlStart(),
)
)
# Emit URLs packet
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=OpenUrlUrls(urls=urls),
)
)
# Emit documents packet
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=OpenUrlDocuments(
documents=SearchDoc.from_saved_search_docs(fetch_docs)
),
)
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
return packets
def create_search_packets(
search_queries: list[str],
saved_search_docs: list[SavedSearchDoc],
is_internet_search: bool,
turn_index: int,
) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=SearchToolStart(
is_internet_search=is_internet_search,
),
)
)
# Emit queries if present
if search_queries:
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=SearchToolQueriesDelta(queries=search_queries),
),
)
# Emit documents if present
if saved_search_docs:
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
obj=SearchToolDocumentsDelta(
documents=SearchDoc.from_saved_search_docs(saved_search_docs)
),
),
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
return packets

View File

@@ -1,3 +1,4 @@
import hashlib
import mimetypes
from io import BytesIO
from typing import Any
@@ -83,6 +84,14 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
def __init__(self, tool_id: int, emitter: Emitter) -> None:
super().__init__(emitter=emitter)
self._id = tool_id
# Cache of (filename, content_hash) -> ci_file_id to avoid re-uploading
# the same file on every tool call iteration within the same agent session.
# Filename is included in the key so two files with identical bytes but
# different names each get their own upload slot.
# TTL assumption: code-interpreter file TTLs (typically hours) greatly
# exceed the lifetime of a single agent session (at most MAX_LLM_CYCLES
# iterations, typically a few minutes), so stale-ID eviction is not needed.
self._uploaded_file_cache: dict[tuple[str, str], str] = {}
@property
def id(self) -> int:
@@ -182,8 +191,13 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
for ind, chat_file in enumerate(chat_files):
file_name = chat_file.filename or f"file_{ind}"
try:
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)
content_hash = hashlib.sha256(chat_file.content).hexdigest()
cache_key = (file_name, content_hash)
ci_file_id = self._uploaded_file_cache.get(cache_key)
if ci_file_id is None:
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)
self._uploaded_file_cache[cache_key] = ci_file_id
# Stage for execution
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
@@ -299,14 +313,10 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
)
# Cleanup staged input files
for file_mapping in files_to_stage:
try:
client.delete_file(file_mapping["file_id"])
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
)
# Note: staged input files are intentionally not deleted here because
# _uploaded_file_cache reuses their file_ids across iterations. They are
# orphaned when the session ends, but the code interpreter cleans up
# stale files on its own TTL.
# Emit file_ids once files are processed
if generated_file_ids:

View File

@@ -74,7 +74,7 @@ def make_structured_onyx_request_id(prefix: str, request_url: str) -> str:
def _make_onyx_request_id(prefix: str, hash_input: str) -> str:
"""helper function to return an id given a string input"""
hash_obj = hashlib.md5(hash_input.encode("utf-8"))
hash_obj = hashlib.md5(hash_input.encode("utf-8"), usedforsecurity=False)
hash_bytes = hash_obj.digest()[:6] # Truncate to 6 bytes
# 6 bytes becomes 8 bytes. we shouldn't need to strip but just in case

View File

@@ -140,10 +140,20 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
return validated_ip, hostname, port
def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
def validate_outbound_http_url(
url: str,
*,
allow_private_network: bool = False,
https_only: bool = False,
) -> str:
"""
Validate a URL that will be used by backend outbound HTTP calls.
Args:
url: The URL to validate.
allow_private_network: If True, skip private/reserved IP checks.
https_only: If True, reject http:// URLs (only https:// is allowed).
Returns:
A normalized URL string with surrounding whitespace removed.
@@ -157,7 +167,12 @@ def validate_outbound_http_url(url: str, *, allow_private_network: bool = False)
parsed = urlparse(normalized_url)
if parsed.scheme not in ("http", "https"):
if https_only:
if parsed.scheme != "https":
raise SSRFException(
f"Invalid URL scheme '{parsed.scheme}'. Only https is allowed."
)
elif parsed.scheme not in ("http", "https"):
raise SSRFException(
f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
)

View File

@@ -698,7 +698,7 @@ py-key-value-aio==0.4.4
# via fastmcp
pyairtable==3.0.1
# via onyx
pyasn1==0.6.2
pyasn1==0.6.3
# via
# pyasn1-modules
# rsa
@@ -752,7 +752,7 @@ pypandoc-binary==1.16.2
# via onyx
pyparsing==3.2.5
# via httplib2
pypdf==6.8.0
pypdf==6.9.1
# via
# onyx
# unstructured-client

View File

@@ -263,7 +263,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.7.0
onyx-devtools==0.7.1
# via onyx
openai==2.14.0
# via
@@ -326,7 +326,7 @@ pure-eval==0.2.3
# via stack-data
py==1.11.0
# via retry
pyasn1==0.6.2
pyasn1==0.6.3
# via
# pyasn1-modules
# rsa

View File

@@ -195,7 +195,7 @@ propcache==0.4.1
# yarl
py==1.11.0
# via retry
pyasn1==0.6.2
pyasn1==0.6.3
# via
# pyasn1-modules
# rsa

View File

@@ -285,7 +285,7 @@ psutil==7.1.3
# via accelerate
py==1.11.0
# via retry
pyasn1==0.6.2
pyasn1==0.6.3
# via
# pyasn1-modules
# rsa

View File

@@ -0,0 +1,471 @@
from __future__ import annotations
import argparse
import asyncio
import json
import logging
import sys
from dataclasses import asdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import TypedDict
from typing import TypeGuard
import aiohttp
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger(__name__)
DEFAULT_API_BASE = "http://localhost:3000"
INTERNAL_SEARCH_TOOL_NAME = "internal_search"
INTERNAL_SEARCH_IN_CODE_TOOL_ID = "SearchTool"
MAX_REQUEST_ATTEMPTS = 5
RETRIABLE_STATUS_CODES = {429, 500, 502, 503, 504}
@dataclass(frozen=True)
class QuestionRecord:
question_id: str
question: str
@dataclass(frozen=True)
class AnswerRecord:
question_id: str
answer: str
document_ids: list[str]
@dataclass(frozen=True)
class FailedQuestionRecord:
question_id: str
error: str
class Citation(TypedDict, total=False):
citation_number: int
document_id: str
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"Submit questions to Onyx chat with internal search forced and write "
"answers to a JSONL file."
)
)
parser.add_argument(
"--questions-file",
type=Path,
required=True,
help="Path to the input questions JSONL file.",
)
parser.add_argument(
"--output-file",
type=Path,
required=True,
help="Path to the output answers JSONL file.",
)
parser.add_argument(
"--api-key",
type=str,
required=True,
help="API key used to authenticate against Onyx.",
)
parser.add_argument(
"--api-base",
type=str,
default=DEFAULT_API_BASE,
help=(
"Frontend base URL for Onyx. If `/api` is omitted, it will be added "
f"automatically. Default: {DEFAULT_API_BASE}"
),
)
parser.add_argument(
"--parallelism",
type=int,
default=1,
help="Number of questions to process in parallel. Default: 1.",
)
parser.add_argument(
"--max-questions",
type=int,
default=None,
help="Optional cap on how many questions to process. Defaults to all.",
)
return parser.parse_args()
def normalize_api_base(api_base: str) -> str:
normalized = api_base.rstrip("/")
if normalized.endswith("/api"):
return normalized
return f"{normalized}/api"
def load_questions(questions_file: Path) -> list[QuestionRecord]:
if not questions_file.exists():
raise FileNotFoundError(f"Questions file not found: {questions_file}")
questions: list[QuestionRecord] = []
with questions_file.open("r", encoding="utf-8") as file:
for line_number, line in enumerate(file, start=1):
stripped_line = line.strip()
if not stripped_line:
continue
try:
payload = json.loads(stripped_line)
except json.JSONDecodeError as exc:
raise ValueError(
f"Invalid JSON on line {line_number} of {questions_file}"
) from exc
question_id = payload.get("question_id")
question = payload.get("question")
if not isinstance(question_id, str) or not question_id:
raise ValueError(
f"Line {line_number} is missing a non-empty `question_id`."
)
if not isinstance(question, str) or not question:
raise ValueError(
f"Line {line_number} is missing a non-empty `question`."
)
questions.append(QuestionRecord(question_id=question_id, question=question))
return questions
async def read_json_response(
response: aiohttp.ClientResponse,
) -> dict[str, Any] | list[dict[str, Any]]:
response_text = await response.text()
if response.status >= 400:
raise RuntimeError(
f"Request to {response.url} failed with {response.status}: {response_text}"
)
try:
payload = json.loads(response_text)
except json.JSONDecodeError as exc:
raise RuntimeError(
f"Request to {response.url} returned non-JSON content: {response_text}"
) from exc
if not isinstance(payload, (dict, list)):
raise RuntimeError(
f"Unexpected response payload type from {response.url}: {type(payload)}"
)
return payload
async def request_json_with_retries(
session: aiohttp.ClientSession,
method: str,
url: str,
headers: dict[str, str],
json_payload: dict[str, Any] | None = None,
) -> dict[str, Any] | list[dict[str, Any]]:
backoff_seconds = 1.0
for attempt in range(1, MAX_REQUEST_ATTEMPTS + 1):
try:
async with session.request(
method=method,
url=url,
headers=headers,
json=json_payload,
) as response:
if (
response.status in RETRIABLE_STATUS_CODES
and attempt < MAX_REQUEST_ATTEMPTS
):
response_text = await response.text()
logger.warning(
"Retryable response from %s on attempt %s/%s: %s %s",
url,
attempt,
MAX_REQUEST_ATTEMPTS,
response.status,
response_text,
)
await asyncio.sleep(backoff_seconds)
backoff_seconds *= 2
continue
return await read_json_response(response)
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
if attempt == MAX_REQUEST_ATTEMPTS:
raise RuntimeError(
f"Request to {url} failed after {MAX_REQUEST_ATTEMPTS} attempts."
) from exc
logger.warning(
"Request to %s failed on attempt %s/%s: %s",
url,
attempt,
MAX_REQUEST_ATTEMPTS,
exc,
)
await asyncio.sleep(backoff_seconds)
backoff_seconds *= 2
raise RuntimeError(f"Request to {url} failed unexpectedly.")
def extract_document_ids(citation_info: object) -> list[str]:
if not isinstance(citation_info, list):
return []
sorted_citations = sorted(
(citation for citation in citation_info if _is_valid_citation(citation)),
key=_citation_sort_key,
)
document_ids: list[str] = []
seen_document_ids: set[str] = set()
for citation in sorted_citations:
document_id = citation["document_id"]
if document_id not in seen_document_ids:
seen_document_ids.add(document_id)
document_ids.append(document_id)
return document_ids
def _is_valid_citation(citation: object) -> TypeGuard[Citation]:
return (
isinstance(citation, dict)
and isinstance(citation.get("document_id"), str)
and bool(citation["document_id"])
)
def _citation_sort_key(citation: Citation) -> int:
citation_number = citation.get("citation_number")
if isinstance(citation_number, int):
return citation_number
return sys.maxsize
async def fetch_internal_search_tool_id(
session: aiohttp.ClientSession,
api_base: str,
headers: dict[str, str],
) -> int:
payload = await request_json_with_retries(
session=session,
method="GET",
url=f"{api_base}/tool",
headers=headers,
)
if not isinstance(payload, list):
raise RuntimeError("Expected `/tool` to return a list.")
for tool in payload:
if not isinstance(tool, dict):
continue
if tool.get("in_code_tool_id") == INTERNAL_SEARCH_IN_CODE_TOOL_ID:
tool_id = tool.get("id")
if isinstance(tool_id, int):
return tool_id
for tool in payload:
if not isinstance(tool, dict):
continue
if tool.get("name") == INTERNAL_SEARCH_TOOL_NAME:
tool_id = tool.get("id")
if isinstance(tool_id, int):
return tool_id
raise RuntimeError(
"Could not find the internal search tool in `/tool`. "
"Make sure SearchTool is available for this environment."
)
async def submit_question(
session: aiohttp.ClientSession,
api_base: str,
headers: dict[str, str],
internal_search_tool_id: int,
question_record: QuestionRecord,
) -> AnswerRecord:
payload = {
"message": question_record.question,
"chat_session_info": {"persona_id": 0},
"parent_message_id": None,
"file_descriptors": [],
"allowed_tool_ids": [internal_search_tool_id],
"forced_tool_id": internal_search_tool_id,
"stream": False,
}
response_payload = await request_json_with_retries(
session=session,
method="POST",
url=f"{api_base}/chat/send-chat-message",
headers=headers,
json_payload=payload,
)
if not isinstance(response_payload, dict):
raise RuntimeError(
"Expected `/chat/send-chat-message` to return an object when `stream=false`."
)
answer = response_payload.get("answer_citationless")
if not isinstance(answer, str):
answer = response_payload.get("answer")
if not isinstance(answer, str):
raise RuntimeError(
f"Response for question {question_record.question_id} is missing `answer`."
)
return AnswerRecord(
question_id=question_record.question_id,
answer=answer,
document_ids=extract_document_ids(response_payload.get("citation_info")),
)
async def generate_answers(
questions: list[QuestionRecord],
output_file: Path,
api_base: str,
api_key: str,
parallelism: int,
) -> None:
if parallelism < 1:
raise ValueError("`--parallelism` must be at least 1.")
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
timeout = aiohttp.ClientTimeout(
total=None,
connect=30,
sock_connect=30,
sock_read=600,
)
connector = aiohttp.TCPConnector(limit=parallelism)
output_file.parent.mkdir(parents=True, exist_ok=True)
with output_file.open("a", encoding="utf-8") as file:
async with aiohttp.ClientSession(
timeout=timeout, connector=connector
) as session:
internal_search_tool_id = await fetch_internal_search_tool_id(
session=session,
api_base=api_base,
headers=headers,
)
logger.info("Using internal search tool id %s", internal_search_tool_id)
semaphore = asyncio.Semaphore(parallelism)
progress_lock = asyncio.Lock()
write_lock = asyncio.Lock()
completed = 0
successful = 0
failed_questions: list[FailedQuestionRecord] = []
total = len(questions)
async def process_question(question_record: QuestionRecord) -> None:
nonlocal completed
nonlocal successful
try:
async with semaphore:
result = await submit_question(
session=session,
api_base=api_base,
headers=headers,
internal_search_tool_id=internal_search_tool_id,
question_record=question_record,
)
except Exception as exc:
async with progress_lock:
completed += 1
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(exc),
)
)
logger.exception(
"Failed question %s (%s/%s)",
question_record.question_id,
completed,
total,
)
return
async with write_lock:
file.write(json.dumps(asdict(result), ensure_ascii=False))
file.write("\n")
file.flush()
async with progress_lock:
completed += 1
successful += 1
logger.info("Processed %s/%s questions", completed, total)
await asyncio.gather(
*(process_question(question_record) for question_record in questions)
)
if failed_questions:
logger.warning(
"Completed with %s failed questions and %s successful questions.",
len(failed_questions),
successful,
)
for failed_question in failed_questions:
logger.warning(
"Failed question %s: %s",
failed_question.question_id,
failed_question.error,
)
def main() -> None:
args = parse_args()
questions = load_questions(args.questions_file)
api_base = normalize_api_base(args.api_base)
if args.max_questions is not None:
if args.max_questions < 1:
raise ValueError("`--max-questions` must be at least 1 when provided.")
questions = questions[: args.max_questions]
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
logger.info("Writing answers to %s", args.output_file)
asyncio.run(
generate_answers(
questions=questions,
output_file=args.output_file,
api_base=api_base,
api_key=args.api_key,
parallelism=args.parallelism,
)
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,291 @@
"""
Script to upload files from a directory as individual file connectors in Onyx.
Each file gets its own connector named after the file.
Usage:
python upload_files_as_connectors.py --data-dir /path/to/files --api-key YOUR_KEY
python upload_files_as_connectors.py --data-dir /path/to/files --api-key YOUR_KEY --api-base http://onyxserver:3000
python upload_files_as_connectors.py --data-dir /path/to/files --api-key YOUR_KEY --file-glob '*.zip'
Requires:
pip install requests
"""
import argparse
import fnmatch
import os
import sys
import threading
import time
import requests
REQUEST_TIMEOUT = 900 # 15 minutes
def _elapsed_printer(label: str, stop_event: threading.Event) -> None:
"""Print a live elapsed-time counter until stop_event is set."""
start = time.monotonic()
while not stop_event.wait(timeout=1):
elapsed = int(time.monotonic() - start)
m, s = divmod(elapsed, 60)
print(f"\r {label} ... {m:02d}:{s:02d}", end="", flush=True)
elapsed = int(time.monotonic() - start)
m, s = divmod(elapsed, 60)
print(f"\r {label} ... {m:02d}:{s:02d} done")
def _timed_request(label: str, fn: object) -> requests.Response:
"""Run a request function while displaying a live elapsed timer."""
stop = threading.Event()
t = threading.Thread(target=_elapsed_printer, args=(label, stop), daemon=True)
t.start()
try:
resp = fn() # type: ignore[operator]
finally:
stop.set()
t.join()
return resp
def upload_file(
session: requests.Session, base_url: str, file_path: str
) -> dict | None:
"""Upload a single file and return the response with file_paths and file_names."""
with open(file_path, "rb") as f:
resp = _timed_request(
"Uploading",
lambda: session.post(
f"{base_url}/api/manage/admin/connector/file/upload",
files={"files": (os.path.basename(file_path), f)},
timeout=REQUEST_TIMEOUT,
),
)
if not resp.ok:
print(f" ERROR uploading: {resp.text}")
return None
return resp.json()
def create_connector(
session: requests.Session,
base_url: str,
name: str,
file_paths: list[str],
file_names: list[str],
zip_metadata_file_id: str | None,
) -> int | None:
"""Create a file connector and return its ID."""
resp = _timed_request(
"Creating connector",
lambda: session.post(
f"{base_url}/api/manage/admin/connector",
json={
"name": name,
"source": "file",
"input_type": "load_state",
"connector_specific_config": {
"file_locations": file_paths,
"file_names": file_names,
"zip_metadata_file_id": zip_metadata_file_id,
},
"refresh_freq": None,
"prune_freq": None,
"indexing_start": None,
"access_type": "public",
"groups": [],
},
timeout=REQUEST_TIMEOUT,
),
)
if not resp.ok:
print(f" ERROR creating connector: {resp.text}")
return None
return resp.json()["id"]
def create_credential(
session: requests.Session, base_url: str, name: str
) -> int | None:
"""Create a dummy credential for the file connector."""
resp = session.post(
f"{base_url}/api/manage/credential",
json={
"credential_json": {},
"admin_public": True,
"source": "file",
"curator_public": True,
"groups": [],
"name": name,
},
timeout=REQUEST_TIMEOUT,
)
if not resp.ok:
print(f" ERROR creating credential: {resp.text}")
return None
return resp.json()["id"]
def link_credential(
session: requests.Session,
base_url: str,
connector_id: int,
credential_id: int,
name: str,
) -> bool:
"""Link the connector to the credential (create CC pair)."""
resp = session.put(
f"{base_url}/api/manage/connector/{connector_id}/credential/{credential_id}",
json={
"name": name,
"access_type": "public",
"groups": [],
"auto_sync_options": None,
"processing_mode": "REGULAR",
},
timeout=REQUEST_TIMEOUT,
)
if not resp.ok:
print(f" ERROR linking credential: {resp.text}")
return False
return True
def run_connector(
session: requests.Session,
base_url: str,
connector_id: int,
credential_id: int,
) -> bool:
"""Trigger the connector to start indexing."""
resp = session.post(
f"{base_url}/api/manage/admin/connector/run-once",
json={
"connector_id": connector_id,
"credentialIds": [credential_id],
"from_beginning": False,
},
timeout=REQUEST_TIMEOUT,
)
if not resp.ok:
print(f" ERROR running connector: {resp.text}")
return False
return True
def process_file(session: requests.Session, base_url: str, file_path: str) -> bool:
"""Process a single file through the full connector creation flow."""
file_name = os.path.basename(file_path)
connector_name = file_name
print(f"Processing: {file_name}")
# Step 1: Upload
upload_resp = upload_file(session, base_url, file_path)
if not upload_resp:
return False
# Step 2: Create connector
connector_id = create_connector(
session,
base_url,
name=f"FileConnector-{connector_name}",
file_paths=upload_resp["file_paths"],
file_names=upload_resp["file_names"],
zip_metadata_file_id=upload_resp.get("zip_metadata_file_id"),
)
if connector_id is None:
return False
# Step 3: Create credential
credential_id = create_credential(session, base_url, name=connector_name)
if credential_id is None:
return False
# Step 4: Link connector to credential
if not link_credential(
session, base_url, connector_id, credential_id, connector_name
):
return False
# Step 5: Trigger indexing
if not run_connector(session, base_url, connector_id, credential_id):
return False
print(f" OK (connector_id={connector_id})")
return True
def get_authenticated_session(api_key: str) -> requests.Session:
"""Create a session authenticated with an API key."""
session = requests.Session()
session.headers.update({"Authorization": f"Bearer {api_key}"})
return session
def main() -> None:
parser = argparse.ArgumentParser(
description="Upload files as individual Onyx file connectors."
)
parser.add_argument(
"--data-dir",
required=True,
help="Directory containing files to upload.",
)
parser.add_argument(
"--api-base",
default="http://localhost:3000",
help="Base URL for the Onyx API (default: http://localhost:3000).",
)
parser.add_argument(
"--api-key",
required=True,
help="API key for authentication.",
)
parser.add_argument(
"--file-glob",
default=None,
help="Glob pattern to filter files (e.g. '*.json', '*.zip').",
)
args = parser.parse_args()
data_dir = args.data_dir
base_url = args.api_base.rstrip("/")
api_key = args.api_key
file_glob = args.file_glob
if not os.path.isdir(data_dir):
print(f"Error: {data_dir} is not a directory")
sys.exit(1)
script_path = os.path.realpath(__file__)
files = sorted(
os.path.join(data_dir, f)
for f in os.listdir(data_dir)
if os.path.isfile(os.path.join(data_dir, f))
and os.path.realpath(os.path.join(data_dir, f)) != script_path
and (file_glob is None or fnmatch.fnmatch(f, file_glob))
)
if not files:
print(f"No files found in {data_dir}")
sys.exit(1)
print(f"Found {len(files)} file(s) in {data_dir}\n")
session = get_authenticated_session(api_key)
success = 0
failed = 0
for file_path in files:
if process_file(session, base_url, file_path):
success += 1
else:
failed += 1
# Small delay to avoid overwhelming the server
time.sleep(0.5)
print(f"\nDone: {success} succeeded, {failed} failed out of {len(files)} files.")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,274 @@
"""
External dependency unit tests for user file delete queue protections.
Verifies that the three mechanisms added to check_for_user_file_delete work
correctly:
1. Queue depth backpressure when the broker queue exceeds
USER_FILE_DELETE_MAX_QUEUE_DEPTH, no new tasks are enqueued.
2. Per-file Redis guard key if the guard key for a file already exists in
Redis, that file is skipped even though it is still in DELETING status.
3. Task expiry every send_task call carries expires=
CELERY_USER_FILE_DELETE_TASK_EXPIRES so that stale queued tasks are
discarded by workers automatically.
Also verifies that delete_user_file_impl clears the guard key the moment
it is picked up by a worker.
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
on the task class so no real broker is needed.
"""
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from unittest.mock import PropertyMock
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_delete_lock_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_delete_queued_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
check_for_user_file_delete,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
process_single_user_file_delete,
)
from onyx.configs.constants import CELERY_USER_FILE_DELETE_TASK_EXPIRES
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import USER_FILE_DELETE_MAX_QUEUE_DEPTH
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.redis.redis_pool import get_redis_client
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PATCH_QUEUE_LEN = (
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
)
def _create_deleting_user_file(db_session: Session, user_id: object) -> UserFile:
"""Insert a UserFile in DELETING status and return it."""
uf = UserFile(
id=uuid4(),
user_id=user_id,
file_id=f"test_file_{uuid4().hex[:8]}",
name=f"test_{uuid4().hex[:8]}.txt",
file_type="text/plain",
status=UserFileStatus.DELETING,
)
db_session.add(uf)
db_session.commit()
db_session.refresh(uf)
return uf
@contextmanager
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
"""Patch the ``app`` property on *task*'s class so that ``self.app``
inside the task function returns *mock_app*.
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
the actual task instance. We patch ``app`` on that instance's class
(a unique Celery-generated Task subclass) so the mock is scoped to this
task only.
"""
task_instance = task.run.__self__
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield
# ---------------------------------------------------------------------------
# Test classes
# ---------------------------------------------------------------------------
class TestDeleteQueueDepthBackpressure:
"""Protection 1: skip all enqueuing when the broker queue is too deep."""
def test_no_tasks_enqueued_when_queue_over_limit(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""When the queue depth exceeds the limit the beat cycle is skipped."""
user = create_test_user(db_session, "del_bp_user")
_create_deleting_user_file(db_session, user.id)
mock_app = MagicMock()
with (
_patch_task_app(check_for_user_file_delete, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=USER_FILE_DELETE_MAX_QUEUE_DEPTH + 1),
):
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
mock_app.send_task.assert_not_called()
class TestDeletePerFileGuardKey:
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
def test_guarded_file_not_re_enqueued(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""A file whose guard key is already set in Redis is skipped."""
user = create_test_user(db_session, "del_guard_user")
uf = _create_deleting_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_delete_queued_key(uf.id)
redis_client.setex(guard_key, CELERY_USER_FILE_DELETE_TASK_EXPIRES, 1)
mock_app = MagicMock()
try:
with (
_patch_task_app(check_for_user_file_delete, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
# send_task must not have been called with this specific file's ID
for call in mock_app.send_task.call_args_list:
kwargs = call.kwargs.get("kwargs", {})
assert kwargs.get("user_file_id") != str(
uf.id
), f"File {uf.id} should have been skipped because its guard key exists"
finally:
redis_client.delete(guard_key)
def test_guard_key_exists_in_redis_after_enqueue(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""After a file is enqueued its guard key is present in Redis with a TTL."""
user = create_test_user(db_session, "del_guard_set_user")
uf = _create_deleting_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_delete_queued_key(uf.id)
redis_client.delete(guard_key) # clean slate
mock_app = MagicMock()
try:
with (
_patch_task_app(check_for_user_file_delete, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
assert redis_client.exists(
guard_key
), "Guard key should be set in Redis after enqueue"
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
assert (
0 < ttl <= CELERY_USER_FILE_DELETE_TASK_EXPIRES
), f"Guard key TTL {ttl}s is outside the expected range (0, {CELERY_USER_FILE_DELETE_TASK_EXPIRES}]"
finally:
redis_client.delete(guard_key)
class TestDeleteTaskExpiry:
"""Protection 3: every send_task call includes an expires value."""
def test_send_task_called_with_expires(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""send_task is called with the correct queue, task name, and expires."""
user = create_test_user(db_session, "del_expires_user")
uf = _create_deleting_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_delete_queued_key(uf.id)
redis_client.delete(guard_key)
mock_app = MagicMock()
try:
with (
_patch_task_app(check_for_user_file_delete, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
# At least one task should have been submitted (for our file)
assert (
mock_app.send_task.call_count >= 1
), "Expected at least one task to be submitted"
# Every submitted task must carry expires
for call in mock_app.send_task.call_args_list:
assert call.args[0] == OnyxCeleryTask.DELETE_SINGLE_USER_FILE
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_DELETE
assert (
call.kwargs.get("expires") == CELERY_USER_FILE_DELETE_TASK_EXPIRES
), "Task must be submitted with the correct expires value to prevent stale task accumulation"
finally:
redis_client.delete(guard_key)
class TestDeleteWorkerClearsGuardKey:
"""process_single_user_file_delete removes the guard key when it picks up a task."""
def test_guard_key_deleted_on_pickup(
self,
tenant_context: None, # noqa: ARG002
) -> None:
"""The guard key is deleted before the worker does any real work.
We simulate an already-locked file so delete_user_file_impl returns
early but crucially, after the guard key deletion.
"""
user_file_id = str(uuid4())
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_delete_queued_key(user_file_id)
# Simulate the guard key set when the beat enqueued the task
redis_client.setex(guard_key, CELERY_USER_FILE_DELETE_TASK_EXPIRES, 1)
assert redis_client.exists(guard_key), "Guard key must exist before pickup"
# Hold the per-file delete lock so the worker exits early without
# touching the database or file store.
lock_key = _user_file_delete_lock_key(user_file_id)
delete_lock = redis_client.lock(lock_key, timeout=10)
acquired = delete_lock.acquire(blocking=False)
assert acquired, "Should be able to acquire the delete lock for this test"
try:
process_single_user_file_delete.run(
user_file_id=user_file_id,
tenant_id=TEST_TENANT_ID,
)
finally:
if delete_lock.owned():
delete_lock.release()
assert not redis_client.exists(
guard_key
), "Guard key should be deleted when the worker picks up the task"

View File

@@ -297,6 +297,10 @@ def index_batch_params(
class TestDocumentIndexOld:
"""Tests the old DocumentIndex interface."""
# TODO(ENG-3864)(andrei): Re-enable this test.
@pytest.mark.xfail(
reason="Flaky test: Retrieved chunks vary non-deterministically before and after changing user projects and personas. Likely a timing issue with the index being updated."
)
def test_update_single_can_clear_user_projects_and_personas(
self,
document_indices: list[DocumentIndex],

View File

@@ -22,18 +22,26 @@ from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.constants import HybridSearchNormalizationPipeline
from onyx.document_index.opensearch.constants import HybridSearchSubqueryConfiguration
from onyx.document_index.opensearch.opensearch_document_index import (
generate_opensearch_filtered_access_control_list,
)
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
from onyx.document_index.opensearch.schema import DocumentSchema
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.search import DocumentQuery
from onyx.document_index.opensearch.search import (
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
get_min_max_normalization_pipeline_name_and_config,
)
from onyx.document_index.opensearch.search import (
get_normalization_pipeline_name_and_config,
)
from onyx.document_index.opensearch.search import (
get_zscore_normalization_pipeline_name_and_config,
)
from onyx.document_index.opensearch.search import MIN_MAX_NORMALIZATION_PIPELINE_NAME
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -49,6 +57,63 @@ def _patch_global_tenant_state(monkeypatch: pytest.MonkeyPatch, state: bool) ->
monkeypatch.setattr("onyx.document_index.opensearch.schema.MULTI_TENANT", state)
def _patch_hybrid_search_subquery_configuration(
monkeypatch: pytest.MonkeyPatch, configuration: HybridSearchSubqueryConfiguration
) -> None:
"""
Patches HYBRID_SEARCH_SUBQUERY_CONFIGURATION wherever necessary for this
test file.
Args:
monkeypatch: The test instance's monkeypatch instance, used for
patching.
configuration: The intended state of
HYBRID_SEARCH_SUBQUERY_CONFIGURATION.
"""
monkeypatch.setattr(
"onyx.document_index.opensearch.constants.HYBRID_SEARCH_SUBQUERY_CONFIGURATION",
configuration,
)
monkeypatch.setattr(
"onyx.document_index.opensearch.search.HYBRID_SEARCH_SUBQUERY_CONFIGURATION",
configuration,
)
def _patch_hybrid_search_normalization_pipeline(
monkeypatch: pytest.MonkeyPatch, pipeline: HybridSearchNormalizationPipeline
) -> None:
"""
Patches HYBRID_SEARCH_NORMALIZATION_PIPELINE wherever necessary for this
test file.
"""
monkeypatch.setattr(
"onyx.document_index.opensearch.constants.HYBRID_SEARCH_NORMALIZATION_PIPELINE",
pipeline,
)
monkeypatch.setattr(
"onyx.document_index.opensearch.search.HYBRID_SEARCH_NORMALIZATION_PIPELINE",
pipeline,
)
def _patch_opensearch_match_highlights_disabled(
monkeypatch: pytest.MonkeyPatch, disabled: bool
) -> None:
"""
Patches OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED wherever necessary for this
test file.
"""
monkeypatch.setattr(
"onyx.configs.app_configs.OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED",
disabled,
)
monkeypatch.setattr(
"onyx.document_index.opensearch.search.OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED",
disabled,
)
def _create_test_document_chunk(
document_id: str,
content: str,
@@ -144,14 +209,27 @@ def test_client(
@pytest.fixture(scope="function")
def search_pipeline(test_client: OpenSearchIndexClient) -> Generator[None, None, None]:
"""Creates a search pipeline for testing with automatic cleanup."""
min_max_normalization_pipeline_name, min_max_normalization_pipeline_config = (
get_min_max_normalization_pipeline_name_and_config()
)
zscore_normalization_pipeline_name, zscore_normalization_pipeline_config = (
get_zscore_normalization_pipeline_name_and_config()
)
test_client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
pipeline_id=min_max_normalization_pipeline_name,
pipeline_body=min_max_normalization_pipeline_config,
)
test_client.create_search_pipeline(
pipeline_id=zscore_normalization_pipeline_name,
pipeline_body=zscore_normalization_pipeline_config,
)
yield # Test runs here.
try:
test_client.delete_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_id=min_max_normalization_pipeline_name,
)
test_client.delete_search_pipeline(
pipeline_id=zscore_normalization_pipeline_name,
)
except Exception:
pass
@@ -166,7 +244,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
# Under test.
# Should not raise.
@@ -182,7 +260,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test.
@@ -211,7 +289,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
@@ -225,7 +303,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
# Under test and postcondition.
# Should return False before creation.
@@ -245,7 +323,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test.
@@ -280,7 +358,7 @@ class TestOpenSearchClient:
},
},
}
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=initial_mappings, settings=settings)
# Under test.
@@ -323,7 +401,7 @@ class TestOpenSearchClient:
"test_field": {"type": "keyword"},
},
}
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=initial_mappings, settings=settings)
# Under test and postcondition.
@@ -358,7 +436,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
# Create once - should succeed.
test_client.create_index(mappings=mappings, settings=settings)
@@ -377,18 +455,19 @@ class TestOpenSearchClient:
self, test_client: OpenSearchIndexClient
) -> None:
"""Tests creating and deleting a search pipeline."""
# Precondition.
pipeline_name, pipeline_config = get_normalization_pipeline_name_and_config()
# Under test and postcondition.
# Should not raise.
test_client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
pipeline_id=pipeline_name,
pipeline_body=pipeline_config,
)
# Under test and postcondition.
# Should not raise.
test_client.delete_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
)
test_client.delete_search_pipeline(pipeline_id=pipeline_name)
def test_index_document(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
@@ -400,7 +479,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
doc = _create_test_document_chunk(
@@ -428,7 +507,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
docs = [
@@ -459,7 +538,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
doc = _create_test_document_chunk(
@@ -487,7 +566,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
original_doc = _create_test_document_chunk(
@@ -522,7 +601,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=False
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test and postcondition.
@@ -541,7 +620,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
doc = _create_test_document_chunk(
@@ -577,7 +656,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test.
@@ -598,7 +677,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index multiple documents.
@@ -674,7 +753,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Create a document to update.
@@ -723,7 +802,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Under test and postcondition.
@@ -734,22 +813,22 @@ class TestOpenSearchClient:
properties_to_update={"hidden": True},
)
def test_hybrid_search_with_pipeline(
def test_hybrid_search_configurations_and_pipelines(
self,
test_client: OpenSearchIndexClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Tests hybrid search with a normalization pipeline."""
"""Tests all hybrid search configurations and pipelines."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents.
docs = {
"doc-1": _create_test_document_chunk(
@@ -780,40 +859,62 @@ class TestOpenSearchClient:
# Refresh index to make documents searchable.
test_client.refresh_index()
# Search query.
query_text = "Python programming"
query_vector = _generate_test_vector(0.12)
search_body = DocumentQuery.get_hybrid_search_query(
query_text=query_text,
query_vector=query_vector,
num_hits=5,
tenant_state=tenant_state,
# We're not worried about filtering here. tenant_id in this object
# is not relevant.
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
include_hidden=False,
)
for configuration in HybridSearchSubqueryConfiguration:
_patch_hybrid_search_subquery_configuration(monkeypatch, configuration)
for pipeline in HybridSearchNormalizationPipeline:
_patch_hybrid_search_normalization_pipeline(monkeypatch, pipeline)
pipeline_name, pipeline_config = (
get_normalization_pipeline_name_and_config()
)
test_client.create_search_pipeline(
pipeline_id=pipeline_name,
pipeline_body=pipeline_config,
)
# Under test.
results = test_client.search(
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
)
# Search query.
query_text = "Python programming"
query_vector = _generate_test_vector(0.12)
search_body = DocumentQuery.get_hybrid_search_query(
query_text=query_text,
query_vector=query_vector,
num_hits=5,
tenant_state=tenant_state,
# We're not worried about filtering here. tenant_id in this object
# is not relevant.
index_filters=IndexFilters(
access_control_list=None, tenant_id=None
),
include_hidden=False,
)
# Postcondition.
assert len(results) == len(docs)
# Assert that all the chunks above are present.
assert all(chunk.document_chunk.document_id in docs.keys() for chunk in results)
# Make sure the chunk contents are preserved.
for i, chunk in enumerate(results):
assert chunk.document_chunk == docs[chunk.document_chunk.document_id]
# Make sure score reporting seems reasonable (it should not be None
# or 0).
assert chunk.score
# Make sure there is some kind of match highlight only for the first
# result. The other results are so bad they're not expected to have
# match highlights.
if i == 0:
assert chunk.match_highlights.get(CONTENT_FIELD_NAME, [])
# Under test.
results = test_client.search(
body=search_body, search_pipeline_id=pipeline_name
)
# Postcondition.
assert len(results) == len(docs)
# Assert that all the chunks above are present.
assert all(
chunk.document_chunk.document_id in docs.keys() for chunk in results
)
# Make sure the chunk contents are preserved.
for i, chunk in enumerate(results):
expected = docs[chunk.document_chunk.document_id]
assert chunk.document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(expected, k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
# Make sure score reporting seems reasonable (it should not be None
# or 0).
assert chunk.score
# Make sure there is some kind of match highlight only for the first
# result. The other results are so bad they're not expected to have
# match highlights.
if i == 0:
assert chunk.match_highlights.get(CONTENT_FIELD_NAME, [])
def test_search_empty_index(
self,
@@ -828,7 +929,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Note no documents were indexed.
@@ -845,11 +946,10 @@ class TestOpenSearchClient:
index_filters=IndexFilters(access_control_list=None, tenant_id=None),
include_hidden=False,
)
pipeline_name, _ = get_normalization_pipeline_name_and_config()
# Under test.
results = test_client.search(
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
)
results = test_client.search(body=search_body, search_pipeline_id=pipeline_name)
# Postcondition.
assert len(results) == 0
@@ -865,12 +965,13 @@ class TestOpenSearchClient:
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
@@ -948,11 +1049,10 @@ class TestOpenSearchClient:
),
include_hidden=False,
)
pipeline_name, _ = get_normalization_pipeline_name_and_config()
# Under test.
results = test_client.search(
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
)
results = test_client.search(body=search_body, search_pipeline_id=pipeline_name)
# Postcondition.
# Should only get the public, non-hidden document, and the private
@@ -962,7 +1062,12 @@ class TestOpenSearchClient:
# ordered; we're just assuming which doc will be the first result here.
assert results[0].document_chunk.document_id == "public-doc"
# Make sure the chunk contents are preserved.
assert results[0].document_chunk == docs["public-doc"]
assert results[0].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["public-doc"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
# Make sure score reporting seems reasonable (it should not be None
# or 0).
assert results[0].score
@@ -970,7 +1075,12 @@ class TestOpenSearchClient:
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
# Same for the second result.
assert results[1].document_chunk.document_id == "private-doc-user-a"
assert results[1].document_chunk == docs["private-doc-user-a"]
assert results[1].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["private-doc-user-a"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[1].score
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
@@ -986,11 +1096,12 @@ class TestOpenSearchClient:
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with varying relevance to the query.
@@ -1067,11 +1178,10 @@ class TestOpenSearchClient:
index_filters=IndexFilters(access_control_list=[], tenant_id=None),
include_hidden=False,
)
pipeline_name, _ = get_normalization_pipeline_name_and_config()
# Under test.
results = test_client.search(
body=search_body, search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
)
results = test_client.search(body=search_body, search_pipeline_id=pipeline_name)
# Postcondition.
# Should only get public, non-hidden documents (3 out of 5).
@@ -1118,7 +1228,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Although very unlikely in practice, let's use the same doc ID just to
@@ -1211,7 +1321,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Don't index any documents.
@@ -1238,7 +1348,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index chunks for two different documents.
@@ -1306,7 +1416,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
@@ -1383,7 +1493,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index docs with various ages.
@@ -1441,15 +1551,16 @@ class TestOpenSearchClient:
),
include_hidden=False,
)
pipeline_name, _ = get_normalization_pipeline_name_and_config()
# Under test.
last_week_results = test_client.search(
body=last_week_search_body,
search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
search_pipeline_id=pipeline_name,
)
last_six_months_results = test_client.search(
body=last_six_months_search_body,
search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
search_pipeline_id=pipeline_name,
)
# Postcondition.
@@ -1474,7 +1585,7 @@ class TestOpenSearchClient:
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index chunks for two different documents, one hidden one not.
@@ -1523,4 +1634,281 @@ class TestOpenSearchClient:
for result in results:
# Note each result must be from doc 1, which is not hidden.
expected_result = doc1_chunks[result.document_chunk.chunk_index]
assert result.document_chunk == expected_result
assert result.document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(expected_result, k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
def test_keyword_search(
self,
test_client: OpenSearchIndexClient,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""
Tests keyword search with filters for ACL, hidden documents, and tenant
isolation.
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
docs = {
"public-doc": _create_test_document_chunk(
document_id="public-doc",
chunk_index=0,
content="Public document content",
hidden=False,
tenant_state=tenant_x,
),
"hidden-doc": _create_test_document_chunk(
document_id="hidden-doc",
chunk_index=0,
content="Hidden document content, spooky",
hidden=True,
tenant_state=tenant_x,
),
"private-doc-user-a": _create_test_document_chunk(
document_id="private-doc-user-a",
chunk_index=0,
content="Private document content, btw my SSN is 123-45-6789",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
# Tests that we don't return documents that don't match keywords at
# all, even if they match filters.
"private-but-not-relevant-doc-user-a": _create_test_document_chunk(
document_id="private-but-not-relevant-doc-user-a",
chunk_index=0,
content="This text should not match the query at all",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
"private-doc-user-b": _create_test_document_chunk(
document_id="private-doc-user-b",
chunk_index=0,
content="Private document content, btw my SSN is 987-65-4321",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-b@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
document_id="should-not-exist-from-tenant-x-pov",
chunk_index=0,
content="This is an entirely different tenant, x should never see this",
# Make this as permissive as possible to exercise tenant
# isolation.
hidden=False,
tenant_state=tenant_y,
),
}
for doc in docs.values():
test_client.index_document(document=doc, tenant_state=doc.tenant_id)
# Refresh index to make documents searchable.
test_client.refresh_index()
# Should not match private-but-not-relevant-doc-user-a.
query_text = "document content"
search_body = DocumentQuery.get_keyword_search_query(
query_text=query_text,
num_hits=5,
tenant_state=tenant_x,
# The user should only be able to see their private docs. tenant_id
# in this object is not relevant.
index_filters=IndexFilters(
access_control_list=[prefix_user_email("user-a@example.com")],
tenant_id=None,
),
include_hidden=False,
)
# Under test.
results = test_client.search(body=search_body, search_pipeline_id=None)
# Postcondition.
# Should only get the public, non-hidden document, and the private
# document for which the user has access.
assert len(results) == 2
# This should be the highest-ranked result, as a higher percentage of
# the content matches the query.
assert results[0].document_chunk.document_id == "public-doc"
# Make sure the chunk contents are preserved.
assert results[0].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["public-doc"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
# Make sure score reporting seems reasonable (it should not be None
# or 0).
assert results[0].score
# Make sure there is some kind of match highlight.
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
# Same for the second result.
assert results[1].document_chunk.document_id == "private-doc-user-a"
assert results[1].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["private-doc-user-a"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[1].score
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
assert results[1].score < results[0].score
def test_semantic_search(
self,
test_client: OpenSearchIndexClient,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""
Tests semantic search with filters for ACL, hidden documents, and tenant
isolation.
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
docs = {
"public-doc": _create_test_document_chunk(
document_id="public-doc",
chunk_index=0,
content="Public document content",
hidden=False,
tenant_state=tenant_x,
# Make this identical to the query vector to test that this
# result is returned first.
content_vector=_generate_test_vector(0.6),
),
"hidden-doc": _create_test_document_chunk(
document_id="hidden-doc",
chunk_index=0,
content="Hidden document content, spooky",
hidden=True,
tenant_state=tenant_x,
),
"private-doc-user-a": _create_test_document_chunk(
document_id="private-doc-user-a",
chunk_index=0,
content="Private document content, btw my SSN is 123-45-6789",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
# Make this different from the query vector to test that this
# result is returned second.
content_vector=_generate_test_vector(0.5),
),
"private-doc-user-b": _create_test_document_chunk(
document_id="private-doc-user-b",
chunk_index=0,
content="Private document content, btw my SSN is 987-65-4321",
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-b@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
),
),
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
document_id="should-not-exist-from-tenant-x-pov",
chunk_index=0,
content="This is an entirely different tenant, x should never see this",
# Make this as permissive as possible to exercise tenant
# isolation.
hidden=False,
tenant_state=tenant_y,
),
}
for doc in docs.values():
test_client.index_document(document=doc, tenant_state=doc.tenant_id)
# Refresh index to make documents searchable.
test_client.refresh_index()
query_vector = _generate_test_vector(0.6)
search_body = DocumentQuery.get_semantic_search_query(
query_embedding=query_vector,
num_hits=5,
tenant_state=tenant_x,
# The user should only be able to see their private docs. tenant_id
# in this object is not relevant.
index_filters=IndexFilters(
access_control_list=[prefix_user_email("user-a@example.com")],
tenant_id=None,
),
include_hidden=False,
)
# Under test.
results = test_client.search(body=search_body, search_pipeline_id=None)
# Postcondition.
# Should only get the public, non-hidden document, and the private
# document for which the user has access.
assert len(results) == 2
# We explicitly expect this to be the highest-ranked result.
assert results[0].document_chunk.document_id == "public-doc"
# Make sure the chunk contents are preserved.
assert results[0].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["public-doc"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[0].score == 1.0
# Same for the second result.
assert results[1].document_chunk.document_id == "private-doc-user-a"
assert results[1].document_chunk == DocumentChunkWithoutVectors(
**{
k: getattr(docs["private-doc-user-a"], k)
for k in DocumentChunkWithoutVectors.model_fields
}
)
assert results[1].score
assert 0.0 < results[1].score < 1.0

View File

@@ -31,7 +31,6 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
)
from onyx.configs.constants import PUBLIC_DOC_PAT
from onyx.configs.constants import SOURCE_TYPE
from onyx.context.search.models import IndexFilters
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import Document
from onyx.db.models import OpenSearchDocumentMigrationRecord
@@ -44,6 +43,7 @@ from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.search import DocumentQuery
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
@@ -70,6 +70,7 @@ from onyx.document_index.vespa_constants import SOURCE_LINKS
from onyx.document_index.vespa_constants import TITLE
from onyx.document_index.vespa_constants import TITLE_EMBEDDING
from onyx.document_index.vespa_constants import USER_PROJECT
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
from tests.external_dependency_unit.full_setup import ensure_full_deployment_setup
@@ -78,24 +79,22 @@ CHUNK_COUNT = 5
def _get_document_chunks_from_opensearch(
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
opensearch_client: OpenSearchIndexClient,
document_id: str,
tenant_state: TenantState,
) -> list[DocumentChunk]:
opensearch_client.refresh_index()
filters = IndexFilters(access_control_list=None, tenant_id=current_tenant_id)
query_body = DocumentQuery.get_from_document_id_query(
document_id=document_id,
tenant_state=TenantState(tenant_id=current_tenant_id, multitenant=False),
index_filters=filters,
include_hidden=False,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
)
search_hits = opensearch_client.search(
body=query_body,
search_pipeline_id=None,
)
return [search_hit.document_chunk for search_hit in search_hits]
results: list[DocumentChunk] = []
for i in range(CHUNK_COUNT):
document_chunk_id: str = get_opensearch_doc_chunk_id(
tenant_state=tenant_state,
document_id=document_id,
chunk_index=i,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
)
result = opensearch_client.get_document(document_chunk_id)
results.append(result)
return results
def _delete_document_chunks_from_opensearch(
@@ -452,10 +451,13 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
for chunks in document_chunks.values():
all_chunks.extend(chunks)
vespa_document_index.index_raw_chunks(all_chunks)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
# Under test.
result = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -477,7 +479,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -522,6 +524,9 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
for chunks in document_chunks.values():
all_chunks.extend(chunks)
vespa_document_index.index_raw_chunks(all_chunks)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
# Run the initial batch. To simulate partial progress we will mock the
# redis lock to return True for the first invocation of .owned() and
@@ -536,7 +541,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
return_value=mock_redis_client,
):
result_1 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
assert result_1 is True
@@ -559,7 +564,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Under test.
# Run the remainder of the migration.
result_2 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -583,7 +588,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -630,6 +635,9 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
for chunks in document_chunks.values():
all_chunks.extend(chunks)
vespa_document_index.index_raw_chunks(all_chunks)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
# Run the initial batch. To simulate partial progress we will mock the
# redis lock to return True for the first invocation of .owned() and
@@ -646,7 +654,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
return_value=mock_redis_client,
):
result_1 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
assert result_1 is True
@@ -691,7 +699,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
),
):
result_2 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -728,7 +736,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
),
):
result_3 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -752,7 +760,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -840,24 +848,25 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
chunk["content"] = (
f"Different content {chunk[CHUNK_ID]} for {test_documents[0].id}"
)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
chunks_for_document_in_opensearch, _ = (
transform_vespa_chunks_to_opensearch_chunks(
document_in_opensearch,
TenantState(tenant_id=get_current_tenant_id(), multitenant=False),
tenant_state,
{},
)
)
opensearch_client.bulk_index_documents(
documents=chunks_for_document_in_opensearch,
tenant_state=TenantState(
tenant_id=get_current_tenant_id(), multitenant=False
),
tenant_state=tenant_state,
update_if_exists=True,
)
# Under test.
result = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -878,7 +887,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -922,11 +931,14 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
for chunks in document_chunks.values():
all_chunks.extend(chunks)
vespa_document_index.index_raw_chunks(all_chunks)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
# Under test.
# First run.
result_1 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -947,7 +959,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
@@ -960,7 +972,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Under test.
# Second run.
result_2 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
tenant_id=tenant_state.tenant_id
)
# Postcondition.
@@ -982,7 +994,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
opensearch_client, document.id, tenant_state
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)

View File

@@ -1219,15 +1219,16 @@ def test_code_interpreter_receives_chat_files(
finally:
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
# Verify: file uploaded, code executed via streaming, staged file cleaned up
# Verify: file uploaded and code executed via streaming.
assert len(mock_ci_server.get_requests(method="POST", path="/v1/files")) == 1
assert (
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
)
delete_requests = mock_ci_server.get_requests(method="DELETE")
assert len(delete_requests) == 1
assert delete_requests[0].path.startswith("/v1/files/")
# Staged input files are intentionally NOT deleted — PythonTool caches their
# file IDs across agent-loop iterations to avoid re-uploading on every call.
# The code interpreter cleans them up via its own TTL.
assert len(mock_ci_server.get_requests(method="DELETE")) == 0
execute_body = mock_ci_server.get_requests(
method="POST", path="/v1/execute/stream"

View File

@@ -1,9 +1,13 @@
"""Tests for llm_loop.py, specifically the construct_message_history function."""
"""Tests for llm_loop.py, including history construction and empty-response paths."""
from unittest.mock import Mock
import pytest
from onyx.chat.llm_loop import _build_empty_llm_response_error
from onyx.chat.llm_loop import _try_fallback_tool_extraction
from onyx.chat.llm_loop import construct_message_history
from onyx.chat.llm_loop import EmptyLLMResponseError
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ContextFileMetadata
@@ -13,6 +17,7 @@ from onyx.chat.models import LlmStepResult
from onyx.chat.models import ToolCallSimple
from onyx.configs.constants import MessageType
from onyx.file_store.models import ChatFileType
from onyx.llm.interfaces import LLMConfig
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.server.query_and_chat.placement import Placement
from onyx.tools.models import ToolCallKickoff
@@ -1167,3 +1172,57 @@ class TestFallbackToolExtraction:
assert result is llm_step_result
assert attempted is False
class TestEmptyLlmResponseClassification:
def _make_llm(self, provider: str = "openai", model: str = "gpt-5.2") -> Mock:
llm = Mock()
llm.config = LLMConfig(
model_provider=provider,
model_name=model,
temperature=0.0,
max_input_tokens=4096,
)
return llm
def test_openai_empty_stream_is_classified_as_budget_exceeded(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setattr("onyx.chat.llm_loop.is_true_openai_model", lambda *_: True)
err = _build_empty_llm_response_error(
llm=self._make_llm(),
llm_step_result=LlmStepResult(
reasoning=None,
answer=None,
tool_calls=None,
raw_answer=None,
),
tool_choice=ToolChoiceOptions.AUTO,
)
assert isinstance(err, EmptyLLMResponseError)
assert err.error_code == "BUDGET_EXCEEDED"
assert err.is_retryable is False
assert "quota" in err.client_error_msg.lower()
def test_reasoning_only_response_uses_generic_empty_response_error(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setattr("onyx.chat.llm_loop.is_true_openai_model", lambda *_: True)
err = _build_empty_llm_response_error(
llm=self._make_llm(),
llm_step_result=LlmStepResult(
reasoning="scratchpad only",
answer=None,
tool_calls=None,
raw_answer=None,
),
tool_choice=ToolChoiceOptions.AUTO,
)
assert isinstance(err, EmptyLLMResponseError)
assert err.error_code == "EMPTY_LLM_RESPONSE"
assert err.is_retryable is True
assert "quota" not in err.client_error_msg.lower()

View File

@@ -0,0 +1,133 @@
import pytest
from onyx.chat.process_message import _resolve_query_processing_hook_result
from onyx.chat.process_message import remove_answer_citations
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
def test_remove_answer_citations_strips_http_markdown_citation() -> None:
answer = "The answer is Paris [[1]](https://example.com/doc)."
assert remove_answer_citations(answer) == "The answer is Paris."
def test_remove_answer_citations_strips_empty_markdown_citation() -> None:
answer = "The answer is Paris [[1]]()."
assert remove_answer_citations(answer) == "The answer is Paris."
def test_remove_answer_citations_strips_citation_with_parentheses_in_url() -> None:
answer = (
"The answer is Paris "
"[[1]](https://en.wikipedia.org/wiki/Function_(mathematics))."
)
assert remove_answer_citations(answer) == "The answer is Paris."
def test_remove_answer_citations_preserves_non_citation_markdown_links() -> None:
answer = (
"See [reference](https://example.com/Function_(mathematics)) "
"for context [[1]](https://en.wikipedia.org/wiki/Function_(mathematics))."
)
assert (
remove_answer_citations(answer)
== "See [reference](https://example.com/Function_(mathematics)) for context."
)
# ---------------------------------------------------------------------------
# Query Processing hook response handling (_resolve_query_processing_hook_result)
# ---------------------------------------------------------------------------
def test_wrong_model_type_raises_internal_error() -> None:
"""If the executor ever returns an unexpected BaseModel type, raise INTERNAL_ERROR
rather than an AssertionError or AttributeError."""
from pydantic import BaseModel as PydanticBaseModel
class _OtherModel(PydanticBaseModel):
pass
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(_OtherModel(), "original query")
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
def test_hook_skipped_leaves_message_text_unchanged() -> None:
result = _resolve_query_processing_hook_result(HookSkipped(), "original query")
assert result == "original query"
def test_hook_soft_failed_leaves_message_text_unchanged() -> None:
result = _resolve_query_processing_hook_result(HookSoftFailed(), "original query")
assert result == "original query"
def test_null_query_raises_query_rejected() -> None:
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(query=None), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_empty_string_query_raises_query_rejected() -> None:
"""Empty string is falsy — must be treated as rejection, same as None."""
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(query=""), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_whitespace_only_query_raises_query_rejected() -> None:
"""Whitespace-only string is truthy but meaningless — must be treated as rejection."""
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(query=" "), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_absent_query_field_raises_query_rejected() -> None:
"""query defaults to None when not provided."""
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_rejection_message_surfaced_in_error_when_provided() -> None:
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(
query=None, rejection_message="Queries about X are not allowed."
),
"original query",
)
assert "Queries about X are not allowed." in str(exc_info.value)
def test_fallback_rejection_message_when_none() -> None:
"""No rejection_message → generic fallback used in OnyxError detail."""
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(query=None, rejection_message=None),
"original query",
)
assert "Your query was rejected." in str(exc_info.value)
def test_nonempty_query_rewrites_message_text() -> None:
result = _resolve_query_processing_hook_result(
QueryProcessingResponse(query="rewritten query"), "original query"
)
assert result == "rewritten query"

View File

@@ -3,7 +3,10 @@ from unittest.mock import Mock
import pytest
from onyx.chat import process_message
from onyx.chat.models import AnswerStream
from onyx.chat.models import StreamingError
from onyx.configs import app_configs
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
@@ -35,3 +38,26 @@ def test_mock_llm_response_requires_integration_mode() -> None:
db_session=Mock(),
)
)
def test_gather_stream_returns_empty_answer_when_streaming_error_only() -> None:
packets: AnswerStream = iter(
[
MessageResponseIDInfo(
user_message_id=None,
reserved_assistant_message_id=42,
),
StreamingError(
error="OpenAI quota exceeded",
error_code="BUDGET_EXCEEDED",
is_retryable=False,
),
]
)
result = process_message.gather_stream(packets)
assert result.answer == ""
assert result.answer_citationless == ""
assert result.error_msg == "OpenAI quota exceeded"
assert result.message_id == 42

View File

@@ -0,0 +1,63 @@
"""
Unit test verifying that the upload API path sends tasks with expires=.
The upload_files_to_user_files_with_indexing function must include expires=
on every send_task call to prevent phantom task accumulation if the worker
is down or slow.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.models import UserFile
from onyx.db.projects import upload_files_to_user_files_with_indexing
def _make_mock_user_file() -> MagicMock:
uf = MagicMock(spec=UserFile)
uf.id = str(uuid4())
return uf
@patch("onyx.db.projects.get_current_tenant_id", return_value="test_tenant")
@patch("onyx.db.projects.create_user_files")
@patch(
"onyx.background.celery.versioned_apps.client.app",
new_callable=MagicMock,
)
def test_send_task_includes_expires(
mock_client_app: MagicMock,
mock_create: MagicMock,
mock_tenant: MagicMock, # noqa: ARG001
) -> None:
"""Every send_task call from the upload path must include expires=."""
user_files = [_make_mock_user_file(), _make_mock_user_file()]
mock_create.return_value = MagicMock(
user_files=user_files,
rejected_files=[],
id_to_temp_id={},
)
mock_user = MagicMock()
mock_db_session = MagicMock()
upload_files_to_user_files_with_indexing(
files=[],
project_id=None,
user=mock_user,
temp_id_map=None,
db_session=mock_db_session,
)
assert mock_client_app.send_task.call_count == len(user_files)
for call in mock_client_app.send_task.call_args_list:
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
assert (
call.kwargs.get("expires") == CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
), "send_task must include expires= to prevent phantom task accumulation"

View File

@@ -0,0 +1,45 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer (pypdf)
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 1
/Kids [ 4 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Parent 2 0 R
>>
endobj
xref
0 5
0000000000 65535 f
0000000015 00000 n
0000000054 00000 n
0000000113 00000 n
0000000162 00000 n
trailer
<<
/Size 5
/Root 3 0 R
/Info 1 0 R
>>
startxref
256
%%EOF

View File

@@ -0,0 +1,89 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer (pypdf)
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 2
/Kids [ 4 0 R 6 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 5 0 R
/Parent 2 0 R
>>
endobj
5 0 obj
<<
/Length 47
>>
stream
BT /F1 12 Tf 50 150 Td (Page one content) Tj ET
endstream
endobj
6 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 7 0 R
/Parent 2 0 R
>>
endobj
7 0 obj
<<
/Length 47
>>
stream
BT /F1 12 Tf 50 150 Td (Page two content) Tj ET
endstream
endobj
xref
0 8
0000000000 65535 f
0000000015 00000 n
0000000054 00000 n
0000000119 00000 n
0000000168 00000 n
0000000349 00000 n
0000000446 00000 n
0000000627 00000 n
trailer
<<
/Size 8
/Root 3 0 R
/Info 1 0 R
>>
startxref
724
%%EOF

View File

@@ -0,0 +1,62 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer (pypdf)
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 1
/Kids [ 4 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 5 0 R
/Parent 2 0 R
>>
endobj
5 0 obj
<<
/Length 42
>>
stream
BT /F1 12 Tf 50 150 Td (Hello World) Tj ET
endstream
endobj
xref
0 6
0000000000 65535 f
0000000015 00000 n
0000000054 00000 n
0000000113 00000 n
0000000162 00000 n
0000000343 00000 n
trailer
<<
/Size 6
/Root 3 0 R
/Info 1 0 R
>>
startxref
435
%%EOF

View File

@@ -0,0 +1,64 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer (pypdf)
/Title (My Title)
/Author (Jane Doe)
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 1
/Kids [ 4 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 5 0 R
/Parent 2 0 R
>>
endobj
5 0 obj
<<
/Length 35
>>
stream
BT /F1 12 Tf 50 150 Td (test) Tj ET
endstream
endobj
xref
0 6
0000000000 65535 f
0000000015 00000 n
0000000091 00000 n
0000000150 00000 n
0000000199 00000 n
0000000380 00000 n
trailer
<<
/Size 6
/Root 3 0 R
/Info 1 0 R
>>
startxref
465
%%EOF

View File

@@ -0,0 +1,89 @@
"""
Unit tests for image summarization error handling.
Verifies that:
1. LLM errors produce actionable error messages (not base64 dumps)
2. Unsupported MIME type logs include the magic bytes and size
3. The ValueError raised on LLM failure preserves the original exception
"""
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.file_processing.image_summarization import _summarize_image
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
class TestSummarizeImageErrorMessage:
"""_summarize_image must not dump base64 image data into error messages."""
def test_error_message_contains_exception_type_not_base64(self) -> None:
"""The ValueError should contain the original exception info, not message payloads."""
mock_llm = MagicMock()
mock_llm.invoke.side_effect = RuntimeError("Connection timeout")
# A fake base64-encoded image string (should NOT appear in the error)
fake_encoded = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg..."
with pytest.raises(ValueError, match="RuntimeError: Connection timeout"):
_summarize_image(fake_encoded, mock_llm, query="test")
def test_error_message_does_not_contain_base64(self) -> None:
"""Ensure base64 data is never included in the error message."""
mock_llm = MagicMock()
mock_llm.invoke.side_effect = RuntimeError("API error")
fake_encoded = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUA"
with pytest.raises(ValueError) as exc_info:
_summarize_image(fake_encoded, mock_llm)
error_str = str(exc_info.value)
assert "base64" not in error_str
assert "iVBOR" not in error_str
def test_original_exception_is_chained(self) -> None:
"""The ValueError should chain the original exception via __cause__."""
mock_llm = MagicMock()
original = RuntimeError("upstream failure")
mock_llm.invoke.side_effect = original
with pytest.raises(ValueError) as exc_info:
_summarize_image("data:image/png;base64,abc", mock_llm)
assert exc_info.value.__cause__ is original
class TestUnsupportedMimeTypeLogging:
"""summarize_image_with_error_handling should log useful info for unsupported formats."""
@patch(
"onyx.file_processing.image_summarization.summarize_image_pipeline",
side_effect=__import__(
"onyx.file_processing.image_summarization",
fromlist=["UnsupportedImageFormatError"],
).UnsupportedImageFormatError("unsupported"),
)
def test_logs_magic_bytes_and_size(
self, mock_pipeline: MagicMock # noqa: ARG002
) -> None:
"""The info log should include magic bytes hex and image size."""
mock_llm = MagicMock()
# TIFF magic bytes (not in the supported list)
image_data = b"\x49\x49\x2a\x00" + b"\x00" * 100
with patch("onyx.file_processing.image_summarization.logger") as mock_logger:
result = summarize_image_with_error_handling(
llm=mock_llm,
image_data=image_data,
context_name="test_image.tiff",
)
assert result is None
mock_logger.info.assert_called_once()
log_args = mock_logger.info.call_args
# Check the format string args contain magic bytes and size
assert "49492a00" in str(log_args)
assert "104" in str(log_args) # 4 + 100 bytes

View File

@@ -0,0 +1,141 @@
"""
Unit tests verifying that LiteLLM error details are extracted and surfaced
in image summarization error messages.
When the LLM call fails, the error handler should include the status_code,
llm_provider, and model from LiteLLM exceptions so operators can diagnose
the root cause (rate limit, content filter, unsupported vision, etc.)
without needing to dig through LiteLLM internals.
"""
from unittest.mock import MagicMock
import pytest
from onyx.file_processing.image_summarization import _summarize_image
def _make_litellm_style_error(
*,
message: str = "API error",
status_code: int | None = None,
llm_provider: str | None = None,
model: str | None = None,
) -> RuntimeError:
"""Create an exception with LiteLLM-style attributes."""
exc = RuntimeError(message)
if status_code is not None:
exc.status_code = status_code # type: ignore[attr-defined]
if llm_provider is not None:
exc.llm_provider = llm_provider # type: ignore[attr-defined]
if model is not None:
exc.model = model # type: ignore[attr-defined]
return exc
class TestLiteLLMErrorExtraction:
"""Verify that LiteLLM error attributes are included in the ValueError."""
def test_status_code_included(self) -> None:
mock_llm = MagicMock()
mock_llm.invoke.side_effect = _make_litellm_style_error(
message="Content filter triggered",
status_code=400,
llm_provider="azure",
model="gpt-4o",
)
with pytest.raises(ValueError, match="status_code=400"):
_summarize_image("data:image/png;base64,abc", mock_llm)
def test_llm_provider_included(self) -> None:
mock_llm = MagicMock()
mock_llm.invoke.side_effect = _make_litellm_style_error(
message="Bad request",
status_code=400,
llm_provider="azure",
)
with pytest.raises(ValueError, match="llm_provider=azure"):
_summarize_image("data:image/png;base64,abc", mock_llm)
def test_model_included(self) -> None:
mock_llm = MagicMock()
mock_llm.invoke.side_effect = _make_litellm_style_error(
message="Bad request",
model="gpt-4o",
)
with pytest.raises(ValueError, match="model=gpt-4o"):
_summarize_image("data:image/png;base64,abc", mock_llm)
def test_all_fields_in_single_message(self) -> None:
mock_llm = MagicMock()
mock_llm.invoke.side_effect = _make_litellm_style_error(
message="Rate limit exceeded",
status_code=429,
llm_provider="azure",
model="gpt-4o",
)
with pytest.raises(ValueError) as exc_info:
_summarize_image("data:image/png;base64,abc", mock_llm)
msg = str(exc_info.value)
assert "status_code=429" in msg
assert "llm_provider=azure" in msg
assert "model=gpt-4o" in msg
assert "Rate limit exceeded" in msg
def test_plain_exception_without_litellm_attrs(self) -> None:
"""Non-LiteLLM exceptions should still produce a useful message."""
mock_llm = MagicMock()
mock_llm.invoke.side_effect = ConnectionError("Connection refused")
with pytest.raises(ValueError) as exc_info:
_summarize_image("data:image/png;base64,abc", mock_llm)
msg = str(exc_info.value)
assert "ConnectionError" in msg
assert "Connection refused" in msg
# Should not contain status_code/llm_provider/model
assert "status_code" not in msg
assert "llm_provider" not in msg
def test_no_base64_in_error(self) -> None:
"""Error messages must not contain the full base64 image payload.
Some LiteLLM exceptions echo the request body (including base64 images)
in their message. The truncation guard ensures the bulk of such a
payload is stripped from the re-raised ValueError.
"""
mock_llm = MagicMock()
# Build a long base64-like payload that exceeds the 512-char truncation
fake_b64_payload = "iVBORw0KGgo" * 100 # ~1100 chars
fake_b64 = f"data:image/png;base64,{fake_b64_payload}"
mock_llm.invoke.side_effect = RuntimeError(
f"Request failed for payload: {fake_b64}"
)
with pytest.raises(ValueError) as exc_info:
_summarize_image(fake_b64, mock_llm)
msg = str(exc_info.value)
# The full payload must not appear (truncation should have kicked in)
assert fake_b64_payload not in msg
assert "truncated" in msg
def test_long_error_message_truncated(self) -> None:
"""Exception messages longer than 512 chars are truncated."""
mock_llm = MagicMock()
long_msg = "x" * 1000
mock_llm.invoke.side_effect = RuntimeError(long_msg)
with pytest.raises(ValueError) as exc_info:
_summarize_image("data:image/png;base64,abc", mock_llm)
msg = str(exc_info.value)
assert "truncated" in msg
# The full 1000-char string should not appear
assert long_msg not in msg

View File

@@ -0,0 +1,124 @@
"""Unit tests for pypdf-dependent PDF processing functions.
Tests cover:
- read_pdf_file: text extraction, metadata, encrypted PDFs, image extraction
- pdf_to_text: convenience wrapper
- is_pdf_protected: password protection detection
Fixture PDFs live in ./fixtures/ and are pre-built so the test layer has no
dependency on pypdf internals (pypdf.generic).
"""
from io import BytesIO
from pathlib import Path
from onyx.file_processing.extract_file_text import pdf_to_text
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.password_validation import is_pdf_protected
FIXTURES = Path(__file__).parent / "fixtures"
def _load(name: str) -> BytesIO:
return BytesIO((FIXTURES / name).read_bytes())
# ── read_pdf_file ────────────────────────────────────────────────────────
class TestReadPdfFile:
def test_basic_text_extraction(self) -> None:
text, _, images = read_pdf_file(_load("simple.pdf"))
assert "Hello World" in text
assert images == []
def test_multi_page_text_extraction(self) -> None:
text, _, _ = read_pdf_file(_load("multipage.pdf"))
assert "Page one content" in text
assert "Page two content" in text
def test_metadata_extraction(self) -> None:
_, pdf_metadata, _ = read_pdf_file(_load("with_metadata.pdf"))
assert pdf_metadata.get("Title") == "My Title"
assert pdf_metadata.get("Author") == "Jane Doe"
def test_encrypted_pdf_with_correct_password(self) -> None:
text, _, _ = read_pdf_file(_load("encrypted.pdf"), pdf_pass="pass123")
assert "Secret Content" in text
def test_encrypted_pdf_without_password(self) -> None:
text, _, _ = read_pdf_file(_load("encrypted.pdf"))
assert text == ""
def test_encrypted_pdf_with_wrong_password(self) -> None:
text, _, _ = read_pdf_file(_load("encrypted.pdf"), pdf_pass="wrong")
assert text == ""
def test_empty_pdf(self) -> None:
text, _, _ = read_pdf_file(_load("empty.pdf"))
assert text.strip() == ""
def test_invalid_pdf_returns_empty(self) -> None:
text, _, images = read_pdf_file(BytesIO(b"this is not a pdf"))
assert text == ""
assert images == []
def test_image_extraction_disabled_by_default(self) -> None:
_, _, images = read_pdf_file(_load("with_image.pdf"))
assert images == []
def test_image_extraction_collects_images(self) -> None:
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
assert len(images) == 1
img_bytes, img_name = images[0]
assert len(img_bytes) > 0
assert img_name # non-empty name
def test_image_callback_streams_instead_of_collecting(self) -> None:
"""With image_callback, images are streamed via callback and not accumulated."""
collected: list[tuple[bytes, str]] = []
def callback(data: bytes, name: str) -> None:
collected.append((data, name))
_, _, images = read_pdf_file(
_load("with_image.pdf"), extract_images=True, image_callback=callback
)
# Callback received the image
assert len(collected) == 1
assert len(collected[0][0]) > 0
# Returned list is empty when callback is used
assert images == []
# ── pdf_to_text ──────────────────────────────────────────────────────────
class TestPdfToText:
def test_returns_text(self) -> None:
assert "Hello World" in pdf_to_text(_load("simple.pdf"))
def test_with_password(self) -> None:
assert "Secret Content" in pdf_to_text(
_load("encrypted.pdf"), pdf_pass="pass123"
)
def test_encrypted_without_password_returns_empty(self) -> None:
assert pdf_to_text(_load("encrypted.pdf")) == ""
# ── is_pdf_protected ─────────────────────────────────────────────────────
class TestIsPdfProtected:
def test_unprotected_pdf(self) -> None:
assert is_pdf_protected(_load("simple.pdf")) is False
def test_protected_pdf(self) -> None:
assert is_pdf_protected(_load("encrypted.pdf")) is True
def test_preserves_file_position(self) -> None:
pdf = _load("simple.pdf")
pdf.seek(42)
is_pdf_protected(pdf)
assert pdf.tell() == 42

View File

@@ -0,0 +1,19 @@
import pytest
from pydantic import BaseModel
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
def test_init_subclass_raises_for_missing_attrs() -> None:
with pytest.raises(TypeError, match="must define class attributes"):
class IncompleteSpec(HookPointSpec):
hook_point = HookPoint.QUERY_PROCESSING
# missing display_name, description, payload_model, response_model, etc.
class _Payload(BaseModel):
pass
payload_model = _Payload
response_model = _Payload

View File

@@ -0,0 +1,678 @@
"""Unit tests for the hook executor."""
import json
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
# A valid QueryProcessingResponse payload — used by success-path tests.
_RESPONSE_PAYLOAD: dict[str, Any] = {"query": "better test"}
def _make_hook(
*,
is_active: bool = True,
endpoint_url: str | None = "https://hook.example.com/query",
api_key: MagicMock | None = None,
timeout_seconds: float = 5.0,
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
hook_id: int = 1,
is_reachable: bool | None = None,
hook_point: HookPoint = HookPoint.QUERY_PROCESSING,
) -> MagicMock:
hook = MagicMock()
hook.is_active = is_active
hook.endpoint_url = endpoint_url
hook.api_key = api_key
hook.timeout_seconds = timeout_seconds
hook.id = hook_id
hook.fail_strategy = fail_strategy
hook.is_reachable = is_reachable
hook.hook_point = hook_point
return hook
def _make_api_key(value: str) -> MagicMock:
api_key = MagicMock()
api_key.get_value.return_value = value
return api_key
def _make_response(
*,
status_code: int = 200,
json_return: Any = _RESPONSE_PAYLOAD,
json_side_effect: Exception | None = None,
) -> MagicMock:
"""Build a response mock with controllable json() behaviour."""
response = MagicMock()
response.status_code = status_code
if json_side_effect is not None:
response.json.side_effect = json_side_effect
else:
response.json.return_value = json_return
return response
def _setup_client(
mock_client_cls: MagicMock,
*,
response: MagicMock | None = None,
side_effect: Exception | None = None,
) -> MagicMock:
"""Wire up the httpx.Client mock and return the inner client.
If side_effect is an httpx.HTTPStatusError, it is raised from
raise_for_status() (matching real httpx behaviour) and post() returns a
response mock with the matching status_code set. All other exceptions are
raised directly from post().
"""
mock_client = MagicMock()
if isinstance(side_effect, httpx.HTTPStatusError):
error_response = MagicMock()
error_response.status_code = side_effect.response.status_code
error_response.raise_for_status.side_effect = side_effect
mock_client.post = MagicMock(return_value=error_response)
else:
mock_client.post = MagicMock(
side_effect=side_effect, return_value=response if not side_effect else None
)
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
return mock_client
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def db_session() -> MagicMock:
return MagicMock()
# ---------------------------------------------------------------------------
# Early-exit guards (no HTTP call, no DB writes)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"hooks_available,hook",
[
# HOOKS_AVAILABLE=False exits before the DB lookup — hook is irrelevant.
pytest.param(False, None, id="hooks_not_available"),
pytest.param(True, None, id="hook_not_found"),
pytest.param(True, _make_hook(is_active=False), id="hook_inactive"),
pytest.param(True, _make_hook(endpoint_url=None), id="no_endpoint_url"),
],
)
def test_early_exit_returns_skipped_with_no_db_writes(
db_session: MagicMock,
hooks_available: bool,
hook: MagicMock | None,
) -> None:
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", hooks_available),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSkipped)
mock_update.assert_not_called()
mock_log.assert_not_called()
# ---------------------------------------------------------------------------
# Successful HTTP call
# ---------------------------------------------------------------------------
def test_success_returns_validated_model_and_sets_reachable(
db_session: MagicMock,
) -> None:
hook = _make_hook()
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, response=_make_response())
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
_, update_kwargs = mock_update.call_args
assert update_kwargs["is_reachable"] is True
mock_log.assert_not_called()
def test_success_skips_reachable_write_when_already_true(db_session: MagicMock) -> None:
"""Deduplication guard: a hook already at is_reachable=True that succeeds
must not trigger a DB write."""
hook = _make_hook(is_reachable=True)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, response=_make_response())
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
mock_update.assert_not_called()
def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
"""response.json() returning a non-dict (e.g. list) must be treated as failure.
The server responded, so is_reachable is not updated."""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response(json_return=["unexpected", "list"]),
)
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "non-dict" in (log_kwargs["error_message"] or "")
mock_update.assert_not_called()
def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
"""response.json() raising must be treated as failure with SOFT strategy.
The server responded, so is_reachable is not updated."""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response(
json_side_effect=json.JSONDecodeError("not JSON", "", 0)
),
)
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "non-JSON" in (log_kwargs["error_message"] or "")
mock_update.assert_not_called()
# ---------------------------------------------------------------------------
# HTTP failure paths
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"exception,fail_strategy,expected_type,expected_is_reachable",
[
# NetworkError → is_reachable=False
pytest.param(
httpx.ConnectError("refused"),
HookFailStrategy.SOFT,
HookSoftFailed,
False,
id="connect_error_soft",
),
pytest.param(
httpx.ConnectError("refused"),
HookFailStrategy.HARD,
OnyxError,
False,
id="connect_error_hard",
),
# 401/403 → is_reachable=False (api_key revoked)
pytest.param(
httpx.HTTPStatusError(
"401",
request=MagicMock(),
response=MagicMock(status_code=401, text="Unauthorized"),
),
HookFailStrategy.SOFT,
HookSoftFailed,
False,
id="auth_401_soft",
),
pytest.param(
httpx.HTTPStatusError(
"403",
request=MagicMock(),
response=MagicMock(status_code=403, text="Forbidden"),
),
HookFailStrategy.HARD,
OnyxError,
False,
id="auth_403_hard",
),
# TimeoutException → no is_reachable write (None)
pytest.param(
httpx.TimeoutException("timeout"),
HookFailStrategy.SOFT,
HookSoftFailed,
None,
id="timeout_soft",
),
pytest.param(
httpx.TimeoutException("timeout"),
HookFailStrategy.HARD,
OnyxError,
None,
id="timeout_hard",
),
# Other HTTP errors → no is_reachable write (None)
pytest.param(
httpx.HTTPStatusError(
"500",
request=MagicMock(),
response=MagicMock(status_code=500, text="error"),
),
HookFailStrategy.SOFT,
HookSoftFailed,
None,
id="http_status_error_soft",
),
pytest.param(
httpx.HTTPStatusError(
"500",
request=MagicMock(),
response=MagicMock(status_code=500, text="error"),
),
HookFailStrategy.HARD,
OnyxError,
None,
id="http_status_error_hard",
),
],
)
def test_http_failure_paths(
db_session: MagicMock,
exception: Exception,
fail_strategy: HookFailStrategy,
expected_type: type,
expected_is_reachable: bool | None,
) -> None:
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, side_effect=exception)
if expected_type is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, expected_type)
if expected_is_reachable is None:
mock_update.assert_not_called()
else:
mock_update.assert_called_once()
_, kwargs = mock_update.call_args
assert kwargs["is_reachable"] is expected_is_reachable
# ---------------------------------------------------------------------------
# Authorization header
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"api_key_value,expect_auth_header",
[
pytest.param("secret-token", True, id="api_key_present"),
pytest.param(None, False, id="api_key_absent"),
],
)
def test_authorization_header(
db_session: MagicMock,
api_key_value: str | None,
expect_auth_header: bool,
) -> None:
api_key = _make_api_key(api_key_value) if api_key_value else None
hook = _make_hook(api_key=api_key)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit"),
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
mock_client = _setup_client(mock_client_cls, response=_make_response())
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
_, call_kwargs = mock_client.post.call_args
if expect_auth_header:
assert call_kwargs["headers"]["Authorization"] == f"Bearer {api_key_value}"
else:
assert "Authorization" not in call_kwargs["headers"]
# ---------------------------------------------------------------------------
# Persist session failure
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"http_exception,expect_onyx_error",
[
pytest.param(None, False, id="success_path"),
pytest.param(httpx.ConnectError("refused"), True, id="hard_fail_path"),
],
)
def test_persist_session_failure_is_swallowed(
db_session: MagicMock,
http_exception: Exception | None,
expect_onyx_error: bool,
) -> None:
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor.get_session_with_current_tenant",
side_effect=RuntimeError("DB unavailable"),
),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response() if not http_exception else None,
side_effect=http_exception,
)
if expect_onyx_error:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
# ---------------------------------------------------------------------------
# Response model validation
# ---------------------------------------------------------------------------
class _StrictResponse(BaseModel):
"""Strict model used to reliably trigger a ValidationError in tests."""
required_field: str # no default → missing key raises ValidationError
def _make_strict_spec() -> MagicMock:
spec = MagicMock()
spec.response_model = _StrictResponse
return spec
@pytest.mark.parametrize(
"fail_strategy,expected_type",
[
pytest.param(
HookFailStrategy.SOFT, HookSoftFailed, id="validation_failure_soft"
),
pytest.param(HookFailStrategy.HARD, OnyxError, id="validation_failure_hard"),
],
)
def test_response_validation_failure_respects_fail_strategy(
db_session: MagicMock,
fail_strategy: HookFailStrategy,
expected_type: type,
) -> None:
"""A response that fails response_model validation is treated like any other
hook failure: logged, is_reachable left unchanged, fail_strategy respected."""
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch(
"onyx.hooks.executor.get_hook_point_spec",
return_value=_make_strict_spec(),
),
patch("httpx.Client") as mock_client_cls,
):
# Response payload is missing required_field → ValidationError
_setup_client(mock_client_cls, response=_make_response(json_return={}))
if expected_type is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
# is_reachable must not be updated — server responded correctly
mock_update.assert_not_called()
# failure must be logged
mock_log.assert_called_once()
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "validation" in (log_kwargs["error_message"] or "").lower()
# ---------------------------------------------------------------------------
# Outer soft-fail guard in execute_hook
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"fail_strategy,expected_type",
[
pytest.param(HookFailStrategy.SOFT, HookSoftFailed, id="unexpected_exc_soft"),
pytest.param(HookFailStrategy.HARD, ValueError, id="unexpected_exc_hard"),
],
)
def test_unexpected_exception_in_inner_respects_fail_strategy(
db_session: MagicMock,
fail_strategy: HookFailStrategy,
expected_type: type,
) -> None:
"""An unexpected exception raised by _execute_hook_inner (not an OnyxError from
HARD fail — e.g. a bug or an assertion error) must be swallowed and return
HookSoftFailed for SOFT strategy, or re-raised for HARD strategy."""
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor._execute_hook_inner",
side_effect=ValueError("unexpected bug"),
),
):
if expected_type is HookSoftFailed:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
else:
with pytest.raises(ValueError, match="unexpected bug"):
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:
"""is_reachable update failing (e.g. concurrent hook deletion) must not
prevent the execution log from being written.
Simulates the production failure path: update_hook__no_commit raises
OnyxError(NOT_FOUND) as it would if the hook was concurrently deleted
between the initial lookup and the reachable update.
"""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch(
"onyx.hooks.executor.update_hook__no_commit",
side_effect=OnyxError(OnyxErrorCode.NOT_FOUND, "hook deleted"),
),
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, side_effect=httpx.ConnectError("refused"))
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
mock_log.assert_called_once()

View File

@@ -0,0 +1,86 @@
import pytest
from pydantic import ValidationError
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.models import HookCreateRequest
from onyx.hooks.models import HookUpdateRequest
def test_hook_update_request_rejects_empty() -> None:
# No fields supplied at all
with pytest.raises(ValidationError, match="At least one field must be provided"):
HookUpdateRequest()
def test_hook_update_request_rejects_null_name_when_only_field() -> None:
# Explicitly setting name=None is rejected as name cannot be cleared
with pytest.raises(ValidationError, match="name cannot be cleared"):
HookUpdateRequest(name=None)
def test_hook_update_request_accepts_single_field() -> None:
req = HookUpdateRequest(name="new name")
assert req.name == "new name"
def test_hook_update_request_accepts_partial_fields() -> None:
req = HookUpdateRequest(fail_strategy=HookFailStrategy.SOFT, timeout_seconds=10.0)
assert req.fail_strategy == HookFailStrategy.SOFT
assert req.timeout_seconds == 10.0
assert req.name is None
def test_hook_update_request_rejects_null_name() -> None:
with pytest.raises(ValidationError, match="name cannot be cleared"):
HookUpdateRequest(name=None, fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_rejects_empty_name() -> None:
with pytest.raises(ValidationError, match="name cannot be cleared"):
HookUpdateRequest(name="", fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_rejects_null_endpoint_url() -> None:
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
HookUpdateRequest(endpoint_url=None, fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_rejects_empty_endpoint_url() -> None:
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
HookUpdateRequest(endpoint_url="", fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_allows_null_api_key() -> None:
# api_key=null is valid — means "clear the api key"
req = HookUpdateRequest(api_key=None)
assert req.api_key is None
assert "api_key" in req.model_fields_set
def test_hook_update_request_rejects_whitespace_name() -> None:
with pytest.raises(ValidationError, match="name cannot be cleared"):
HookUpdateRequest(name=" ", fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_rejects_whitespace_endpoint_url() -> None:
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
HookUpdateRequest(endpoint_url=" ", fail_strategy=HookFailStrategy.SOFT)
def test_hook_create_request_rejects_whitespace_name() -> None:
with pytest.raises(ValidationError, match="whitespace-only"):
HookCreateRequest(
name=" ",
hook_point=HookPoint.QUERY_PROCESSING,
endpoint_url="https://example.com/hook",
)
def test_hook_create_request_rejects_whitespace_endpoint_url() -> None:
with pytest.raises(ValidationError, match="whitespace-only"):
HookCreateRequest(
name="my hook",
hook_point=HookPoint.QUERY_PROCESSING,
endpoint_url=" ",
)

View File

@@ -0,0 +1,62 @@
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.query_processing import QueryProcessingSpec
def test_hook_point_is_query_processing() -> None:
assert QueryProcessingSpec().hook_point == HookPoint.QUERY_PROCESSING
def test_default_fail_strategy_is_hard() -> None:
assert QueryProcessingSpec().default_fail_strategy == HookFailStrategy.HARD
def test_default_timeout_seconds() -> None:
# User is actively waiting — 5s is the documented contract for this hook point
assert QueryProcessingSpec().default_timeout_seconds == 5.0
def test_input_schema_required_fields() -> None:
schema = QueryProcessingSpec().input_schema
assert schema["type"] == "object"
required = schema["required"]
assert "query" in required
assert "user_email" in required
assert "chat_session_id" in required
def test_input_schema_chat_session_id_is_string() -> None:
props = QueryProcessingSpec().input_schema["properties"]
assert props["chat_session_id"]["type"] == "string"
def test_input_schema_query_is_string() -> None:
props = QueryProcessingSpec().input_schema["properties"]
assert props["query"]["type"] == "string"
def test_input_schema_user_email_is_nullable() -> None:
props = QueryProcessingSpec().input_schema["properties"]
# Pydantic v2 emits anyOf for nullable fields
assert any(s.get("type") == "null" for s in props["user_email"]["anyOf"])
def test_output_schema_query_is_optional() -> None:
# query defaults to None (absent = reject); not required in the schema
schema = QueryProcessingSpec().output_schema
assert "query" not in schema.get("required", [])
def test_output_schema_query_is_nullable() -> None:
# null means "reject the query"; Pydantic v2 emits anyOf for nullable fields
props = QueryProcessingSpec().output_schema["properties"]
assert any(s.get("type") == "null" for s in props["query"]["anyOf"])
def test_output_schema_rejection_message_is_optional() -> None:
schema = QueryProcessingSpec().output_schema
assert "rejection_message" not in schema.get("required", [])
def test_input_schema_no_additional_properties() -> None:
assert QueryProcessingSpec().input_schema.get("additionalProperties") is False

View File

@@ -0,0 +1,47 @@
import pytest
from onyx.db.enums import HookPoint
from onyx.hooks import registry as registry_module
from onyx.hooks.registry import get_all_specs
from onyx.hooks.registry import get_hook_point_spec
from onyx.hooks.registry import validate_registry
def test_registry_covers_all_hook_points() -> None:
"""Every HookPoint enum member must have a registered spec."""
assert {s.hook_point for s in get_all_specs()} == set(
HookPoint
), f"Missing specs for: {set(HookPoint) - {s.hook_point for s in get_all_specs()}}"
def test_get_hook_point_spec_returns_correct_spec() -> None:
for hook_point in HookPoint:
spec = get_hook_point_spec(hook_point)
assert spec.hook_point == hook_point
def test_get_all_specs_returns_all() -> None:
specs = get_all_specs()
assert len(specs) == len(HookPoint)
assert {s.hook_point for s in specs} == set(HookPoint)
def test_get_hook_point_spec_raises_for_unregistered(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""get_hook_point_spec raises ValueError when a hook point has no spec."""
monkeypatch.setattr(registry_module, "_REGISTRY", {})
with pytest.raises(ValueError, match="No spec registered for hook point"):
get_hook_point_spec(HookPoint.QUERY_PROCESSING)
def test_validate_registry_passes() -> None:
validate_registry() # should not raise with the real registry
def test_validate_registry_raises_for_incomplete(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(registry_module, "_REGISTRY", {})
with pytest.raises(RuntimeError, match="Hook point\\(s\\) have no registered spec"):
validate_registry()

View File

@@ -256,7 +256,6 @@ def test_multiple_tool_calls(default_multi_llm: LitellmLLM) -> None:
{"role": "user", "content": "What's the weather and time in New York?"}
],
tools=tools,
tool_choice=None,
stream=True,
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
@@ -412,7 +411,6 @@ def test_multiple_tool_calls_streaming(default_multi_llm: LitellmLLM) -> None:
{"role": "user", "content": "What's the weather and time in New York?"}
],
tools=tools,
tool_choice=None,
stream=True,
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
@@ -1431,3 +1429,36 @@ def test_strip_tool_content_merges_consecutive_tool_results() -> None:
assert "sunny 72F" in merged
assert "tc_2" in merged
assert "headline news" in merged
def test_no_tool_choice_sent_when_no_tools(default_multi_llm: LitellmLLM) -> None:
"""Regression test for providers (e.g. Fireworks) that reject tool_choice=null.
When no tools are provided, tool_choice must not be forwarded to
litellm.completion() at all — not even as None.
"""
messages: LanguageModelInput = [UserMessage(content="Hello!")]
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello!"),
finish_reason="stop",
index=0,
)
],
model="gpt-3.5-turbo",
),
]
with patch("litellm.completion") as mock_completion:
mock_completion.return_value = mock_stream_chunks
default_multi_llm.invoke(messages, tools=None)
_, kwargs = mock_completion.call_args
assert (
"tool_choice" not in kwargs
), "tool_choice must not be sent to providers when no tools are provided"

View File

@@ -0,0 +1,130 @@
"""
Unit tests for vision model selection logging in get_default_llm_with_vision.
Verifies that operators get clear feedback about:
1. Which vision model was selected and why
2. When the default vision model doesn't support image input
3. When no vision-capable model exists at all
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.llm.factory import get_default_llm_with_vision
_FACTORY = "onyx.llm.factory"
def _make_mock_model(
*,
name: str = "gpt-4o",
provider: str = "openai",
provider_id: int = 1,
flow_types: list[str] | None = None,
) -> MagicMock:
model = MagicMock()
model.name = name
model.llm_provider_id = provider_id
model.llm_provider.provider = provider
model.llm_model_flow_types = flow_types or []
return model
@patch(f"{_FACTORY}.get_session_with_current_tenant")
@patch(f"{_FACTORY}.fetch_default_vision_model")
@patch(f"{_FACTORY}.model_supports_image_input", return_value=True)
@patch(f"{_FACTORY}.llm_from_provider")
@patch(f"{_FACTORY}.LLMProviderView")
@patch(f"{_FACTORY}.logger")
def test_logs_when_using_default_vision_model(
mock_logger: MagicMock,
mock_provider_view: MagicMock, # noqa: ARG001
mock_llm_from: MagicMock, # noqa: ARG001
mock_supports: MagicMock, # noqa: ARG001
mock_fetch_default: MagicMock,
mock_session: MagicMock, # noqa: ARG001
) -> None:
mock_fetch_default.return_value = _make_mock_model(name="gpt-4o", provider="azure")
get_default_llm_with_vision()
mock_logger.info.assert_called_once()
log_msg = mock_logger.info.call_args[0][0]
assert "default vision model" in log_msg.lower()
@patch(f"{_FACTORY}.get_session_with_current_tenant")
@patch(f"{_FACTORY}.fetch_default_vision_model")
@patch(f"{_FACTORY}.model_supports_image_input", return_value=False)
@patch(f"{_FACTORY}.fetch_existing_models", return_value=[])
@patch(f"{_FACTORY}.logger")
def test_warns_when_default_model_lacks_vision(
mock_logger: MagicMock,
mock_fetch_models: MagicMock, # noqa: ARG001
mock_supports: MagicMock, # noqa: ARG001
mock_fetch_default: MagicMock,
mock_session: MagicMock, # noqa: ARG001
) -> None:
mock_fetch_default.return_value = _make_mock_model(
name="text-only-model", provider="azure"
)
result = get_default_llm_with_vision()
assert result is None
# Should have warned about the default model not supporting vision
warning_calls = [
call
for call in mock_logger.warning.call_args_list
if "does not support" in str(call)
]
assert len(warning_calls) >= 1
@patch(f"{_FACTORY}.get_session_with_current_tenant")
@patch(f"{_FACTORY}.fetch_default_vision_model", return_value=None)
@patch(f"{_FACTORY}.fetch_existing_models", return_value=[])
@patch(f"{_FACTORY}.logger")
def test_warns_when_no_models_exist(
mock_logger: MagicMock,
mock_fetch_models: MagicMock, # noqa: ARG001
mock_fetch_default: MagicMock, # noqa: ARG001
mock_session: MagicMock, # noqa: ARG001
) -> None:
result = get_default_llm_with_vision()
assert result is None
mock_logger.warning.assert_called_once()
log_msg = mock_logger.warning.call_args[0][0]
assert "no llm models" in log_msg.lower()
@patch(f"{_FACTORY}.get_session_with_current_tenant")
@patch(f"{_FACTORY}.fetch_default_vision_model", return_value=None)
@patch(f"{_FACTORY}.fetch_existing_models")
@patch(f"{_FACTORY}.model_supports_image_input", return_value=False)
@patch(f"{_FACTORY}.LLMProviderView")
@patch(f"{_FACTORY}.logger")
def test_warns_when_no_model_supports_vision(
mock_logger: MagicMock,
mock_provider_view: MagicMock, # noqa: ARG001
mock_supports: MagicMock, # noqa: ARG001
mock_fetch_models: MagicMock,
mock_fetch_default: MagicMock, # noqa: ARG001
mock_session: MagicMock, # noqa: ARG001
) -> None:
mock_fetch_models.return_value = [
_make_mock_model(name="text-model-1", provider="openai"),
_make_mock_model(name="text-model-2", provider="azure", provider_id=2),
]
result = get_default_llm_with_vision()
assert result is None
warning_calls = [
call
for call in mock_logger.warning.call_args_list
if "no vision-capable model" in str(call).lower()
]
assert len(warning_calls) == 1

View File

@@ -0,0 +1,278 @@
"""Unit tests for onyx.server.features.hooks.api helpers.
Covers:
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
- _validate_endpoint: httpx exception → HookValidateStatus mapping
ConnectTimeout → cannot_connect (TCP handshake never completed)
ConnectError → cannot_connect (DNS / TLS failure)
ReadTimeout et al. → timeout (TCP connected, server slow)
Any other exc → cannot_connect
- _raise_for_validation_failure: HookValidateStatus → OnyxError mapping
"""
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.models import HookValidateResponse
from onyx.hooks.models import HookValidateStatus
from onyx.server.features.hooks.api import _check_ssrf_safety
from onyx.server.features.hooks.api import _raise_for_validation_failure
from onyx.server.features.hooks.api import _validate_endpoint
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_URL = "https://example.com/hook"
_API_KEY = "secret"
_TIMEOUT = 5.0
def _mock_response(status_code: int) -> MagicMock:
response = MagicMock()
response.status_code = status_code
return response
# ---------------------------------------------------------------------------
# _check_ssrf_safety
# ---------------------------------------------------------------------------
class TestCheckSsrfSafety:
def _call(self, url: str) -> None:
_check_ssrf_safety(url)
# --- scheme checks ---
def test_https_is_allowed(self) -> None:
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
self._call("https://example.com/hook") # must not raise
@pytest.mark.parametrize(
"url", ["http://example.com/hook", "ftp://example.com/hook"]
)
def test_non_https_scheme_rejected(self, url: str) -> None:
with pytest.raises(OnyxError) as exc_info:
self._call(url)
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert "https" in (exc_info.value.detail or "").lower()
# --- private IP blocklist ---
@pytest.mark.parametrize(
"ip",
[
pytest.param("127.0.0.1", id="loopback"),
pytest.param("10.0.0.1", id="RFC1918-A"),
pytest.param("172.16.0.1", id="RFC1918-B"),
pytest.param("192.168.1.1", id="RFC1918-C"),
pytest.param("169.254.169.254", id="link-local-IMDS"),
pytest.param("100.64.0.1", id="shared-address-space"),
pytest.param("::1", id="IPv6-loopback"),
pytest.param("fc00::1", id="IPv6-ULA"),
pytest.param("fe80::1", id="IPv6-link-local"),
],
)
def test_private_ip_is_blocked(self, ip: str) -> None:
with (
patch("onyx.utils.url.socket.getaddrinfo") as mock_dns,
pytest.raises(OnyxError) as exc_info,
):
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
self._call("https://internal.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert ip in (exc_info.value.detail or "")
def test_public_ip_is_allowed(self) -> None:
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
self._call("https://example.com/hook") # must not raise
def test_dns_resolution_failure_raises(self) -> None:
import socket
with (
patch(
"onyx.utils.url.socket.getaddrinfo",
side_effect=socket.gaierror("name not found"),
),
pytest.raises(OnyxError) as exc_info,
):
self._call("https://no-such-host.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
# ---------------------------------------------------------------------------
# _validate_endpoint
# ---------------------------------------------------------------------------
class TestValidateEndpoint:
def _call(self, *, api_key: str | None = _API_KEY) -> HookValidateResponse:
# Bypass SSRF check — tested separately in TestCheckSsrfSafety.
with patch("onyx.server.features.hooks.api._check_ssrf_safety"):
return _validate_endpoint(
endpoint_url=_URL,
api_key=api_key,
timeout_seconds=_TIMEOUT,
)
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_2xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(200)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_5xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(500)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize("status_code", [401, 403])
def test_401_403_returns_auth_failed(
self, mock_client_cls: MagicMock, status_code: int
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(status_code)
)
result = self._call()
assert result.status == HookValidateStatus.auth_failed
assert str(status_code) in (result.error_message or "")
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_4xx_non_auth_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(422)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_timeout_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectTimeout("timed out")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize(
"exc",
[
httpx.ReadTimeout("read timeout"),
httpx.WriteTimeout("write timeout"),
httpx.PoolTimeout("pool timeout"),
],
)
def test_read_write_pool_timeout_returns_timeout(
self, mock_client_cls: MagicMock, exc: httpx.TimeoutException
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = exc
assert self._call().status == HookValidateStatus.timeout
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_error_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
# Covers DNS failures, TLS errors, and other connection-level errors.
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectError("name resolution failed")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_arbitrary_exception_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
ConnectionRefusedError("refused")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_api_key_sent_as_bearer(self, mock_client_cls: MagicMock) -> None:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = _mock_response(200)
self._call(api_key="mykey")
_, kwargs = mock_post.call_args
assert kwargs["headers"]["Authorization"] == "Bearer mykey"
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_no_api_key_omits_auth_header(self, mock_client_cls: MagicMock) -> None:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = _mock_response(200)
self._call(api_key=None)
_, kwargs = mock_post.call_args
assert "Authorization" not in kwargs["headers"]
# ---------------------------------------------------------------------------
# _raise_for_validation_failure
# ---------------------------------------------------------------------------
class TestRaiseForValidationFailure:
@pytest.mark.parametrize(
"status, expected_code",
[
(HookValidateStatus.auth_failed, OnyxErrorCode.CREDENTIAL_INVALID),
(HookValidateStatus.timeout, OnyxErrorCode.GATEWAY_TIMEOUT),
(HookValidateStatus.cannot_connect, OnyxErrorCode.BAD_GATEWAY),
],
)
def test_raises_correct_error_code(
self, status: HookValidateStatus, expected_code: OnyxErrorCode
) -> None:
validation = HookValidateResponse(status=status, error_message="some error")
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.error_code == expected_code
def test_auth_failed_passes_error_message_directly(self) -> None:
validation = HookValidateResponse(
status=HookValidateStatus.auth_failed, error_message="bad credentials"
)
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.detail == "bad credentials"
@pytest.mark.parametrize(
"status", [HookValidateStatus.timeout, HookValidateStatus.cannot_connect]
)
def test_timeout_and_cannot_connect_wrap_error_message(
self, status: HookValidateStatus
) -> None:
validation = HookValidateResponse(status=status, error_message="raw error")
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.detail == "Endpoint validation failed: raw error"
# ---------------------------------------------------------------------------
# HookValidateStatus enum string values (API contract)
# ---------------------------------------------------------------------------
class TestHookValidateStatusValues:
@pytest.mark.parametrize(
"status, expected",
[
(HookValidateStatus.passed, "passed"),
(HookValidateStatus.auth_failed, "auth_failed"),
(HookValidateStatus.timeout, "timeout"),
(HookValidateStatus.cannot_connect, "cannot_connect"),
],
)
def test_string_values(self, status: HookValidateStatus, expected: str) -> None:
assert status == expected

View File

@@ -0,0 +1,109 @@
import io
import zipfile
from unittest.mock import MagicMock
from unittest.mock import patch
from zipfile import BadZipFile
import pytest
from fastapi import UploadFile
from starlette.datastructures import Headers
from onyx.configs.constants import FileOrigin
from onyx.server.documents.connector import upload_files
def _create_test_zip() -> bytes:
"""Create a simple in-memory zip file containing two text files."""
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
zf.writestr("file1.txt", "hello")
zf.writestr("file2.txt", "world")
return buf.getvalue()
def _make_upload_file(content: bytes, filename: str, content_type: str) -> UploadFile:
return UploadFile(
file=io.BytesIO(content),
filename=filename,
headers=Headers({"content-type": content_type}),
)
@patch("onyx.server.documents.connector.get_default_file_store")
def test_upload_zip_with_unzip_true_extracts_files(
mock_get_store: MagicMock,
) -> None:
"""When unzip=True (default), a zip upload is extracted into individual files."""
mock_store = MagicMock()
mock_store.save_file.side_effect = lambda **kwargs: f"id-{kwargs['display_name']}"
mock_get_store.return_value = mock_store
zip_bytes = _create_test_zip()
upload = _make_upload_file(zip_bytes, "test.zip", "application/zip")
result = upload_files([upload], FileOrigin.CONNECTOR)
# Should have extracted the two individual files, not stored the zip itself
assert len(result.file_paths) == 2
assert "id-file1.txt" in result.file_paths
assert "id-file2.txt" in result.file_paths
assert "file1.txt" in result.file_names
assert "file2.txt" in result.file_names
@patch("onyx.server.documents.connector.get_default_file_store")
def test_upload_zip_with_unzip_false_stores_zip_as_is(
mock_get_store: MagicMock,
) -> None:
"""When unzip=False, the zip file is stored as-is without extraction."""
mock_store = MagicMock()
mock_store.save_file.return_value = "zip-file-id"
mock_get_store.return_value = mock_store
zip_bytes = _create_test_zip()
upload = _make_upload_file(zip_bytes, "site_export.zip", "application/zip")
result = upload_files([upload], FileOrigin.CONNECTOR, unzip=False)
# Should store exactly one file (the zip itself)
assert len(result.file_paths) == 1
assert result.file_paths[0] == "zip-file-id"
assert result.file_names == ["site_export.zip"]
# No zip metadata should be created
assert result.zip_metadata_file_id is None
# Verify the stored content is a valid zip
saved_content: io.BytesIO = mock_store.save_file.call_args[1]["content"]
saved_content.seek(0)
with zipfile.ZipFile(saved_content, "r") as zf:
assert set(zf.namelist()) == {"file1.txt", "file2.txt"}
@patch("onyx.server.documents.connector.get_default_file_store")
def test_upload_invalid_zip_with_unzip_false_raises(
mock_get_store: MagicMock,
) -> None:
"""An invalid zip is rejected even when unzip=False (validation still runs)."""
mock_get_store.return_value = MagicMock()
bad_zip = _make_upload_file(b"not a zip", "bad.zip", "application/zip")
with pytest.raises(BadZipFile):
upload_files([bad_zip], FileOrigin.CONNECTOR, unzip=False)
@patch("onyx.server.documents.connector.get_default_file_store")
def test_upload_multiple_zips_rejected_when_unzip_false(
mock_get_store: MagicMock,
) -> None:
"""The seen_zip guard rejects a second zip even when unzip=False."""
mock_store = MagicMock()
mock_store.save_file.return_value = "zip-id"
mock_get_store.return_value = mock_store
zip_bytes = _create_test_zip()
zip1 = _make_upload_file(zip_bytes, "a.zip", "application/zip")
zip2 = _make_upload_file(zip_bytes, "b.zip", "application/zip")
with pytest.raises(Exception, match="Only one zip file"):
upload_files([zip1, zip2], FileOrigin.CONNECTOR, unzip=False)

View File

@@ -0,0 +1,208 @@
"""Unit tests for PythonTool file-upload caching.
Verifies that PythonTool reuses code-interpreter file IDs across multiple
run() calls within the same session instead of re-uploading identical content
on every agent loop iteration.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.tools.models import ChatFile
from onyx.tools.models import PythonToolOverrideKwargs
from onyx.tools.tool_implementations.python.code_interpreter_client import (
StreamResultEvent,
)
from onyx.tools.tool_implementations.python.python_tool import PythonTool
TOOL_MODULE = "onyx.tools.tool_implementations.python.python_tool"
def _make_stream_result() -> StreamResultEvent:
return StreamResultEvent(
exit_code=0,
timed_out=False,
duration_ms=10,
files=[],
)
def _make_tool() -> PythonTool:
emitter = MagicMock()
return PythonTool(tool_id=1, emitter=emitter)
def _make_override(files: list[ChatFile]) -> PythonToolOverrideKwargs:
return PythonToolOverrideKwargs(chat_files=files)
def _run_tool(tool: PythonTool, mock_client: MagicMock, files: list[ChatFile]) -> None:
"""Call tool.run() with a mocked CodeInterpreterClient context manager."""
from onyx.server.query_and_chat.placement import Placement
mock_client.execute_streaming.return_value = iter([_make_stream_result()])
ctx = MagicMock()
ctx.__enter__ = MagicMock(return_value=mock_client)
ctx.__exit__ = MagicMock(return_value=False)
placement = Placement(turn_index=0, tab_index=0)
override = _make_override(files)
with patch(f"{TOOL_MODULE}.CodeInterpreterClient", return_value=ctx):
tool.run(placement=placement, override_kwargs=override, code="print('hi')")
# ---------------------------------------------------------------------------
# Cache hit: same content uploaded in a second call reuses the file_id
# ---------------------------------------------------------------------------
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
def test_same_file_uploaded_only_once_across_two_runs() -> None:
tool = _make_tool()
client = MagicMock()
client.upload_file.return_value = "file-id-abc"
pptx_content = b"fake pptx bytes"
files = [ChatFile(filename="report.pptx", content=pptx_content)]
_run_tool(tool, client, files)
_run_tool(tool, client, files)
# upload_file should only have been called once across both runs
client.upload_file.assert_called_once_with(pptx_content, "report.pptx")
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
def test_cached_file_id_is_staged_on_second_run() -> None:
tool = _make_tool()
client = MagicMock()
client.upload_file.return_value = "file-id-abc"
files = [ChatFile(filename="data.pptx", content=b"content")]
_run_tool(tool, client, files)
# On the second run, execute_streaming should still receive the file
client.execute_streaming.return_value = iter([_make_stream_result()])
ctx = MagicMock()
ctx.__enter__ = MagicMock(return_value=client)
ctx.__exit__ = MagicMock(return_value=False)
from onyx.server.query_and_chat.placement import Placement
placement = Placement(turn_index=1, tab_index=0)
with patch(f"{TOOL_MODULE}.CodeInterpreterClient", return_value=ctx):
tool.run(
placement=placement,
override_kwargs=_make_override(files),
code="print('hi')",
)
# The second execute_streaming call should include the file
_, kwargs = client.execute_streaming.call_args
staged_files = kwargs.get("files") or []
assert any(f["file_id"] == "file-id-abc" for f in staged_files)
# ---------------------------------------------------------------------------
# Cache miss: different content triggers a new upload
# ---------------------------------------------------------------------------
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
def test_different_file_content_uploaded_separately() -> None:
tool = _make_tool()
client = MagicMock()
client.upload_file.side_effect = ["file-id-v1", "file-id-v2"]
file_v1 = ChatFile(filename="report.pptx", content=b"version 1")
file_v2 = ChatFile(filename="report.pptx", content=b"version 2")
_run_tool(tool, client, [file_v1])
_run_tool(tool, client, [file_v2])
assert client.upload_file.call_count == 2
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
def test_multiple_distinct_files_each_uploaded_once() -> None:
tool = _make_tool()
client = MagicMock()
client.upload_file.side_effect = ["id-a", "id-b"]
files = [
ChatFile(filename="a.pptx", content=b"aaa"),
ChatFile(filename="b.xlsx", content=b"bbb"),
]
_run_tool(tool, client, files)
_run_tool(tool, client, files)
# Two distinct files — each uploaded exactly once
assert client.upload_file.call_count == 2
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
def test_same_content_different_filename_uploaded_separately() -> None:
# Identical bytes but different names must each get their own upload slot
# so both files appear under their respective paths in the workspace.
tool = _make_tool()
client = MagicMock()
client.upload_file.side_effect = ["id-v1", "id-v2"]
same_bytes = b"shared content"
files = [
ChatFile(filename="report_v1.csv", content=same_bytes),
ChatFile(filename="report_v2.csv", content=same_bytes),
]
_run_tool(tool, client, files)
assert client.upload_file.call_count == 2
# ---------------------------------------------------------------------------
# No cross-instance sharing: a fresh PythonTool re-uploads everything
# ---------------------------------------------------------------------------
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
def test_new_tool_instance_re_uploads_file() -> None:
client = MagicMock()
client.upload_file.side_effect = ["id-session-1", "id-session-2"]
files = [ChatFile(filename="deck.pptx", content=b"slide data")]
tool_session_1 = _make_tool()
_run_tool(tool_session_1, client, files)
tool_session_2 = _make_tool()
_run_tool(tool_session_2, client, files)
# Different instances — each uploads independently
assert client.upload_file.call_count == 2
# ---------------------------------------------------------------------------
# Upload failure: failed upload is not cached, retried next run
# ---------------------------------------------------------------------------
@patch(f"{TOOL_MODULE}.CODE_INTERPRETER_BASE_URL", "http://fake:8000")
def test_upload_failure_not_cached() -> None:
tool = _make_tool()
client = MagicMock()
# First call raises, second succeeds
client.upload_file.side_effect = [Exception("network error"), "file-id-ok"]
files = [ChatFile(filename="slides.pptx", content=b"data")]
# First run — upload fails, file is skipped but not cached
_run_tool(tool, client, files)
# Second run — should attempt upload again
_run_tool(tool, client, files)
assert client.upload_file.call_count == 2

93
cubic.yaml Normal file
View File

@@ -0,0 +1,93 @@
# yaml-language-server: $schema=https://cubic.dev/schema/cubic-repository-config.schema.json
version: 1
reviews:
enabled: true
sensitivity: medium
incremental_commits: true
check_drafts: false
custom_instructions: |
Use explicit type annotations for variables to enhance code clarity,
especially when moving type hints around in the code.
Use `contributing_guides/best_practices.md` as core review context.
Prefer consistency with existing patterns, fix issues in code you touch,
avoid tacking new features onto muddy interfaces, fail loudly instead of
silently swallowing errors, keep code strictly typed, preserve clear state
boundaries, remove duplicate or dead logic, break up overly long functions,
avoid hidden import-time side effects, respect module boundaries, and favor
correctness-by-construction over relying on callers to use an API correctly.
Reference these files for additional context:
- `contributing_guides/best_practices.md` — Best practices for contributing to the codebase
- `CLAUDE.md` — Project instructions and coding standards
- `backend/alembic/README.md` — Migration guidance, including multi-tenant migration behavior
- `deployment/helm/charts/onyx/values-lite.yaml` — Lite deployment Helm values and service assumptions
- `deployment/docker_compose/docker-compose.onyx-lite.yml` — Lite deployment Docker Compose overlay and disabled service behavior
ignore:
files:
- greptile.json
- cubic.yaml
custom_rules:
- name: TODO format
description: >
Whenever a TODO is added, there must always be an associated name or
ticket in the style of TODO(name): ... or TODO(1234): ...
- name: Frontend standards
description: >
For frontend changes, enforce all standards described in the
web/AGENTS.md file.
include:
- web/**
- desktop/**
- name: No debugging code
description: >
Remove temporary debugging code before merging to production,
especially tenant-specific debugging logs.
- name: No hardcoded booleans
description: >
When hardcoding a boolean variable to a constant value, remove the
variable entirely and clean up all places where it's used rather than
just setting it to a constant.
- name: Multi-tenant awareness
description: >
Code changes must consider both multi-tenant and single-tenant
deployments. In multi-tenant mode, preserve tenant isolation, ensure
tenant context is propagated correctly, and avoid assumptions that only
hold for a single shared schema or globally shared state. In
single-tenant mode, avoid introducing unnecessary tenant-specific
requirements or cloud-only control-plane dependencies.
- name: Onyx lite compatibility
description: >
Code changes must consider both regular Onyx deployments and Onyx lite
deployments. Lite deployments disable the vector DB, Redis, model
servers, and background workers by default, use PostgreSQL-backed
cache/auth/file storage, and rely on the API server to handle
background work. Do not assume those services are available unless the
code path is explicitly limited to full deployments.
- name: OnyxError over HTTPException
description: >
Never raise HTTPException directly in business code. Use
`raise OnyxError(OnyxErrorCode.XXX, "message")` from
`onyx.error_handling.exceptions`. A global FastAPI exception handler
converts OnyxError into structured JSON responses with
{"error_code": "...", "detail": "..."}. Error codes are defined in
`onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors
with dynamic HTTP status codes, use `status_code_override`:
`raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`.
include:
- backend/**/*.py
issues:
fix_with_cubic_buttons: true
pr_comment_fixes: true
fix_commits_to_pr: true

View File

@@ -12,6 +12,11 @@ services:
api_server:
ports:
- "8080:8080"
deploy:
resources:
limits:
cpus: "${API_SERVER_CPU_LIMIT:-0}"
memory: "${API_SERVER_MEM_LIMIT:-0}"
# Uncomment the block below to enable the MCP server for Onyx.
# mcp_server:

View File

@@ -489,20 +489,18 @@ services:
- "${HOST_PORT_80:-80}:80"
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
minio:
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1

View File

@@ -290,25 +290,20 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
- ../data/certbot/conf:/etc/letsencrypt
- ../data/certbot/www:/var/www/certbot
# sleep a little bit to allow the web_server / api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod"
env_file:
- .env.nginx
environment:

View File

@@ -314,21 +314,19 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
- ../data/sslcerts:/etc/nginx/sslcerts
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod.no-letsencrypt"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod.no-letsencrypt"
env_file:
- .env.nginx
environment:

View File

@@ -333,25 +333,20 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
- ../data/certbot/conf:/etc/letsencrypt
- ../data/certbot/www:/var/www/certbot
# sleep a little bit to allow the web_server / api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod"
env_file:
- .env.nginx
environment:

View File

@@ -202,20 +202,18 @@ services:
ports:
- "${NGINX_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
minio:
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1

View File

@@ -477,7 +477,10 @@ services:
- "${HOST_PORT_80:-80}:80"
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
# Mount templates read-only; the startup command copies them into
# the writable /etc/nginx/conf.d/ inside the container. This avoids
# "Permission denied" errors on Windows Docker bind mounts.
- ../data/nginx:/nginx-templates:ro
# PRODUCTION: Add SSL certificate volumes for HTTPS support:
# - ../data/certbot/conf:/etc/letsencrypt
# - ../data/certbot/www:/var/www/certbot
@@ -489,12 +492,13 @@ services:
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not receive any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
# PRODUCTION: Change to app.conf.template.prod for production nginx config
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
cache:
image: redis:7.4-alpine

File diff suppressed because it is too large Load Diff

View File

@@ -96,8 +96,8 @@ fi
# When --lite is passed as a flag, lower resource thresholds early (before the
# resource check). When lite is chosen interactively, the thresholds are adjusted
# inside the new-deployment flow, after the resource check has already passed
# with the standard thresholds — which is the safer direction.
# after the resource check has already passed with the standard thresholds —
# which is the safer direction.
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
@@ -110,9 +110,6 @@ LITE_COMPOSE_FILE="docker-compose.onyx-lite.yml"
# Build the -f flags for docker compose.
# Pass "true" as $1 to auto-detect a previously-downloaded lite overlay
# (used by shutdown/delete-data so users don't need to remember --lite).
# Without the argument, the lite overlay is only included when --lite was
# explicitly passed — preventing install/start from silently staying in
# lite mode just because the file exists on disk from a prior run.
compose_file_args() {
local auto_detect="${1:-false}"
local args="-f docker-compose.yml"
@@ -177,7 +174,7 @@ ensure_file() {
# --- Interactive prompt helpers ---
is_interactive() {
[[ "$NO_PROMPT" = false ]] && [[ -t 0 ]]
[[ "$NO_PROMPT" = false ]]
}
prompt_or_default() {
@@ -207,6 +204,16 @@ prompt_yn_or_default() {
fi
}
confirm_action() {
local description="$1"
prompt_yn_or_default "Install ${description}? (Y/n) [default: Y] " "Y"
if [[ "$REPLY" =~ ^[Nn] ]]; then
print_warning "Skipping: ${description}"
return 1
fi
return 0
}
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
@@ -395,6 +402,11 @@ fi
if ! command -v docker &> /dev/null; then
if [[ "$OSTYPE" == "linux-gnu"* ]] || [[ -n "${WSL_DISTRO_NAME:-}" ]]; then
print_info "Docker is required but not installed."
if ! confirm_action "Docker Engine"; then
print_error "Docker is required to run Onyx."
exit 1
fi
install_docker_linux
if ! command -v docker &> /dev/null; then
print_error "Docker installation failed."
@@ -411,7 +423,11 @@ if command -v docker &> /dev/null \
&& ! command -v docker-compose &> /dev/null \
&& { [[ "$OSTYPE" == "linux-gnu"* ]] || [[ -n "${WSL_DISTRO_NAME:-}" ]]; }; then
print_info "Docker Compose not found — installing plugin..."
print_info "Docker Compose is required but not installed."
if ! confirm_action "Docker Compose plugin"; then
print_error "Docker Compose is required to run Onyx."
exit 1
fi
COMPOSE_ARCH="$(uname -m)"
COMPOSE_URL="https://github.com/docker/compose/releases/latest/download/docker-compose-linux-${COMPOSE_ARCH}"
COMPOSE_DIR="/usr/local/lib/docker/cli-plugins"
@@ -562,10 +578,31 @@ version_compare() {
# Check Docker daemon
if ! docker info &> /dev/null; then
print_error "Docker daemon is not running. Please start Docker."
exit 1
if [[ "$OSTYPE" == "darwin"* ]]; then
print_info "Docker daemon is not running. Starting Docker Desktop..."
open -a Docker
# Wait up to 120 seconds for Docker to be ready
DOCKER_WAIT=0
DOCKER_MAX_WAIT=120
while ! docker info &> /dev/null; do
if [ $DOCKER_WAIT -ge $DOCKER_MAX_WAIT ]; then
print_error "Docker Desktop did not start within ${DOCKER_MAX_WAIT} seconds."
print_info "Please start Docker Desktop manually and re-run this script."
exit 1
fi
printf "\r\033[KWaiting for Docker Desktop to start... (%ds)" "$DOCKER_WAIT"
sleep 2
DOCKER_WAIT=$((DOCKER_WAIT + 2))
done
echo ""
print_success "Docker Desktop is now running"
else
print_error "Docker daemon is not running. Please start Docker."
exit 1
fi
else
print_success "Docker daemon is running"
fi
print_success "Docker daemon is running"
# Check Docker resources
print_step "Verifying Docker resources"
@@ -705,25 +742,48 @@ if [ "$COMPOSE_VERSION" != "dev" ] && version_compare "$COMPOSE_VERSION" "2.24.0
print_info "Proceeding with installation despite Docker Compose version compatibility issues..."
fi
# Handle lite overlay: ensure it if --lite, clean up stale copies otherwise
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
if [[ "$LITE_MODE" = false ]]; then
print_info "Which deployment mode would you like?"
echo ""
echo " 1) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
echo " LLM chat, tools, file uploads, and Projects still work"
echo " 2) Standard - Full deployment with search, connectors, and RAG"
echo ""
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
echo ""
case "$REPLY" in
2)
print_info "Selected: Standard mode"
;;
*)
LITE_MODE=true
print_info "Selected: Lite mode"
;;
esac
else
print_info "Deployment mode: Lite (set via --lite flag)"
fi
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
print_error "--include-craft cannot be used with Lite mode."
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
exit 1
fi
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
fi
# Handle lite overlay file based on selected mode
if [[ "$LITE_MODE" = true ]]; then
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
elif [[ -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" ]]; then
if [[ -f "${INSTALL_ROOT}/deployment/.env" ]]; then
print_warning "Existing lite overlay found but --lite was not passed."
prompt_yn_or_default "Remove lite overlay and switch to standard mode? (y/N): " "n"
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
print_info "Keeping existing lite overlay. Pass --lite to keep using lite mode."
LITE_MODE=true
else
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed lite overlay (switching to standard mode)"
fi
else
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed previous lite overlay (switching to standard mode)"
fi
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed previous lite overlay (switching to standard mode)"
fi
ensure_file "${INSTALL_ROOT}/deployment/env.template" \
@@ -745,6 +805,7 @@ print_success "All configuration files ready"
# Set up deployment configuration
print_step "Setting up deployment configs"
ENV_FILE="${INSTALL_ROOT}/deployment/.env"
ENV_TEMPLATE="${INSTALL_ROOT}/deployment/env.template"
# Check if services are already running
if [ -d "${INSTALL_ROOT}/deployment" ] && [ -f "${INSTALL_ROOT}/deployment/docker-compose.yml" ]; then
# Determine compose command
@@ -785,22 +846,22 @@ if [ -f "$ENV_FILE" ]; then
if [ "$REPLY" = "update" ]; then
print_info "Update selected. Which tag would you like to deploy?"
echo ""
echo "• Press Enter for latest (recommended)"
echo "• Press Enter for edge (recommended)"
echo "• Type a specific tag (e.g., v0.1.0)"
echo ""
if [ "$INCLUDE_CRAFT" = true ]; then
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
VERSION="$REPLY"
else
prompt_or_default "Enter tag [default: latest]: " "latest"
prompt_or_default "Enter tag [default: edge]: " "edge"
VERSION="$REPLY"
fi
echo ""
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
print_info "Selected: craft-latest (Craft enabled)"
elif [ "$VERSION" = "latest" ]; then
print_info "Selected: Latest version"
elif [ "$VERSION" = "edge" ]; then
print_info "Selected: edge (latest nightly)"
else
print_info "Selected: $VERSION"
fi
@@ -852,45 +913,6 @@ else
print_info "No existing .env file found. Setting up new deployment..."
echo ""
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
if [[ "$LITE_MODE" = false ]]; then
print_info "Which deployment mode would you like?"
echo ""
echo " 1) Standard - Full deployment with search, connectors, and RAG"
echo " 2) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
echo " LLM chat, tools, file uploads, and Projects still work"
echo ""
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
echo ""
case "$REPLY" in
2)
LITE_MODE=true
print_info "Selected: Lite mode"
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
;;
*)
print_info "Selected: Standard mode"
;;
esac
else
print_info "Deployment mode: Lite (set via --lite flag)"
fi
# Validate lite + craft combination (could now be set interactively)
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
print_error "--include-craft cannot be used with Lite mode."
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
exit 1
fi
# Adjust resource expectations for lite mode
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
fi
# Ask for version
print_info "Which tag would you like to deploy?"
echo ""
@@ -901,18 +923,18 @@ else
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
VERSION="$REPLY"
else
echo "• Press Enter for latest (recommended)"
echo "• Press Enter for edge (recommended)"
echo "• Type a specific tag (e.g., v0.1.0)"
echo ""
prompt_or_default "Enter tag [default: latest]: " "latest"
prompt_or_default "Enter tag [default: edge]: " "edge"
VERSION="$REPLY"
fi
echo ""
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
print_info "Selected: craft-latest (Craft enabled)"
elif [ "$VERSION" = "latest" ]; then
print_info "Selected: Latest tag"
elif [ "$VERSION" = "edge" ]; then
print_info "Selected: edge (latest nightly)"
else
print_info "Selected: $VERSION"
fi
@@ -1070,20 +1092,39 @@ fi
export HOST_PORT=$AVAILABLE_PORT
print_success "Using port $AVAILABLE_PORT for nginx"
# Determine if we're using the latest tag or a craft tag (both should force pull)
# Determine if we're using a floating tag (edge, latest, craft-*) that should force pull
# Read IMAGE_TAG from .env file and remove any quotes or whitespace
CURRENT_IMAGE_TAG=$(grep "^IMAGE_TAG=" "$ENV_FILE" | head -1 | cut -d'=' -f2 | tr -d ' "'"'"'')
if [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
if [ "$CURRENT_IMAGE_TAG" = "edge" ] || [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
USE_LATEST=true
if [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
print_info "Using craft tag '$CURRENT_IMAGE_TAG' - will force pull and recreate containers"
else
print_info "Using 'latest' tag - will force pull and recreate containers"
print_info "Using '$CURRENT_IMAGE_TAG' tag - will force pull and recreate containers"
fi
else
USE_LATEST=false
fi
# For pinned version tags, re-download config files from that tag so the
# compose file matches the images being pulled (the initial download used main).
if [[ "$USE_LATEST" = false ]] && [[ "$USE_LOCAL_FILES" = false ]]; then
PINNED_BASE="https://raw.githubusercontent.com/onyx-dot-app/onyx/${CURRENT_IMAGE_TAG}/deployment"
print_info "Fetching config files matching tag ${CURRENT_IMAGE_TAG}..."
if download_file "${PINNED_BASE}/docker_compose/docker-compose.yml" "${INSTALL_ROOT}/deployment/docker-compose.yml" 2>/dev/null; then
download_file "${PINNED_BASE}/data/nginx/app.conf.template" "${INSTALL_ROOT}/data/nginx/app.conf.template" 2>/dev/null || true
download_file "${PINNED_BASE}/data/nginx/run-nginx.sh" "${INSTALL_ROOT}/data/nginx/run-nginx.sh" 2>/dev/null || true
chmod +x "${INSTALL_ROOT}/data/nginx/run-nginx.sh"
if [[ "$LITE_MODE" = true ]]; then
download_file "${PINNED_BASE}/docker_compose/${LITE_COMPOSE_FILE}" \
"${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" 2>/dev/null || true
fi
print_success "Config files updated to match ${CURRENT_IMAGE_TAG}"
else
print_warning "Tag ${CURRENT_IMAGE_TAG} not found on GitHub — using main branch configs"
fi
fi
# Pull Docker images with reduced output
print_step "Pulling Docker images"
print_info "This may take several minutes depending on your internet connection..."

View File

@@ -19,12 +19,13 @@
## Output template to file and inspect
* cd charts/onyx
* helm template test-output . > test-output.yaml
* helm template test-output . --set auth.opensearch.values.opensearch_admin_password='StrongPassword123!' > test-output.yaml
## Test the entire cluster manually
* cd charts/onyx
* helm install onyx . -n onyx --set postgresql.primary.persistence.enabled=false
* helm install onyx . -n onyx --set postgresql.primary.persistence.enabled=false --set auth.opensearch.values.opensearch_admin_password='StrongPassword123!'
* the postgres flag is to keep the storage ephemeral for testing. You probably don't want to set that in prod.
* the OpenSearch admin password must be set on first install unless you are supplying `auth.opensearch.existingSecret`.
* no flag for ephemeral vespa storage yet, might be good for testing
* kubectl -n onyx port-forward service/onyx-nginx 8080:80
* this will forward the local port 8080 to the installed chart for you to run tests, etc.

View File

@@ -5,7 +5,7 @@ home: https://www.onyx.app/
sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.4.35
version: 0.4.36
appVersion: latest
annotations:
category: Productivity

View File

@@ -1,6 +1,9 @@
# Values for chart-testing (ct lint/install)
# This file is automatically used by ct when running lint and install commands
auth:
opensearch:
values:
opensearch_admin_password: "placeholder-OpenSearch1!"
userauth:
values:
user_auth_secret: "placeholder-for-ci-testing"

View File

@@ -1177,12 +1177,17 @@ auth:
# Secrets values IF existingSecret is empty. Key here must match the value
# in secretKeys to be used. Values will be base64 encoded in the k8s
# cluster.
# For the bundled OpenSearch chart, the admin password is consumed during
# initial cluster setup. Changing this value later will update Onyx's
# client credentials, but will not rotate the OpenSearch admin password.
# Set this before first install or use existingSecret to preserve the
# current secret on upgrade.
# Password must meet OpenSearch complexity requirements:
# min 8 chars, uppercase, lowercase, digit, and special character.
# CHANGE THIS FOR PRODUCTION.
# Required when auth.opensearch.enabled=true and no existing secret exists.
values:
opensearch_admin_username: "admin"
opensearch_admin_password: "OnyxDev1!"
opensearch_admin_password: ""
userauth:
# -- Used for password reset / verification tokens and OAuth/OIDC state signing.
# Disabled by default to preserve upgrade compatibility for existing Helm customers.

View File

@@ -127,6 +127,7 @@ Inputs (common):
- `name` (default `onyx`), `region` (default `us-west-2`), `tags`
- `postgres_username`, `postgres_password`
- `create_vpc` (default true) or existing VPC details and `s3_vpc_endpoint_id`
- WAF controls such as `waf_allowed_ip_cidrs`, `waf_common_rule_set_count_rules`, rate limits, geo restrictions, and logging retention
### `vpc`
- Builds a VPC sized for EKS with multiple private and public subnets

View File

@@ -88,6 +88,8 @@ module "waf" {
tags = local.merged_tags
# WAF configuration with sensible defaults
allowed_ip_cidrs = var.waf_allowed_ip_cidrs
common_rule_set_count_rules = var.waf_common_rule_set_count_rules
rate_limit_requests_per_5_minutes = var.waf_rate_limit_requests_per_5_minutes
api_rate_limit_requests_per_5_minutes = var.waf_api_rate_limit_requests_per_5_minutes
geo_restriction_countries = var.waf_geo_restriction_countries

Some files were not shown because too many files have changed in this diff Show More