mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-24 09:02:43 +00:00
Compare commits
49 Commits
jamison/tr
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
06f76bac0c | ||
|
|
e59b0c0d23 | ||
|
|
dfa37cce8b | ||
|
|
6dce6b09e4 | ||
|
|
0eba41c487 | ||
|
|
a426930123 | ||
|
|
73d98c7fa5 | ||
|
|
a096cf3997 | ||
|
|
1e01ff8f10 | ||
|
|
3665bb23c2 | ||
|
|
8ff0d5fc15 | ||
|
|
645d45776a | ||
|
|
fa06e4ebd5 | ||
|
|
2a61e3ce4c | ||
|
|
d8733dd89f | ||
|
|
c651177529 | ||
|
|
3193fe76e4 | ||
|
|
b9025c57f6 | ||
|
|
1b0d62c16e | ||
|
|
7a0c977eb7 | ||
|
|
a9c04dca89 | ||
|
|
8d76a95c86 | ||
|
|
2a69561ec5 | ||
|
|
e1655426d6 | ||
|
|
9634266bc6 | ||
|
|
0c962d882d | ||
|
|
4bf3edf83b | ||
|
|
c48a77c644 | ||
|
|
26d70ab16b | ||
|
|
f8a2f3ac93 | ||
|
|
5186356a26 | ||
|
|
7b826e2a4e | ||
|
|
c175dc8f6a | ||
|
|
aa11813cc0 | ||
|
|
6235f49b49 | ||
|
|
fd6a110794 | ||
|
|
bd42c459d6 | ||
|
|
aede532e63 | ||
|
|
068ac543ad | ||
|
|
30e7a831a5 | ||
|
|
276261c96d | ||
|
|
205f1410e4 | ||
|
|
a93d154c27 | ||
|
|
1361879bd0 | ||
|
|
c58cc320b2 | ||
|
|
461350958a | ||
|
|
50dde0be1a | ||
|
|
199e1df453 | ||
|
|
996b674840 |
6
.github/workflows/deployment.yml
vendored
6
.github/workflows/deployment.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
fetch-tags: true
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
enable-cache: false
|
||||
@@ -165,7 +165,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
@@ -307,7 +307,7 @@ jobs:
|
||||
xdg-utils
|
||||
|
||||
- name: setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6.2.0
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v6.3.0
|
||||
with:
|
||||
node-version: 24
|
||||
package-manager-cache: false
|
||||
|
||||
@@ -114,7 +114,7 @@ jobs:
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/pr-desktop-build.yml
vendored
2
.github/workflows/pr-desktop-build.yml
vendored
@@ -50,7 +50,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f
|
||||
with:
|
||||
node-version: 24
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning]
|
||||
|
||||
2
.github/workflows/pr-jest-tests.yml
vendored
2
.github/workflows/pr-jest-tests.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning] test-only workflow; no deploy artifacts
|
||||
|
||||
6
.github/workflows/pr-playwright-tests.yml
vendored
6
.github/workflows/pr-playwright-tests.yml
vendored
@@ -272,7 +272,7 @@ jobs:
|
||||
|
||||
- name: Setup node
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning]
|
||||
@@ -471,7 +471,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: always()
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -614,7 +614,7 @@ jobs:
|
||||
|
||||
- name: Setup node
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning]
|
||||
|
||||
2
.github/workflows/pr-python-model-tests.yml
vendored
2
.github/workflows/pr-python-model-tests.yml
vendored
@@ -73,7 +73,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f
|
||||
|
||||
- name: Build and load
|
||||
uses: docker/bake-action@5be5f02ff8819ecd3092ea6b2e6261c31774f2b4 # ratchet:docker/bake-action@v6
|
||||
uses: docker/bake-action@82490499d2e5613fcead7e128237ef0b0ea210f7 # ratchet:docker/bake-action@v7.0.0
|
||||
env:
|
||||
TAG: model-server-${{ github.run_id }}
|
||||
with:
|
||||
|
||||
2
.github/workflows/pr-quality-checks.yml
vendored
2
.github/workflows/pr-quality-checks.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
||||
- name: Setup Terraform
|
||||
uses: hashicorp/setup-terraform@5e8dbf3c6d9deaf4193ca7a8fb23f2ac83bb6c85 # ratchet:hashicorp/setup-terraform@v4.0.0
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v6
|
||||
with: # zizmor: ignore[cache-poisoning]
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
|
||||
2
.github/workflows/preview.yml
vendored
2
.github/workflows/preview.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
|
||||
2
.github/workflows/release-cli.yml
vendored
2
.github/workflows/release-cli.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/release-devtools.yml
vendored
2
.github/workflows/release-devtools.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/storybook-deploy.yml
vendored
2
.github/workflows/storybook-deploy.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
|
||||
2
.github/workflows/zizmor.yml
vendored
2
.github/workflows/zizmor.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
279
AGENTS.md
279
AGENTS.md
@@ -167,284 +167,7 @@ web/
|
||||
|
||||
## Frontend Standards
|
||||
|
||||
### 1. Import Standards
|
||||
|
||||
**Always use absolute imports with the `@` prefix.**
|
||||
|
||||
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useAuth } from "@/hooks/useAuth";
|
||||
import { Text } from "@/refresh-components/texts/Text";
|
||||
|
||||
// ❌ Bad
|
||||
import { Button } from "../../../components/ui/button";
|
||||
import { useAuth } from "./hooks/useAuth";
|
||||
```
|
||||
|
||||
### 2. React Component Functions
|
||||
|
||||
**Prefer regular functions over arrow functions for React components.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
function UserProfile({ userId }: UserProfileProps) {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
const UserProfile = ({ userId }: UserProfileProps) => {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Props Interface Extraction
|
||||
|
||||
**Extract prop types into their own interface definitions.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
interface UserCardProps {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}
|
||||
|
||||
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({
|
||||
user,
|
||||
showActions = false,
|
||||
onEdit
|
||||
}: {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Spacing Guidelines
|
||||
|
||||
**Prefer padding over margins for spacing.**
|
||||
|
||||
**Reason:** We want to consolidate usage to paddings instead of margins.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
<div className="p-4 space-y-2">
|
||||
<div className="p-2">Content</div>
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className="m-4 space-y-2">
|
||||
<div className="m-2">Content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
### 5. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
|
||||
|
||||
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
|
||||
<div className="bg-background-neutral-03 text-text-02">
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true, // Will apply dark:invert internally
|
||||
});
|
||||
|
||||
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
|
||||
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
|
||||
});
|
||||
|
||||
// ❌ Bad - Manual dark mode overrides
|
||||
<div className="bg-white dark:bg-black text-black dark:text-white">
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 6. Class Name Utilities
|
||||
|
||||
**Use the `cn` utility instead of raw string formatting for classNames.**
|
||||
|
||||
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
|
||||
|
||||
```typescript
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// ✅ Good
|
||||
<div className={cn(
|
||||
'base-class',
|
||||
isActive && 'active-class',
|
||||
className
|
||||
)}>
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 7. Custom Hooks Organization
|
||||
|
||||
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
|
||||
|
||||
**Reason:** This is just a layout preference. Keeps code clean.
|
||||
|
||||
```typescript
|
||||
// web/src/hooks/useUserData.ts
|
||||
export function useUserData(userId: string) {
|
||||
// hook implementation
|
||||
}
|
||||
|
||||
// web/src/hooks/useLocalStorage.ts
|
||||
export function useLocalStorage<T>(key: string, initialValue: T) {
|
||||
// hook implementation
|
||||
}
|
||||
```
|
||||
|
||||
### 8. Icon Usage
|
||||
|
||||
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
|
||||
|
||||
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import SvgX from "@/icons/x";
|
||||
import SvgMoreHorizontal from "@/icons/more-horizontal";
|
||||
|
||||
// ❌ Bad
|
||||
import { User } from "lucide-react";
|
||||
import { FiSearch } from "react-icons/fi";
|
||||
```
|
||||
|
||||
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
|
||||
If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
### 9. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 10. Component Usage
|
||||
|
||||
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
|
||||
|
||||
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import Button from '@/refresh-components/buttons/Button'
|
||||
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
|
||||
import SvgPlusCircle from '@/icons/plus-circle'
|
||||
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<InputTypeIn placeholder="Search..." />
|
||||
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<input placeholder="Name" />
|
||||
<textarea placeholder="Message" />
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 11. Colors
|
||||
|
||||
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
|
||||
|
||||
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
|
||||
|
||||
**Available color categories:**
|
||||
|
||||
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
|
||||
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
|
||||
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
|
||||
- **Actions:** `action-link-XX`, `action-danger-XX`
|
||||
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
|
||||
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Use custom Onyx color classes
|
||||
<div className="bg-background-neutral-01 border border-border-02" />
|
||||
<div className="bg-background-tint-02 border border-border-01" />
|
||||
<div className="bg-status-success-01" />
|
||||
<div className="bg-action-link-01" />
|
||||
<div className="bg-theme-primary-05" />
|
||||
|
||||
// ❌ Bad - Do NOT use standard Tailwind colors
|
||||
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
|
||||
<div className="bg-white border border-slate-200" />
|
||||
<div className="bg-green-100 text-green-700" />
|
||||
<div className="bg-blue-100 text-blue-600" />
|
||||
<div className="bg-indigo-500" />
|
||||
```
|
||||
|
||||
### 12. Data Fetching
|
||||
|
||||
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
|
||||
|
||||
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
|
||||
Frontend standards for the `web/` and `desktop/` projects live in `web/AGENTS.md`.
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
|
||||
@@ -47,6 +47,8 @@ RUN apt-get update && \
|
||||
gcc \
|
||||
nano \
|
||||
vim \
|
||||
# Install procps so kubernetes exec sessions can use ps aux for debugging
|
||||
procps \
|
||||
libjemalloc2 \
|
||||
&& \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""rename persona is_visible to is_listed and featured to is_featured
|
||||
|
||||
Revision ID: b728689f45b1
|
||||
Revises: 689433b0d8de
|
||||
Create Date: 2026-03-23 12:36:26.607305
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b728689f45b1"
|
||||
down_revision = "689433b0d8de"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("persona", "is_visible", new_column_name="is_listed")
|
||||
op.alter_column("persona", "featured", new_column_name="is_featured")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("persona", "is_listed", new_column_name="is_visible")
|
||||
op.alter_column("persona", "is_featured", new_column_name="featured")
|
||||
@@ -36,6 +36,56 @@ TABLES_WITH_USER_ID = [
|
||||
]
|
||||
|
||||
|
||||
def _dedupe_null_notifications(connection: sa.Connection) -> None:
|
||||
# Multiple NULL-owned notifications can exist because the unique index treats
|
||||
# NULL user_id values as distinct. Before migrating them to the anonymous
|
||||
# user, collapse duplicates and remove rows that would conflict with an
|
||||
# already-existing anonymous notification.
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
WITH ranked_null_notifications AS (
|
||||
SELECT
|
||||
id,
|
||||
ROW_NUMBER() OVER (
|
||||
PARTITION BY notif_type, COALESCE(additional_data, '{}'::jsonb)
|
||||
ORDER BY first_shown DESC, last_shown DESC, id DESC
|
||||
) AS row_num
|
||||
FROM notification
|
||||
WHERE user_id IS NULL
|
||||
)
|
||||
DELETE FROM notification
|
||||
WHERE id IN (
|
||||
SELECT id
|
||||
FROM ranked_null_notifications
|
||||
WHERE row_num > 1
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
if result.rowcount > 0:
|
||||
print(f"Deleted {result.rowcount} duplicate NULL-owned notifications")
|
||||
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM notification AS null_owned
|
||||
USING notification AS anonymous_owned
|
||||
WHERE null_owned.user_id IS NULL
|
||||
AND anonymous_owned.user_id = :user_id
|
||||
AND null_owned.notif_type = anonymous_owned.notif_type
|
||||
AND COALESCE(null_owned.additional_data, '{}'::jsonb) =
|
||||
COALESCE(anonymous_owned.additional_data, '{}'::jsonb)
|
||||
"""
|
||||
),
|
||||
{"user_id": ANONYMOUS_USER_UUID},
|
||||
)
|
||||
if result.rowcount > 0:
|
||||
print(
|
||||
f"Deleted {result.rowcount} NULL-owned notifications that conflict with existing anonymous-owned notifications"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""
|
||||
Create the anonymous user for anonymous access feature.
|
||||
@@ -65,7 +115,12 @@ def upgrade() -> None:
|
||||
|
||||
# Migrate any remaining user_id=NULL records to anonymous user
|
||||
for table in TABLES_WITH_USER_ID:
|
||||
try:
|
||||
# Dedup notifications outside the savepoint so deletions persist
|
||||
# even if the subsequent UPDATE rolls back
|
||||
if table == "notification":
|
||||
_dedupe_null_notifications(connection)
|
||||
|
||||
with connection.begin_nested():
|
||||
# Exclude public credential (id=0) which must remain user_id=NULL
|
||||
# Exclude builtin tools (in_code_tool_id IS NOT NULL) which must remain user_id=NULL
|
||||
# Exclude builtin personas (builtin_persona=True) which must remain user_id=NULL
|
||||
@@ -80,6 +135,7 @@ def upgrade() -> None:
|
||||
condition = "user_id IS NULL AND is_public = false"
|
||||
else:
|
||||
condition = "user_id IS NULL"
|
||||
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
f"""
|
||||
@@ -92,19 +148,19 @@ def upgrade() -> None:
|
||||
)
|
||||
if result.rowcount > 0:
|
||||
print(f"Updated {result.rowcount} rows in {table} to anonymous user")
|
||||
except Exception as e:
|
||||
print(f"Skipping {table}: {e}")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""
|
||||
Set anonymous user's records back to NULL and delete the anonymous user.
|
||||
|
||||
Note: Duplicate NULL-owned notifications removed during upgrade are not restored.
|
||||
"""
|
||||
connection = op.get_bind()
|
||||
|
||||
# Set records back to NULL
|
||||
for table in TABLES_WITH_USER_ID:
|
||||
try:
|
||||
with connection.begin_nested():
|
||||
connection.execute(
|
||||
sa.text(
|
||||
f"""
|
||||
@@ -115,8 +171,6 @@ def downgrade() -> None:
|
||||
),
|
||||
{"user_id": ANONYMOUS_USER_UUID},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Delete the anonymous user
|
||||
connection.execute(
|
||||
|
||||
@@ -25,9 +25,6 @@ from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
# Default number of pre-provisioned tenants to maintain
|
||||
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
|
||||
|
||||
# Soft time limit for tenant pre-provisioning tasks (in seconds)
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
# Hard time limit for tenant pre-provisioning tasks (in seconds)
|
||||
@@ -58,7 +55,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
lock_check: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# These tasks should never overlap
|
||||
@@ -74,9 +71,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
num_available_tenants = db_session.query(AvailableTenant).count()
|
||||
|
||||
# Get the target number of available tenants
|
||||
num_minimum_available_tenants = getattr(
|
||||
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
|
||||
)
|
||||
num_minimum_available_tenants = TARGET_AVAILABLE_TENANTS
|
||||
|
||||
# Calculate how many new tenants we need to provision
|
||||
if num_available_tenants < num_minimum_available_tenants:
|
||||
@@ -98,7 +93,12 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
task_logger.exception("Error in check_available_tenants task")
|
||||
|
||||
finally:
|
||||
lock_check.release()
|
||||
try:
|
||||
lock_check.release()
|
||||
except Exception:
|
||||
task_logger.warning(
|
||||
"Could not release check lock (likely expired), continuing"
|
||||
)
|
||||
|
||||
|
||||
def pre_provision_tenant() -> None:
|
||||
@@ -113,7 +113,7 @@ def pre_provision_tenant() -> None:
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
lock_provision: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CLOUD_PRE_PROVISION_TENANT_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
|
||||
@@ -185,4 +185,9 @@ def pre_provision_tenant() -> None:
|
||||
except Exception:
|
||||
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
|
||||
finally:
|
||||
lock_provision.release()
|
||||
try:
|
||||
lock_provision.release()
|
||||
except Exception:
|
||||
task_logger.warning(
|
||||
"Could not release provision lock (likely expired), continuing"
|
||||
)
|
||||
|
||||
@@ -157,7 +157,11 @@ def fetch_logo_helper(db_session: Session) -> Response: # noqa: ARG001
|
||||
detail="No logo file found",
|
||||
)
|
||||
else:
|
||||
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
|
||||
return Response(
|
||||
content=onyx_file.data,
|
||||
media_type=onyx_file.mime_type,
|
||||
headers={"Cache-Control": "no-cache"},
|
||||
)
|
||||
|
||||
|
||||
def fetch_logotype_helper(db_session: Session) -> Response: # noqa: ARG001
|
||||
|
||||
@@ -178,7 +178,7 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=persona.datetime_aware,
|
||||
featured=persona.featured,
|
||||
is_featured=persona.is_featured,
|
||||
commit=False,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
@@ -80,15 +80,45 @@ def capture_and_sync_with_alternate_posthog(
|
||||
logger.error(f"Error identifying cloud posthog user: {e}")
|
||||
|
||||
|
||||
def alias_user(distinct_id: str, anonymous_id: str) -> None:
|
||||
"""Link an anonymous distinct_id to an identified user, merging person profiles.
|
||||
|
||||
No-ops when the IDs match (e.g. returning users whose PostHog cookie
|
||||
already contains their identified user ID).
|
||||
"""
|
||||
if not posthog or anonymous_id == distinct_id:
|
||||
return
|
||||
|
||||
try:
|
||||
posthog.alias(previous_id=anonymous_id, distinct_id=distinct_id)
|
||||
posthog.flush()
|
||||
except Exception as e:
|
||||
logger.error(f"Error aliasing PostHog user: {e}")
|
||||
|
||||
|
||||
def get_anon_id_from_request(request: Any) -> str | None:
|
||||
"""Extract the anonymous distinct_id from the app PostHog cookie on a request."""
|
||||
if not POSTHOG_API_KEY:
|
||||
return None
|
||||
|
||||
cookie_name = f"ph_{POSTHOG_API_KEY}_posthog"
|
||||
if (cookie_value := request.cookies.get(cookie_name)) and (
|
||||
parsed := parse_posthog_cookie(cookie_value)
|
||||
):
|
||||
return parsed.get("distinct_id")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_marketing_posthog_cookie_name() -> str | None:
|
||||
if not MARKETING_POSTHOG_API_KEY:
|
||||
return None
|
||||
return f"onyx_custom_ph_{MARKETING_POSTHOG_API_KEY}_posthog"
|
||||
|
||||
|
||||
def parse_marketing_cookie(cookie_value: str) -> dict[str, Any] | None:
|
||||
def parse_posthog_cookie(cookie_value: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Parse the URL-encoded JSON marketing cookie.
|
||||
Parse a URL-encoded JSON PostHog cookie
|
||||
|
||||
Expected format (URL-encoded):
|
||||
{"distinct_id":"...", "featureFlags":{"landing_page_variant":"..."}, ...}
|
||||
@@ -102,7 +132,7 @@ def parse_marketing_cookie(cookie_value: str) -> dict[str, Any] | None:
|
||||
cookie_data = json.loads(decoded_cookie)
|
||||
|
||||
distinct_id = cookie_data.get("distinct_id")
|
||||
if not distinct_id:
|
||||
if not distinct_id or not isinstance(distinct_id, str):
|
||||
return None
|
||||
|
||||
return cookie_data
|
||||
|
||||
@@ -135,6 +135,8 @@ from onyx.redis.redis_pool import retrieve_ws_token_data
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_alias
|
||||
from onyx.utils.telemetry import mt_cloud_get_anon_id
|
||||
from onyx.utils.telemetry import mt_cloud_identify
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
@@ -251,18 +253,12 @@ def verify_email_is_invited(email: str) -> None:
|
||||
whitelist = get_invited_users()
|
||||
|
||||
if not email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Email must be specified"},
|
||||
)
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email must be specified")
|
||||
|
||||
try:
|
||||
email_info = validate_email(email, check_deliverability=False)
|
||||
except EmailUndeliverableError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Email is not valid"},
|
||||
)
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email is not valid")
|
||||
|
||||
for email_whitelist in whitelist:
|
||||
try:
|
||||
@@ -279,12 +275,9 @@ def verify_email_is_invited(email: str) -> None:
|
||||
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
|
||||
return
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"code": REGISTER_INVITE_ONLY_CODE,
|
||||
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
|
||||
},
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.UNAUTHORIZED,
|
||||
"This workspace is invite-only. Please ask your admin to invite you.",
|
||||
)
|
||||
|
||||
|
||||
@@ -294,48 +287,47 @@ def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
|
||||
verify_email_is_invited(email)
|
||||
|
||||
|
||||
def verify_email_domain(email: str) -> None:
|
||||
def verify_email_domain(email: str, *, is_registration: bool = False) -> None:
|
||||
if email.count("@") != 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email is not valid",
|
||||
)
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email is not valid")
|
||||
|
||||
local_part, domain = email.split("@")
|
||||
domain = domain.lower()
|
||||
local_part = local_part.lower()
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
# Normalize googlemail.com to gmail.com (they deliver to the same inbox)
|
||||
if domain == "googlemail.com":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Please use @gmail.com instead of @googlemail.com.",
|
||||
)
|
||||
|
||||
# Only block dotted Gmail on new signups — existing users must still be
|
||||
# able to sign in with the address they originally registered with.
|
||||
if is_registration and domain == "gmail.com" and "." in local_part:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Gmail addresses with '.' are not allowed. Please use your base email address.",
|
||||
)
|
||||
|
||||
if "+" in local_part and domain != "onyx.app":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"reason": "Email addresses with '+' are not allowed. Please use your base email address."
|
||||
},
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Email addresses with '+' are not allowed. Please use your base email address.",
|
||||
)
|
||||
|
||||
# Check if email uses a disposable/temporary domain
|
||||
if is_disposable_email(email):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"reason": "Disposable email addresses are not allowed. Please use a permanent email address."
|
||||
},
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Disposable email addresses are not allowed. Please use a permanent email address.",
|
||||
)
|
||||
|
||||
# Check domain whitelist if configured
|
||||
if VALID_EMAIL_DOMAINS:
|
||||
if domain not in VALID_EMAIL_DOMAINS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email domain is not valid",
|
||||
)
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email domain is not valid")
|
||||
|
||||
|
||||
def enforce_seat_limit(db_session: Session, seats_needed: int = 1) -> None:
|
||||
@@ -351,7 +343,7 @@ def enforce_seat_limit(db_session: Session, seats_needed: int = 1) -> None:
|
||||
)(db_session, seats_needed=seats_needed)
|
||||
|
||||
if result is not None and not result.available:
|
||||
raise HTTPException(status_code=402, detail=result.error_message)
|
||||
raise OnyxError(OnyxErrorCode.SEAT_LIMIT_EXCEEDED, result.error_message)
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
@@ -404,10 +396,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
captcha_token or "", expected_action="signup"
|
||||
)
|
||||
except CaptchaVerificationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": str(e)},
|
||||
)
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
# We verify the password here to make sure it's valid before we proceed
|
||||
await self.validate_password(
|
||||
@@ -417,13 +406,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# Check for disposable emails BEFORE provisioning tenant
|
||||
# This prevents creating tenants for throwaway email addresses
|
||||
try:
|
||||
verify_email_domain(user_create.email)
|
||||
except HTTPException as e:
|
||||
verify_email_domain(user_create.email, is_registration=True)
|
||||
except OnyxError as e:
|
||||
# Log blocked disposable email attempts
|
||||
if (
|
||||
e.status_code == status.HTTP_400_BAD_REQUEST
|
||||
and "Disposable email" in str(e.detail)
|
||||
):
|
||||
if "Disposable email" in e.detail:
|
||||
domain = (
|
||||
user_create.email.split("@")[-1]
|
||||
if "@" in user_create.email
|
||||
@@ -567,9 +553,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
result = await db_session.execute(
|
||||
select(Persona.id)
|
||||
.where(
|
||||
Persona.featured.is_(True),
|
||||
Persona.is_featured.is_(True),
|
||||
Persona.is_public.is_(True),
|
||||
Persona.is_visible.is_(True),
|
||||
Persona.is_listed.is_(True),
|
||||
Persona.deleted.is_(False),
|
||||
)
|
||||
.order_by(
|
||||
@@ -697,6 +683,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
raise exceptions.UserNotExists()
|
||||
|
||||
except exceptions.UserNotExists:
|
||||
verify_email_domain(account_email, is_registration=True)
|
||||
|
||||
# Check seat availability before creating (single-tenant only)
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
@@ -795,6 +783,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
logger.exception("Error deleting anonymous user cookie")
|
||||
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
# Link the anonymous PostHog session to the identified user so that
|
||||
# pre-login session recordings and events merge into one person profile.
|
||||
if anon_id := mt_cloud_get_anon_id(request):
|
||||
mt_cloud_alias(distinct_id=str(user.id), anonymous_id=anon_id)
|
||||
|
||||
mt_cloud_identify(
|
||||
distinct_id=str(user.id),
|
||||
properties={"email": user.email, "tenant_id": tenant_id},
|
||||
@@ -818,6 +812,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_count = await get_user_count()
|
||||
logger.debug(f"Current tenant user count: {user_count}")
|
||||
|
||||
# Link the anonymous PostHog session to the identified user so
|
||||
# that pre-signup session recordings merge into one person profile.
|
||||
if anon_id := mt_cloud_get_anon_id(request):
|
||||
mt_cloud_alias(distinct_id=str(user.id), anonymous_id=anon_id)
|
||||
|
||||
# Ensure a PostHog person profile exists for this user.
|
||||
mt_cloud_identify(
|
||||
distinct_id=str(user.id),
|
||||
@@ -846,9 +845,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
attribute="get_marketing_posthog_cookie_name",
|
||||
noop_return_value=None,
|
||||
)
|
||||
parse_marketing_cookie = fetch_ee_implementation_or_noop(
|
||||
parse_posthog_cookie = fetch_ee_implementation_or_noop(
|
||||
module="onyx.utils.posthog_client",
|
||||
attribute="parse_marketing_cookie",
|
||||
attribute="parse_posthog_cookie",
|
||||
noop_return_value=None,
|
||||
)
|
||||
capture_and_sync_with_alternate_posthog = fetch_ee_implementation_or_noop(
|
||||
@@ -862,7 +861,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
and user_count is not None
|
||||
and (marketing_cookie_name := get_marketing_posthog_cookie_name())
|
||||
and (marketing_cookie_value := request.cookies.get(marketing_cookie_name))
|
||||
and (parsed_cookie := parse_marketing_cookie(marketing_cookie_value))
|
||||
and (parsed_cookie := parse_posthog_cookie(marketing_cookie_value))
|
||||
):
|
||||
marketing_anonymous_id = parsed_cookie["distinct_id"]
|
||||
|
||||
|
||||
@@ -9,12 +9,12 @@ from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
@@ -362,7 +362,7 @@ if not MULTI_TENANT:
|
||||
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
if not MULTI_TENANT and HOOK_ENABLED:
|
||||
if HOOKS_AVAILABLE:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
|
||||
@@ -30,6 +30,8 @@ from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.utils import plaintext_file_name_for_id
|
||||
from onyx.file_store.utils import store_plaintext
|
||||
from onyx.kg.models import KGException
|
||||
from onyx.kg.setup.kg_default_entity_definitions import (
|
||||
populate_missing_default_entity_types__commit,
|
||||
@@ -289,6 +291,33 @@ def process_kg_commands(
|
||||
raise KGException("KG setup done")
|
||||
|
||||
|
||||
def _get_or_extract_plaintext(
|
||||
file_id: str,
|
||||
extract_fn: Callable[[], str],
|
||||
) -> str:
|
||||
"""Load cached plaintext for a file, or extract and store it.
|
||||
|
||||
Tries to read pre-stored plaintext from the file store. On a miss,
|
||||
calls extract_fn to produce the text, then stores the result so
|
||||
future calls skip the expensive extraction.
|
||||
"""
|
||||
file_store = get_default_file_store()
|
||||
plaintext_key = plaintext_file_name_for_id(file_id)
|
||||
|
||||
# Try cached plaintext first.
|
||||
try:
|
||||
plaintext_io = file_store.read_file(plaintext_key, mode="b")
|
||||
return plaintext_io.read().decode("utf-8")
|
||||
except Exception:
|
||||
logger.exception(f"Error when reading file, id={file_id}")
|
||||
|
||||
# Cache miss — extract and store.
|
||||
content_text = extract_fn()
|
||||
if content_text:
|
||||
store_plaintext(file_id, content_text)
|
||||
return content_text
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def load_chat_file(
|
||||
file_descriptor: FileDescriptor, db_session: Session
|
||||
@@ -303,12 +332,23 @@ def load_chat_file(
|
||||
file_type = ChatFileType(file_descriptor["type"])
|
||||
|
||||
if file_type.is_text_file():
|
||||
try:
|
||||
content_text = extract_file_text(
|
||||
file_id = file_descriptor["id"]
|
||||
|
||||
def _extract() -> str:
|
||||
return extract_file_text(
|
||||
file=file_io,
|
||||
file_name=file_descriptor.get("name") or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
|
||||
# Use the user_file_id as cache key when available (matches what
|
||||
# the celery indexing worker stores), otherwise fall back to the
|
||||
# file store id (covers code-interpreter-generated files, etc.).
|
||||
user_file_id_str = file_descriptor.get("user_file_id")
|
||||
cache_key = user_file_id_str or file_id
|
||||
|
||||
try:
|
||||
content_text = _get_or_extract_plaintext(cache_key, _extract)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to retrieve content for file {file_descriptor['id']}: {str(e)}"
|
||||
|
||||
@@ -177,8 +177,8 @@ class ExtractedContextFiles(BaseModel):
|
||||
class SearchParams(BaseModel):
|
||||
"""Resolved search filter IDs and search-tool usage for a chat turn."""
|
||||
|
||||
search_project_id: int | None
|
||||
search_persona_id: int | None
|
||||
project_id_filter: int | None
|
||||
persona_id_filter: int | None
|
||||
search_usage: SearchToolUsage
|
||||
|
||||
|
||||
|
||||
@@ -399,13 +399,13 @@ def determine_search_params(
|
||||
"""
|
||||
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
|
||||
|
||||
search_project_id: int | None = None
|
||||
search_persona_id: int | None = None
|
||||
project_id_filter: int | None = None
|
||||
persona_id_filter: int | None = None
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
if is_custom_persona:
|
||||
search_persona_id = persona_id
|
||||
persona_id_filter = persona_id
|
||||
else:
|
||||
search_project_id = project_id
|
||||
project_id_filter = project_id
|
||||
|
||||
search_usage = SearchToolUsage.AUTO
|
||||
if not is_custom_persona and project_id:
|
||||
@@ -418,8 +418,8 @@ def determine_search_params(
|
||||
search_usage = SearchToolUsage.DISABLED
|
||||
|
||||
return SearchParams(
|
||||
search_project_id=search_project_id,
|
||||
search_persona_id=search_persona_id,
|
||||
project_id_filter=project_id_filter,
|
||||
persona_id_filter=persona_id_filter,
|
||||
search_usage=search_usage,
|
||||
)
|
||||
|
||||
@@ -711,8 +711,8 @@ def handle_stream_message_objects(
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=new_msg_req.internal_search_filters,
|
||||
project_id=search_params.search_project_id,
|
||||
persona_id=search_params.search_persona_id,
|
||||
project_id_filter=search_params.project_id_filter,
|
||||
persona_id_filter=search_params.persona_id_filter,
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
|
||||
@@ -88,8 +88,9 @@ WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
|
||||
IFRAME_TEXT_LENGTH_THRESHOLD = 700
|
||||
# Message indicating JavaScript is disabled, which often appears when scraping fails
|
||||
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
|
||||
# Grace period after page navigation to allow bot-detection challenges to complete
|
||||
BOT_DETECTION_GRACE_PERIOD_MS = 5000
|
||||
# Grace period after page navigation to allow bot-detection challenges
|
||||
# and SPA content rendering to complete
|
||||
PAGE_RENDER_TIMEOUT_MS = 5000
|
||||
|
||||
# Define common headers that mimic a real browser
|
||||
DEFAULT_USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"
|
||||
@@ -547,7 +548,15 @@ class WebConnector(LoadConnector):
|
||||
)
|
||||
# Give the page a moment to start rendering after navigation commits.
|
||||
# Allows CloudFlare and other bot-detection challenges to complete.
|
||||
page.wait_for_timeout(BOT_DETECTION_GRACE_PERIOD_MS)
|
||||
page.wait_for_timeout(PAGE_RENDER_TIMEOUT_MS)
|
||||
|
||||
# Wait for network activity to settle so SPAs that fetch content
|
||||
# asynchronously after the initial JS bundle have time to render.
|
||||
try:
|
||||
# A bit of extra time to account for long-polling, websockets, etc.
|
||||
page.wait_for_load_state("networkidle", timeout=PAGE_RENDER_TIMEOUT_MS)
|
||||
except TimeoutError:
|
||||
pass
|
||||
|
||||
last_modified = (
|
||||
page_response.header_value("Last-Modified") if page_response else None
|
||||
@@ -576,7 +585,7 @@ class WebConnector(LoadConnector):
|
||||
# (e.g., CloudFlare protection keeps making requests)
|
||||
try:
|
||||
page.wait_for_load_state(
|
||||
"networkidle", timeout=BOT_DETECTION_GRACE_PERIOD_MS
|
||||
"networkidle", timeout=PAGE_RENDER_TIMEOUT_MS
|
||||
)
|
||||
except TimeoutError:
|
||||
# If networkidle times out, just give it a moment for content to render
|
||||
|
||||
@@ -2,7 +2,6 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -70,9 +69,13 @@ class BaseFilters(BaseModel):
|
||||
|
||||
|
||||
class UserFileFilters(BaseModel):
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
# Scopes search to user files tagged with a given project/persona in Vespa.
|
||||
# These are NOT simply the IDs of the current project or persona — they are
|
||||
# only set when the persona's/project's user files overflowed the LLM
|
||||
# context window and must be searched via vector DB instead of being loaded
|
||||
# directly into the prompt.
|
||||
project_id_filter: int | None = None
|
||||
persona_id_filter: int | None = None
|
||||
|
||||
|
||||
class AssistantKnowledgeFilters(BaseModel):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -39,9 +38,8 @@ logger = setup_logger()
|
||||
def _build_index_filters(
|
||||
user_provided_filters: BaseFilters | None,
|
||||
user: User, # Used for ACLs, anonymous users only see public docs
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
user_file_ids: list[UUID] | None,
|
||||
project_id_filter: int | None,
|
||||
persona_id_filter: int | None,
|
||||
persona_document_sets: list[str] | None,
|
||||
persona_time_cutoff: datetime | None,
|
||||
db_session: Session | None = None,
|
||||
@@ -97,16 +95,6 @@ def _build_index_filters(
|
||||
if not source_filter and detected_source_filter:
|
||||
source_filter = detected_source_filter
|
||||
|
||||
# CRITICAL FIX: If user_file_ids are present, we must ensure "user_file"
|
||||
# source type is included in the filter, otherwise user files will be excluded!
|
||||
if user_file_ids and source_filter:
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# Add user_file to the source filter if not already present
|
||||
if DocumentSource.USER_FILE not in source_filter:
|
||||
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
|
||||
logger.debug("Added USER_FILE to source_filter for user knowledge search")
|
||||
|
||||
if bypass_acl:
|
||||
user_acl_filters = None
|
||||
elif acl_filters is not None:
|
||||
@@ -117,9 +105,8 @@ def _build_index_filters(
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
project_id_filter=project_id_filter,
|
||||
persona_id_filter=persona_id_filter,
|
||||
source_type=source_filter,
|
||||
document_set=document_set_filter,
|
||||
time_cutoff=time_filter,
|
||||
@@ -265,19 +252,16 @@ def search_pipeline(
|
||||
db_session: Session | None = None,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
# If a persona_id is provided, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
# Vespa metadata filters for overflowing user files. NOT the raw IDs
|
||||
# of the current project/persona — only set when user files couldn't fit
|
||||
# in the LLM context and need to be searched via vector DB.
|
||||
project_id_filter: int | None = None,
|
||||
persona_id_filter: int | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB queries (no session needed)
|
||||
acl_filters: list[str] | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
user_uploaded_persona_files: list[UUID] | None = (
|
||||
[user_file.id for user_file in persona.user_files] if persona else None
|
||||
)
|
||||
|
||||
persona_document_sets: list[str] | None = (
|
||||
[persona_document_set.name for persona_document_set in persona.document_sets]
|
||||
if persona
|
||||
@@ -302,9 +286,8 @@ def search_pipeline(
|
||||
filters = _build_index_filters(
|
||||
user_provided_filters=chunk_search_request.user_selected_filters,
|
||||
user=user,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
user_file_ids=user_uploaded_persona_files,
|
||||
project_id_filter=project_id_filter,
|
||||
persona_id_filter=persona_id_filter,
|
||||
persona_document_sets=persona_document_sets,
|
||||
persona_time_cutoff=persona_time_cutoff,
|
||||
db_session=db_session,
|
||||
|
||||
@@ -110,7 +110,6 @@ def search_chunks(
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
|
||||
federated_sources = set(
|
||||
|
||||
@@ -583,6 +583,67 @@ def get_latest_index_attempt_for_cc_pair_id(
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_latest_successful_index_attempt_for_cc_pair_id(
|
||||
db_session: Session,
|
||||
connector_credential_pair_id: int,
|
||||
secondary_index: bool = False,
|
||||
) -> IndexAttempt | None:
|
||||
"""Returns the most recent successful index attempt for the given cc pair,
|
||||
filtered to the current (or future) search settings.
|
||||
Uses MAX(id) semantics to match get_latest_index_attempts_by_status."""
|
||||
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
|
||||
stmt = (
|
||||
select(IndexAttempt)
|
||||
.where(
|
||||
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
|
||||
IndexAttempt.status.in_(
|
||||
[IndexingStatus.SUCCESS, IndexingStatus.COMPLETED_WITH_ERRORS]
|
||||
),
|
||||
)
|
||||
.join(SearchSettings)
|
||||
.where(SearchSettings.status == status)
|
||||
.order_by(desc(IndexAttempt.id))
|
||||
.limit(1)
|
||||
)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_latest_successful_index_attempts_parallel(
|
||||
secondary_index: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
"""Batch version: returns the latest successful index attempt per cc pair.
|
||||
Covers both SUCCESS and COMPLETED_WITH_ERRORS (matching is_successful())."""
|
||||
model_status = (
|
||||
IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
latest_ids = (
|
||||
select(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.max(IndexAttempt.id).label("max_id"),
|
||||
)
|
||||
.join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
|
||||
.where(
|
||||
SearchSettings.status == model_status,
|
||||
IndexAttempt.status.in_(
|
||||
[IndexingStatus.SUCCESS, IndexingStatus.COMPLETED_WITH_ERRORS]
|
||||
),
|
||||
)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
stmt = select(IndexAttempt).join(
|
||||
latest_ids,
|
||||
(
|
||||
IndexAttempt.connector_credential_pair_id
|
||||
== latest_ids.c.connector_credential_pair_id
|
||||
)
|
||||
& (IndexAttempt.id == latest_ids.c.max_id),
|
||||
)
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def count_index_attempts_for_cc_pair(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
|
||||
@@ -3467,9 +3467,9 @@ class Persona(Base):
|
||||
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# Featured personas are highlighted in the UI
|
||||
featured: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# controls whether the persona is available to be selected by users
|
||||
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
is_featured: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# controls whether the persona is listed in user-facing agent lists
|
||||
is_listed: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
# controls the ordering of personas in the UI
|
||||
# higher priority personas are displayed first, ties are resolved by the ID,
|
||||
# where lower value IDs (e.g. created earlier) are displayed first
|
||||
|
||||
@@ -126,7 +126,7 @@ def _add_user_filters(
|
||||
else:
|
||||
# Group the public persona conditions
|
||||
public_condition = (Persona.is_public == True) & ( # noqa: E712
|
||||
Persona.is_visible == True # noqa: E712
|
||||
Persona.is_listed == True # noqa: E712
|
||||
)
|
||||
|
||||
where_clause |= public_condition
|
||||
@@ -260,7 +260,7 @@ def create_update_persona(
|
||||
|
||||
try:
|
||||
# Featured persona validation
|
||||
if create_persona_request.featured:
|
||||
if create_persona_request.is_featured:
|
||||
# Curators can edit featured personas, but not make them
|
||||
# TODO this will be reworked soon with RBAC permissions feature
|
||||
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
|
||||
@@ -300,7 +300,7 @@ def create_update_persona(
|
||||
remove_image=create_persona_request.remove_image,
|
||||
search_start_date=create_persona_request.search_start_date,
|
||||
label_ids=create_persona_request.label_ids,
|
||||
featured=create_persona_request.featured,
|
||||
is_featured=create_persona_request.is_featured,
|
||||
user_file_ids=converted_user_file_ids,
|
||||
commit=False,
|
||||
hierarchy_node_ids=create_persona_request.hierarchy_node_ids,
|
||||
@@ -910,11 +910,11 @@ def upsert_persona(
|
||||
uploaded_image_id: str | None = None,
|
||||
icon_name: str | None = None,
|
||||
display_priority: int | None = None,
|
||||
is_visible: bool = True,
|
||||
is_listed: bool = True,
|
||||
remove_image: bool | None = None,
|
||||
search_start_date: datetime | None = None,
|
||||
builtin_persona: bool = False,
|
||||
featured: bool | None = None,
|
||||
is_featured: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
hierarchy_node_ids: list[int] | None = None,
|
||||
@@ -1037,13 +1037,13 @@ def upsert_persona(
|
||||
if remove_image or uploaded_image_id:
|
||||
existing_persona.uploaded_image_id = uploaded_image_id
|
||||
existing_persona.icon_name = icon_name
|
||||
existing_persona.is_visible = is_visible
|
||||
existing_persona.is_listed = is_listed
|
||||
existing_persona.search_start_date = search_start_date
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.featured = (
|
||||
featured if featured is not None else existing_persona.featured
|
||||
existing_persona.is_featured = (
|
||||
is_featured if is_featured is not None else existing_persona.is_featured
|
||||
)
|
||||
# Update embedded prompt fields if provided
|
||||
if system_prompt is not None:
|
||||
@@ -1109,9 +1109,9 @@ def upsert_persona(
|
||||
uploaded_image_id=uploaded_image_id,
|
||||
icon_name=icon_name,
|
||||
display_priority=display_priority,
|
||||
is_visible=is_visible,
|
||||
is_listed=is_listed,
|
||||
search_start_date=search_start_date,
|
||||
featured=(featured if featured is not None else False),
|
||||
is_featured=(is_featured if is_featured is not None else False),
|
||||
user_files=user_files or [],
|
||||
labels=labels or [],
|
||||
hierarchy_nodes=hierarchy_nodes or [],
|
||||
@@ -1158,7 +1158,7 @@ def delete_old_default_personas(
|
||||
|
||||
def update_persona_featured(
|
||||
persona_id: int,
|
||||
featured: bool,
|
||||
is_featured: bool,
|
||||
db_session: Session,
|
||||
user: User,
|
||||
) -> None:
|
||||
@@ -1166,13 +1166,13 @@ def update_persona_featured(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
persona.featured = featured
|
||||
persona.is_featured = is_featured
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_visibility(
|
||||
persona_id: int,
|
||||
is_visible: bool,
|
||||
is_listed: bool,
|
||||
db_session: Session,
|
||||
user: User,
|
||||
) -> None:
|
||||
@@ -1180,7 +1180,7 @@ def update_persona_visibility(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
persona.is_visible = is_visible
|
||||
persona.is_listed = is_listed
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ def create_slack_channel_persona(
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
is_public=True,
|
||||
featured=False,
|
||||
is_featured=False,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ import time
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
@@ -149,6 +150,9 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
Returns None if search settings did not change, or the old search settings if they
|
||||
did change.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
return None
|
||||
|
||||
# Default CC-pair created for Ingestion API unused here
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
|
||||
|
||||
@@ -10,8 +10,8 @@ How `IndexFilters` fields combine into the final query filter. Applies to both V
|
||||
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
|
||||
| **ACL** | `access_control_list` | OR within, AND with rest |
|
||||
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
|
||||
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
|
||||
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
|
||||
| **Knowledge scope** | `document_set`, `attached_document_ids`, `hierarchy_node_ids`, `persona_id_filter` | OR within group, AND with rest |
|
||||
| **Additive scope** | `project_id_filter` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
|
||||
|
||||
## How filters combine
|
||||
|
||||
@@ -31,12 +31,22 @@ AND time >= cutoff -- if set
|
||||
|
||||
The knowledge scope filter controls **what knowledge an assistant can access**.
|
||||
|
||||
### Primary vs additive triggers
|
||||
|
||||
- **`persona_id_filter`** is a **primary** trigger. A persona with user files IS explicit
|
||||
knowledge, so `persona_id_filter` alone can start a knowledge scope. Note: this is
|
||||
NOT the raw ID of the persona being used — it is only set when the persona's
|
||||
user files overflowed the LLM context window.
|
||||
- **`project_id_filter`** is **additive**. It widens an existing scope to include project
|
||||
files but never restricts on its own — a chat inside a project should still search
|
||||
team knowledge when no other knowledge is attached.
|
||||
|
||||
### No explicit knowledge attached
|
||||
|
||||
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
|
||||
When `document_set`, `attached_document_ids`, `hierarchy_node_ids`, and `persona_id_filter` are all empty/None:
|
||||
|
||||
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
|
||||
- `project_id` and `persona_id` are ignored — they never restrict on their own.
|
||||
- `project_id_filter` is ignored — it never restricts on its own.
|
||||
|
||||
### One explicit knowledge type
|
||||
|
||||
@@ -44,39 +54,40 @@ When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_no
|
||||
-- Only document sets
|
||||
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
|
||||
|
||||
-- Only user files
|
||||
AND (document_id = "uuid-1" OR document_id = "uuid-2")
|
||||
-- Only persona user files (overflowed context)
|
||||
AND (personas contains 42)
|
||||
```
|
||||
|
||||
### Multiple explicit knowledge types (OR'd)
|
||||
|
||||
```
|
||||
-- Document sets + user files
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR document_id = "uuid-1"
|
||||
)
|
||||
```
|
||||
|
||||
### Explicit knowledge + overflowing user files
|
||||
|
||||
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
|
||||
|
||||
```
|
||||
-- Document sets + persona user files overflowed
|
||||
-- Document sets + persona user files
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR personas contains 42
|
||||
)
|
||||
```
|
||||
|
||||
-- User files + project files overflowed
|
||||
### Explicit knowledge + overflowing project files
|
||||
|
||||
When an explicit knowledge restriction is in effect **and** `project_id_filter` is set (project files overflowed the LLM context window), `project_id_filter` widens the filter:
|
||||
|
||||
```
|
||||
-- Document sets + project files overflowed
|
||||
AND (
|
||||
document_id = "uuid-1"
|
||||
document_sets contains "Engineering"
|
||||
OR user_project contains 7
|
||||
)
|
||||
|
||||
-- Persona user files + project files (won't happen in practice;
|
||||
-- custom personas ignore project files per the precedence rule)
|
||||
AND (
|
||||
personas contains 42
|
||||
OR user_project contains 7
|
||||
)
|
||||
```
|
||||
|
||||
### Only project_id or persona_id (no explicit knowledge)
|
||||
### Only project_id_filter (no explicit knowledge)
|
||||
|
||||
No knowledge scope filter. The assistant searches everything.
|
||||
|
||||
@@ -91,11 +102,10 @@ AND (acl contains ...)
|
||||
| Filter field | Vespa field | Vespa type | Purpose |
|
||||
|---|---|---|---|
|
||||
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
|
||||
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
|
||||
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
|
||||
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
|
||||
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
|
||||
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
|
||||
| `persona_id_filter` | `personas` | `array<int>` | Persona tag for overflowing user files (**primary** trigger) |
|
||||
| `project_id_filter` | `user_project` | `array<int>` | Project tag for overflowing project files (**additive** only) |
|
||||
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
|
||||
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
|
||||
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from contextlib import AbstractContextManager
|
||||
@@ -1062,7 +1063,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
f"Body: {get_new_body_without_vectors(body)}\n"
|
||||
f"Search pipeline ID: {search_pipeline_id}\n"
|
||||
f"Phase took: {phase_took}\n"
|
||||
f"Profile: {profile}\n"
|
||||
f"Profile: {json.dumps(profile, indent=2)}\n"
|
||||
)
|
||||
if timed_out:
|
||||
error_str = f"OpenSearch client error: Search timed out for index {self._index_name}."
|
||||
|
||||
@@ -950,7 +950,86 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_pipeline_id=normalization_pipeline_name,
|
||||
)
|
||||
|
||||
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
|
||||
# Good place for a breakpoint to inspect the search hits if you have
|
||||
# "explain" enabled.
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
)
|
||||
for search_hit in search_hits
|
||||
]
|
||||
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
|
||||
inference_chunks_uncleaned
|
||||
)
|
||||
|
||||
return inference_chunks
|
||||
|
||||
def keyword_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Keyword retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_keyword_search_query(
|
||||
query_text=query,
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
# NOTE: Index filters includes metadata tags which were filtered
|
||||
# for invalid unicode at indexing time. In theory it would be
|
||||
# ideal to do filtering here as well, in practice we never did
|
||||
# that in the Vespa codepath and have not seen issues in
|
||||
# production, so we deliberately conform to the existing logic
|
||||
# in order to not unknowningly introduce a possible bug.
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
)
|
||||
for search_hit in search_hits
|
||||
]
|
||||
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
|
||||
inference_chunks_uncleaned
|
||||
)
|
||||
|
||||
return inference_chunks
|
||||
|
||||
def semantic_retrieval(
|
||||
self,
|
||||
query_embedding: Embedding,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Semantic retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_semantic_search_query(
|
||||
query_embedding=query_embedding,
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
# NOTE: Index filters includes metadata tags which were filtered
|
||||
# for invalid unicode at indexing time. In theory it would be
|
||||
# ideal to do filtering here as well, in practice we never did
|
||||
# that in the Vespa codepath and have not seen issues in
|
||||
# production, so we deliberately conform to the existing logic
|
||||
# in order to not unknowningly introduce a possible bug.
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
|
||||
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
|
||||
@@ -219,9 +218,8 @@ class DocumentQuery:
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=min_chunk_index,
|
||||
max_chunk_index=max_chunk_index,
|
||||
@@ -286,9 +284,8 @@ class DocumentQuery:
|
||||
source_types=[],
|
||||
tags=[],
|
||||
document_sets=[],
|
||||
user_file_ids=[],
|
||||
project_id=None,
|
||||
persona_id=None,
|
||||
project_id_filter=None,
|
||||
persona_id_filter=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -356,9 +353,8 @@ class DocumentQuery:
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -404,12 +400,168 @@ class DocumentQuery:
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
# Explain is for scoring breakdowns.
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_hybrid_search_body["explain"] = True
|
||||
|
||||
return final_hybrid_search_body
|
||||
|
||||
@staticmethod
|
||||
def get_keyword_search_query(
|
||||
query_text: str,
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final keyword search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the keyword search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final keyword search query.
|
||||
"""
|
||||
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
|
||||
raise ValueError(
|
||||
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
|
||||
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
|
||||
)
|
||||
|
||||
keyword_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
attached_document_ids=index_filters.attached_document_ids,
|
||||
hierarchy_node_ids=index_filters.hierarchy_node_ids,
|
||||
)
|
||||
|
||||
keyword_search_query = (
|
||||
DocumentQuery._get_title_content_combined_keyword_search_query(
|
||||
query_text, search_filters=keyword_search_filters
|
||||
)
|
||||
)
|
||||
|
||||
final_keyword_search_query: dict[str, Any] = {
|
||||
"query": keyword_search_query,
|
||||
"size": num_hits,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
# Exclude retrieving the vector fields in order to save on
|
||||
# retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
}
|
||||
|
||||
if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED:
|
||||
final_keyword_search_query["highlight"] = (
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
if not OPENSEARCH_PROFILING_DISABLED:
|
||||
final_keyword_search_query["profile"] = True
|
||||
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_keyword_search_query["explain"] = True
|
||||
|
||||
return final_keyword_search_query
|
||||
|
||||
@staticmethod
|
||||
def get_semantic_search_query(
|
||||
query_embedding: list[float],
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final semantic search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
|
||||
Args:
|
||||
query_embedding: The vector embedding of the text to query for.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the semantic search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final semantic search query.
|
||||
"""
|
||||
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
|
||||
raise ValueError(
|
||||
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
|
||||
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
|
||||
)
|
||||
|
||||
semantic_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
attached_document_ids=index_filters.attached_document_ids,
|
||||
hierarchy_node_ids=index_filters.hierarchy_node_ids,
|
||||
)
|
||||
|
||||
semantic_search_query = (
|
||||
DocumentQuery._get_content_vector_similarity_search_query(
|
||||
query_embedding,
|
||||
vector_candidates=num_hits,
|
||||
search_filters=semantic_search_filters,
|
||||
)
|
||||
)
|
||||
|
||||
final_semantic_search_query: dict[str, Any] = {
|
||||
"query": semantic_search_query,
|
||||
"size": num_hits,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
# Exclude retrieving the vector fields in order to save on
|
||||
# retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
}
|
||||
|
||||
if not OPENSEARCH_PROFILING_DISABLED:
|
||||
final_semantic_search_query["profile"] = True
|
||||
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_semantic_search_query["explain"] = True
|
||||
|
||||
return final_semantic_search_query
|
||||
|
||||
@staticmethod
|
||||
def get_random_search_query(
|
||||
tenant_state: TenantState,
|
||||
@@ -433,9 +585,8 @@ class DocumentQuery:
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -581,8 +732,9 @@ class DocumentQuery:
|
||||
def _get_content_vector_similarity_search_query(
|
||||
query_vector: list[float],
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
|
||||
search_filters: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
query = {
|
||||
"knn": {
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
@@ -591,11 +743,19 @@ class DocumentQuery:
|
||||
}
|
||||
}
|
||||
|
||||
if search_filters is not None:
|
||||
query["knn"][CONTENT_VECTOR_FIELD_NAME]["filter"] = {
|
||||
"bool": {"filter": search_filters}
|
||||
}
|
||||
|
||||
return query
|
||||
|
||||
@staticmethod
|
||||
def _get_title_content_combined_keyword_search_query(
|
||||
query_text: str,
|
||||
search_filters: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
query = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{
|
||||
@@ -636,10 +796,19 @@ class DocumentQuery:
|
||||
}
|
||||
}
|
||||
},
|
||||
]
|
||||
],
|
||||
# Ensure at least one term from the query is present in the
|
||||
# document. This defaults to 1, unless a filter or must clause
|
||||
# is supplied, in which case it defaults to 0.
|
||||
"minimum_should_match": 1,
|
||||
}
|
||||
}
|
||||
|
||||
if search_filters is not None:
|
||||
query["bool"]["filter"] = search_filters
|
||||
|
||||
return query
|
||||
|
||||
@staticmethod
|
||||
def _get_search_filters(
|
||||
tenant_state: TenantState,
|
||||
@@ -648,9 +817,8 @@ class DocumentQuery:
|
||||
source_types: list[DocumentSource],
|
||||
tags: list[Tag],
|
||||
document_sets: list[str],
|
||||
user_file_ids: list[UUID],
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
project_id_filter: int | None,
|
||||
persona_id_filter: int | None,
|
||||
time_cutoff: datetime | None,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
@@ -681,12 +849,12 @@ class DocumentQuery:
|
||||
list corresponding to a tag will be retrieved.
|
||||
document_sets: If supplied, only documents with at least one
|
||||
document set ID from this list will be retrieved.
|
||||
user_file_ids: If supplied, only document IDs in this list will be
|
||||
retrieved.
|
||||
project_id: If not None, only documents with this project ID in user
|
||||
projects will be retrieved.
|
||||
persona_id: If not None, only documents whose personas array
|
||||
contains this persona ID will be retrieved.
|
||||
project_id_filter: If not None, only documents with this project ID
|
||||
in user projects will be retrieved. Additive — only applied
|
||||
when a knowledge scope already exists.
|
||||
persona_id_filter: If not None, only documents whose personas array
|
||||
contains this persona ID will be retrieved. Primary — creates
|
||||
a knowledge scope on its own.
|
||||
time_cutoff: Time cutoff for the documents to retrieve. If not None,
|
||||
Documents which were last updated before this date will not be
|
||||
returned. For documents which do not have a value for their last
|
||||
@@ -703,10 +871,6 @@ class DocumentQuery:
|
||||
NOTE: See DocumentChunk.max_chunk_size.
|
||||
document_id: The document ID to retrieve. If None, no filter will be
|
||||
applied for this. Defaults to None.
|
||||
WARNING: This filters on the same property as user_file_ids.
|
||||
Although it would never make sense to supply both, note that if
|
||||
user_file_ids is supplied and does not contain document_id, no
|
||||
matches will be retrieved.
|
||||
attached_document_ids: Document IDs explicitly attached to the
|
||||
assistant. If provided along with hierarchy_node_ids, documents
|
||||
matching EITHER criteria will be retrieved (OR logic).
|
||||
@@ -767,15 +931,6 @@ class DocumentQuery:
|
||||
)
|
||||
return document_set_filter
|
||||
|
||||
def _get_user_file_id_filter(user_file_ids: list[UUID]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_file_id_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for user_file_id in user_file_ids:
|
||||
user_file_id_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": str(user_file_id)}}}
|
||||
)
|
||||
return user_file_id_filter
|
||||
|
||||
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
@@ -876,14 +1031,17 @@ class DocumentQuery:
|
||||
# assistant can see. When none are set the assistant searches
|
||||
# everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user files
|
||||
# findable but must NOT trigger the restriction on their own (an agent
|
||||
# with no explicit knowledge should search everything).
|
||||
# persona_id_filter is a primary trigger — a persona with user files IS
|
||||
# explicit knowledge, so it can start a knowledge scope on its own.
|
||||
#
|
||||
# project_id_filter is additive — it widens the scope to also cover
|
||||
# overflowing project files but never restricts on its own (a chat
|
||||
# inside a project should still search team knowledge).
|
||||
has_knowledge_scope = (
|
||||
attached_document_ids
|
||||
or hierarchy_node_ids
|
||||
or user_file_ids
|
||||
or document_sets
|
||||
or persona_id_filter is not None
|
||||
)
|
||||
|
||||
if has_knowledge_scope:
|
||||
@@ -898,23 +1056,17 @@ class DocumentQuery:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(hierarchy_node_ids)
|
||||
)
|
||||
if user_file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(user_file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
# Additive: widen scope to also cover overflowing user files, but
|
||||
# only when an explicit restriction is already in effect.
|
||||
if project_id is not None:
|
||||
if persona_id_filter is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_project_filter(project_id)
|
||||
_get_persona_filter(persona_id_filter)
|
||||
)
|
||||
if persona_id is not None:
|
||||
if project_id_filter is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_persona_filter(persona_id)
|
||||
_get_user_project_filter(project_id_filter)
|
||||
)
|
||||
filter_clauses.append(knowledge_filter)
|
||||
|
||||
@@ -932,8 +1084,6 @@ class DocumentQuery:
|
||||
)
|
||||
|
||||
if document_id is not None:
|
||||
# WARNING: If user_file_ids has elements and if none of them are
|
||||
# document_id, no matches will be retrieved.
|
||||
filter_clauses.append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
)
|
||||
|
||||
@@ -199,31 +199,29 @@ def build_vespa_filters(
|
||||
]
|
||||
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
|
||||
|
||||
# Knowledge scope: explicit knowledge attachments (document_sets,
|
||||
# user_file_ids) restrict what an assistant can see. When none are
|
||||
# set, the assistant can see everything.
|
||||
# Knowledge scope: explicit knowledge attachments restrict what an
|
||||
# assistant can see. When none are set, the assistant can see
|
||||
# everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user
|
||||
# files findable in Vespa but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
# persona_id_filter is a primary trigger — a persona with user files IS
|
||||
# explicit knowledge, so it can start a knowledge scope on its own.
|
||||
#
|
||||
# project_id_filter is additive — it widens the scope to also cover
|
||||
# overflowing project files but never restricts on its own (a chat
|
||||
# inside a project should still search team knowledge).
|
||||
knowledge_scope_parts: list[str] = []
|
||||
|
||||
_append(
|
||||
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
)
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id_filter))
|
||||
|
||||
user_file_ids_str = (
|
||||
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
|
||||
)
|
||||
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
|
||||
|
||||
# Only include project/persona scopes when an explicit knowledge
|
||||
# restriction is already in effect — they widen the scope to also
|
||||
# cover overflowing user files but never restrict on their own.
|
||||
# project_id_filter only widens an existing scope.
|
||||
if knowledge_scope_parts:
|
||||
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
|
||||
_append(
|
||||
knowledge_scope_parts,
|
||||
_build_user_project_filter(filters.project_id_filter),
|
||||
)
|
||||
|
||||
if len(knowledge_scope_parts) > 1:
|
||||
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
|
||||
|
||||
@@ -88,6 +88,7 @@ class OnyxErrorCode(Enum):
|
||||
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
|
||||
BAD_GATEWAY = ("BAD_GATEWAY", 502)
|
||||
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
|
||||
HOOK_EXECUTION_FAILED = ("HOOK_EXECUTION_FAILED", 502)
|
||||
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
|
||||
|
||||
def __init__(self, code: str, status_code: int) -> None:
|
||||
|
||||
@@ -38,17 +38,7 @@ def get_federated_retrieval_functions(
|
||||
source_types: list[DocumentSource] | None,
|
||||
document_set_names: list[str] | None,
|
||||
slack_context: SlackContext | None = None,
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
) -> list[FederatedRetrievalInfo]:
|
||||
# When User Knowledge (user files) is the only knowledge source enabled,
|
||||
# skip federated connectors entirely. User Knowledge mode means the agent
|
||||
# should ONLY use uploaded files, not team connectors like Slack.
|
||||
if user_file_ids and not document_set_names:
|
||||
logger.debug(
|
||||
"Skipping all federated connectors: User Knowledge mode enabled "
|
||||
f"with {len(user_file_ids)} user files and no document sets"
|
||||
)
|
||||
return []
|
||||
|
||||
# Check for Slack bot context first (regardless of user_id)
|
||||
if slack_context:
|
||||
|
||||
@@ -23,45 +23,55 @@ from onyx.utils.timing import log_function_time
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a user file."""
|
||||
return f"plaintext_{user_file_id}"
|
||||
def plaintext_file_name_for_id(file_id: str) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a file."""
|
||||
return f"plaintext_{file_id}"
|
||||
|
||||
|
||||
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
|
||||
def store_plaintext(file_id: str, plaintext_content: str) -> bool:
|
||||
"""
|
||||
Store plaintext content for a user file in the file store.
|
||||
Store plaintext content for a file in the file store.
|
||||
|
||||
Args:
|
||||
user_file_id: The ID of the user file
|
||||
file_id: The ID of the file (user_file or artifact_file)
|
||||
plaintext_content: The plaintext content to store
|
||||
|
||||
Returns:
|
||||
bool: True if storage was successful, False otherwise
|
||||
"""
|
||||
# Skip empty content
|
||||
if not plaintext_content:
|
||||
return False
|
||||
|
||||
# Get plaintext file name
|
||||
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
|
||||
|
||||
plaintext_file_name = plaintext_file_name_for_id(file_id)
|
||||
try:
|
||||
file_store = get_default_file_store()
|
||||
file_content = BytesIO(plaintext_content.encode("utf-8"))
|
||||
file_store.save_file(
|
||||
content=file_content,
|
||||
display_name=f"Plaintext for user file {user_file_id}",
|
||||
display_name=f"Plaintext for {file_id}",
|
||||
file_origin=FileOrigin.PLAINTEXT_CACHE,
|
||||
file_type="text/plain",
|
||||
file_id=plaintext_file_name,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
|
||||
logger.warning(f"Failed to store plaintext for {file_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# --- Convenience wrappers for callers that use user-file UUIDs ---
|
||||
|
||||
|
||||
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a user file."""
|
||||
return plaintext_file_name_for_id(str(user_file_id))
|
||||
|
||||
|
||||
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
|
||||
"""Store plaintext content for a user file (delegates to :func:`store_plaintext`)."""
|
||||
return store_plaintext(str(user_file_id), plaintext_content)
|
||||
|
||||
|
||||
def load_chat_file_by_id(file_id: str) -> InMemoryChatFile:
|
||||
"""Load a file directly from the file store using its file_record ID.
|
||||
|
||||
|
||||
330
backend/onyx/hooks/executor.py
Normal file
330
backend/onyx/hooks/executor.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
|
||||
|
||||
Usage (Celery tasks and FastAPI handlers):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
|
||||
)
|
||||
|
||||
if isinstance(result, HookSkipped):
|
||||
# no active hook configured — continue with original behavior
|
||||
...
|
||||
elif isinstance(result, HookSoftFailed):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is the response payload dict from the customer's endpoint
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
--------------------------
|
||||
``is_reachable`` on the Hook row is updated selectively — only when the outcome
|
||||
carries meaningful signal about physical reachability:
|
||||
|
||||
NetworkError (DNS, connection refused) → False (cannot reach the server)
|
||||
HTTP 401 / 403 → False (api_key revoked or invalid)
|
||||
TimeoutException → None (server may be slow, skip write)
|
||||
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
|
||||
Unknown exception → None (no signal, skip write)
|
||||
Non-JSON / non-dict response → None (server responded, skip write)
|
||||
Success (2xx, valid dict) → True (confirmed reachable)
|
||||
|
||||
None means "leave the current value unchanged" — no DB round-trip is made.
|
||||
|
||||
DB session design
|
||||
-----------------
|
||||
The executor uses three sessions:
|
||||
|
||||
1. Caller's session (db_session) — used only for the hook lookup read. All
|
||||
needed fields are extracted from the Hook object before the HTTP call, so
|
||||
the caller's session is not held open during the external HTTP request.
|
||||
|
||||
2. Log session — a separate short-lived session opened after the HTTP call
|
||||
completes to write the HookExecutionLog row on failure. Success runs are
|
||||
not recorded. Committed independently of everything else.
|
||||
|
||||
3. Reachable session — a second short-lived session to update is_reachable on
|
||||
the Hook. Kept separate from the log session so a concurrent hook deletion
|
||||
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
|
||||
prevent the execution log from being written. This update is best-effort.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.hook import create_hook_execution_log__no_commit
|
||||
from onyx.db.hook import get_non_deleted_hook_by_hook_point
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class HookSkipped:
|
||||
"""No active hook configured for this hook point."""
|
||||
|
||||
|
||||
class HookSoftFailed:
|
||||
"""Hook was called but failed with SOFT fail strategy — continuing."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _HttpOutcome(BaseModel):
|
||||
"""Structured result of an HTTP hook call, returned by _process_response."""
|
||||
|
||||
is_success: bool
|
||||
updated_is_reachable: (
|
||||
bool | None
|
||||
) # True/False = write to DB, None = unchanged (skip write)
|
||||
status_code: int | None
|
||||
error_message: str | None
|
||||
response_payload: dict[str, Any] | None
|
||||
|
||||
|
||||
def _lookup_hook(
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
) -> Hook | HookSkipped:
|
||||
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
|
||||
|
||||
No HTTP call is made and no DB writes are performed for any HookSkipped path.
|
||||
There is nothing to log and no reachability information to update.
|
||||
"""
|
||||
if not HOOKS_AVAILABLE:
|
||||
return HookSkipped()
|
||||
hook = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if hook is None or not hook.is_active:
|
||||
return HookSkipped()
|
||||
if not hook.endpoint_url:
|
||||
return HookSkipped()
|
||||
return hook
|
||||
|
||||
|
||||
def _process_response(
|
||||
*,
|
||||
response: httpx.Response | None,
|
||||
exc: Exception | None,
|
||||
timeout: float,
|
||||
) -> _HttpOutcome:
|
||||
"""Process the result of an HTTP call and return a structured outcome.
|
||||
|
||||
Called after the client.post() try/except. If post() raised, exc is set and
|
||||
response is None. Otherwise response is set and exc is None. Handles
|
||||
raise_for_status(), JSON decoding, and the dict shape check.
|
||||
"""
|
||||
if exc is not None:
|
||||
if isinstance(exc, httpx.NetworkError):
|
||||
msg = f"Hook network error (endpoint unreachable): {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False,
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
msg = f"Hook timed out after {timeout}s: {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # timeout doesn't indicate unreachability
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
msg = f"Hook call failed: {exc}"
|
||||
logger.exception(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # unknown error — don't make assumptions
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
"exactly one of response or exc must be non-None; both are None"
|
||||
)
|
||||
status_code = response.status_code
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
|
||||
# so the operator knows to update it. All other HTTP errors keep is_reachable
|
||||
# as-is (server is up, the request just failed for application reasons).
|
||||
auth_failed = e.response.status_code in (401, 403)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False if auth_failed else None,
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except (json.JSONDecodeError, httpx.DecodingError) as e:
|
||||
msg = f"Hook returned non-JSON response: {e}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
|
||||
logger.warning(msg)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
return _HttpOutcome(
|
||||
is_success=True,
|
||||
updated_is_reachable=True,
|
||||
status_code=status_code,
|
||||
error_message=None,
|
||||
response_payload=response_payload,
|
||||
)
|
||||
|
||||
|
||||
def _persist_result(
|
||||
*,
|
||||
hook_id: int,
|
||||
outcome: _HttpOutcome,
|
||||
duration_ms: int,
|
||||
) -> None:
|
||||
"""Write the execution log on failure and optionally update is_reachable, each
|
||||
in its own session so a failure in one does not affect the other."""
|
||||
# Only write the execution log on failure — success runs are not recorded.
|
||||
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
|
||||
# deleted between the initial lookup and here).
|
||||
if not outcome.is_success:
|
||||
try:
|
||||
with get_session_with_current_tenant() as log_session:
|
||||
create_hook_execution_log__no_commit(
|
||||
db_session=log_session,
|
||||
hook_id=hook_id,
|
||||
is_success=False,
|
||||
error_message=outcome.error_message,
|
||||
status_code=outcome.status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
log_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to persist hook execution log for hook_id={hook_id}"
|
||||
)
|
||||
|
||||
# Update is_reachable separately — best-effort, non-critical.
|
||||
# None means the value is unchanged (set by the caller to skip the no-op write).
|
||||
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
|
||||
# concurrently deleted, so keep this isolated from the log write above.
|
||||
if outcome.updated_is_reachable is not None:
|
||||
try:
|
||||
with get_session_with_current_tenant() as reachable_session:
|
||||
update_hook__no_commit(
|
||||
db_session=reachable_session,
|
||||
hook_id=hook_id,
|
||||
is_reachable=outcome.updated_is_reachable,
|
||||
)
|
||||
reachable_session.commit()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def execute_hook(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
payload: dict[str, Any],
|
||||
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
|
||||
"""Execute the hook for the given hook point synchronously."""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
"active hooks without an endpoint_url must be rejected by _lookup_hook"
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
response: httpx.Response | None = None
|
||||
exc: Exception | None = None
|
||||
try:
|
||||
api_key: str | None = (
|
||||
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
)
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(timeout=timeout) as client:
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
outcome = outcome.model_copy(update={"updated_is_reachable": None})
|
||||
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
|
||||
|
||||
if not outcome.is_success:
|
||||
if fail_strategy == HookFailStrategy.HARD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.HOOK_EXECUTION_FAILED,
|
||||
outcome.error_message or "Hook execution failed.",
|
||||
)
|
||||
logger.warning(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
if outcome.response_payload is None:
|
||||
raise ValueError(
|
||||
f"response_payload is None for successful hook call (hook_id={hook_id})"
|
||||
)
|
||||
return outcome.response_payload
|
||||
@@ -42,12 +42,8 @@ class HookUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
endpoint_url: str | None = None
|
||||
api_key: NonEmptySecretStr | None = None
|
||||
fail_strategy: HookFailStrategy | None = (
|
||||
None # if None in model_fields_set, reset to spec default
|
||||
)
|
||||
timeout_seconds: float | None = Field(
|
||||
default=None, gt=0
|
||||
) # if None in model_fields_set, reset to spec default
|
||||
fail_strategy: HookFailStrategy | None = None
|
||||
timeout_seconds: float | None = Field(default=None, gt=0)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_at_least_one_field(self) -> "HookUpdateRequest":
|
||||
@@ -60,6 +56,14 @@ class HookUpdateRequest(BaseModel):
|
||||
and not (self.endpoint_url or "").strip()
|
||||
):
|
||||
raise ValueError("endpoint_url cannot be cleared.")
|
||||
if "fail_strategy" in self.model_fields_set and self.fail_strategy is None:
|
||||
raise ValueError(
|
||||
"fail_strategy cannot be null; omit the field to leave it unchanged."
|
||||
)
|
||||
if "timeout_seconds" in self.model_fields_set and self.timeout_seconds is None:
|
||||
raise ValueError(
|
||||
"timeout_seconds cannot be null; omit the field to leave it unchanged."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@@ -90,38 +94,28 @@ class HookResponse(BaseModel):
|
||||
fail_strategy: HookFailStrategy
|
||||
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
|
||||
is_active: bool
|
||||
is_reachable: bool | None
|
||||
creator_email: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class HookValidateStatus(str, Enum):
|
||||
passed = "passed" # server responded (any status except 401/403)
|
||||
auth_failed = "auth_failed" # server responded with 401 or 403
|
||||
timeout = (
|
||||
"timeout" # TCP connected, but read/write timed out (server exists but slow)
|
||||
)
|
||||
cannot_connect = "cannot_connect" # could not connect to the server
|
||||
|
||||
|
||||
class HookValidateResponse(BaseModel):
|
||||
success: bool
|
||||
status: HookValidateStatus
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HookHealthStatus(str, Enum):
|
||||
healthy = "healthy" # green — reachable, no failures in last 1h
|
||||
degraded = "degraded" # yellow — reachable, failures in last 1h
|
||||
unreachable = "unreachable" # red — is_reachable=false or null
|
||||
|
||||
|
||||
class HookFailureRecord(BaseModel):
|
||||
class HookExecutionRecord(BaseModel):
|
||||
error_message: str | None = None
|
||||
status_code: int | None = None
|
||||
duration_ms: int | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class HookHealthResponse(BaseModel):
|
||||
status: HookHealthStatus
|
||||
recent_failures: list[HookFailureRecord] = Field(
|
||||
default_factory=list,
|
||||
description="Last 10 failures, newest first",
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
@@ -13,22 +14,25 @@ _REQUIRED_ATTRS = (
|
||||
"default_timeout_seconds",
|
||||
"fail_hard_description",
|
||||
"default_fail_strategy",
|
||||
"payload_model",
|
||||
"response_model",
|
||||
)
|
||||
|
||||
|
||||
class HookPointSpec(ABC):
|
||||
class HookPointSpec:
|
||||
"""Static metadata and contract for a pipeline hook point.
|
||||
|
||||
This is NOT a regular class meant for direct instantiation by callers.
|
||||
Each concrete subclass represents exactly one hook point and is instantiated
|
||||
once at startup, registered in onyx.hooks.registry._REGISTRY. No caller
|
||||
should ever create instances directly — use get_hook_point_spec() or
|
||||
get_all_specs() from the registry instead.
|
||||
once at startup, registered in onyx.hooks.registry._REGISTRY. Prefer
|
||||
get_hook_point_spec() or get_all_specs() from the registry over direct
|
||||
instantiation.
|
||||
|
||||
Each hook point is a concrete subclass of this class. Onyx engineers
|
||||
own these definitions — customers never touch this code.
|
||||
|
||||
Subclasses must define all attributes as class-level constants.
|
||||
payload_model and response_model must be Pydantic BaseModel subclasses;
|
||||
input_schema and output_schema are derived from them automatically.
|
||||
"""
|
||||
|
||||
hook_point: HookPoint
|
||||
@@ -39,21 +43,33 @@ class HookPointSpec(ABC):
|
||||
default_fail_strategy: HookFailStrategy
|
||||
docs_url: str | None = None
|
||||
|
||||
payload_model: ClassVar[type[BaseModel]]
|
||||
response_model: ClassVar[type[BaseModel]]
|
||||
|
||||
# Computed once at class definition time from payload_model / response_model.
|
||||
input_schema: ClassVar[dict[str, Any]]
|
||||
output_schema: ClassVar[dict[str, Any]]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: object) -> None:
|
||||
"""Enforce that every concrete subclass declares all required class attributes.
|
||||
|
||||
Called automatically by Python whenever a class inherits from HookPointSpec.
|
||||
Abstract subclasses (those still carrying unimplemented abstract methods) are
|
||||
skipped — they are intermediate base classes and may not yet define everything.
|
||||
Only fully concrete subclasses are validated, ensuring a clear TypeError at
|
||||
import time rather than a confusing AttributeError at runtime.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Skip intermediate abstract subclasses — they may still be partially defined.
|
||||
if getattr(cls, "__abstractmethods__", None):
|
||||
return
|
||||
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
|
||||
if missing:
|
||||
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
"""JSON schema describing the request payload sent to the customer's endpoint."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
"""JSON schema describing the expected response from the customer's endpoint."""
|
||||
for attr in ("payload_model", "response_model"):
|
||||
val = getattr(cls, attr, None)
|
||||
if val is None or not (
|
||||
isinstance(val, type) and issubclass(val, BaseModel)
|
||||
):
|
||||
raise TypeError(
|
||||
f"{cls.__name__}.{attr} must be a Pydantic BaseModel subclass, got {val!r}"
|
||||
)
|
||||
cls.input_schema = cls.payload_model.model_json_schema()
|
||||
cls.output_schema = cls.response_model.model_json_schema()
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
# TODO(@Bo-Onyx): define payload and response fields
|
||||
class DocumentIngestionPayload(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIngestionResponse(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIngestionSpec(HookPointSpec):
|
||||
"""Hook point that runs during document ingestion.
|
||||
|
||||
@@ -18,12 +27,5 @@ class DocumentIngestionSpec(HookPointSpec):
|
||||
fail_hard_description = "The document will not be indexed."
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
# TODO(@Bo-Onyx): define input schema
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
# TODO(@Bo-Onyx): define output schema
|
||||
return {"type": "object", "properties": {}}
|
||||
payload_model = DocumentIngestionPayload
|
||||
response_model = DocumentIngestionResponse
|
||||
|
||||
@@ -1,10 +1,39 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
class QueryProcessingPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
query: str = Field(description="The raw query string exactly as the user typed it.")
|
||||
user_email: str | None = Field(
|
||||
description="Email of the user submitting the query, or null if unauthenticated."
|
||||
)
|
||||
chat_session_id: str = Field(
|
||||
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
|
||||
)
|
||||
|
||||
|
||||
class QueryProcessingResponse(BaseModel):
|
||||
# Intentionally permissive — customer endpoints may return extra fields.
|
||||
query: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The query to use in the pipeline. "
|
||||
"Null, empty string, or absent = reject the query."
|
||||
),
|
||||
)
|
||||
rejection_message: str | None = Field(
|
||||
default=None,
|
||||
description="Message shown to the user when the query is rejected. Falls back to a generic message if not provided.",
|
||||
)
|
||||
|
||||
|
||||
class QueryProcessingSpec(HookPointSpec):
|
||||
"""Hook point that runs on every user query before it enters the pipeline.
|
||||
|
||||
@@ -37,47 +66,5 @@ class QueryProcessingSpec(HookPointSpec):
|
||||
)
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The raw query string exactly as the user typed it.",
|
||||
},
|
||||
"user_email": {
|
||||
"type": ["string", "null"],
|
||||
"description": "Email of the user submitting the query, or null if unauthenticated.",
|
||||
},
|
||||
"chat_session_id": {
|
||||
"type": "string",
|
||||
"description": "UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires.",
|
||||
},
|
||||
},
|
||||
"required": ["query", "user_email", "chat_session_id"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"The (optionally modified) query to use. "
|
||||
"Set to null to reject the query."
|
||||
),
|
||||
},
|
||||
"rejection_message": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"Message shown to the user when query is null. "
|
||||
"Falls back to a generic message if not provided."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
payload_model = QueryProcessingPayload
|
||||
response_model = QueryProcessingResponse
|
||||
|
||||
5
backend/onyx/hooks/utils.py
Normal file
5
backend/onyx/hooks/utils.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT
|
||||
@@ -77,6 +77,7 @@ from onyx.server.features.default_assistant.api import (
|
||||
)
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.hierarchy.api import router as hierarchy_router
|
||||
from onyx.server.features.hooks.api import router as hook_router
|
||||
from onyx.server.features.input_prompt.api import (
|
||||
admin_router as admin_input_prompt_router,
|
||||
)
|
||||
@@ -453,6 +454,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
|
||||
register_onyx_exception_handlers(application)
|
||||
|
||||
include_router_with_global_prefix_prepended(application, hook_router)
|
||||
include_router_with_global_prefix_prepended(application, password_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
|
||||
@@ -43,6 +43,9 @@ from onyx.db.index_attempt import count_index_attempt_errors_for_cc_pair
|
||||
from onyx.db.index_attempt import count_index_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
|
||||
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from onyx.db.index_attempt import (
|
||||
get_latest_successful_index_attempt_for_cc_pair_id,
|
||||
)
|
||||
from onyx.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import IndexAttempt
|
||||
@@ -190,6 +193,11 @@ def get_cc_pair_full_info(
|
||||
only_finished=False,
|
||||
)
|
||||
|
||||
latest_successful_attempt = get_latest_successful_index_attempt_for_cc_pair_id(
|
||||
db_session=db_session,
|
||||
connector_credential_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# Get latest permission sync attempt for status
|
||||
latest_permission_sync_attempt = None
|
||||
if cc_pair.access_type == AccessType.SYNC:
|
||||
@@ -207,6 +215,11 @@ def get_cc_pair_full_info(
|
||||
cc_pair_id=cc_pair_id,
|
||||
),
|
||||
last_index_attempt=latest_attempt,
|
||||
last_successful_index_time=(
|
||||
latest_successful_attempt.time_started
|
||||
if latest_successful_attempt
|
||||
else None
|
||||
),
|
||||
latest_deletion_attempt=get_deletion_attempt_snapshot(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
|
||||
@@ -3,6 +3,7 @@ import math
|
||||
import mimetypes
|
||||
import os
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -109,6 +110,9 @@ from onyx.db.federated import fetch_all_federated_connectors_parallel
|
||||
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import get_latest_index_attempts_by_status
|
||||
from onyx.db.index_attempt import get_latest_index_attempts_parallel
|
||||
from onyx.db.index_attempt import (
|
||||
get_latest_successful_index_attempts_parallel,
|
||||
)
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import FederatedConnector
|
||||
from onyx.db.models import IndexAttempt
|
||||
@@ -1158,21 +1162,26 @@ def get_connector_indexing_status(
|
||||
),
|
||||
(),
|
||||
),
|
||||
# Get most recent successful index attempts
|
||||
(
|
||||
lambda: get_latest_successful_index_attempts_parallel(
|
||||
request.secondary_index,
|
||||
),
|
||||
(),
|
||||
),
|
||||
]
|
||||
|
||||
if user and user.role == UserRole.ADMIN:
|
||||
# For Admin users, we already got all the cc pair in editable_cc_pairs
|
||||
# its not needed to get them again
|
||||
(
|
||||
editable_cc_pairs,
|
||||
federated_connectors,
|
||||
latest_index_attempts,
|
||||
latest_finished_index_attempts,
|
||||
latest_successful_index_attempts,
|
||||
) = run_functions_tuples_in_parallel(parallel_functions)
|
||||
non_editable_cc_pairs = []
|
||||
else:
|
||||
parallel_functions.append(
|
||||
# Get non-editable connector/credential pairs
|
||||
(
|
||||
lambda: get_connector_credential_pairs_for_user_parallel(
|
||||
user, False, None, True, True, False, True, request.source
|
||||
@@ -1186,6 +1195,7 @@ def get_connector_indexing_status(
|
||||
federated_connectors,
|
||||
latest_index_attempts,
|
||||
latest_finished_index_attempts,
|
||||
latest_successful_index_attempts,
|
||||
non_editable_cc_pairs,
|
||||
) = run_functions_tuples_in_parallel(parallel_functions)
|
||||
|
||||
@@ -1197,6 +1207,9 @@ def get_connector_indexing_status(
|
||||
latest_finished_index_attempts = cast(
|
||||
list[IndexAttempt], latest_finished_index_attempts
|
||||
)
|
||||
latest_successful_index_attempts = cast(
|
||||
list[IndexAttempt], latest_successful_index_attempts
|
||||
)
|
||||
|
||||
document_count_info = get_document_counts_for_all_cc_pairs(db_session)
|
||||
|
||||
@@ -1206,42 +1219,48 @@ def get_connector_indexing_status(
|
||||
for connector_id, credential_id, cnt in document_count_info
|
||||
}
|
||||
|
||||
cc_pair_to_latest_index_attempt: dict[tuple[int, int], IndexAttempt] = {
|
||||
(
|
||||
attempt.connector_credential_pair.connector_id,
|
||||
attempt.connector_credential_pair.credential_id,
|
||||
): attempt
|
||||
for attempt in latest_index_attempts
|
||||
}
|
||||
def _attempt_lookup(
|
||||
attempts: list[IndexAttempt],
|
||||
) -> dict[int, IndexAttempt]:
|
||||
return {attempt.connector_credential_pair_id: attempt for attempt in attempts}
|
||||
|
||||
cc_pair_to_latest_finished_index_attempt: dict[tuple[int, int], IndexAttempt] = {
|
||||
(
|
||||
attempt.connector_credential_pair.connector_id,
|
||||
attempt.connector_credential_pair.credential_id,
|
||||
): attempt
|
||||
for attempt in latest_finished_index_attempts
|
||||
}
|
||||
cc_pair_to_latest_index_attempt = _attempt_lookup(latest_index_attempts)
|
||||
cc_pair_to_latest_finished_index_attempt = _attempt_lookup(
|
||||
latest_finished_index_attempts
|
||||
)
|
||||
cc_pair_to_latest_successful_index_attempt = _attempt_lookup(
|
||||
latest_successful_index_attempts
|
||||
)
|
||||
|
||||
def build_connector_indexing_status(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
is_editable: bool,
|
||||
) -> ConnectorIndexingStatusLite | None:
|
||||
# TODO remove this to enable ingestion API
|
||||
if cc_pair.name == "DefaultCCPair":
|
||||
return None
|
||||
|
||||
latest_attempt = cc_pair_to_latest_index_attempt.get(
|
||||
(cc_pair.connector_id, cc_pair.credential_id)
|
||||
)
|
||||
latest_attempt = cc_pair_to_latest_index_attempt.get(cc_pair.id)
|
||||
latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get(
|
||||
(cc_pair.connector_id, cc_pair.credential_id)
|
||||
cc_pair.id
|
||||
)
|
||||
latest_successful_attempt = cc_pair_to_latest_successful_index_attempt.get(
|
||||
cc_pair.id
|
||||
)
|
||||
doc_count = cc_pair_to_document_cnt.get(
|
||||
(cc_pair.connector_id, cc_pair.credential_id), 0
|
||||
)
|
||||
|
||||
return _get_connector_indexing_status_lite(
|
||||
cc_pair, latest_attempt, latest_finished_attempt, is_editable, doc_count
|
||||
cc_pair,
|
||||
latest_attempt,
|
||||
latest_finished_attempt,
|
||||
(
|
||||
latest_successful_attempt.time_started
|
||||
if latest_successful_attempt
|
||||
else None
|
||||
),
|
||||
is_editable,
|
||||
doc_count,
|
||||
)
|
||||
|
||||
# Process editable cc_pairs
|
||||
@@ -1402,6 +1421,7 @@ def _get_connector_indexing_status_lite(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
latest_index_attempt: IndexAttempt | None,
|
||||
latest_finished_index_attempt: IndexAttempt | None,
|
||||
last_successful_index_time: datetime | None,
|
||||
is_editable: bool,
|
||||
document_cnt: int,
|
||||
) -> ConnectorIndexingStatusLite | None:
|
||||
@@ -1435,7 +1455,7 @@ def _get_connector_indexing_status_lite(
|
||||
else None
|
||||
),
|
||||
last_status=latest_index_attempt.status if latest_index_attempt else None,
|
||||
last_success=cc_pair.last_successful_index_time,
|
||||
last_success=last_successful_index_time,
|
||||
docs_indexed=document_cnt,
|
||||
latest_index_attempt_docs_indexed=(
|
||||
latest_index_attempt.total_docs_indexed if latest_index_attempt else None
|
||||
|
||||
@@ -330,6 +330,7 @@ class CCPairFullInfo(BaseModel):
|
||||
num_docs_indexed: int, # not ideal, but this must be computed separately
|
||||
is_editable_for_current_user: bool,
|
||||
indexing: bool,
|
||||
last_successful_index_time: datetime | None = None,
|
||||
last_permission_sync_attempt_status: PermissionSyncStatus | None = None,
|
||||
permission_syncing: bool = False,
|
||||
last_permission_sync_attempt_finished: datetime | None = None,
|
||||
@@ -382,9 +383,7 @@ class CCPairFullInfo(BaseModel):
|
||||
creator_email=(
|
||||
cc_pair_model.creator.email if cc_pair_model.creator else None
|
||||
),
|
||||
last_indexed=(
|
||||
last_index_attempt.time_started if last_index_attempt else None
|
||||
),
|
||||
last_indexed=last_successful_index_time,
|
||||
last_pruned=cc_pair_model.last_pruned,
|
||||
last_full_permission_sync=cls._get_last_full_permission_sync(cc_pair_model),
|
||||
overall_indexing_speed=overall_indexing_speed,
|
||||
|
||||
@@ -6978,9 +6978,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/flatted": {
|
||||
"version": "3.3.3",
|
||||
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz",
|
||||
"integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==",
|
||||
"version": "3.4.2",
|
||||
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz",
|
||||
"integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==",
|
||||
"dev": true,
|
||||
"license": "ISC"
|
||||
},
|
||||
|
||||
0
backend/onyx/server/features/hooks/__init__.py
Normal file
0
backend/onyx/server/features/hooks/__init__.py
Normal file
453
backend/onyx/server/features/hooks/api.py
Normal file
453
backend/onyx/server/features/hooks/api.py
Normal file
@@ -0,0 +1,453 @@
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.constants import UNSET
|
||||
from onyx.db.constants import UnsetType
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.hook import create_hook__no_commit
|
||||
from onyx.db.hook import delete_hook__no_commit
|
||||
from onyx.db.hook import get_hook_by_id
|
||||
from onyx.db.hook import get_hook_execution_logs
|
||||
from onyx.db.hook import get_hooks
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.api_dependencies import require_hook_enabled
|
||||
from onyx.hooks.models import HookCreateRequest
|
||||
from onyx.hooks.models import HookExecutionRecord
|
||||
from onyx.hooks.models import HookPointMetaResponse
|
||||
from onyx.hooks.models import HookResponse
|
||||
from onyx.hooks.models import HookUpdateRequest
|
||||
from onyx.hooks.models import HookValidateResponse
|
||||
from onyx.hooks.models import HookValidateStatus
|
||||
from onyx.hooks.registry import get_all_specs
|
||||
from onyx.hooks.registry import get_hook_point_spec
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.url import SSRFException
|
||||
from onyx.utils.url import validate_outbound_http_url
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSRF protection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _check_ssrf_safety(endpoint_url: str) -> None:
|
||||
"""Raise OnyxError if endpoint_url could be used for SSRF.
|
||||
|
||||
Delegates to validate_outbound_http_url with https_only=True.
|
||||
"""
|
||||
try:
|
||||
validate_outbound_http_url(endpoint_url, https_only=True)
|
||||
except (SSRFException, ValueError) as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookResponse:
|
||||
return HookResponse(
|
||||
id=hook.id,
|
||||
name=hook.name,
|
||||
hook_point=hook.hook_point,
|
||||
endpoint_url=hook.endpoint_url,
|
||||
fail_strategy=hook.fail_strategy,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
is_active=hook.is_active,
|
||||
is_reachable=hook.is_reachable,
|
||||
creator_email=(
|
||||
creator_email
|
||||
if creator_email is not None
|
||||
else (hook.creator.email if hook.creator else None)
|
||||
),
|
||||
created_at=hook.created_at,
|
||||
updated_at=hook.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _get_hook_or_404(
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
include_creator: bool = False,
|
||||
) -> Hook:
|
||||
hook = get_hook_by_id(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
include_creator=include_creator,
|
||||
)
|
||||
if hook is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook {hook_id} not found.")
|
||||
return hook
|
||||
|
||||
|
||||
def _raise_for_validation_failure(validation: HookValidateResponse) -> None:
|
||||
"""Raise an appropriate OnyxError for a non-passed validation result."""
|
||||
if validation.status == HookValidateStatus.auth_failed:
|
||||
raise OnyxError(OnyxErrorCode.CREDENTIAL_INVALID, validation.error_message)
|
||||
if validation.status == HookValidateStatus.timeout:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.GATEWAY_TIMEOUT,
|
||||
f"Endpoint validation failed: {validation.error_message}",
|
||||
)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Endpoint validation failed: {validation.error_message}",
|
||||
)
|
||||
|
||||
|
||||
def _validate_endpoint(
|
||||
endpoint_url: str,
|
||||
api_key: str | None,
|
||||
timeout_seconds: float,
|
||||
) -> HookValidateResponse:
|
||||
"""Check whether endpoint_url is reachable by sending an empty POST request.
|
||||
|
||||
We use POST since hook endpoints expect POST requests. The server will typically
|
||||
respond with 4xx (missing/invalid body) — that is fine. Any HTTP response means
|
||||
the server is up and routable. A 401/403 response returns auth_failed
|
||||
(not reachable — indicates the api_key is invalid).
|
||||
|
||||
Timeout handling:
|
||||
- ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
|
||||
(operator should consider increasing timeout_seconds).
|
||||
- All other exceptions → cannot_connect.
|
||||
"""
|
||||
_check_ssrf_safety(endpoint_url)
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
try:
|
||||
with httpx.Client(timeout=timeout_seconds, follow_redirects=False) as client:
|
||||
response = client.post(endpoint_url, headers=headers)
|
||||
if response.status_code in (401, 403):
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.auth_failed,
|
||||
error_message=f"Authentication failed (HTTP {response.status_code})",
|
||||
)
|
||||
return HookValidateResponse(status=HookValidateStatus.passed)
|
||||
except httpx.TimeoutException as exc:
|
||||
# ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
|
||||
if isinstance(exc, httpx.ConnectTimeout):
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connect timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
logger.warning(
|
||||
"Hook endpoint validation: read/write timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.timeout,
|
||||
error_message="Endpoint timed out — consider increasing timeout_seconds.",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connection error for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
router = APIRouter(prefix="/admin/hooks")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hook endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/specs")
|
||||
def get_hook_point_specs(
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
) -> list[HookPointMetaResponse]:
|
||||
return [
|
||||
HookPointMetaResponse(
|
||||
hook_point=spec.hook_point,
|
||||
display_name=spec.display_name,
|
||||
description=spec.description,
|
||||
docs_url=spec.docs_url,
|
||||
input_schema=spec.input_schema,
|
||||
output_schema=spec.output_schema,
|
||||
default_timeout_seconds=spec.default_timeout_seconds,
|
||||
default_fail_strategy=spec.default_fail_strategy,
|
||||
fail_hard_description=spec.fail_hard_description,
|
||||
)
|
||||
for spec in get_all_specs()
|
||||
]
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_hooks(
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[HookResponse]:
|
||||
hooks = get_hooks(db_session=db_session, include_creator=True)
|
||||
return [_hook_to_response(h) for h in hooks]
|
||||
|
||||
|
||||
@router.post("")
|
||||
def create_hook(
|
||||
req: HookCreateRequest,
|
||||
user: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
"""Create a new hook. The endpoint is validated before persisting — creation fails if
|
||||
the endpoint cannot be reached or the api_key is invalid. Hooks are created inactive;
|
||||
use POST /{hook_id}/activate once ready to receive traffic."""
|
||||
spec = get_hook_point_spec(req.hook_point)
|
||||
api_key = req.api_key.get_secret_value() if req.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=req.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
|
||||
)
|
||||
if validation.status != HookValidateStatus.passed:
|
||||
_raise_for_validation_failure(validation)
|
||||
|
||||
hook = create_hook__no_commit(
|
||||
db_session=db_session,
|
||||
name=req.name,
|
||||
hook_point=req.hook_point,
|
||||
endpoint_url=req.endpoint_url,
|
||||
api_key=api_key,
|
||||
fail_strategy=req.fail_strategy or spec.default_fail_strategy,
|
||||
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
|
||||
creator_id=user.id,
|
||||
)
|
||||
hook.is_reachable = True
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook, creator_email=user.email)
|
||||
|
||||
|
||||
@router.get("/{hook_id}")
|
||||
def get_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id, include_creator=True)
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.patch("/{hook_id}")
|
||||
def update_hook(
|
||||
hook_id: int,
|
||||
req: HookUpdateRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
"""Update hook fields. If endpoint_url, api_key, or timeout_seconds changes, the
|
||||
endpoint is re-validated using the effective values. For active hooks the update is
|
||||
rejected on validation failure, keeping live traffic unaffected. For inactive hooks
|
||||
the update goes through regardless and is_reachable is updated to reflect the result.
|
||||
|
||||
Note: if an active hook's endpoint is currently down, even a timeout_seconds-only
|
||||
increase will be rejected. The recovery flow is: deactivate → update → reactivate.
|
||||
"""
|
||||
# api_key: UNSET = no change, None = clear, value = update
|
||||
api_key: str | None | UnsetType
|
||||
if "api_key" not in req.model_fields_set:
|
||||
api_key = UNSET
|
||||
elif req.api_key is None:
|
||||
api_key = None
|
||||
else:
|
||||
api_key = req.api_key.get_secret_value()
|
||||
|
||||
endpoint_url_changing = "endpoint_url" in req.model_fields_set
|
||||
api_key_changing = not isinstance(api_key, UnsetType)
|
||||
timeout_changing = "timeout_seconds" in req.model_fields_set
|
||||
|
||||
validated_is_reachable: bool | None = None
|
||||
if endpoint_url_changing or api_key_changing or timeout_changing:
|
||||
existing = _get_hook_or_404(db_session, hook_id)
|
||||
effective_url: str = (
|
||||
req.endpoint_url if endpoint_url_changing else existing.endpoint_url # type: ignore[assignment] # endpoint_url is required on create and cannot be cleared on update
|
||||
)
|
||||
effective_api_key: str | None = (
|
||||
(api_key if not isinstance(api_key, UnsetType) else None)
|
||||
if api_key_changing
|
||||
else (
|
||||
existing.api_key.get_value(apply_mask=False)
|
||||
if existing.api_key
|
||||
else None
|
||||
)
|
||||
)
|
||||
effective_timeout: float = (
|
||||
req.timeout_seconds if timeout_changing else existing.timeout_seconds # type: ignore[assignment] # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest)
|
||||
)
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=effective_url,
|
||||
api_key=effective_api_key,
|
||||
timeout_seconds=effective_timeout,
|
||||
)
|
||||
if existing.is_active and validation.status != HookValidateStatus.passed:
|
||||
_raise_for_validation_failure(validation)
|
||||
validated_is_reachable = validation.status == HookValidateStatus.passed
|
||||
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
name=req.name,
|
||||
endpoint_url=(req.endpoint_url if endpoint_url_changing else UNSET),
|
||||
api_key=api_key,
|
||||
fail_strategy=req.fail_strategy,
|
||||
timeout_seconds=req.timeout_seconds,
|
||||
is_reachable=validated_is_reachable,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.delete("/{hook_id}")
|
||||
def delete_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
delete_hook__no_commit(db_session=db_session, hook_id=hook_id)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@router.post("/{hook_id}/activate")
|
||||
def activate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id)
|
||||
if not hook.endpoint_url:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
|
||||
)
|
||||
|
||||
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=hook.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
)
|
||||
if validation.status != HookValidateStatus.passed:
|
||||
# Persist is_reachable=False in a separate session so the request
|
||||
# session has no commits on the failure path and the transaction
|
||||
# boundary stays clean.
|
||||
if hook.is_reachable is not False:
|
||||
with get_session_with_current_tenant() as side_session:
|
||||
update_hook__no_commit(
|
||||
db_session=side_session, hook_id=hook_id, is_reachable=False
|
||||
)
|
||||
side_session.commit()
|
||||
_raise_for_validation_failure(validation)
|
||||
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
is_active=True,
|
||||
is_reachable=True,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.post("/{hook_id}/validate")
|
||||
def validate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookValidateResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id)
|
||||
if not hook.endpoint_url:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
|
||||
)
|
||||
|
||||
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=hook.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
)
|
||||
validation_passed = validation.status == HookValidateStatus.passed
|
||||
if hook.is_reachable != validation_passed:
|
||||
update_hook__no_commit(
|
||||
db_session=db_session, hook_id=hook_id, is_reachable=validation_passed
|
||||
)
|
||||
db_session.commit()
|
||||
return validation
|
||||
|
||||
|
||||
@router.post("/{hook_id}/deactivate")
|
||||
def deactivate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
is_active=False,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Execution log endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/{hook_id}/execution-logs")
|
||||
def list_hook_execution_logs(
|
||||
hook_id: int,
|
||||
limit: int = Query(default=10, ge=1, le=100),
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[HookExecutionRecord]:
|
||||
_get_hook_or_404(db_session, hook_id)
|
||||
logs = get_hook_execution_logs(db_session=db_session, hook_id=hook_id, limit=limit)
|
||||
return [
|
||||
HookExecutionRecord(
|
||||
error_message=log.error_message,
|
||||
status_code=log.status_code,
|
||||
duration_ms=log.duration_ms,
|
||||
created_at=log.created_at,
|
||||
)
|
||||
for log in logs
|
||||
]
|
||||
@@ -119,8 +119,8 @@ admin_agents_router = APIRouter(prefix=ADMIN_AGENTS_RESOURCE)
|
||||
agents_router = APIRouter(prefix=AGENTS_RESOURCE)
|
||||
|
||||
|
||||
class IsVisibleRequest(BaseModel):
|
||||
is_visible: bool
|
||||
class IsListedRequest(BaseModel):
|
||||
is_listed: bool
|
||||
|
||||
|
||||
class IsPublicRequest(BaseModel):
|
||||
@@ -128,19 +128,19 @@ class IsPublicRequest(BaseModel):
|
||||
|
||||
|
||||
class IsFeaturedRequest(BaseModel):
|
||||
featured: bool
|
||||
is_featured: bool
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/visible")
|
||||
@admin_router.patch("/{persona_id}/listed")
|
||||
def patch_persona_visibility(
|
||||
persona_id: int,
|
||||
is_visible_request: IsVisibleRequest,
|
||||
is_listed_request: IsListedRequest,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_persona_visibility(
|
||||
persona_id=persona_id,
|
||||
is_visible=is_visible_request.is_visible,
|
||||
is_listed=is_listed_request.is_listed,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
@@ -175,7 +175,7 @@ def patch_persona_featured_status(
|
||||
try:
|
||||
update_persona_featured(
|
||||
persona_id=persona_id,
|
||||
featured=is_featured_request.featured,
|
||||
is_featured=is_featured_request.is_featured,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
|
||||
@@ -123,7 +123,7 @@ class PersonaUpsertRequest(BaseModel):
|
||||
)
|
||||
search_start_date: datetime | None = None
|
||||
label_ids: list[int] | None = None
|
||||
featured: bool = False
|
||||
is_featured: bool = False
|
||||
display_priority: int | None = None
|
||||
# Accept string UUIDs from frontend
|
||||
user_file_ids: list[str] | None = None
|
||||
@@ -165,9 +165,9 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
icon_name: str | None
|
||||
|
||||
is_public: bool
|
||||
is_visible: bool
|
||||
is_listed: bool
|
||||
display_priority: int | None
|
||||
featured: bool
|
||||
is_featured: bool
|
||||
builtin_persona: bool
|
||||
|
||||
# Used for filtering
|
||||
@@ -218,9 +218,9 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
uploaded_image_id=persona.uploaded_image_id,
|
||||
icon_name=persona.icon_name,
|
||||
is_public=persona.is_public,
|
||||
is_visible=persona.is_visible,
|
||||
is_listed=persona.is_listed,
|
||||
display_priority=persona.display_priority,
|
||||
featured=persona.featured,
|
||||
is_featured=persona.is_featured,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
|
||||
owner=(
|
||||
@@ -236,13 +236,13 @@ class PersonaSnapshot(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
is_public: bool
|
||||
is_visible: bool
|
||||
is_listed: bool
|
||||
uploaded_image_id: str | None
|
||||
icon_name: str | None
|
||||
# Return string UUIDs to frontend for consistency
|
||||
user_file_ids: list[str]
|
||||
display_priority: int | None
|
||||
featured: bool
|
||||
is_featured: bool
|
||||
builtin_persona: bool
|
||||
starter_messages: list[StarterMessage] | None
|
||||
tools: list[ToolSnapshot]
|
||||
@@ -271,12 +271,12 @@ class PersonaSnapshot(BaseModel):
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
is_public=persona.is_public,
|
||||
is_visible=persona.is_visible,
|
||||
is_listed=persona.is_listed,
|
||||
uploaded_image_id=persona.uploaded_image_id,
|
||||
icon_name=persona.icon_name,
|
||||
user_file_ids=[str(file.id) for file in persona.user_files],
|
||||
display_priority=persona.display_priority,
|
||||
featured=persona.featured,
|
||||
is_featured=persona.is_featured,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
starter_messages=persona.starter_messages,
|
||||
tools=[
|
||||
@@ -337,12 +337,12 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
is_public=persona.is_public,
|
||||
is_visible=persona.is_visible,
|
||||
is_listed=persona.is_listed,
|
||||
uploaded_image_id=persona.uploaded_image_id,
|
||||
icon_name=persona.icon_name,
|
||||
user_file_ids=[str(file.id) for file in persona.user_files],
|
||||
display_priority=persona.display_priority,
|
||||
featured=persona.featured,
|
||||
is_featured=persona.is_featured,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
starter_messages=persona.starter_messages,
|
||||
users=[
|
||||
|
||||
@@ -351,7 +351,7 @@ def upsert_project_instructions(
|
||||
class ProjectPayload(BaseModel):
|
||||
project: UserProjectSnapshot
|
||||
files: list[UserFileSnapshot] | None = None
|
||||
persona_id_to_featured: dict[int, bool] | None = None
|
||||
persona_id_to_is_featured: dict[int, bool] | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -370,11 +370,13 @@ def get_project_details(
|
||||
if session.persona_id is not None
|
||||
]
|
||||
personas = get_personas_by_ids(persona_ids, db_session)
|
||||
persona_id_to_featured = {persona.id: persona.featured for persona in personas}
|
||||
persona_id_to_is_featured = {
|
||||
persona.id: persona.is_featured for persona in personas
|
||||
}
|
||||
return ProjectPayload(
|
||||
project=project,
|
||||
files=files,
|
||||
persona_id_to_featured=persona_id_to_featured,
|
||||
persona_id_to_is_featured=persona_id_to_is_featured,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -142,7 +142,7 @@ def enable_or_disable_kg(
|
||||
users=[user.id],
|
||||
groups=[],
|
||||
label_ids=[],
|
||||
featured=False,
|
||||
is_featured=False,
|
||||
display_priority=0,
|
||||
user_file_ids=[],
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from fastapi import Depends
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx import __version__ as onyx_version
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import is_user_admin
|
||||
@@ -79,6 +80,7 @@ def fetch_settings(
|
||||
needs_reindexing=needs_reindexing,
|
||||
onyx_craft_enabled=onyx_craft_enabled_for_user,
|
||||
vector_db_enabled=not DISABLE_VECTOR_DB,
|
||||
version=onyx_version,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -104,3 +104,5 @@ class UserSettings(Settings):
|
||||
# False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
|
||||
# document sets are unavailable.
|
||||
vector_db_enabled: bool = True
|
||||
# Application version, read from the ONYX_VERSION env var at startup.
|
||||
version: str | None = None
|
||||
|
||||
@@ -53,8 +53,12 @@ logger = setup_logger()
|
||||
|
||||
class SearchToolConfig(BaseModel):
|
||||
user_selected_filters: BaseFilters | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
# Vespa metadata filters for overflowing user files. These are NOT the
|
||||
# IDs of the current project/persona — they are only set when the
|
||||
# project's/persona's user files didn't fit in the LLM context window and
|
||||
# must be found via vector DB search instead.
|
||||
project_id_filter: int | None = None
|
||||
persona_id_filter: int | None = None
|
||||
bypass_acl: bool = False
|
||||
additional_context: str | None = None
|
||||
slack_context: SlackContext | None = None
|
||||
@@ -180,8 +184,8 @@ def construct_tools(
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
project_id_filter=search_tool_config.project_id_filter,
|
||||
persona_id_filter=search_tool_config.persona_id_filter,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
@@ -396,6 +400,7 @@ def construct_tools(
|
||||
tool_definition=saved_tool.mcp_input_schema or {},
|
||||
connection_config=connection_config,
|
||||
user_email=user_email,
|
||||
user_id=str(user.id),
|
||||
user_oauth_token=mcp_user_oauth_token,
|
||||
additional_headers=additional_mcp_headers,
|
||||
)
|
||||
@@ -428,8 +433,8 @@ def construct_tools(
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
project_id_filter=search_tool_config.project_id_filter,
|
||||
persona_id_filter=search_tool_config.persona_id_filter,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
@@ -47,6 +49,7 @@ class MCPTool(Tool[None]):
|
||||
tool_definition: dict[str, Any],
|
||||
connection_config: MCPConnectionConfig | None = None,
|
||||
user_email: str = "",
|
||||
user_id: str = "",
|
||||
user_oauth_token: str | None = None,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
@@ -56,6 +59,7 @@ class MCPTool(Tool[None]):
|
||||
self.mcp_server = mcp_server
|
||||
self.connection_config = connection_config
|
||||
self.user_email = user_email
|
||||
self._user_id = user_id
|
||||
self._user_oauth_token = user_oauth_token
|
||||
self._additional_headers = additional_headers or {}
|
||||
|
||||
@@ -198,12 +202,42 @@ class MCPTool(Tool[None]):
|
||||
llm_facing_response=llm_facing_response,
|
||||
)
|
||||
|
||||
# For OAuth servers, construct OAuthClientProvider so the MCP SDK
|
||||
# can refresh expired tokens automatically
|
||||
auth: OAuthClientProvider | None = None
|
||||
if (
|
||||
self.mcp_server.auth_type == MCPAuthenticationType.OAUTH
|
||||
and self.connection_config is not None
|
||||
and self._user_id
|
||||
):
|
||||
if self.mcp_server.transport == MCPTransport.SSE:
|
||||
logger.warning(
|
||||
f"MCP tool '{self._name}': OAuth token refresh is not supported "
|
||||
f"for SSE transport — auth provider will be ignored. "
|
||||
f"Re-authentication may be required after token expiry."
|
||||
)
|
||||
else:
|
||||
from onyx.server.features.mcp.api import UNUSED_RETURN_PATH
|
||||
from onyx.server.features.mcp.api import make_oauth_provider
|
||||
|
||||
# user_id is the requesting user's UUID; safe here because
|
||||
# UNUSED_RETURN_PATH ensures redirect_handler raises immediately
|
||||
# and user_id is never consulted for Redis state lookups.
|
||||
auth = make_oauth_provider(
|
||||
self.mcp_server,
|
||||
self._user_id,
|
||||
UNUSED_RETURN_PATH,
|
||||
self.connection_config.id,
|
||||
None,
|
||||
)
|
||||
|
||||
tool_result = call_mcp_tool(
|
||||
self.mcp_server.server_url,
|
||||
self._name,
|
||||
llm_kwargs,
|
||||
connection_headers=headers,
|
||||
transport=self.mcp_server.transport or MCPTransport.STREAMABLE_HTTP,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
logger.info(f"MCP tool '{self._name}' executed successfully")
|
||||
@@ -248,6 +282,7 @@ class MCPTool(Tool[None]):
|
||||
"invalid token",
|
||||
"invalid api key",
|
||||
"invalid credentials",
|
||||
"please reconnect to the server",
|
||||
]
|
||||
|
||||
is_auth_error = any(
|
||||
|
||||
@@ -764,8 +764,7 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
|
||||
tags=None,
|
||||
access_control_list=access_control_list,
|
||||
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
|
||||
user_file_ids=None,
|
||||
project_id=None,
|
||||
project_id_filter=None,
|
||||
)
|
||||
|
||||
def _merge_indexed_and_crawled_results(
|
||||
|
||||
@@ -244,10 +244,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
document_index: DocumentIndex,
|
||||
# Respecting user selections
|
||||
user_selected_filters: BaseFilters | None,
|
||||
# If the chat is part of a project
|
||||
project_id: int | None,
|
||||
# If set, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
# Vespa metadata filters for overflowing user files. NOT the raw IDs
|
||||
# of the current project/persona — only set when user files couldn't
|
||||
# fit in the LLM context and need to be searched via vector DB.
|
||||
project_id_filter: int | None,
|
||||
persona_id_filter: int | None = None,
|
||||
bypass_acl: bool = False,
|
||||
# Slack context for federated Slack search (tokens fetched internally)
|
||||
slack_context: SlackContext | None = None,
|
||||
@@ -261,8 +262,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
self.llm = llm
|
||||
self.document_index = document_index
|
||||
self.user_selected_filters = user_selected_filters
|
||||
self.project_id = project_id
|
||||
self.persona_id = persona_id
|
||||
self.project_id_filter = project_id_filter
|
||||
self.persona_id_filter = persona_id_filter
|
||||
self.bypass_acl = bypass_acl
|
||||
self.slack_context = slack_context
|
||||
self.enable_slack_search = enable_slack_search
|
||||
@@ -451,13 +452,15 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
hybrid_alpha=hybrid_alpha,
|
||||
# For projects, the search scope is the project and has no other limits
|
||||
user_selected_filters=(
|
||||
self.user_selected_filters if self.project_id is None else None
|
||||
self.user_selected_filters
|
||||
if self.project_id_filter is None
|
||||
else None
|
||||
),
|
||||
bypass_acl=self.bypass_acl,
|
||||
limit=num_hits,
|
||||
),
|
||||
project_id=self.project_id,
|
||||
persona_id=self.persona_id,
|
||||
project_id_filter=self.project_id_filter,
|
||||
persona_id_filter=self.persona_id_filter,
|
||||
document_index=self.document_index,
|
||||
user=self.user,
|
||||
persona=self.persona,
|
||||
@@ -574,7 +577,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
)
|
||||
|
||||
# Federated retrieval functions (non-Slack; Slack is separate)
|
||||
if self.project_id is not None:
|
||||
if self.project_id_filter is not None:
|
||||
# Project mode ignores user filters → no federated sources
|
||||
prefetch_source_types = None
|
||||
else:
|
||||
@@ -587,16 +590,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
persona_document_sets = (
|
||||
[ds.name for ds in self.persona.document_sets] if self.persona else None
|
||||
)
|
||||
user_file_ids = (
|
||||
[uf.id for uf in self.persona.user_files] if self.persona else None
|
||||
)
|
||||
federated_retrieval_infos = (
|
||||
get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=self.user.id if self.user else None,
|
||||
source_types=prefetch_source_types,
|
||||
document_set_names=persona_document_sets,
|
||||
user_file_ids=user_file_ids,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
@@ -189,3 +189,30 @@ def mt_cloud_identify(
|
||||
attribute="identify_user",
|
||||
fallback=noop_fallback,
|
||||
)(distinct_id, properties)
|
||||
|
||||
|
||||
def mt_cloud_alias(
|
||||
distinct_id: str,
|
||||
anonymous_id: str,
|
||||
) -> None:
|
||||
"""Link an anonymous distinct_id to an identified user (Cloud only)."""
|
||||
if not MULTI_TENANT:
|
||||
return
|
||||
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.utils.posthog_client",
|
||||
attribute="alias_user",
|
||||
fallback=noop_fallback,
|
||||
)(distinct_id, anonymous_id)
|
||||
|
||||
|
||||
def mt_cloud_get_anon_id(request: Any) -> str | None:
|
||||
"""Extract the anonymous distinct_id from the app PostHog cookie (Cloud only)."""
|
||||
if not MULTI_TENANT or not request:
|
||||
return None
|
||||
|
||||
return fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.utils.posthog_client",
|
||||
attribute="get_anon_id_from_request",
|
||||
fallback=noop_fallback,
|
||||
)(request)
|
||||
|
||||
@@ -140,10 +140,20 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
|
||||
return validated_ip, hostname, port
|
||||
|
||||
|
||||
def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
|
||||
def validate_outbound_http_url(
|
||||
url: str,
|
||||
*,
|
||||
allow_private_network: bool = False,
|
||||
https_only: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Validate a URL that will be used by backend outbound HTTP calls.
|
||||
|
||||
Args:
|
||||
url: The URL to validate.
|
||||
allow_private_network: If True, skip private/reserved IP checks.
|
||||
https_only: If True, reject http:// URLs (only https:// is allowed).
|
||||
|
||||
Returns:
|
||||
A normalized URL string with surrounding whitespace removed.
|
||||
|
||||
@@ -157,7 +167,12 @@ def validate_outbound_http_url(url: str, *, allow_private_network: bool = False)
|
||||
|
||||
parsed = urlparse(normalized_url)
|
||||
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
if https_only:
|
||||
if parsed.scheme != "https":
|
||||
raise SSRFException(
|
||||
f"Invalid URL scheme '{parsed.scheme}'. Only https is allowed."
|
||||
)
|
||||
elif parsed.scheme not in ("http", "https"):
|
||||
raise SSRFException(
|
||||
f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
|
||||
)
|
||||
|
||||
170
backend/scripts/debugging/opensearch/benchmark_retrieval.py
Normal file
170
backend/scripts/debugging/opensearch/benchmark_retrieval.py
Normal file
@@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Benchmarks OpenSearchDocumentIndex latency.
|
||||
|
||||
Requires Onyx to be running as it reads search settings from the database.
|
||||
|
||||
Usage:
|
||||
source .venv/bin/activate
|
||||
python backend/scripts/debugging/opensearch/benchmark_retrieval.py --help
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import statistics
|
||||
import time
|
||||
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.context.search.enums import QueryType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from scripts.debugging.opensearch.constants import DEV_TENANT_ID
|
||||
from scripts.debugging.opensearch.embedding_io import load_query_embedding_from_file
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
DEFAULT_N = 50
|
||||
|
||||
|
||||
def main() -> None:
|
||||
def add_query_embedding_argument(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--embedding-file-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the query embedding file.",
|
||||
)
|
||||
|
||||
def add_query_string_argument(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--query",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Query string.",
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="A benchmarking tool to measure OpenSearch retrieval latency."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
type=int,
|
||||
default=DEFAULT_N,
|
||||
help=f"Number of samples to take (default: {DEFAULT_N}).",
|
||||
)
|
||||
subparsers = parser.add_subparsers(
|
||||
dest="query_type",
|
||||
help="Query type to benchmark.",
|
||||
required=True,
|
||||
)
|
||||
|
||||
hybrid_parser = subparsers.add_parser(
|
||||
"hybrid", help="Benchmark hybrid retrieval latency."
|
||||
)
|
||||
add_query_embedding_argument(hybrid_parser)
|
||||
add_query_string_argument(hybrid_parser)
|
||||
|
||||
keyword_parser = subparsers.add_parser(
|
||||
"keyword", help="Benchmark keyword retrieval latency."
|
||||
)
|
||||
add_query_string_argument(keyword_parser)
|
||||
|
||||
semantic_parser = subparsers.add_parser(
|
||||
"semantic", help="Benchmark semantic retrieval latency."
|
||||
)
|
||||
add_query_embedding_argument(semantic_parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.n < 1:
|
||||
parser.error("Number of samples (-n) must be at least 1.")
|
||||
|
||||
if MULTI_TENANT:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(DEV_TENANT_ID)
|
||||
|
||||
SqlEngine.init_engine(pool_size=1, max_overflow=0)
|
||||
with get_session_with_current_tenant() as session:
|
||||
search_settings = get_current_search_settings(session)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
tenant_state = TenantState(
|
||||
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
|
||||
)
|
||||
index = OpenSearchDocumentIndex(
|
||||
tenant_state=tenant_state,
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
)
|
||||
filters = IndexFilters(
|
||||
access_control_list=[],
|
||||
tenant_id=get_current_tenant_id(),
|
||||
)
|
||||
|
||||
if args.query_type == "hybrid":
|
||||
embedding = load_query_embedding_from_file(args.embedding_file_path)
|
||||
search_callable = lambda: index.hybrid_retrieval( # noqa: E731
|
||||
query=args.query,
|
||||
query_embedding=embedding,
|
||||
final_keywords=None,
|
||||
# This arg doesn't do anything right now.
|
||||
query_type=QueryType.KEYWORD,
|
||||
filters=filters,
|
||||
num_to_retrieve=NUM_RETURNED_HITS,
|
||||
)
|
||||
elif args.query_type == "keyword":
|
||||
search_callable = lambda: index.keyword_retrieval( # noqa: E731
|
||||
query=args.query,
|
||||
filters=filters,
|
||||
num_to_retrieve=NUM_RETURNED_HITS,
|
||||
)
|
||||
elif args.query_type == "semantic":
|
||||
embedding = load_query_embedding_from_file(args.embedding_file_path)
|
||||
search_callable = lambda: index.semantic_retrieval( # noqa: E731
|
||||
query_embedding=embedding,
|
||||
filters=filters,
|
||||
num_to_retrieve=NUM_RETURNED_HITS,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid query type: {args.query_type}")
|
||||
|
||||
print(f"Running {args.n} invocations of {args.query_type} retrieval...")
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(args.n):
|
||||
start = time.perf_counter()
|
||||
results = search_callable()
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
latencies.append(elapsed_ms)
|
||||
# Print the current iteration and its elapsed time on the same line.
|
||||
print(
|
||||
f" [{i:>{len(str(args.n))}}] {elapsed_ms:7.1f} ms ({len(results)} results) (top result doc ID, chunk idx: {results[0].document_id if results else 'N/A'}, {results[0].chunk_id if results else 'N/A'})",
|
||||
end="\r",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
print()
|
||||
print(f"Results over {args.n} invocations:")
|
||||
print(f" mean: {statistics.mean(latencies):7.1f} ms")
|
||||
print(
|
||||
f" stdev: {statistics.stdev(latencies):7.1f} ms"
|
||||
if args.n > 1
|
||||
else " stdev: N/A (only 1 sample)"
|
||||
)
|
||||
print(f" max: {max(latencies):7.1f} ms (i: {latencies.index(max(latencies))})")
|
||||
print(f" min: {min(latencies):7.1f} ms (i: {latencies.index(min(latencies))})")
|
||||
if args.n >= 20:
|
||||
print(f" p50: {statistics.median(latencies):7.1f} ms")
|
||||
print(f" p95: {statistics.quantiles(latencies, n=20)[-1]:7.1f} ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
backend/scripts/debugging/opensearch/constants.py
Normal file
1
backend/scripts/debugging/opensearch/constants.py
Normal file
@@ -0,0 +1 @@
|
||||
DEV_TENANT_ID = "tenant_dev"
|
||||
64
backend/scripts/debugging/opensearch/embed_and_save.py
Normal file
64
backend/scripts/debugging/opensearch/embed_and_save.py
Normal file
@@ -0,0 +1,64 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Embeds a query and saves the embedding to a file.
|
||||
|
||||
Requires Onyx to be running as it reads search settings from the database.
|
||||
|
||||
Usage:
|
||||
source .venv/bin/activate
|
||||
python backend/scripts/debugging/opensearch/embed_and_save.py --help
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
from onyx.context.search.utils import get_query_embedding
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from scripts.debugging.opensearch.constants import DEV_TENANT_ID
|
||||
from scripts.debugging.opensearch.embedding_io import save_query_embedding_to_file
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="A tool to embed a query and save the embedding to a file."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--query",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Query string to embed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--file-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output file to save the embedding to.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if MULTI_TENANT:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(DEV_TENANT_ID)
|
||||
|
||||
SqlEngine.init_engine(pool_size=1, max_overflow=0)
|
||||
with get_session_with_current_tenant() as session:
|
||||
start = time.perf_counter()
|
||||
query_embedding = get_query_embedding(
|
||||
query=args.query,
|
||||
db_session=session,
|
||||
embedding_model=None,
|
||||
)
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
|
||||
save_query_embedding_to_file(query_embedding, args.file_path)
|
||||
print(
|
||||
f"Query embedding of dimension {len(query_embedding)} generated in {elapsed_ms:.1f} ms and saved to {args.file_path}."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
43
backend/scripts/debugging/opensearch/embedding_io.py
Normal file
43
backend/scripts/debugging/opensearch/embedding_io.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
def load_query_embedding_from_file(file_path: str) -> Embedding:
|
||||
"""Returns an embedding vector read from a file.
|
||||
|
||||
The file should be formatted as follows:
|
||||
- The first line should contain an integer representing the embedding
|
||||
dimension.
|
||||
- Every subsequent line should contain a float value representing a
|
||||
component of the embedding vector.
|
||||
- The size and embedding content should all be delimited by a newline.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file containing the embedding vector.
|
||||
|
||||
Returns:
|
||||
Embedding: The embedding vector.
|
||||
"""
|
||||
with open(file_path, "r") as f:
|
||||
dimension = int(f.readline().strip())
|
||||
embedding = [float(line.strip()) for line in f.readlines()]
|
||||
assert len(embedding) == dimension, "Embedding dimension mismatch."
|
||||
return embedding
|
||||
|
||||
|
||||
def save_query_embedding_to_file(embedding: Embedding, file_path: str) -> None:
|
||||
"""Saves an embedding vector to a file.
|
||||
|
||||
The file will be formatted as follows:
|
||||
- The first line will contain the embedding dimension.
|
||||
- Every subsequent line will contain a float value representing a
|
||||
component of the embedding vector.
|
||||
- The size and embedding content will all be delimited by a newline.
|
||||
|
||||
Args:
|
||||
embedding: The embedding vector to save.
|
||||
file_path: Path to the file to save the embedding vector to.
|
||||
"""
|
||||
with open(file_path, "w") as f:
|
||||
f.write(f"{len(embedding)}\n")
|
||||
for component in embedding:
|
||||
f.write(f"{component}\n")
|
||||
@@ -2,9 +2,10 @@
|
||||
"""A utility to interact with OpenSearch.
|
||||
|
||||
Usage:
|
||||
python3 opensearch_debug.py --help
|
||||
python3 opensearch_debug.py list
|
||||
python3 opensearch_debug.py delete <index_name>
|
||||
source .venv/bin/activate
|
||||
python backend/scripts/debugging/opensearch/opensearch_debug.py --help
|
||||
python backend/scripts/debugging/opensearch/opensearch_debug.py list
|
||||
python backend/scripts/debugging/opensearch/opensearch_debug.py delete <index_name>
|
||||
|
||||
Environment Variables:
|
||||
OPENSEARCH_HOST: OpenSearch host
|
||||
@@ -107,16 +108,15 @@ def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="A utility to interact with OpenSearch."
|
||||
)
|
||||
add_standard_arguments(parser)
|
||||
subparsers = parser.add_subparsers(
|
||||
dest="command", help="Command to execute.", required=True
|
||||
)
|
||||
|
||||
list_parser = subparsers.add_parser("list", help="List all indices with info.")
|
||||
add_standard_arguments(list_parser)
|
||||
subparsers.add_parser("list", help="List all indices with info.")
|
||||
|
||||
delete_parser = subparsers.add_parser("delete", help="Delete an index.")
|
||||
delete_parser.add_argument("index", help="Index name.", type=str)
|
||||
add_standard_arguments(delete_parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ def test_stream_chat_message_objects_without_web_search(
|
||||
db_session=db_session,
|
||||
tool_ids=[], # Explicitly no tools
|
||||
document_set_ids=None,
|
||||
is_visible=True,
|
||||
is_listed=True,
|
||||
)
|
||||
|
||||
# Create a chat session with our test persona
|
||||
|
||||
@@ -91,7 +91,7 @@ def _create_test_persona(
|
||||
document_sets=[],
|
||||
users=[user],
|
||||
groups=[],
|
||||
is_visible=True,
|
||||
is_listed=True,
|
||||
is_public=True,
|
||||
display_priority=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -63,7 +63,7 @@ def _create_persona(db_session: Session, user: User) -> Persona:
|
||||
document_sets=[],
|
||||
users=[user],
|
||||
groups=[],
|
||||
is_visible=True,
|
||||
is_listed=True,
|
||||
is_public=True,
|
||||
display_priority=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -0,0 +1,248 @@
|
||||
"""Shared fixtures for document_index external dependency tests.
|
||||
|
||||
Provides Vespa and OpenSearch index setup, tenant context, and chunk helpers.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
EMBEDDING_DIM = 128
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_chunk(
|
||||
doc_id: str,
|
||||
chunk_id: int = 0,
|
||||
content: str = "test content",
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
"""Create a chunk suitable for external dependency testing (128-dim embeddings)."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
)
|
||||
embeddings = ChunkEmbedding(
|
||||
full_embedding=[1.0] + [0.0] * (EMBEDDING_DIM - 1),
|
||||
mini_chunk_embeddings=[],
|
||||
)
|
||||
source_document = Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="test_doc",
|
||||
source=DocumentSource.FILE,
|
||||
sections=[],
|
||||
metadata={},
|
||||
title="test title",
|
||||
)
|
||||
return DocMetadataAwareIndexChunk(
|
||||
tenant_id=tenant_id,
|
||||
access=access,
|
||||
document_sets=set(),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=0,
|
||||
ancestor_hierarchy_node_ids=[],
|
||||
embeddings=embeddings,
|
||||
title_embedding=[1.0] + [0.0] * (EMBEDDING_DIM - 1),
|
||||
source_document=source_document,
|
||||
title_prefix="",
|
||||
metadata_suffix_keyword="",
|
||||
metadata_suffix_semantic="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
chunk_id=chunk_id,
|
||||
blurb=content[:50],
|
||||
content=content,
|
||||
source_links={0: ""},
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
)
|
||||
|
||||
|
||||
def make_indexing_metadata(
|
||||
doc_ids: list[str],
|
||||
old_counts: list[int],
|
||||
new_counts: list[int],
|
||||
) -> IndexingMetadata:
|
||||
return IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
doc_id: IndexingMetadata.ChunkCounts(
|
||||
old_chunk_cnt=old,
|
||||
new_chunk_cnt=new,
|
||||
)
|
||||
for doc_id, old, new in zip(doc_ids, old_counts, new_counts)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tenant_context() -> Generator[None, None, None]:
|
||||
"""Sets up tenant context for testing."""
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def test_index_name() -> Generator[str, None, None]:
|
||||
yield f"test_index_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def httpx_client() -> Generator[httpx.Client, None, None]:
|
||||
client = get_vespa_http_client()
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vespa_index(
|
||||
httpx_client: httpx.Client,
|
||||
tenant_context: None, # noqa: ARG001
|
||||
test_index_name: str,
|
||||
) -> Generator[VespaIndex, None, None]:
|
||||
"""Create a Vespa index, wait for schema readiness, and yield it."""
|
||||
vespa_idx = VespaIndex(
|
||||
index_name=test_index_name,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
backend_dir = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "..")
|
||||
)
|
||||
with patch("os.getcwd", return_value=backend_dir):
|
||||
vespa_idx.ensure_indices_exist(
|
||||
primary_embedding_dim=EMBEDDING_DIM,
|
||||
primary_embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
secondary_index_embedding_dim=None,
|
||||
secondary_index_embedding_precision=None,
|
||||
)
|
||||
if not wait_for_vespa_with_timeout(wait_limit=90):
|
||||
pytest.fail("Vespa is not available.")
|
||||
|
||||
# Wait until the schema is actually ready for writes on content nodes. We
|
||||
# probe by attempting a PUT; 200 means the schema is live, 400 means not
|
||||
# yet. This is only temporary until we entirely move off of Vespa.
|
||||
probe_doc = {
|
||||
"fields": {
|
||||
"document_id": "__probe__",
|
||||
"chunk_id": 0,
|
||||
"blurb": "",
|
||||
"title": "",
|
||||
"skip_title": True,
|
||||
"content": "",
|
||||
"content_summary": "",
|
||||
"source_type": "file",
|
||||
"source_links": "null",
|
||||
"semantic_identifier": "",
|
||||
"section_continuation": False,
|
||||
"large_chunk_reference_ids": [],
|
||||
"metadata": "{}",
|
||||
"metadata_list": [],
|
||||
"metadata_suffix": "",
|
||||
"chunk_context": "",
|
||||
"doc_summary": "",
|
||||
"embeddings": {"full_chunk": [1.0] + [0.0] * (EMBEDDING_DIM - 1)},
|
||||
"access_control_list": {},
|
||||
"document_sets": {},
|
||||
"image_file_name": None,
|
||||
"user_project": [],
|
||||
"personas": [],
|
||||
"boost": 0.0,
|
||||
"aggregated_chunk_boost_factor": 0.0,
|
||||
"primary_owners": [],
|
||||
"secondary_owners": [],
|
||||
}
|
||||
}
|
||||
probe_url = (
|
||||
f"http://localhost:8081/document/v1/default/{test_index_name}/docid/__probe__"
|
||||
)
|
||||
schema_ready = False
|
||||
for _ in range(60):
|
||||
resp = httpx_client.post(probe_url, json=probe_doc)
|
||||
if resp.status_code == 200:
|
||||
schema_ready = True
|
||||
httpx_client.delete(probe_url)
|
||||
break
|
||||
time.sleep(1)
|
||||
if not schema_ready:
|
||||
pytest.fail(f"Vespa schema '{test_index_name}' did not become ready in time.")
|
||||
|
||||
yield vespa_idx
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def opensearch_old_index(
|
||||
tenant_context: None, # noqa: ARG001
|
||||
test_index_name: str,
|
||||
) -> Generator[OpenSearchOldDocumentIndex, None, None]:
|
||||
"""Create an OpenSearch index via the old adapter and yield it."""
|
||||
if not wait_for_opensearch_with_timeout():
|
||||
pytest.fail("OpenSearch is not available.")
|
||||
|
||||
opensearch_idx = OpenSearchOldDocumentIndex(
|
||||
index_name=test_index_name,
|
||||
embedding_dim=EMBEDDING_DIM,
|
||||
embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
secondary_index_name=None,
|
||||
secondary_embedding_dim=None,
|
||||
secondary_embedding_precision=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
multitenant=MULTI_TENANT,
|
||||
)
|
||||
opensearch_idx.ensure_indices_exist(
|
||||
primary_embedding_dim=EMBEDDING_DIM,
|
||||
primary_embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
secondary_index_embedding_dim=None,
|
||||
secondary_index_embedding_precision=None,
|
||||
)
|
||||
|
||||
yield opensearch_idx
|
||||
@@ -0,0 +1,203 @@
|
||||
"""External dependency tests for the new DocumentIndex interface.
|
||||
|
||||
These tests assume Vespa and OpenSearch are running.
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.document_index.interfaces_new import DocumentIndex as DocumentIndexNew
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
|
||||
from tests.external_dependency_unit.document_index.conftest import make_chunk
|
||||
from tests.external_dependency_unit.document_index.conftest import (
|
||||
make_indexing_metadata,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vespa_document_index(
|
||||
vespa_index: VespaIndex, # noqa: ARG001 — ensures schema exists
|
||||
httpx_client: httpx.Client,
|
||||
test_index_name: str,
|
||||
) -> Generator[VespaDocumentIndex, None, None]:
|
||||
yield VespaDocumentIndex(
|
||||
index_name=test_index_name,
|
||||
tenant_state=TenantState(tenant_id=TEST_TENANT_ID, multitenant=False),
|
||||
large_chunks_enabled=False,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def opensearch_document_index(
|
||||
opensearch_old_index: OpenSearchOldDocumentIndex, # noqa: ARG001 — ensures index exists
|
||||
test_index_name: str,
|
||||
) -> Generator[OpenSearchDocumentIndex, None, None]:
|
||||
yield OpenSearchDocumentIndex(
|
||||
tenant_state=TenantState(tenant_id=TEST_TENANT_ID, multitenant=False),
|
||||
index_name=test_index_name,
|
||||
embedding_dim=EMBEDDING_DIM,
|
||||
embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def document_indices(
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_document_index: OpenSearchDocumentIndex,
|
||||
) -> Generator[list[DocumentIndexNew], None, None]:
|
||||
yield [opensearch_document_index, vespa_document_index]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDocumentIndexNew:
|
||||
"""Tests the new DocumentIndex interface against real Vespa and OpenSearch."""
|
||||
|
||||
def test_index_single_new_doc(
|
||||
self,
|
||||
document_indices: list[DocumentIndexNew],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""Indexing a single new document returns one record with already_existed=False."""
|
||||
for document_index in document_indices:
|
||||
doc_id = f"test_single_new_{uuid.uuid4().hex[:8]}"
|
||||
chunk = make_chunk(doc_id)
|
||||
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[1])
|
||||
|
||||
results = document_index.index(chunks=[chunk], indexing_metadata=metadata)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].document_id == doc_id
|
||||
assert results[0].already_existed is False
|
||||
|
||||
def test_index_existing_doc_already_existed_true(
|
||||
self,
|
||||
document_indices: list[DocumentIndexNew],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""Re-indexing a doc with previous chunks returns already_existed=True."""
|
||||
for document_index in document_indices:
|
||||
doc_id = f"test_existing_{uuid.uuid4().hex[:8]}"
|
||||
chunk = make_chunk(doc_id)
|
||||
|
||||
# First index — brand new document.
|
||||
metadata_first = make_indexing_metadata(
|
||||
[doc_id], old_counts=[0], new_counts=[1]
|
||||
)
|
||||
document_index.index(chunks=[chunk], indexing_metadata=metadata_first)
|
||||
|
||||
# Allow near-real-time indexing to settle (needed for Vespa).
|
||||
time.sleep(1)
|
||||
|
||||
# Re-index — old_chunk_cnt=1 signals the document already existed.
|
||||
metadata_second = make_indexing_metadata(
|
||||
[doc_id], old_counts=[1], new_counts=[1]
|
||||
)
|
||||
results = document_index.index(
|
||||
chunks=[chunk], indexing_metadata=metadata_second
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].already_existed is True
|
||||
|
||||
def test_index_multiple_docs(
|
||||
self,
|
||||
document_indices: list[DocumentIndexNew],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""Indexing multiple documents returns one record per unique document."""
|
||||
for document_index in document_indices:
|
||||
doc1 = f"test_multi_1_{uuid.uuid4().hex[:8]}"
|
||||
doc2 = f"test_multi_2_{uuid.uuid4().hex[:8]}"
|
||||
chunks = [
|
||||
make_chunk(doc1, chunk_id=0),
|
||||
make_chunk(doc1, chunk_id=1),
|
||||
make_chunk(doc2, chunk_id=0),
|
||||
]
|
||||
metadata = make_indexing_metadata(
|
||||
[doc1, doc2], old_counts=[0, 0], new_counts=[2, 1]
|
||||
)
|
||||
|
||||
results = document_index.index(chunks=chunks, indexing_metadata=metadata)
|
||||
|
||||
result_map = {r.document_id: r.already_existed for r in results}
|
||||
assert len(result_map) == 2
|
||||
assert result_map[doc1] is False
|
||||
assert result_map[doc2] is False
|
||||
|
||||
def test_index_deduplicates_doc_ids_in_results(
|
||||
self,
|
||||
document_indices: list[DocumentIndexNew],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""Multiple chunks from the same document produce only one
|
||||
DocumentInsertionRecord."""
|
||||
for document_index in document_indices:
|
||||
doc_id = f"test_dedup_{uuid.uuid4().hex[:8]}"
|
||||
chunks = [make_chunk(doc_id, chunk_id=i) for i in range(5)]
|
||||
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[5])
|
||||
|
||||
results = document_index.index(chunks=chunks, indexing_metadata=metadata)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].document_id == doc_id
|
||||
|
||||
def test_index_mixed_new_and_existing_docs(
|
||||
self,
|
||||
document_indices: list[DocumentIndexNew],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A batch with both new and existing documents returns the correct
|
||||
already_existed flag for each."""
|
||||
for document_index in document_indices:
|
||||
existing_doc = f"test_mixed_exist_{uuid.uuid4().hex[:8]}"
|
||||
new_doc = f"test_mixed_new_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Pre-index the existing document.
|
||||
pre_chunk = make_chunk(existing_doc)
|
||||
pre_metadata = make_indexing_metadata(
|
||||
[existing_doc], old_counts=[0], new_counts=[1]
|
||||
)
|
||||
document_index.index(chunks=[pre_chunk], indexing_metadata=pre_metadata)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
# Now index a batch with the existing doc and a new doc.
|
||||
chunks = [
|
||||
make_chunk(existing_doc, chunk_id=0),
|
||||
make_chunk(new_doc, chunk_id=0),
|
||||
]
|
||||
metadata = make_indexing_metadata(
|
||||
[existing_doc, new_doc], old_counts=[1, 0], new_counts=[1, 1]
|
||||
)
|
||||
|
||||
results = document_index.index(chunks=chunks, indexing_metadata=metadata)
|
||||
|
||||
result_map = {r.document_id: r.already_existed for r in results}
|
||||
assert len(result_map) == 2
|
||||
assert result_map[existing_doc] is True
|
||||
assert result_map[new_doc] is False
|
||||
@@ -1,275 +1,41 @@
|
||||
"""External dependency tests for the old DocumentIndex interface.
|
||||
|
||||
These tests assume Vespa and OpenSearch are running.
|
||||
|
||||
TODO(ENG-3764)(andrei): Consolidate some of these test fixtures.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def opensearch_available() -> Generator[None, None, None]:
|
||||
"""Verifies OpenSearch is running, fails the test if not."""
|
||||
if not wait_for_opensearch_with_timeout():
|
||||
pytest.fail("OpenSearch is not available.")
|
||||
yield # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def test_index_name() -> Generator[str, None, None]:
|
||||
yield f"test_index_{uuid.uuid4().hex[:8]}" # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tenant_context() -> Generator[None, None, None]:
|
||||
"""Sets up tenant context for testing."""
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
try:
|
||||
yield # Test runs here.
|
||||
finally:
|
||||
# Reset the tenant context after the test
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def httpx_client() -> Generator[httpx.Client, None, None]:
|
||||
client = get_vespa_http_client()
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vespa_document_index(
|
||||
httpx_client: httpx.Client,
|
||||
tenant_context: None, # noqa: ARG001
|
||||
test_index_name: str,
|
||||
) -> Generator[VespaIndex, None, None]:
|
||||
vespa_index = VespaIndex(
|
||||
index_name=test_index_name,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
backend_dir = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "..")
|
||||
)
|
||||
with patch("os.getcwd", return_value=backend_dir):
|
||||
vespa_index.ensure_indices_exist(
|
||||
primary_embedding_dim=128,
|
||||
primary_embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
secondary_index_embedding_dim=None,
|
||||
secondary_index_embedding_precision=None,
|
||||
)
|
||||
# Verify Vespa is running, fails the test if not. Try 90 seconds for testing
|
||||
# in CI. We have to do this here because this endpoint only becomes live
|
||||
# once we create an index.
|
||||
if not wait_for_vespa_with_timeout(wait_limit=90):
|
||||
pytest.fail("Vespa is not available.")
|
||||
|
||||
# Wait until the schema is actually ready for writes on content nodes. We
|
||||
# probe by attempting a PUT; 200 means the schema is live, 400 means not
|
||||
# yet. This is so scuffed but running the test is really flakey otherwise;
|
||||
# this is only temporary until we entirely move off of Vespa.
|
||||
probe_doc = {
|
||||
"fields": {
|
||||
"document_id": "__probe__",
|
||||
"chunk_id": 0,
|
||||
"blurb": "",
|
||||
"title": "",
|
||||
"skip_title": True,
|
||||
"content": "",
|
||||
"content_summary": "",
|
||||
"source_type": "file",
|
||||
"source_links": "null",
|
||||
"semantic_identifier": "",
|
||||
"section_continuation": False,
|
||||
"large_chunk_reference_ids": [],
|
||||
"metadata": "{}",
|
||||
"metadata_list": [],
|
||||
"metadata_suffix": "",
|
||||
"chunk_context": "",
|
||||
"doc_summary": "",
|
||||
"embeddings": {"full_chunk": [1.0] + [0.0] * 127},
|
||||
"access_control_list": {},
|
||||
"document_sets": {},
|
||||
"image_file_name": None,
|
||||
"user_project": [],
|
||||
"personas": [],
|
||||
"boost": 0.0,
|
||||
"aggregated_chunk_boost_factor": 0.0,
|
||||
"primary_owners": [],
|
||||
"secondary_owners": [],
|
||||
}
|
||||
}
|
||||
schema_ready = False
|
||||
probe_url = (
|
||||
f"http://localhost:8081/document/v1/default/{test_index_name}/docid/__probe__"
|
||||
)
|
||||
for _ in range(60):
|
||||
resp = httpx_client.post(probe_url, json=probe_doc)
|
||||
if resp.status_code == 200:
|
||||
schema_ready = True
|
||||
# Clean up the probe document.
|
||||
httpx_client.delete(probe_url)
|
||||
break
|
||||
time.sleep(1)
|
||||
if not schema_ready:
|
||||
pytest.fail(f"Vespa schema '{test_index_name}' did not become ready in time.")
|
||||
|
||||
yield vespa_index # Test runs here.
|
||||
|
||||
# TODO(ENG-3765)(andrei): Explicitly cleanup index. Not immediately
|
||||
# pressing; in CI we should be using fresh instances of dependencies each
|
||||
# time anyway.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def opensearch_document_index(
|
||||
opensearch_available: None, # noqa: ARG001
|
||||
tenant_context: None, # noqa: ARG001
|
||||
test_index_name: str,
|
||||
) -> Generator[OpenSearchOldDocumentIndex, None, None]:
|
||||
opensearch_index = OpenSearchOldDocumentIndex(
|
||||
index_name=test_index_name,
|
||||
embedding_dim=128,
|
||||
embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
secondary_index_name=None,
|
||||
secondary_embedding_dim=None,
|
||||
secondary_embedding_precision=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
multitenant=MULTI_TENANT,
|
||||
)
|
||||
opensearch_index.ensure_indices_exist(
|
||||
primary_embedding_dim=128,
|
||||
primary_embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
secondary_index_embedding_dim=None,
|
||||
secondary_index_embedding_precision=None,
|
||||
)
|
||||
|
||||
yield opensearch_index # Test runs here.
|
||||
|
||||
# TODO(ENG-3765)(andrei): Explicitly cleanup index. Not immediately
|
||||
# pressing; in CI we should be using fresh instances of dependencies each
|
||||
# time anyway.
|
||||
from tests.external_dependency_unit.document_index.conftest import make_chunk
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def document_indices(
|
||||
vespa_document_index: VespaIndex,
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex,
|
||||
vespa_index: VespaIndex,
|
||||
opensearch_old_index: OpenSearchOldDocumentIndex,
|
||||
) -> Generator[list[DocumentIndex], None, None]:
|
||||
# Ideally these are parametrized; doing so with pytest fixtures is tricky.
|
||||
yield [opensearch_document_index, vespa_document_index] # Test runs here.
|
||||
yield [opensearch_old_index, vespa_index]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def chunks(
|
||||
tenant_context: None, # noqa: ARG001
|
||||
) -> Generator[list[DocMetadataAwareIndexChunk], None, None]:
|
||||
result = []
|
||||
chunk_count = 5
|
||||
doc_id = "test_doc"
|
||||
tenant_id = get_current_tenant_id()
|
||||
access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
)
|
||||
document_sets: set[str] = set()
|
||||
user_project: list[int] = list()
|
||||
personas: list[int] = list()
|
||||
boost = 0
|
||||
blurb = "blurb"
|
||||
content = "content"
|
||||
title_prefix = ""
|
||||
doc_summary = ""
|
||||
chunk_context = ""
|
||||
title_embedding = [1.0] + [0] * 127
|
||||
# Full 0 vectors are not supported for cos similarity.
|
||||
embeddings = ChunkEmbedding(
|
||||
full_embedding=[1.0] + [0] * 127, mini_chunk_embeddings=[]
|
||||
)
|
||||
source_document = Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="semantic identifier",
|
||||
source=DocumentSource.FILE,
|
||||
sections=[],
|
||||
metadata={},
|
||||
title="title",
|
||||
)
|
||||
metadata_suffix_keyword = ""
|
||||
image_file_id = None
|
||||
source_links: dict[int, str] = {0: ""}
|
||||
ancestor_hierarchy_node_ids: list[int] = []
|
||||
for i in range(chunk_count):
|
||||
result.append(
|
||||
DocMetadataAwareIndexChunk(
|
||||
tenant_id=tenant_id,
|
||||
access=access,
|
||||
document_sets=document_sets,
|
||||
user_project=user_project,
|
||||
personas=personas,
|
||||
boost=boost,
|
||||
aggregated_chunk_boost_factor=0,
|
||||
ancestor_hierarchy_node_ids=ancestor_hierarchy_node_ids,
|
||||
embeddings=embeddings,
|
||||
title_embedding=title_embedding,
|
||||
source_document=source_document,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
metadata_suffix_semantic="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary=doc_summary,
|
||||
chunk_context=chunk_context,
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
chunk_id=i,
|
||||
blurb=blurb,
|
||||
content=content,
|
||||
source_links=source_links,
|
||||
image_file_id=image_file_id,
|
||||
section_continuation=False,
|
||||
)
|
||||
)
|
||||
yield result # Test runs here.
|
||||
yield [make_chunk("test_doc", chunk_id=i) for i in range(5)]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -336,8 +102,8 @@ class TestDocumentIndexOld:
|
||||
project_persona_filters = IndexFilters(
|
||||
access_control_list=None,
|
||||
tenant_id=tenant_id,
|
||||
project_id=1,
|
||||
persona_id=2,
|
||||
project_id_filter=1,
|
||||
persona_id_filter=2,
|
||||
# We need this even though none of the chunks belong to a
|
||||
# document set because project_id and persona_id are only
|
||||
# additive filters in the event the agent has knowledge scope;
|
||||
|
||||
@@ -1,34 +1,30 @@
|
||||
"""Tests for OpenSearch assistant knowledge filter construction.
|
||||
|
||||
These tests verify that when an assistant (persona) has user files attached,
|
||||
the search filter includes those user file IDs in the assistant knowledge filter
|
||||
with OR logic (not AND), ensuring user files are discoverable alongside other
|
||||
knowledge types like attached documents and hierarchy nodes.
|
||||
|
||||
This prevents a regression where user_file_ids were added as a separate AND
|
||||
filter, making it impossible to find user files when the assistant also had
|
||||
attached documents or hierarchy nodes (since no document could match both).
|
||||
These tests verify that when an assistant (persona) has knowledge attached,
|
||||
the search filter includes the appropriate scope filters with OR logic (not AND),
|
||||
ensuring documents are discoverable across knowledge types like attached documents,
|
||||
hierarchy nodes, document sets, and persona/project user files.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_ID_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.search import DocumentQuery
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
USER_FILE_ID = UUID("6ad84e45-4450-406c-9d36-fcb5e74aca6b")
|
||||
ATTACHED_DOCUMENT_ID = "https://docs.google.com/document/d/test-doc-id"
|
||||
HIERARCHY_NODE_ID = 42
|
||||
PERSONA_ID = 7
|
||||
|
||||
|
||||
def _get_search_filters(
|
||||
source_types: list[DocumentSource],
|
||||
user_file_ids: list[UUID],
|
||||
attached_document_ids: list[str] | None,
|
||||
hierarchy_node_ids: list[int] | None,
|
||||
persona_id_filter: int | None = None,
|
||||
document_sets: list[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
return DocumentQuery._get_search_filters(
|
||||
tenant_state=TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False),
|
||||
@@ -36,15 +32,14 @@ def _get_search_filters(
|
||||
access_control_list=["user_email:test@example.com"],
|
||||
source_types=source_types,
|
||||
tags=[],
|
||||
document_sets=[],
|
||||
project_id=None,
|
||||
persona_id=None,
|
||||
document_sets=document_sets or [],
|
||||
project_id_filter=None,
|
||||
persona_id_filter=persona_id_filter,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
max_chunk_size=None,
|
||||
document_id=None,
|
||||
user_file_ids=user_file_ids,
|
||||
attached_document_ids=attached_document_ids,
|
||||
hierarchy_node_ids=hierarchy_node_ids,
|
||||
)
|
||||
@@ -53,137 +48,97 @@ def _get_search_filters(
|
||||
class TestAssistantKnowledgeFilter:
|
||||
"""Tests for assistant knowledge filter construction in OpenSearch queries."""
|
||||
|
||||
def test_user_file_ids_included_in_assistant_knowledge_filter(self) -> None:
|
||||
"""
|
||||
Tests that user_file_ids are included in the assistant knowledge filter
|
||||
with OR logic when the assistant has both user files and attached documents.
|
||||
|
||||
This prevents the regression where user files were ANDed with other
|
||||
knowledge types, making them unfindable.
|
||||
"""
|
||||
|
||||
# Under test: Call the filter construction method directly
|
||||
def test_persona_id_filter_added_when_knowledge_scope_exists(self) -> None:
|
||||
"""persona_id_filter should be OR'd into the knowledge scope filter
|
||||
when explicit knowledge attachments (attached_document_ids,
|
||||
hierarchy_node_ids, document_sets) are present."""
|
||||
filter_clauses = _get_search_filters(
|
||||
source_types=[DocumentSource.FILE, DocumentSource.USER_FILE],
|
||||
user_file_ids=[USER_FILE_ID],
|
||||
source_types=[DocumentSource.FILE],
|
||||
attached_document_ids=[ATTACHED_DOCUMENT_ID],
|
||||
hierarchy_node_ids=[HIERARCHY_NODE_ID],
|
||||
persona_id_filter=PERSONA_ID,
|
||||
)
|
||||
|
||||
knowledge_filter = None
|
||||
for clause in filter_clauses:
|
||||
if "bool" in clause and "should" in clause["bool"]:
|
||||
if clause["bool"].get("minimum_should_match") == 1:
|
||||
knowledge_filter = clause
|
||||
break
|
||||
|
||||
assert knowledge_filter is not None, (
|
||||
"Expected to find an assistant knowledge filter with "
|
||||
"'minimum_should_match: 1'"
|
||||
)
|
||||
|
||||
should_clauses = knowledge_filter["bool"]["should"]
|
||||
persona_found = any(
|
||||
clause.get("term", {}).get(PERSONAS_FIELD_NAME, {}).get("value")
|
||||
== PERSONA_ID
|
||||
for clause in should_clauses
|
||||
)
|
||||
assert persona_found, (
|
||||
f"Expected persona_id={PERSONA_ID} filter on {PERSONAS_FIELD_NAME} "
|
||||
f"in should clauses. Got: {should_clauses}"
|
||||
)
|
||||
|
||||
def test_persona_id_filter_alone_creates_knowledge_scope(self) -> None:
|
||||
"""persona_id_filter IS a primary knowledge scope trigger — a persona
|
||||
with user files is explicit knowledge, so it should restrict
|
||||
search on its own."""
|
||||
filter_clauses = _get_search_filters(
|
||||
source_types=[],
|
||||
attached_document_ids=None,
|
||||
hierarchy_node_ids=None,
|
||||
persona_id_filter=PERSONA_ID,
|
||||
)
|
||||
|
||||
# Postcondition: Find the assistant knowledge filter (bool with should clauses)
|
||||
knowledge_filter = None
|
||||
for clause in filter_clauses:
|
||||
if "bool" in clause and "should" in clause["bool"]:
|
||||
# Check if this is the knowledge filter (has minimum_should_match=1)
|
||||
if clause["bool"].get("minimum_should_match") == 1:
|
||||
knowledge_filter = clause
|
||||
break
|
||||
|
||||
assert (
|
||||
knowledge_filter is not None
|
||||
), "Expected to find an assistant knowledge filter with 'minimum_should_match: 1'"
|
||||
|
||||
# The knowledge filter should have 3 should clauses (user files, attached docs, hierarchy nodes)
|
||||
should_clauses = knowledge_filter["bool"]["should"]
|
||||
assert (
|
||||
len(should_clauses) == 3
|
||||
), f"Expected 3 should clauses (user_file, attached_doc, hierarchy_node), got {len(should_clauses)}"
|
||||
|
||||
# Verify user_file_id is in one of the should clauses
|
||||
user_file_filter_found = False
|
||||
for should_clause in should_clauses:
|
||||
# The user file filter uses a nested bool with should for each file ID
|
||||
if "bool" in should_clause and "should" in should_clause["bool"]:
|
||||
for term_clause in should_clause["bool"]["should"]:
|
||||
if "term" in term_clause:
|
||||
term_value = term_clause["term"].get(DOCUMENT_ID_FIELD_NAME, {})
|
||||
if term_value.get("value") == str(USER_FILE_ID):
|
||||
user_file_filter_found = True
|
||||
break
|
||||
|
||||
assert user_file_filter_found, (
|
||||
f"Expected user_file_id {USER_FILE_ID} to be in the assistant knowledge "
|
||||
f"filter's should clauses. Filter structure: {knowledge_filter}"
|
||||
), "Expected persona_id_filter alone to create a knowledge scope filter"
|
||||
persona_found = any(
|
||||
clause.get("term", {}).get(PERSONAS_FIELD_NAME, {}).get("value")
|
||||
== PERSONA_ID
|
||||
for clause in knowledge_filter["bool"]["should"]
|
||||
)
|
||||
assert persona_found, (
|
||||
f"Expected persona_id={PERSONA_ID} filter in knowledge scope. "
|
||||
f"Got: {knowledge_filter}"
|
||||
)
|
||||
|
||||
def test_user_file_ids_only_creates_knowledge_filter(self) -> None:
|
||||
"""
|
||||
Tests that when only user_file_ids are provided (no attached_documents or
|
||||
hierarchy_nodes), the assistant knowledge filter is still created with the
|
||||
user file IDs.
|
||||
"""
|
||||
# Precondition
|
||||
|
||||
def test_knowledge_filter_with_document_sets_and_persona_filter(self) -> None:
|
||||
"""document_sets and persona_id_filter should be OR'd together in
|
||||
the knowledge scope filter."""
|
||||
filter_clauses = _get_search_filters(
|
||||
source_types=[DocumentSource.USER_FILE],
|
||||
user_file_ids=[USER_FILE_ID],
|
||||
source_types=[],
|
||||
attached_document_ids=None,
|
||||
hierarchy_node_ids=None,
|
||||
persona_id_filter=PERSONA_ID,
|
||||
document_sets=["engineering"],
|
||||
)
|
||||
|
||||
# Postcondition: Find filter that contains our user file ID
|
||||
user_file_filter_found = False
|
||||
knowledge_filter = None
|
||||
for clause in filter_clauses:
|
||||
clause_str = str(clause)
|
||||
if str(USER_FILE_ID) in clause_str:
|
||||
user_file_filter_found = True
|
||||
break
|
||||
if "bool" in clause and "should" in clause["bool"]:
|
||||
if clause["bool"].get("minimum_should_match") == 1:
|
||||
knowledge_filter = clause
|
||||
break
|
||||
|
||||
assert (
|
||||
user_file_filter_found
|
||||
), f"Expected user_file_id {USER_FILE_ID} to be in the filter clauses. Got: {filter_clauses}"
|
||||
knowledge_filter is not None
|
||||
), "Expected knowledge filter when document_sets is provided"
|
||||
|
||||
def test_no_separate_user_file_filter_when_assistant_has_knowledge(self) -> None:
|
||||
"""
|
||||
Tests that user_file_ids are NOT added as a separate AND filter when the
|
||||
assistant has other knowledge attached (attached_documents or hierarchy_nodes).
|
||||
"""
|
||||
|
||||
filter_clauses = _get_search_filters(
|
||||
source_types=[DocumentSource.FILE, DocumentSource.USER_FILE],
|
||||
user_file_ids=[USER_FILE_ID],
|
||||
attached_document_ids=[ATTACHED_DOCUMENT_ID],
|
||||
hierarchy_node_ids=None,
|
||||
)
|
||||
|
||||
# Postcondition: Count how many times user_file_id appears in filter clauses
|
||||
# It should appear exactly once (in the knowledge filter), not twice
|
||||
user_file_id_str = str(USER_FILE_ID)
|
||||
occurrences = 0
|
||||
for clause in filter_clauses:
|
||||
if user_file_id_str in str(clause):
|
||||
occurrences += 1
|
||||
|
||||
assert occurrences == 1, (
|
||||
f"Expected user_file_id to appear exactly once in filter clauses "
|
||||
f"(inside the assistant knowledge filter), but found {occurrences} "
|
||||
f"occurrences. This suggests user_file_ids is being added as both a "
|
||||
f"separate AND filter and inside the knowledge filter. "
|
||||
f"Filter clauses: {filter_clauses}"
|
||||
)
|
||||
|
||||
def test_multiple_user_files_all_included_in_filter(self) -> None:
|
||||
"""
|
||||
Tests that when multiple user files are attached to an assistant,
|
||||
all of them are included in the filter.
|
||||
"""
|
||||
# Precondition
|
||||
user_file_ids = [
|
||||
UUID("6ad84e45-4450-406c-9d36-fcb5e74aca6b"),
|
||||
UUID("7be95f56-5561-517d-ae47-acd6f85bdb7c"),
|
||||
UUID("8cf06a67-6672-628e-bf58-ade7a96cec8d"),
|
||||
]
|
||||
|
||||
filter_clauses = _get_search_filters(
|
||||
source_types=[DocumentSource.USER_FILE],
|
||||
user_file_ids=user_file_ids,
|
||||
attached_document_ids=[ATTACHED_DOCUMENT_ID],
|
||||
hierarchy_node_ids=None,
|
||||
)
|
||||
|
||||
# Postcondition: All user file IDs should be in the filter
|
||||
filter_str = str(filter_clauses)
|
||||
for user_file_id in user_file_ids:
|
||||
assert (
|
||||
str(user_file_id) in filter_str
|
||||
), f"Expected user_file_id {user_file_id} to be in the filter clauses"
|
||||
filter_str = str(knowledge_filter)
|
||||
assert (
|
||||
"engineering" in filter_str
|
||||
), "Expected document_set 'engineering' in knowledge filter"
|
||||
assert (
|
||||
str(PERSONA_ID) in filter_str
|
||||
), f"Expected persona_id_filter {PERSONA_ID} in knowledge filter"
|
||||
|
||||
@@ -1640,3 +1640,275 @@ class TestOpenSearchClient:
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
|
||||
def test_keyword_search(
|
||||
self,
|
||||
test_client: OpenSearchIndexClient,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Tests keyword search with filters for ACL, hidden documents, and tenant
|
||||
isolation.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, True)
|
||||
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
|
||||
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
|
||||
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_x.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
docs = {
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
chunk_index=0,
|
||||
content="Public document content",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"hidden-doc": _create_test_document_chunk(
|
||||
document_id="hidden-doc",
|
||||
chunk_index=0,
|
||||
content="Hidden document content, spooky",
|
||||
hidden=True,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"private-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 123-45-6789",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
# Tests that we don't return documents that don't match keywords at
|
||||
# all, even if they match filters.
|
||||
"private-but-not-relevant-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-but-not-relevant-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="This text should not match the query at all",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"private-doc-user-b": _create_test_document_chunk(
|
||||
document_id="private-doc-user-b",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 987-65-4321",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-b@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
|
||||
document_id="should-not-exist-from-tenant-x-pov",
|
||||
chunk_index=0,
|
||||
content="This is an entirely different tenant, x should never see this",
|
||||
# Make this as permissive as possible to exercise tenant
|
||||
# isolation.
|
||||
hidden=False,
|
||||
tenant_state=tenant_y,
|
||||
),
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc, tenant_state=doc.tenant_id)
|
||||
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
# Should not match private-but-not-relevant-doc-user-a.
|
||||
query_text = "document content"
|
||||
search_body = DocumentQuery.get_keyword_search_query(
|
||||
query_text=query_text,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_x,
|
||||
# The user should only be able to see their private docs. tenant_id
|
||||
# in this object is not relevant.
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=[prefix_user_email("user-a@example.com")],
|
||||
tenant_id=None,
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
|
||||
# Postcondition.
|
||||
# Should only get the public, non-hidden document, and the private
|
||||
# document for which the user has access.
|
||||
assert len(results) == 2
|
||||
# This should be the highest-ranked result, as a higher percentage of
|
||||
# the content matches the query.
|
||||
assert results[0].document_chunk.document_id == "public-doc"
|
||||
# Make sure the chunk contents are preserved.
|
||||
assert results[0].document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(docs["public-doc"], k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
# Make sure score reporting seems reasonable (it should not be None
|
||||
# or 0).
|
||||
assert results[0].score
|
||||
# Make sure there is some kind of match highlight.
|
||||
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
# Same for the second result.
|
||||
assert results[1].document_chunk.document_id == "private-doc-user-a"
|
||||
assert results[1].document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(docs["private-doc-user-a"], k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
assert results[1].score
|
||||
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
assert results[1].score < results[0].score
|
||||
|
||||
def test_semantic_search(
|
||||
self,
|
||||
test_client: OpenSearchIndexClient,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Tests semantic search with filters for ACL, hidden documents, and tenant
|
||||
isolation.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, True)
|
||||
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
|
||||
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_x.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
docs = {
|
||||
"public-doc": _create_test_document_chunk(
|
||||
document_id="public-doc",
|
||||
chunk_index=0,
|
||||
content="Public document content",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
# Make this identical to the query vector to test that this
|
||||
# result is returned first.
|
||||
content_vector=_generate_test_vector(0.6),
|
||||
),
|
||||
"hidden-doc": _create_test_document_chunk(
|
||||
document_id="hidden-doc",
|
||||
chunk_index=0,
|
||||
content="Hidden document content, spooky",
|
||||
hidden=True,
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
"private-doc-user-a": _create_test_document_chunk(
|
||||
document_id="private-doc-user-a",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 123-45-6789",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-a@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
# Make this different from the query vector to test that this
|
||||
# result is returned second.
|
||||
content_vector=_generate_test_vector(0.5),
|
||||
),
|
||||
"private-doc-user-b": _create_test_document_chunk(
|
||||
document_id="private-doc-user-b",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 987-65-4321",
|
||||
hidden=False,
|
||||
tenant_state=tenant_x,
|
||||
document_access=DocumentAccess.build(
|
||||
user_emails=["user-b@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
),
|
||||
),
|
||||
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
|
||||
document_id="should-not-exist-from-tenant-x-pov",
|
||||
chunk_index=0,
|
||||
content="This is an entirely different tenant, x should never see this",
|
||||
# Make this as permissive as possible to exercise tenant
|
||||
# isolation.
|
||||
hidden=False,
|
||||
tenant_state=tenant_y,
|
||||
),
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc, tenant_state=doc.tenant_id)
|
||||
|
||||
# Refresh index to make documents searchable.
|
||||
test_client.refresh_index()
|
||||
|
||||
query_vector = _generate_test_vector(0.6)
|
||||
search_body = DocumentQuery.get_semantic_search_query(
|
||||
query_embedding=query_vector,
|
||||
num_hits=5,
|
||||
tenant_state=tenant_x,
|
||||
# The user should only be able to see their private docs. tenant_id
|
||||
# in this object is not relevant.
|
||||
index_filters=IndexFilters(
|
||||
access_control_list=[prefix_user_email("user-a@example.com")],
|
||||
tenant_id=None,
|
||||
),
|
||||
include_hidden=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
|
||||
# Postcondition.
|
||||
# Should only get the public, non-hidden document, and the private
|
||||
# document for which the user has access.
|
||||
assert len(results) == 2
|
||||
# We explicitly expect this to be the highest-ranked result.
|
||||
assert results[0].document_chunk.document_id == "public-doc"
|
||||
# Make sure the chunk contents are preserved.
|
||||
assert results[0].document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(docs["public-doc"], k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
assert results[0].score == 1.0
|
||||
# Same for the second result.
|
||||
assert results[1].document_chunk.document_id == "private-doc-user-a"
|
||||
assert results[1].document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(docs["private-doc-user-a"], k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
assert results[1].score
|
||||
assert 0.0 < results[1].score < 1.0
|
||||
|
||||
@@ -52,7 +52,7 @@ def _create_test_persona_with_mcp_tool(
|
||||
document_sets=[],
|
||||
users=[user],
|
||||
groups=[],
|
||||
is_visible=True,
|
||||
is_listed=True,
|
||||
is_public=True,
|
||||
display_priority=None,
|
||||
starter_messages=None,
|
||||
@@ -368,9 +368,10 @@ class TestMCPPassThroughOAuth:
|
||||
def mock_call_mcp_tool(
|
||||
server_url: str, # noqa: ARG001
|
||||
tool_name: str, # noqa: ARG001
|
||||
kwargs: dict[str, Any], # noqa: ARG001
|
||||
arguments: dict[str, Any], # noqa: ARG001
|
||||
connection_headers: dict[str, str],
|
||||
transport: MCPTransport, # noqa: ARG001
|
||||
auth: Any = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
captured_headers.update(connection_headers)
|
||||
return mocked_response
|
||||
|
||||
@@ -62,7 +62,7 @@ def _create_test_persona(db_session: Session, user: User, tools: list[Tool]) ->
|
||||
document_sets=[],
|
||||
users=[user],
|
||||
groups=[],
|
||||
is_visible=True,
|
||||
is_listed=True,
|
||||
is_public=True,
|
||||
display_priority=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -53,7 +53,7 @@ class PersonaManager:
|
||||
label_ids=label_ids or [],
|
||||
user_file_ids=user_file_ids or [],
|
||||
display_priority=display_priority,
|
||||
featured=featured,
|
||||
is_featured=featured,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
@@ -79,7 +79,7 @@ class PersonaManager:
|
||||
users=users or [],
|
||||
groups=groups or [],
|
||||
label_ids=label_ids or [],
|
||||
featured=featured,
|
||||
is_featured=featured,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -122,7 +122,7 @@ class PersonaManager:
|
||||
users=[UUID(user) for user in (users or persona.users)],
|
||||
groups=groups or persona.groups,
|
||||
label_ids=label_ids or persona.label_ids,
|
||||
featured=featured if featured is not None else persona.featured,
|
||||
is_featured=featured if featured is not None else persona.is_featured,
|
||||
)
|
||||
|
||||
response = requests.patch(
|
||||
@@ -152,7 +152,7 @@ class PersonaManager:
|
||||
users=[user["email"] for user in updated_persona_data["users"]],
|
||||
groups=updated_persona_data["groups"],
|
||||
label_ids=[label["id"] for label in updated_persona_data["labels"]],
|
||||
featured=updated_persona_data["featured"],
|
||||
is_featured=updated_persona_data["is_featured"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -205,9 +205,13 @@ class PersonaManager:
|
||||
mismatches.append(
|
||||
("is_public", persona.is_public, fetched_persona.is_public)
|
||||
)
|
||||
if fetched_persona.featured != persona.featured:
|
||||
if fetched_persona.is_featured != persona.is_featured:
|
||||
mismatches.append(
|
||||
("featured", persona.featured, fetched_persona.featured)
|
||||
(
|
||||
"is_featured",
|
||||
persona.is_featured,
|
||||
fetched_persona.is_featured,
|
||||
)
|
||||
)
|
||||
if (
|
||||
fetched_persona.llm_model_provider_override
|
||||
|
||||
@@ -169,7 +169,7 @@ class DATestPersona(BaseModel):
|
||||
users: list[str]
|
||||
groups: list[int]
|
||||
label_ids: list[int]
|
||||
featured: bool = False
|
||||
is_featured: bool = False
|
||||
|
||||
# Embedded prompt fields (no longer separate prompt_ids)
|
||||
system_prompt: str | None = None
|
||||
|
||||
@@ -14,6 +14,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
|
||||
@@ -28,6 +29,9 @@ _BACKEND_DIR = os.path.normpath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
|
||||
)
|
||||
|
||||
_DROP_SCHEMA_MAX_RETRIES = 3
|
||||
_DROP_SCHEMA_RETRY_DELAY_SEC = 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -50,6 +54,39 @@ def _run_script(
|
||||
)
|
||||
|
||||
|
||||
def _force_drop_schema(engine: Engine, schema: str) -> None:
|
||||
"""Terminate backends using *schema* then drop it, retrying on deadlock.
|
||||
|
||||
Background Celery workers may discover test schemas (they match the
|
||||
``tenant_`` prefix) and hold locks on tables inside them. A bare
|
||||
``DROP SCHEMA … CASCADE`` can deadlock with those workers, so we
|
||||
first kill their connections and retry if we still hit a deadlock.
|
||||
"""
|
||||
for attempt in range(_DROP_SCHEMA_MAX_RETRIES):
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT pg_terminate_backend(l.pid)
|
||||
FROM pg_locks l
|
||||
JOIN pg_class c ON c.oid = l.relation
|
||||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE n.nspname = :schema
|
||||
AND l.pid != pg_backend_pid()
|
||||
"""
|
||||
),
|
||||
{"schema": schema},
|
||||
)
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
return
|
||||
except Exception:
|
||||
if attempt == _DROP_SCHEMA_MAX_RETRIES - 1:
|
||||
raise
|
||||
time.sleep(_DROP_SCHEMA_RETRY_DELAY_SEC)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -104,9 +141,7 @@ def tenant_schema_at_head(
|
||||
|
||||
yield schema
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
_force_drop_schema(engine, schema)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -123,9 +158,7 @@ def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
|
||||
|
||||
yield schema
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
_force_drop_schema(engine, schema)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -150,9 +183,7 @@ def tenant_schema_bad_rev(engine: Engine) -> Generator[str, None, None]:
|
||||
|
||||
yield schema
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
_force_drop_schema(engine, schema)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
"""
|
||||
Integration tests for the "Last Indexed" time displayed on both the
|
||||
per-connector detail page and the all-connectors listing page.
|
||||
|
||||
Expected behavior: "Last Indexed" = time_started of the most recent
|
||||
successful index attempt for the cc pair, regardless of pagination.
|
||||
|
||||
Edge cases:
|
||||
1. First page of index attempts is entirely errors — last_indexed should
|
||||
still reflect the older successful attempt beyond page 1.
|
||||
2. Credential swap — successful attempts, then failures after a
|
||||
"credential change"; last_indexed should reflect the most recent
|
||||
successful attempt.
|
||||
3. Mix of statuses — only the most recent successful attempt matters.
|
||||
4. COMPLETED_WITH_ERRORS counts as a success for last_indexed purposes.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.server.documents.models import CCPairFullInfo
|
||||
from onyx.server.documents.models import ConnectorIndexingStatusLite
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _wait_for_real_success(
|
||||
cc_pair: DATestCCPair,
|
||||
admin: DATestUser,
|
||||
) -> None:
|
||||
"""Wait for the initial index attempt to complete successfully."""
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair,
|
||||
after=datetime(2000, 1, 1, tzinfo=timezone.utc),
|
||||
user_performing_action=admin,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
|
||||
def _get_detail(cc_pair_id: int, admin: DATestUser) -> CCPairFullInfo:
|
||||
result = CCPairManager.get_single(cc_pair_id, admin)
|
||||
assert result is not None
|
||||
return result
|
||||
|
||||
|
||||
def _get_listing(cc_pair_id: int, admin: DATestUser) -> ConnectorIndexingStatusLite:
|
||||
result = CCPairManager.get_indexing_status_by_id(cc_pair_id, admin)
|
||||
assert result is not None
|
||||
return result
|
||||
|
||||
|
||||
def test_last_indexed_first_page_all_errors(reset: None) -> None: # noqa: ARG001
|
||||
"""When the first page of index attempts is entirely errors but an
|
||||
older successful attempt exists, both the detail page and the listing
|
||||
page should still show the time of that successful attempt.
|
||||
|
||||
The detail page UI uses page size 8. We insert 10 failed attempts
|
||||
more recent than the initial success to push the success off page 1.
|
||||
"""
|
||||
admin = UserManager.create(name="admin_first_page_errors")
|
||||
cc_pair = CCPairManager.create_from_scratch(user_performing_action=admin)
|
||||
_wait_for_real_success(cc_pair, admin)
|
||||
|
||||
# Baseline: last_success should be set from the initial successful run
|
||||
listing_before = _get_listing(cc_pair.id, admin)
|
||||
assert listing_before.last_success is not None
|
||||
|
||||
# 10 recent failures push the success off page 1
|
||||
IndexAttemptManager.create_test_index_attempts(
|
||||
num_attempts=10,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=IndexingStatus.FAILED,
|
||||
error_msg="simulated failure",
|
||||
base_time=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
|
||||
detail = _get_detail(cc_pair.id, admin)
|
||||
listing = _get_listing(cc_pair.id, admin)
|
||||
|
||||
assert (
|
||||
detail.last_indexed is not None
|
||||
), "Detail page last_indexed is None even though a successful attempt exists"
|
||||
assert (
|
||||
listing.last_success is not None
|
||||
), "Listing page last_success is None even though a successful attempt exists"
|
||||
|
||||
# Both surfaces must agree
|
||||
assert detail.last_indexed == listing.last_success, (
|
||||
f"Detail last_indexed={detail.last_indexed} != "
|
||||
f"listing last_success={listing.last_success}"
|
||||
)
|
||||
|
||||
|
||||
def test_last_indexed_credential_swap_scenario(reset: None) -> None: # noqa: ARG001
|
||||
"""Perform an actual credential swap: create connector + cred1 (cc_pair_1),
|
||||
wait for success, then associate a new cred2 with the same connector
|
||||
(cc_pair_2), wait for that to succeed, and inject failures on cc_pair_2.
|
||||
|
||||
cc_pair_2's last_indexed must reflect cc_pair_2's own success, not
|
||||
cc_pair_1's older one. Both the detail page and listing page must agree.
|
||||
"""
|
||||
admin = UserManager.create(name="admin_cred_swap")
|
||||
|
||||
connector = ConnectorManager.create(user_performing_action=admin)
|
||||
cred1 = CredentialManager.create(user_performing_action=admin)
|
||||
cc_pair_1 = CCPairManager.create(
|
||||
connector_id=connector.id,
|
||||
credential_id=cred1.id,
|
||||
user_performing_action=admin,
|
||||
)
|
||||
_wait_for_real_success(cc_pair_1, admin)
|
||||
|
||||
cred2 = CredentialManager.create(user_performing_action=admin, name="swapped-cred")
|
||||
cc_pair_2 = CCPairManager.create(
|
||||
connector_id=connector.id,
|
||||
credential_id=cred2.id,
|
||||
user_performing_action=admin,
|
||||
)
|
||||
_wait_for_real_success(cc_pair_2, admin)
|
||||
|
||||
listing_after_swap = _get_listing(cc_pair_2.id, admin)
|
||||
assert listing_after_swap.last_success is not None
|
||||
|
||||
IndexAttemptManager.create_test_index_attempts(
|
||||
num_attempts=10,
|
||||
cc_pair_id=cc_pair_2.id,
|
||||
status=IndexingStatus.FAILED,
|
||||
error_msg="credential expired",
|
||||
base_time=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
|
||||
detail = _get_detail(cc_pair_2.id, admin)
|
||||
listing = _get_listing(cc_pair_2.id, admin)
|
||||
|
||||
assert detail.last_indexed is not None
|
||||
assert listing.last_success is not None
|
||||
|
||||
assert detail.last_indexed == listing.last_success, (
|
||||
f"Detail last_indexed={detail.last_indexed} != "
|
||||
f"listing last_success={listing.last_success}"
|
||||
)
|
||||
|
||||
|
||||
def test_last_indexed_mixed_statuses(reset: None) -> None: # noqa: ARG001
|
||||
"""Mix of in_progress, failed, and successful attempts. Only the most
|
||||
recent successful attempt's time matters."""
|
||||
admin = UserManager.create(name="admin_mixed")
|
||||
cc_pair = CCPairManager.create_from_scratch(user_performing_action=admin)
|
||||
_wait_for_real_success(cc_pair, admin)
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
|
||||
# Success 5 hours ago
|
||||
IndexAttemptManager.create_test_index_attempts(
|
||||
num_attempts=1,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=IndexingStatus.SUCCESS,
|
||||
base_time=now - timedelta(hours=5),
|
||||
)
|
||||
|
||||
# Failures 3 hours ago
|
||||
IndexAttemptManager.create_test_index_attempts(
|
||||
num_attempts=3,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=IndexingStatus.FAILED,
|
||||
error_msg="transient failure",
|
||||
base_time=now - timedelta(hours=3),
|
||||
)
|
||||
|
||||
# In-progress 1 hour ago
|
||||
IndexAttemptManager.create_test_index_attempts(
|
||||
num_attempts=1,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=IndexingStatus.IN_PROGRESS,
|
||||
base_time=now - timedelta(hours=1),
|
||||
)
|
||||
|
||||
detail = _get_detail(cc_pair.id, admin)
|
||||
listing = _get_listing(cc_pair.id, admin)
|
||||
|
||||
assert detail.last_indexed is not None
|
||||
assert listing.last_success is not None
|
||||
|
||||
assert detail.last_indexed == listing.last_success, (
|
||||
f"Detail last_indexed={detail.last_indexed} != "
|
||||
f"listing last_success={listing.last_success}"
|
||||
)
|
||||
|
||||
|
||||
def test_last_indexed_completed_with_errors(reset: None) -> None: # noqa: ARG001
|
||||
"""COMPLETED_WITH_ERRORS is treated as a successful attempt (matching
|
||||
IndexingStatus.is_successful()). When it is the most recent "success"
|
||||
and later attempts all failed, both surfaces should reflect its time."""
|
||||
admin = UserManager.create(name="admin_completed_errors")
|
||||
cc_pair = CCPairManager.create_from_scratch(user_performing_action=admin)
|
||||
_wait_for_real_success(cc_pair, admin)
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
|
||||
# COMPLETED_WITH_ERRORS 2 hours ago
|
||||
IndexAttemptManager.create_test_index_attempts(
|
||||
num_attempts=1,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=IndexingStatus.COMPLETED_WITH_ERRORS,
|
||||
base_time=now - timedelta(hours=2),
|
||||
)
|
||||
|
||||
# 10 failures after — push everything else off page 1
|
||||
IndexAttemptManager.create_test_index_attempts(
|
||||
num_attempts=10,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=IndexingStatus.FAILED,
|
||||
error_msg="post-partial failure",
|
||||
base_time=now,
|
||||
)
|
||||
|
||||
detail = _get_detail(cc_pair.id, admin)
|
||||
listing = _get_listing(cc_pair.id, admin)
|
||||
|
||||
assert (
|
||||
detail.last_indexed is not None
|
||||
), "COMPLETED_WITH_ERRORS should count as a success for last_indexed"
|
||||
assert (
|
||||
listing.last_success is not None
|
||||
), "COMPLETED_WITH_ERRORS should count as a success for last_success"
|
||||
|
||||
assert detail.last_indexed == listing.last_success, (
|
||||
f"Detail last_indexed={detail.last_indexed} != "
|
||||
f"listing last_success={listing.last_success}"
|
||||
)
|
||||
@@ -35,8 +35,8 @@ def _create_test_persona(db_session: Session, persona_id: int, name: str) -> Per
|
||||
id=persona_id,
|
||||
name=name,
|
||||
description="Test persona for Discord bot tests",
|
||||
is_visible=True,
|
||||
featured=False,
|
||||
is_listed=True,
|
||||
is_featured=False,
|
||||
deleted=False,
|
||||
builtin_persona=False,
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ def test_cold_startup_default_assistant() -> None:
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id, name, builtin_persona, featured, deleted
|
||||
SELECT id, name, builtin_persona, is_featured, deleted
|
||||
FROM persona
|
||||
WHERE builtin_persona = true
|
||||
ORDER BY id
|
||||
@@ -40,7 +40,7 @@ def test_cold_startup_default_assistant() -> None:
|
||||
assert default[0] == 0, "Default assistant should have ID 0"
|
||||
assert default[1] == "Assistant", "Should be named 'Assistant'"
|
||||
assert default[2] is True, "Should be builtin"
|
||||
assert default[3] is True, "Should be featured"
|
||||
assert default[3] is True, "Should be is_featured"
|
||||
assert default[4] is False, "Should not be deleted"
|
||||
|
||||
# Check tools are properly associated
|
||||
|
||||
@@ -7,6 +7,7 @@ import json
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.configs.constants import ANONYMOUS_USER_UUID
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from tests.integration.common_utils.reset import downgrade_postgres
|
||||
@@ -237,7 +238,6 @@ def test_jira_connector_migration() -> None:
|
||||
upgrade_postgres(
|
||||
database="postgres", config_name="alembic", revision="da42808081e3"
|
||||
)
|
||||
|
||||
# Verify the upgrade was applied correctly
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
results = db_session.execute(
|
||||
@@ -322,3 +322,165 @@ def test_jira_connector_migration() -> None:
|
||||
== "https://example.atlassian.net/projects/TEST"
|
||||
)
|
||||
assert config_2["batch_size"] == 50
|
||||
|
||||
|
||||
def test_anonymous_user_migration_dedupes_null_notifications() -> None:
|
||||
downgrade_postgres(
|
||||
database="postgres", config_name="alembic", revision="base", clear_data=True
|
||||
)
|
||||
upgrade_postgres(
|
||||
database="postgres",
|
||||
config_name="alembic",
|
||||
revision="f7ca3e2f45d9",
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO notification (
|
||||
id,
|
||||
notif_type,
|
||||
user_id,
|
||||
dismissed,
|
||||
last_shown,
|
||||
first_shown,
|
||||
title,
|
||||
description,
|
||||
additional_data
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
1,
|
||||
'RELEASE_NOTES',
|
||||
NULL,
|
||||
FALSE,
|
||||
NOW(),
|
||||
NOW(),
|
||||
'Onyx v2.10.0 is available!',
|
||||
'Check out what''s new in v2.10.0',
|
||||
'{"version":"v2.10.0","link":"https://docs.onyx.app/changelog#v2-10-0"}'::jsonb
|
||||
),
|
||||
(
|
||||
2,
|
||||
'RELEASE_NOTES',
|
||||
NULL,
|
||||
FALSE,
|
||||
NOW(),
|
||||
NOW(),
|
||||
'Onyx v2.10.0 is available!',
|
||||
'Check out what''s new in v2.10.0',
|
||||
'{"version":"v2.10.0","link":"https://docs.onyx.app/changelog#v2-10-0"}'::jsonb
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
upgrade_postgres(
|
||||
database="postgres", config_name="alembic", revision="e7f8a9b0c1d2"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
notifications = db_session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id, user_id
|
||||
FROM notification
|
||||
ORDER BY id
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
anonymous_user = db_session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id, email, role
|
||||
FROM "user"
|
||||
WHERE id = :user_id
|
||||
"""
|
||||
),
|
||||
{"user_id": ANONYMOUS_USER_UUID},
|
||||
).fetchone()
|
||||
|
||||
assert len(notifications) == 1
|
||||
assert notifications[0].id == 2 # Higher id wins when timestamps are equal
|
||||
assert str(notifications[0].user_id) == ANONYMOUS_USER_UUID
|
||||
assert anonymous_user is not None
|
||||
assert anonymous_user.email == "anonymous@onyx.app"
|
||||
assert anonymous_user.role == "LIMITED"
|
||||
|
||||
|
||||
def test_anonymous_user_migration_collision_with_existing_anonymous_notification() -> (
|
||||
None
|
||||
):
|
||||
"""Test that a NULL-owned notification that collides with an already-existing
|
||||
anonymous-owned notification is removed during migration."""
|
||||
downgrade_postgres(
|
||||
database="postgres", config_name="alembic", revision="base", clear_data=True
|
||||
)
|
||||
upgrade_postgres(
|
||||
database="postgres",
|
||||
config_name="alembic",
|
||||
revision="f7ca3e2f45d9",
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Create the anonymous user early so we can insert a notification owned by it
|
||||
db_session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO "user" (id, email, hashed_password, is_active, is_superuser, is_verified, role)
|
||||
VALUES (:id, 'anonymous@onyx.app', '', TRUE, FALSE, TRUE, 'LIMITED')
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"id": ANONYMOUS_USER_UUID},
|
||||
)
|
||||
# Insert an anonymous-owned notification (already migrated in a prior partial run)
|
||||
db_session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO notification (
|
||||
id, notif_type, user_id, dismissed, last_shown, first_shown,
|
||||
title, description, additional_data
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
1, 'RELEASE_NOTES', :user_id, FALSE, NOW(), NOW(),
|
||||
'Onyx v2.10.0 is available!',
|
||||
'Check out what''s new in v2.10.0',
|
||||
'{"version":"v2.10.0","link":"https://docs.onyx.app/changelog#v2-10-0"}'::jsonb
|
||||
),
|
||||
(
|
||||
2, 'RELEASE_NOTES', NULL, FALSE, NOW(), NOW(),
|
||||
'Onyx v2.10.0 is available!',
|
||||
'Check out what''s new in v2.10.0',
|
||||
'{"version":"v2.10.0","link":"https://docs.onyx.app/changelog#v2-10-0"}'::jsonb
|
||||
)
|
||||
"""
|
||||
),
|
||||
{"user_id": ANONYMOUS_USER_UUID},
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
upgrade_postgres(
|
||||
database="postgres", config_name="alembic", revision="e7f8a9b0c1d2"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
notifications = db_session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id, user_id
|
||||
FROM notification
|
||||
ORDER BY id
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
# Only the original anonymous-owned notification should remain;
|
||||
# the NULL-owned duplicate should have been deleted
|
||||
assert len(notifications) == 1
|
||||
assert notifications[0].id == 1
|
||||
assert str(notifications[0].user_id) == ANONYMOUS_USER_UUID
|
||||
|
||||
@@ -33,8 +33,8 @@ def test_unified_assistant(
|
||||
"search, web browsing, and image generation"
|
||||
in unified_assistant.description.lower()
|
||||
)
|
||||
assert unified_assistant.featured is True
|
||||
assert unified_assistant.is_visible is True
|
||||
assert unified_assistant.is_featured is True
|
||||
assert unified_assistant.is_listed is True
|
||||
|
||||
# Verify tools
|
||||
tools = unified_assistant.tools
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import csv
|
||||
import io
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@@ -139,12 +141,12 @@ def test_chat_history_csv_export(
|
||||
assert headers["Content-Type"] == "text/csv; charset=utf-8"
|
||||
assert "Content-Disposition" in headers
|
||||
|
||||
# Verify CSV content
|
||||
csv_lines = csv_content.strip().split("\n")
|
||||
assert len(csv_lines) == 3 # Header + 2 QA pairs
|
||||
assert "chat_session_id" in csv_content
|
||||
assert "user_message" in csv_content
|
||||
assert "ai_response" in csv_content
|
||||
# Use csv.reader to properly handle newlines inside quoted fields
|
||||
csv_rows = list(csv.reader(io.StringIO(csv_content)))
|
||||
assert len(csv_rows) == 3 # Header + 2 QA pairs
|
||||
assert csv_rows[0][0] == "chat_session_id"
|
||||
assert "user_message" in csv_rows[0]
|
||||
assert "ai_response" in csv_rows[0]
|
||||
assert "What was the Q1 revenue?" in csv_content
|
||||
assert "What about Q2 revenue?" in csv_content
|
||||
|
||||
@@ -156,5 +158,5 @@ def test_chat_history_csv_export(
|
||||
end_time=past_end,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
csv_lines = csv_content.strip().split("\n")
|
||||
assert len(csv_lines) == 1 # Only header, no data rows
|
||||
csv_rows = list(csv.reader(io.StringIO(csv_content)))
|
||||
assert len(csv_rows) == 1 # Only header, no data rows
|
||||
|
||||
@@ -86,7 +86,7 @@ async def test_get_or_create_user_skips_inactive(
|
||||
"""Inactive users should not be re-authenticated via JWT."""
|
||||
monkeypatch.setattr(users_module, "TRACK_EXTERNAL_IDP_EXPIRY", True)
|
||||
monkeypatch.setattr(users_module, "verify_email_is_invited", lambda _: None)
|
||||
monkeypatch.setattr(users_module, "verify_email_domain", lambda _: None)
|
||||
monkeypatch.setattr(users_module, "verify_email_domain", lambda *_a, **_kw: None)
|
||||
|
||||
email = "inactive@example.com"
|
||||
payload: dict[str, Any] = {"email": email}
|
||||
@@ -126,7 +126,7 @@ async def test_get_or_create_user_handles_race_conditions(
|
||||
"""If provisioning races, newly inactive users should still be blocked."""
|
||||
monkeypatch.setattr(users_module, "TRACK_EXTERNAL_IDP_EXPIRY", True)
|
||||
monkeypatch.setattr(users_module, "verify_email_is_invited", lambda _: None)
|
||||
monkeypatch.setattr(users_module, "verify_email_domain", lambda _: None)
|
||||
monkeypatch.setattr(users_module, "verify_email_domain", lambda *_a, **_kw: None)
|
||||
|
||||
email = "race@example.com"
|
||||
payload: dict[str, Any] = {"email": email}
|
||||
@@ -182,7 +182,7 @@ async def test_get_or_create_user_provisions_new_user(
|
||||
monkeypatch.setattr(users_module, "TRACK_EXTERNAL_IDP_EXPIRY", False)
|
||||
monkeypatch.setattr(users_module, "generate_password", lambda: "TempPass123!")
|
||||
monkeypatch.setattr(users_module, "verify_email_is_invited", lambda _: None)
|
||||
monkeypatch.setattr(users_module, "verify_email_domain", lambda _: None)
|
||||
monkeypatch.setattr(users_module, "verify_email_domain", lambda *_a, **_kw: None)
|
||||
|
||||
recorded: dict[str, Any] = {}
|
||||
|
||||
|
||||
@@ -15,11 +15,11 @@ from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.users import UserManager
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
# Note: Only async test methods are marked with @pytest.mark.asyncio individually
|
||||
# to avoid warnings on synchronous tests
|
||||
@@ -89,11 +89,11 @@ class TestDisposableEmailValidation:
|
||||
user_manager = UserManager(MagicMock())
|
||||
|
||||
# Execute & Assert
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
with pytest.raises(OnyxError) as exc:
|
||||
await user_manager.create(mock_user_create)
|
||||
|
||||
assert exc.value.status_code == 400
|
||||
assert "Disposable email" in str(exc.value.detail)
|
||||
assert "Disposable email" in exc.value.detail
|
||||
# Verify we never got to tenant provisioning
|
||||
mock_fetch_ee.assert_not_called()
|
||||
|
||||
@@ -138,7 +138,9 @@ class TestDisposableEmailValidation:
|
||||
pass # We just want to verify domain check passed
|
||||
|
||||
# Verify domain validation was called
|
||||
mock_verify_domain.assert_called_once_with(mock_user_create.email)
|
||||
mock_verify_domain.assert_called_once_with(
|
||||
mock_user_create.email, is_registration=True
|
||||
)
|
||||
|
||||
|
||||
class TestMultiTenantInviteLogic:
|
||||
@@ -331,7 +333,7 @@ class TestSAMLOIDCBehavior:
|
||||
mock_get_invited.return_value = ["allowed@example.com"]
|
||||
|
||||
# Execute & Assert
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
with pytest.raises(OnyxError) as exc:
|
||||
verify_email_is_invited("newuser@example.com")
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
@@ -385,7 +387,7 @@ class TestWhitelistBehavior:
|
||||
mock_get_invited.return_value = ["allowed@example.com"]
|
||||
|
||||
# Execute & Assert
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
with pytest.raises(OnyxError) as exc:
|
||||
verify_email_is_invited("notallowed@example.com")
|
||||
|
||||
assert exc.value.status_code == 403
|
||||
@@ -420,7 +422,7 @@ class TestSeatLimitEnforcement:
|
||||
"onyx.auth.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda *_a, **_kw: seat_result,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
with pytest.raises(OnyxError) as exc:
|
||||
enforce_seat_limit(MagicMock())
|
||||
|
||||
assert exc.value.status_code == 402
|
||||
@@ -490,7 +492,9 @@ class TestCaseInsensitiveEmailMatching:
|
||||
pass
|
||||
|
||||
# Verify flow
|
||||
mock_verify_domain.assert_called_once_with(user_create.email)
|
||||
mock_verify_domain.assert_called_once_with(
|
||||
user_create.email, is_registration=True
|
||||
)
|
||||
|
||||
@patch("onyx.auth.users.is_disposable_email")
|
||||
@patch("onyx.auth.users.verify_email_domain")
|
||||
@@ -540,5 +544,7 @@ class TestCaseInsensitiveEmailMatching:
|
||||
pass
|
||||
|
||||
# Verify flow
|
||||
mock_verify_domain.assert_called_once_with(mock_user_create.email)
|
||||
mock_verify_domain.assert_called_once_with(
|
||||
mock_user_create.email, is_registration=True
|
||||
)
|
||||
mock_verify_invited.assert_called_once() # Existing tenant = invite needed
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
import onyx.auth.users as users
|
||||
from onyx.auth.users import verify_email_domain
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
def test_verify_email_domain_allows_case_insensitive_match(
|
||||
@@ -21,7 +21,7 @@ def test_verify_email_domain_rejects_non_whitelisted_domain(
|
||||
) -> None:
|
||||
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", ["example.com"], raising=False)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
with pytest.raises(OnyxError) as exc:
|
||||
verify_email_domain("user@another.com")
|
||||
assert exc.value.status_code == 400
|
||||
assert "Email domain is not valid" in exc.value.detail
|
||||
@@ -32,7 +32,7 @@ def test_verify_email_domain_invalid_email_format(
|
||||
) -> None:
|
||||
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", ["example.com"], raising=False)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
with pytest.raises(OnyxError) as exc:
|
||||
verify_email_domain("userexample.com") # missing '@'
|
||||
assert exc.value.status_code == 400
|
||||
assert "Email is not valid" in exc.value.detail
|
||||
@@ -44,10 +44,10 @@ def test_verify_email_domain_rejects_plus_addressing(
|
||||
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
|
||||
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
with pytest.raises(OnyxError) as exc:
|
||||
verify_email_domain("user+tag@gmail.com")
|
||||
assert exc.value.status_code == 400
|
||||
assert "'+'" in str(exc.value.detail)
|
||||
assert "'+'" in exc.value.detail
|
||||
|
||||
|
||||
def test_verify_email_domain_allows_plus_for_onyx_app(
|
||||
@@ -60,13 +60,53 @@ def test_verify_email_domain_allows_plus_for_onyx_app(
|
||||
verify_email_domain("user+tag@onyx.app")
|
||||
|
||||
|
||||
def test_verify_email_domain_rejects_dotted_gmail_on_registration(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
|
||||
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
|
||||
|
||||
with pytest.raises(OnyxError) as exc:
|
||||
verify_email_domain("first.last@gmail.com", is_registration=True)
|
||||
assert exc.value.status_code == 400
|
||||
assert "'.'" in exc.value.detail
|
||||
|
||||
|
||||
def test_verify_email_domain_dotted_gmail_allowed_when_not_registration(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
|
||||
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
|
||||
|
||||
# Existing user signing in — should not be blocked
|
||||
verify_email_domain("first.last@gmail.com", is_registration=False)
|
||||
|
||||
|
||||
def test_verify_email_domain_allows_dotted_non_gmail_on_registration(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
|
||||
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
|
||||
|
||||
verify_email_domain("first.last@example.com", is_registration=True)
|
||||
|
||||
|
||||
def test_verify_email_domain_dotted_gmail_allowed_when_not_cloud(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
|
||||
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.BASIC, raising=False)
|
||||
|
||||
verify_email_domain("first.last@gmail.com", is_registration=True)
|
||||
|
||||
|
||||
def test_verify_email_domain_rejects_googlemail(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
|
||||
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
with pytest.raises(OnyxError) as exc:
|
||||
verify_email_domain("user@googlemail.com")
|
||||
assert exc.value.status_code == 400
|
||||
assert "gmail.com" in str(exc.value.detail)
|
||||
assert "gmail.com" in exc.value.detail
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
import onyx.auth.users as users
|
||||
from onyx.auth.users import verify_email_is_invited
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
@pytest.mark.parametrize("auth_type", [AuthType.SAML, AuthType.OIDC])
|
||||
@@ -35,7 +35,7 @@ def test_verify_email_is_invited_enforced_for_basic_auth(
|
||||
raising=False,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
with pytest.raises(OnyxError) as exc:
|
||||
verify_email_is_invited("newuser@example.com")
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
@@ -324,7 +324,7 @@ class TestExtractContextFiles:
|
||||
|
||||
class TestSearchFilterDetermination:
|
||||
"""Verify that determine_search_params correctly resolves
|
||||
search_project_id, search_persona_id, and search_usage based on
|
||||
project_id_filter, persona_id_filter, and search_usage based on
|
||||
the extraction result and the precedence rule.
|
||||
"""
|
||||
|
||||
@@ -353,8 +353,8 @@ class TestSearchFilterDetermination:
|
||||
uncapped_token_count=100,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_persona_id is None
|
||||
assert result.project_id_filter is None
|
||||
assert result.persona_id_filter is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_custom_persona_files_overflow_persona_filter(self) -> None:
|
||||
@@ -364,8 +364,8 @@ class TestSearchFilterDetermination:
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(use_as_search_filter=True),
|
||||
)
|
||||
assert result.search_persona_id == 42
|
||||
assert result.search_project_id is None
|
||||
assert result.persona_id_filter == 42
|
||||
assert result.project_id_filter is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_custom_persona_no_files_no_project_leak(self) -> None:
|
||||
@@ -375,8 +375,8 @@ class TestSearchFilterDetermination:
|
||||
project_id=99,
|
||||
extracted_context_files=self._make_context(),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.search_persona_id is None
|
||||
assert result.project_id_filter is None
|
||||
assert result.persona_id_filter is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_default_persona_project_files_fit_disables_search(self) -> None:
|
||||
@@ -389,7 +389,7 @@ class TestSearchFilterDetermination:
|
||||
uncapped_token_count=100,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.project_id_filter is None
|
||||
assert result.search_usage == SearchToolUsage.DISABLED
|
||||
|
||||
def test_default_persona_project_files_overflow_enables_search(self) -> None:
|
||||
@@ -402,8 +402,8 @@ class TestSearchFilterDetermination:
|
||||
uncapped_token_count=7000,
|
||||
),
|
||||
)
|
||||
assert result.search_project_id == 99
|
||||
assert result.search_persona_id is None
|
||||
assert result.project_id_filter == 99
|
||||
assert result.persona_id_filter is None
|
||||
assert result.search_usage == SearchToolUsage.ENABLED
|
||||
|
||||
def test_default_persona_no_project_auto(self) -> None:
|
||||
@@ -413,7 +413,7 @@ class TestSearchFilterDetermination:
|
||||
project_id=None,
|
||||
extracted_context_files=self._make_context(),
|
||||
)
|
||||
assert result.search_project_id is None
|
||||
assert result.project_id_filter is None
|
||||
assert result.search_usage == SearchToolUsage.AUTO
|
||||
|
||||
def test_default_persona_project_no_files_disables_search(self) -> None:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
@@ -11,12 +10,10 @@ def test_init_subclass_raises_for_missing_attrs() -> None:
|
||||
|
||||
class IncompleteSpec(HookPointSpec):
|
||||
hook_point = HookPoint.QUERY_PROCESSING
|
||||
# missing display_name, description, etc.
|
||||
# missing display_name, description, payload_model, response_model, etc.
|
||||
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return {}
|
||||
class _Payload(BaseModel):
|
||||
pass
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return {}
|
||||
payload_model = _Payload
|
||||
response_model = _Payload
|
||||
|
||||
541
backend/tests/unit/onyx/hooks/test_executor.py
Normal file
541
backend/tests/unit/onyx/hooks/test_executor.py
Normal file
@@ -0,0 +1,541 @@
|
||||
"""Unit tests for the hook executor."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.executor import execute_hook
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
|
||||
_RESPONSE_PAYLOAD: dict[str, Any] = {"rewritten_query": "better test"}
|
||||
|
||||
|
||||
def _make_hook(
|
||||
*,
|
||||
is_active: bool = True,
|
||||
endpoint_url: str | None = "https://hook.example.com/query",
|
||||
api_key: MagicMock | None = None,
|
||||
timeout_seconds: float = 5.0,
|
||||
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
|
||||
hook_id: int = 1,
|
||||
is_reachable: bool | None = None,
|
||||
) -> MagicMock:
|
||||
hook = MagicMock()
|
||||
hook.is_active = is_active
|
||||
hook.endpoint_url = endpoint_url
|
||||
hook.api_key = api_key
|
||||
hook.timeout_seconds = timeout_seconds
|
||||
hook.id = hook_id
|
||||
hook.fail_strategy = fail_strategy
|
||||
hook.is_reachable = is_reachable
|
||||
return hook
|
||||
|
||||
|
||||
def _make_api_key(value: str) -> MagicMock:
|
||||
api_key = MagicMock()
|
||||
api_key.get_value.return_value = value
|
||||
return api_key
|
||||
|
||||
|
||||
def _make_response(
|
||||
*,
|
||||
status_code: int = 200,
|
||||
json_return: Any = _RESPONSE_PAYLOAD,
|
||||
json_side_effect: Exception | None = None,
|
||||
) -> MagicMock:
|
||||
"""Build a response mock with controllable json() behaviour."""
|
||||
response = MagicMock()
|
||||
response.status_code = status_code
|
||||
if json_side_effect is not None:
|
||||
response.json.side_effect = json_side_effect
|
||||
else:
|
||||
response.json.return_value = json_return
|
||||
return response
|
||||
|
||||
|
||||
def _setup_client(
|
||||
mock_client_cls: MagicMock,
|
||||
*,
|
||||
response: MagicMock | None = None,
|
||||
side_effect: Exception | None = None,
|
||||
) -> MagicMock:
|
||||
"""Wire up the httpx.Client mock and return the inner client.
|
||||
|
||||
If side_effect is an httpx.HTTPStatusError, it is raised from
|
||||
raise_for_status() (matching real httpx behaviour) and post() returns a
|
||||
response mock with the matching status_code set. All other exceptions are
|
||||
raised directly from post().
|
||||
"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
if isinstance(side_effect, httpx.HTTPStatusError):
|
||||
error_response = MagicMock()
|
||||
error_response.status_code = side_effect.response.status_code
|
||||
error_response.raise_for_status.side_effect = side_effect
|
||||
mock_client.post = MagicMock(return_value=error_response)
|
||||
else:
|
||||
mock_client.post = MagicMock(
|
||||
side_effect=side_effect, return_value=response if not side_effect else None
|
||||
)
|
||||
|
||||
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Early-exit guards (no HTTP call, no DB writes)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hooks_available,hook",
|
||||
[
|
||||
# HOOKS_AVAILABLE=False exits before the DB lookup — hook is irrelevant.
|
||||
pytest.param(False, None, id="hooks_not_available"),
|
||||
pytest.param(True, None, id="hook_not_found"),
|
||||
pytest.param(True, _make_hook(is_active=False), id="hook_inactive"),
|
||||
pytest.param(True, _make_hook(endpoint_url=None), id="no_endpoint_url"),
|
||||
],
|
||||
)
|
||||
def test_early_exit_returns_skipped_with_no_db_writes(
|
||||
db_session: MagicMock,
|
||||
hooks_available: bool,
|
||||
hook: MagicMock | None,
|
||||
) -> None:
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", hooks_available),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSkipped)
|
||||
mock_update.assert_not_called()
|
||||
mock_log.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Successful HTTP call
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> None:
|
||||
hook = _make_hook()
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, response=_make_response())
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert result == _RESPONSE_PAYLOAD
|
||||
_, update_kwargs = mock_update.call_args
|
||||
assert update_kwargs["is_reachable"] is True
|
||||
mock_log.assert_not_called()
|
||||
|
||||
|
||||
def test_success_skips_reachable_write_when_already_true(db_session: MagicMock) -> None:
|
||||
"""Deduplication guard: a hook already at is_reachable=True that succeeds
|
||||
must not trigger a DB write."""
|
||||
hook = _make_hook(is_reachable=True)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, response=_make_response())
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert result == _RESPONSE_PAYLOAD
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
|
||||
"""response.json() returning a non-dict (e.g. list) must be treated as failure.
|
||||
The server responded, so is_reachable is not updated."""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(
|
||||
mock_client_cls,
|
||||
response=_make_response(json_return=["unexpected", "list"]),
|
||||
)
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
_, log_kwargs = mock_log.call_args
|
||||
assert log_kwargs["is_success"] is False
|
||||
assert "non-dict" in (log_kwargs["error_message"] or "")
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
|
||||
"""response.json() raising must be treated as failure with SOFT strategy.
|
||||
The server responded, so is_reachable is not updated."""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(
|
||||
mock_client_cls,
|
||||
response=_make_response(
|
||||
json_side_effect=json.JSONDecodeError("not JSON", "", 0)
|
||||
),
|
||||
)
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
_, log_kwargs = mock_log.call_args
|
||||
assert log_kwargs["is_success"] is False
|
||||
assert "non-JSON" in (log_kwargs["error_message"] or "")
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP failure paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"exception,fail_strategy,expected_type,expected_is_reachable",
|
||||
[
|
||||
# NetworkError → is_reachable=False
|
||||
pytest.param(
|
||||
httpx.ConnectError("refused"),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
False,
|
||||
id="connect_error_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.ConnectError("refused"),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
False,
|
||||
id="connect_error_hard",
|
||||
),
|
||||
# 401/403 → is_reachable=False (api_key revoked)
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"401",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=401, text="Unauthorized"),
|
||||
),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
False,
|
||||
id="auth_401_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"403",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=403, text="Forbidden"),
|
||||
),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
False,
|
||||
id="auth_403_hard",
|
||||
),
|
||||
# TimeoutException → no is_reachable write (None)
|
||||
pytest.param(
|
||||
httpx.TimeoutException("timeout"),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
None,
|
||||
id="timeout_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.TimeoutException("timeout"),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
None,
|
||||
id="timeout_hard",
|
||||
),
|
||||
# Other HTTP errors → no is_reachable write (None)
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"500",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=500, text="error"),
|
||||
),
|
||||
HookFailStrategy.SOFT,
|
||||
HookSoftFailed,
|
||||
None,
|
||||
id="http_status_error_soft",
|
||||
),
|
||||
pytest.param(
|
||||
httpx.HTTPStatusError(
|
||||
"500",
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=500, text="error"),
|
||||
),
|
||||
HookFailStrategy.HARD,
|
||||
OnyxError,
|
||||
None,
|
||||
id="http_status_error_hard",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_http_failure_paths(
|
||||
db_session: MagicMock,
|
||||
exception: Exception,
|
||||
fail_strategy: HookFailStrategy,
|
||||
expected_type: type,
|
||||
expected_is_reachable: bool | None,
|
||||
) -> None:
|
||||
hook = _make_hook(fail_strategy=fail_strategy)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, side_effect=exception)
|
||||
|
||||
if expected_type is OnyxError:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
|
||||
else:
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert isinstance(result, expected_type)
|
||||
|
||||
if expected_is_reachable is None:
|
||||
mock_update.assert_not_called()
|
||||
else:
|
||||
mock_update.assert_called_once()
|
||||
_, kwargs = mock_update.call_args
|
||||
assert kwargs["is_reachable"] is expected_is_reachable
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authorization header
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_key_value,expect_auth_header",
|
||||
[
|
||||
pytest.param("secret-token", True, id="api_key_present"),
|
||||
pytest.param(None, False, id="api_key_absent"),
|
||||
],
|
||||
)
|
||||
def test_authorization_header(
|
||||
db_session: MagicMock,
|
||||
api_key_value: str | None,
|
||||
expect_auth_header: bool,
|
||||
) -> None:
|
||||
api_key = _make_api_key(api_key_value) if api_key_value else None
|
||||
hook = _make_hook(api_key=api_key)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit"),
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
mock_client = _setup_client(mock_client_cls, response=_make_response())
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
_, call_kwargs = mock_client.post.call_args
|
||||
if expect_auth_header:
|
||||
assert call_kwargs["headers"]["Authorization"] == f"Bearer {api_key_value}"
|
||||
else:
|
||||
assert "Authorization" not in call_kwargs["headers"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persist session failure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"http_exception,expected_result",
|
||||
[
|
||||
pytest.param(None, _RESPONSE_PAYLOAD, id="success_path"),
|
||||
pytest.param(httpx.ConnectError("refused"), OnyxError, id="hard_fail_path"),
|
||||
],
|
||||
)
|
||||
def test_persist_session_failure_is_swallowed(
|
||||
db_session: MagicMock,
|
||||
http_exception: Exception | None,
|
||||
expected_result: Any,
|
||||
) -> None:
|
||||
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_session_with_current_tenant",
|
||||
side_effect=RuntimeError("DB unavailable"),
|
||||
),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(
|
||||
mock_client_cls,
|
||||
response=_make_response() if not http_exception else None,
|
||||
side_effect=http_exception,
|
||||
)
|
||||
|
||||
if expected_result is OnyxError:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
|
||||
else:
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:
|
||||
"""is_reachable update failing (e.g. concurrent hook deletion) must not
|
||||
prevent the execution log from being written.
|
||||
|
||||
Simulates the production failure path: update_hook__no_commit raises
|
||||
OnyxError(NOT_FOUND) as it would if the hook was concurrently deleted
|
||||
between the initial lookup and the reachable update.
|
||||
"""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.hooks.executor.update_hook__no_commit",
|
||||
side_effect=OnyxError(OnyxErrorCode.NOT_FOUND, "hook deleted"),
|
||||
),
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, side_effect=httpx.ConnectError("refused"))
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
mock_log.assert_called_once()
|
||||
@@ -37,18 +37,20 @@ def test_input_schema_query_is_string() -> None:
|
||||
|
||||
def test_input_schema_user_email_is_nullable() -> None:
|
||||
props = QueryProcessingSpec().input_schema["properties"]
|
||||
assert "null" in props["user_email"]["type"]
|
||||
# Pydantic v2 emits anyOf for nullable fields
|
||||
assert any(s.get("type") == "null" for s in props["user_email"]["anyOf"])
|
||||
|
||||
|
||||
def test_output_schema_query_is_required() -> None:
|
||||
def test_output_schema_query_is_optional() -> None:
|
||||
# query defaults to None (absent = reject); not required in the schema
|
||||
schema = QueryProcessingSpec().output_schema
|
||||
assert "query" in schema["required"]
|
||||
assert "query" not in schema.get("required", [])
|
||||
|
||||
|
||||
def test_output_schema_query_is_nullable() -> None:
|
||||
# null means "reject the query"
|
||||
# null means "reject the query"; Pydantic v2 emits anyOf for nullable fields
|
||||
props = QueryProcessingSpec().output_schema["properties"]
|
||||
assert "null" in props["query"]["type"]
|
||||
assert any(s.get("type") == "null" for s in props["query"]["anyOf"])
|
||||
|
||||
|
||||
def test_output_schema_rejection_message_is_optional() -> None:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user