mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-27 04:35:50 +00:00
Compare commits
223 Commits
refactor/c
...
experiment
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5bb8f16d6c | ||
|
|
21aa89badc | ||
|
|
be13aa1310 | ||
|
|
45d38c4906 | ||
|
|
8aab518532 | ||
|
|
da6ce10e86 | ||
|
|
aaf8253520 | ||
|
|
7c7f81b164 | ||
|
|
2d4a3c72e9 | ||
|
|
7c51712018 | ||
|
|
aa5614695d | ||
|
|
8d7255d3c4 | ||
|
|
d403498f48 | ||
|
|
9ef3095c17 | ||
|
|
a39e93a0cb | ||
|
|
46d73cdfee | ||
|
|
1e04ce78e0 | ||
|
|
f9b81c1725 | ||
|
|
3bc1b89fee | ||
|
|
01743d99d4 | ||
|
|
092c1db7e0 | ||
|
|
40ac0d859a | ||
|
|
929e58361f | ||
|
|
6d472df7c5 | ||
|
|
cfa7acd904 | ||
|
|
5c5a6f943b | ||
|
|
d04128b8b1 | ||
|
|
bbebdf8f78 | ||
|
|
161279a2d5 | ||
|
|
e5ebb45a20 | ||
|
|
320ba9cb1b | ||
|
|
f2e8cb3114 | ||
|
|
43054a28ec | ||
|
|
dc74aa7b1f | ||
|
|
bd773191c2 | ||
|
|
66dbff41e6 | ||
|
|
1dcffe38bc | ||
|
|
c35e883564 | ||
|
|
fefcd58481 | ||
|
|
bdc89d9e3f | ||
|
|
f4d777b80d | ||
|
|
da4d57b5e3 | ||
|
|
dcdcd067bd | ||
|
|
8b15a29723 | ||
|
|
763853674f | ||
|
|
429b6f3465 | ||
|
|
37d5be1b40 | ||
|
|
8ab99dbb06 | ||
|
|
52799e9c7a | ||
|
|
aef009cc97 | ||
|
|
18d1ea1770 | ||
|
|
f336ad00f4 | ||
|
|
0558e687d9 | ||
|
|
784a99e24a | ||
|
|
da1f5a11f4 | ||
|
|
5633805890 | ||
|
|
0817b45ae1 | ||
|
|
af0e4bdebc | ||
|
|
4cd2320732 | ||
|
|
90a361f0e1 | ||
|
|
194efde97b | ||
|
|
d922a42262 | ||
|
|
f00c3a486e | ||
|
|
192080c9e4 | ||
|
|
c5787dc073 | ||
|
|
d424d6462c | ||
|
|
ecea86deb6 | ||
|
|
a5c1f50a8a | ||
|
|
4a04cfd486 | ||
|
|
f22e9628db | ||
|
|
255ba10af6 | ||
|
|
563202a080 | ||
|
|
1062dc0743 | ||
|
|
0826348568 | ||
|
|
375079136d | ||
|
|
82aad5e253 | ||
|
|
beb1c49c69 | ||
|
|
c4556515be | ||
|
|
a4387f230b | ||
|
|
d91e452658 | ||
|
|
dd274f8667 | ||
|
|
2c82f0da16 | ||
|
|
26101636f2 | ||
|
|
5e2c0c6cf4 | ||
|
|
33b64db498 | ||
|
|
b925cc1a56 | ||
|
|
bac4b7c945 | ||
|
|
6f6ef1e657 | ||
|
|
885c69f460 | ||
|
|
4b837303ff | ||
|
|
d856a9befb | ||
|
|
adade353c5 | ||
|
|
3cb6ec2f85 | ||
|
|
691eebf00a | ||
|
|
905b6633e6 | ||
|
|
fd088196ff | ||
|
|
cafbf5b8be | ||
|
|
1235181559 | ||
|
|
caa2e45632 | ||
|
|
9c62e03120 | ||
|
|
0937305064 | ||
|
|
e4c06570e3 | ||
|
|
78fc7c86d7 | ||
|
|
84d3aea847 | ||
|
|
00a404d3cd | ||
|
|
787cf90d96 | ||
|
|
15fe47adc5 | ||
|
|
29958f1a52 | ||
|
|
ac7f9838bc | ||
|
|
d0fa4b3319 | ||
|
|
3fb4fb422e | ||
|
|
ba5da22ea1 | ||
|
|
9909049047 | ||
|
|
c516aa3e3c | ||
|
|
5cc6220417 | ||
|
|
15da1e0a88 | ||
|
|
e9ff00890b | ||
|
|
67747a9d93 | ||
|
|
edfc51b439 | ||
|
|
ac4fba947e | ||
|
|
c142b2db02 | ||
|
|
fb7e7e4395 | ||
|
|
113f23398e | ||
|
|
5a8716026a | ||
|
|
3389140bfd | ||
|
|
13109e7b81 | ||
|
|
56ad457168 | ||
|
|
a81aea2afc | ||
|
|
7cb5c9c4a6 | ||
|
|
3520c58a22 | ||
|
|
bd9d1bfa27 | ||
|
|
14416cc3db | ||
|
|
d7fce14d26 | ||
|
|
39a8d8ed05 | ||
|
|
82f735a434 | ||
|
|
aadb58518b | ||
|
|
0755499e0f | ||
|
|
27aaf977a2 | ||
|
|
9f707f195e | ||
|
|
3e35570f70 | ||
|
|
53b1bf3b2c | ||
|
|
5a3fa6b648 | ||
|
|
fc6a37850b | ||
|
|
aa6fec3d58 | ||
|
|
efa6005e36 | ||
|
|
921bfc72f4 | ||
|
|
812603152d | ||
|
|
6779d8fbd7 | ||
|
|
2c9826e4a9 | ||
|
|
5b54687077 | ||
|
|
0f7e2ee674 | ||
|
|
ea466648d9 | ||
|
|
a402911ee6 | ||
|
|
7ae9ba807d | ||
|
|
1f79223c42 | ||
|
|
c0c2247d5a | ||
|
|
2989ceda41 | ||
|
|
c825f5eca6 | ||
|
|
a8965def79 | ||
|
|
59e1ad51ba | ||
|
|
0e70a8f826 | ||
|
|
0891737dfd | ||
|
|
5a20112670 | ||
|
|
584f2e2638 | ||
|
|
aa24b16ec1 | ||
|
|
50aa9d7df6 | ||
|
|
bfda586054 | ||
|
|
e04392fbb1 | ||
|
|
e46c6c5175 | ||
|
|
f59792b4ac | ||
|
|
973b9456e9 | ||
|
|
aa8d126513 | ||
|
|
a6da5add49 | ||
|
|
3356f90437 | ||
|
|
27c254ecf9 | ||
|
|
09678b3c8e | ||
|
|
ecdb962e24 | ||
|
|
63b9b91565 | ||
|
|
14770e6e90 | ||
|
|
14807d986a | ||
|
|
290eb98020 | ||
|
|
fe6fa3d034 | ||
|
|
2a60a02e0e | ||
|
|
3bcd666e90 | ||
|
|
684013732c | ||
|
|
367dcb8f8b | ||
|
|
59dfed0bc8 | ||
|
|
7a719b54bb | ||
|
|
25ef5ff010 | ||
|
|
53f9f042a1 | ||
|
|
3469f0c979 | ||
|
|
f688efbcd6 | ||
|
|
250658a8b2 | ||
|
|
5150ffc3e0 | ||
|
|
858c1dbe4a | ||
|
|
a8e7353227 | ||
|
|
343cda35cb | ||
|
|
1cbe47d85e | ||
|
|
221658132a | ||
|
|
fe8fb9eb75 | ||
|
|
f7925584b8 | ||
|
|
00b0e15ed7 | ||
|
|
c2968e3bfe | ||
|
|
978f0a9d35 | ||
|
|
410340fe37 | ||
|
|
a0545c7eb3 | ||
|
|
aa46a8bba2 | ||
|
|
cd5aaa0302 | ||
|
|
db33efaeaa | ||
|
|
c28d37dff8 | ||
|
|
696e72bcbb | ||
|
|
cfdac8083a | ||
|
|
781aab67fa | ||
|
|
b14d357d55 | ||
|
|
60bc1ce8a1 | ||
|
|
c89e82ee58 | ||
|
|
529ab8179f | ||
|
|
19716874b2 | ||
|
|
1ac3b8515d | ||
|
|
0d0c8580ca | ||
|
|
96a38dcc06 | ||
|
|
6fe72e5524 | ||
|
|
1f4ee4d550 |
1
.claude/skills
Symbolic link
1
.claude/skills
Symbolic link
@@ -0,0 +1 @@
|
||||
../.cursor/skills
|
||||
248
.cursor/skills/playwright/SKILL.md
Normal file
248
.cursor/skills/playwright/SKILL.md
Normal file
@@ -0,0 +1,248 @@
|
||||
---
|
||||
name: playwright-e2e-tests
|
||||
description: Write and maintain Playwright end-to-end tests for the Onyx application. Use when creating new E2E tests, debugging test failures, adding test coverage, or when the user mentions Playwright, E2E tests, or browser testing.
|
||||
---
|
||||
|
||||
# Playwright E2E Tests
|
||||
|
||||
## Project Layout
|
||||
|
||||
- **Tests**: `web/tests/e2e/` — organized by feature (`auth/`, `admin/`, `chat/`, `assistants/`, `connectors/`, `mcp/`)
|
||||
- **Config**: `web/playwright.config.ts`
|
||||
- **Utilities**: `web/tests/e2e/utils/`
|
||||
- **Constants**: `web/tests/e2e/constants.ts`
|
||||
- **Global setup**: `web/tests/e2e/global-setup.ts`
|
||||
- **Output**: `web/output/playwright/`
|
||||
|
||||
## Imports
|
||||
|
||||
Always use absolute imports with the `@tests/e2e/` prefix — never relative paths (`../`, `../../`). The underlying `@tests/` alias is defined in `web/tsconfig.json` and resolves to `web/tests/`, so `@tests/e2e/...` maps to `web/tests/e2e/...`.
|
||||
|
||||
```typescript
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
|
||||
import { TEST_ADMIN_CREDENTIALS } from "@tests/e2e/constants";
|
||||
```
|
||||
|
||||
All new files should be `.ts`, not `.js`.
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
# Run a specific test file
|
||||
npx playwright test web/tests/e2e/chat/default_assistant.spec.ts
|
||||
|
||||
# Run a specific project
|
||||
npx playwright test --project admin
|
||||
npx playwright test --project exclusive
|
||||
```
|
||||
|
||||
## Test Projects
|
||||
|
||||
| Project | Description | Parallelism |
|
||||
|---------|-------------|-------------|
|
||||
| `admin` | Standard tests (excludes `@exclusive`) | Parallel |
|
||||
| `exclusive` | Serial, slower tests (tagged `@exclusive`) | 1 worker |
|
||||
|
||||
All tests use `admin_auth.json` storage state by default (pre-authenticated admin session).
|
||||
|
||||
## Authentication
|
||||
|
||||
Global setup (`global-setup.ts`) runs automatically before all tests and handles:
|
||||
|
||||
- Server readiness check (polls health endpoint, 60s timeout)
|
||||
- Provisioning test users: admin, admin2, and a **pool of worker users** (`worker0@example.com` through `worker7@example.com`) (idempotent)
|
||||
- API login + saving storage states: `admin_auth.json`, `admin2_auth.json`, and `worker{N}_auth.json` for each worker user
|
||||
- Setting display name to `"worker"` for each worker user
|
||||
- Promoting admin2 to admin role
|
||||
- Ensuring a public LLM provider exists
|
||||
|
||||
Both test projects set `storageState: "admin_auth.json"`, so **every test starts pre-authenticated as admin with no login code needed**.
|
||||
|
||||
When a test needs a different user, use API-based login — never drive the login UI:
|
||||
|
||||
```typescript
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
|
||||
await page.context().clearCookies();
|
||||
await loginAs(page, "admin2");
|
||||
|
||||
// Log in as the worker-specific user (preferred for test isolation):
|
||||
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
|
||||
await page.context().clearCookies();
|
||||
await loginAsWorkerUser(page, testInfo.workerIndex);
|
||||
```
|
||||
|
||||
## Test Structure
|
||||
|
||||
Tests start pre-authenticated as admin — navigate and test directly:
|
||||
|
||||
```typescript
|
||||
import { test, expect } from "@playwright/test";
|
||||
|
||||
test.describe("Feature Name", () => {
|
||||
test("should describe expected behavior clearly", async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
// Already authenticated as admin — go straight to testing
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
**User isolation** — tests that modify visible app state (creating assistants, sending chat messages, pinning items) should run as a **worker-specific user** and clean up resources in `afterAll`. Global setup provisions a pool of worker users (`worker0@example.com` through `worker7@example.com`). `loginAsWorkerUser` maps `testInfo.workerIndex` to a pool slot via modulo, so retry workers (which get incrementing indices beyond the pool size) safely reuse existing users. This ensures parallel workers never share user state, keeps usernames deterministic for screenshots, and avoids cross-contamination:
|
||||
|
||||
```typescript
|
||||
import { test } from "@playwright/test";
|
||||
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
|
||||
|
||||
test.beforeEach(async ({ page }, testInfo) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAsWorkerUser(page, testInfo.workerIndex);
|
||||
});
|
||||
```
|
||||
|
||||
If the test requires admin privileges *and* modifies visible state, use `"admin2"` instead — it's a pre-provisioned admin account that keeps the primary `"admin"` clean for other parallel tests. Switch to `"admin"` only for privileged setup (creating providers, configuring tools), then back to the worker user for the actual test. See `chat/default_assistant.spec.ts` for a full example.
|
||||
|
||||
`loginAsRandomUser` exists for the rare case where the test requires a brand-new user (e.g. onboarding flows). Avoid it elsewhere — it produces non-deterministic usernames that complicate screenshots.
|
||||
|
||||
**API resource setup** — only when tests need to create backend resources (image gen configs, web search providers, MCP servers). Use `beforeAll`/`afterAll` with `OnyxApiClient` to create and clean up. See `chat/default_assistant.spec.ts` or `mcp/mcp_oauth_flow.spec.ts` for examples. This is uncommon (~4 of 37 test files).
|
||||
|
||||
## Key Utilities
|
||||
|
||||
### `OnyxApiClient` (`@tests/e2e/utils/onyxApiClient`)
|
||||
|
||||
Backend API client for test setup/teardown. Key methods:
|
||||
|
||||
- **Connectors**: `createFileConnector()`, `deleteCCPair()`, `pauseConnector()`
|
||||
- **LLM Providers**: `ensurePublicProvider()`, `createRestrictedProvider()`, `setProviderAsDefault()`
|
||||
- **Assistants**: `createAssistant()`, `deleteAssistant()`, `findAssistantByName()`
|
||||
- **User Groups**: `createUserGroup()`, `deleteUserGroup()`, `setUserRole()`
|
||||
- **Tools**: `createWebSearchProvider()`, `createImageGenerationConfig()`
|
||||
- **Chat**: `createChatSession()`, `deleteChatSession()`
|
||||
|
||||
### `chatActions` (`@tests/e2e/utils/chatActions`)
|
||||
|
||||
- `sendMessage(page, message)` — sends a message and waits for AI response
|
||||
- `startNewChat(page)` — clicks new-chat button and waits for intro
|
||||
- `verifyDefaultAssistantIsChosen(page)` — checks Onyx logo is visible
|
||||
- `verifyAssistantIsChosen(page, name)` — checks assistant name display
|
||||
- `switchModel(page, modelName)` — switches LLM model via popover
|
||||
|
||||
### `visualRegression` (`@tests/e2e/utils/visualRegression`)
|
||||
|
||||
- `expectScreenshot(page, { name, mask?, hide?, fullPage? })`
|
||||
- `expectElementScreenshot(locator, { name, mask?, hide? })`
|
||||
- Controlled by `VISUAL_REGRESSION=true` env var
|
||||
|
||||
### `theme` (`@tests/e2e/utils/theme`)
|
||||
|
||||
- `THEMES` — `["light", "dark"] as const` array for iterating over both themes
|
||||
- `setThemeBeforeNavigation(page, theme)` — sets `next-themes` theme via `localStorage` before navigation
|
||||
|
||||
When tests need light/dark screenshots, loop over `THEMES` at the `test.describe` level and call `setThemeBeforeNavigation` in `beforeEach` **before** any `page.goto()`. Include the theme in screenshot names. See `admin/admin_pages.spec.ts` or `chat/chat_message_rendering.spec.ts` for examples:
|
||||
|
||||
```typescript
|
||||
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";
|
||||
|
||||
for (const theme of THEMES) {
|
||||
test.describe(`Feature (${theme} mode)`, () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await setThemeBeforeNavigation(page, theme);
|
||||
});
|
||||
|
||||
test("renders correctly", async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await expectScreenshot(page, { name: `feature-${theme}` });
|
||||
});
|
||||
});
|
||||
}
|
||||
```
|
||||
|
||||
### `tools` (`@tests/e2e/utils/tools`)
|
||||
|
||||
- `TOOL_IDS` — centralized `data-testid` selectors for tool options
|
||||
- `openActionManagement(page)` — opens the tool management popover
|
||||
|
||||
## Locator Strategy
|
||||
|
||||
Use locators in this priority order:
|
||||
|
||||
1. **`data-testid` / `aria-label`** — preferred for Onyx components
|
||||
```typescript
|
||||
page.getByTestId("AppSidebar/new-session")
|
||||
page.getByLabel("admin-page-title")
|
||||
```
|
||||
|
||||
2. **Role-based** — for standard HTML elements
|
||||
```typescript
|
||||
page.getByRole("button", { name: "Create" })
|
||||
page.getByRole("dialog")
|
||||
```
|
||||
|
||||
3. **Text/Label** — for visible text content
|
||||
```typescript
|
||||
page.getByText("Custom Assistant")
|
||||
page.getByLabel("Email")
|
||||
```
|
||||
|
||||
4. **CSS selectors** — last resort, only when above won't work
|
||||
```typescript
|
||||
page.locator('input[name="name"]')
|
||||
page.locator("#onyx-chat-input-textarea")
|
||||
```
|
||||
|
||||
**Never use** `page.locator` with complex CSS/XPath when a built-in locator works.
|
||||
|
||||
## Assertions
|
||||
|
||||
Use web-first assertions — they auto-retry until the condition is met:
|
||||
|
||||
```typescript
|
||||
// Visibility
|
||||
await expect(page.getByTestId("onyx-logo")).toBeVisible({ timeout: 5000 });
|
||||
|
||||
// Text content
|
||||
await expect(page.getByTestId("assistant-name-display")).toHaveText("My Assistant");
|
||||
|
||||
// Count
|
||||
await expect(page.locator('[data-testid="onyx-ai-message"]')).toHaveCount(2, { timeout: 30000 });
|
||||
|
||||
// URL
|
||||
await expect(page).toHaveURL(/chatId=/);
|
||||
|
||||
// Element state
|
||||
await expect(toggle).toBeChecked();
|
||||
await expect(button).toBeEnabled();
|
||||
```
|
||||
|
||||
**Never use** `assert` statements or hardcoded `page.waitForTimeout()`.
|
||||
|
||||
## Waiting Strategy
|
||||
|
||||
```typescript
|
||||
// Wait for load state after navigation
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
// Wait for specific element
|
||||
await page.getByTestId("chat-intro").waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
// Wait for URL change
|
||||
await page.waitForFunction(() => window.location.href.includes("chatId="), null, { timeout: 10000 });
|
||||
|
||||
// Wait for network response
|
||||
await page.waitForResponse(resp => resp.url().includes("/api/chat") && resp.status() === 200);
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Descriptive test names** — clearly state expected behavior: `"should display greeting message when opening new chat"`
|
||||
2. **API-first setup** — use `OnyxApiClient` for backend state; reserve UI interactions for the behavior under test
|
||||
3. **User isolation** — tests that modify visible app state (sidebar, chat history) should run as the worker-specific user via `loginAsWorkerUser(page, testInfo.workerIndex)` (not admin) and clean up resources in `afterAll`. Each parallel worker gets its own user, preventing cross-contamination. Reserve `loginAsRandomUser` for flows that require a brand-new user (e.g. onboarding)
|
||||
4. **DRY helpers** — extract reusable logic into `utils/` with JSDoc comments
|
||||
5. **No hardcoded waits** — use `waitFor`, `waitForLoadState`, or web-first assertions
|
||||
6. **Parallel-safe** — no shared mutable state between tests. Prefer static, human-readable names (e.g. `"E2E-CMD Chat 1"`) and clean up resources by ID in `afterAll`. This keeps screenshots deterministic and avoids needing to mask/hide dynamic text. Only fall back to timestamps (`` `test-${Date.now()}` ``) when resources cannot be reliably cleaned up or when name collisions across parallel workers would cause functional failures
|
||||
7. **Error context** — catch and re-throw with useful debug info (page text, URL, etc.)
|
||||
8. **Tag slow tests** — mark serial/slow tests with `@exclusive` in the test title
|
||||
9. **Visual regression** — use `expectScreenshot()` for UI consistency checks
|
||||
10. **Minimal comments** — only comment to clarify non-obvious intent; never restate what the next line of code does
|
||||
73
.github/actions/build-backend-image/action.yml
vendored
Normal file
73
.github/actions/build-backend-image/action.yml
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
name: "Build Backend Image"
|
||||
description: "Builds and pushes the backend Docker image with cache reuse"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
docker-no-cache:
|
||||
description: "Set to 'true' to disable docker build cache"
|
||||
required: false
|
||||
default: "false"
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
push: true
|
||||
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-backend-${{ inputs.run-id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache
|
||||
type=registry,ref=onyxdotapp/onyx-backend:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache,mode=max
|
||||
no-cache: ${{ inputs.docker-no-cache == 'true' }}
|
||||
75
.github/actions/build-integration-image/action.yml
vendored
Normal file
75
.github/actions/build-integration-image/action.yml
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
name: "Build Integration Image"
|
||||
description: "Builds and pushes the integration test image with docker bake"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Build and push integration test image with Docker Bake
|
||||
shell: bash
|
||||
env:
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
TAG: nightly-llm-it-${{ inputs.run-id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ inputs.github-sha }}
|
||||
run: |
|
||||
docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
|
||||
integration
|
||||
68
.github/actions/build-model-server-image/action.yml
vendored
Normal file
68
.github/actions/build-model-server-image/action.yml
vendored
Normal file
@@ -0,0 +1,68 @@
|
||||
name: "Build Model Server Image"
|
||||
description: "Builds and pushes the model server Docker image with cache reuse"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
push: true
|
||||
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-model-server-${{ inputs.run-id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache
|
||||
type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache,mode=max
|
||||
118
.github/actions/run-nightly-provider-chat-test/action.yml
vendored
Normal file
118
.github/actions/run-nightly-provider-chat-test/action.yml
vendored
Normal file
@@ -0,0 +1,118 @@
|
||||
name: "Run Nightly Provider Chat Test"
|
||||
description: "Starts required compose services and runs nightly provider integration test"
|
||||
inputs:
|
||||
provider:
|
||||
description: "Provider slug for NIGHTLY_LLM_PROVIDER"
|
||||
required: true
|
||||
models:
|
||||
description: "Comma-separated model list for NIGHTLY_LLM_MODELS"
|
||||
required: true
|
||||
provider-api-key:
|
||||
description: "API key for NIGHTLY_LLM_API_KEY"
|
||||
required: false
|
||||
default: ""
|
||||
strict:
|
||||
description: "String true/false for NIGHTLY_LLM_STRICT"
|
||||
required: true
|
||||
api-base:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE"
|
||||
required: false
|
||||
default: ""
|
||||
custom-config-json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
required: false
|
||||
default: ""
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in image tags"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Create .env file for Docker Compose
|
||||
shell: bash
|
||||
env:
|
||||
ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
RUN_ID: ${{ inputs.run-id }}
|
||||
run: |
|
||||
cat <<EOF2 > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
AWS_REGION_NAME=us-west-2
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
|
||||
EOF2
|
||||
|
||||
- name: Start Docker containers
|
||||
shell: bash
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server
|
||||
|
||||
- name: Run nightly provider integration test
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
env:
|
||||
MODELS: ${{ inputs.models }}
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
RUN_ID: ${{ inputs.run-id }}
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 2
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AWS_REGION_NAME=us-west-2 \
|
||||
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
|
||||
-e NIGHTLY_LLM_MODELS="${MODELS}" \
|
||||
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
|
||||
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
|
||||
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
|
||||
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
|
||||
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
|
||||
/app/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py
|
||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -8,5 +8,5 @@
|
||||
|
||||
## Additional Options
|
||||
|
||||
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
|
||||
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
|
||||
- [ ] [Optional] Override Linear Check
|
||||
|
||||
2
.github/workflows/helm-chart-releases.yml
vendored
2
.github/workflows/helm-chart-releases.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
|
||||
helm repo update
|
||||
|
||||
- name: Build chart dependencies
|
||||
|
||||
49
.github/workflows/nightly-llm-provider-chat.yml
vendored
Normal file
49
.github/workflows/nightly-llm-provider-chat.yml
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
name: Nightly LLM Provider Chat Tests
|
||||
concurrency:
|
||||
group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
|
||||
- cron: "30 10 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
with:
|
||||
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
|
||||
bedrock_models: ${{ vars.NIGHTLY_LLM_BEDROCK_MODELS }}
|
||||
vertex_ai_models: ${{ vars.NIGHTLY_LLM_VERTEX_AI_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
|
||||
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: provider-chat-test
|
||||
title: "🚨 Scheduled LLM Provider Chat Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
151
.github/workflows/nightly-scan-licenses.yml
vendored
151
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -1,151 +0,0 @@
|
||||
# Scan for problematic software licenses
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
|
||||
name: 'Nightly - Scan licenses'
|
||||
on:
|
||||
# schedule:
|
||||
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
scan-licenses:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
|
||||
timeout-minutes: 45
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
|
||||
- name: Get explicit and transitive dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
pip freeze > requirements-all.txt
|
||||
|
||||
- name: Check python
|
||||
id: license_check_report
|
||||
uses: pilosus/action-pip-license-checker@e909b0226ff49d3235c99c4585bc617f49fff16a # ratchet:pilosus/action-pip-license-checker@v3
|
||||
with:
|
||||
requirements: 'requirements-all.txt'
|
||||
fail: 'Copyleft'
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: always()
|
||||
env:
|
||||
REPORT: ${{ steps.license_check_report.outputs.report }}
|
||||
run: echo "$REPORT"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
# be careful enabling the sarif and upload as it may spam the security tab
|
||||
# with a huge amount of items. Work out the issues before enabling upload.
|
||||
# - name: Run Trivy vulnerability scanner in repo mode
|
||||
# if: always()
|
||||
# uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
# with:
|
||||
# scan-type: fs
|
||||
# scan-ref: .
|
||||
# scanners: license
|
||||
# format: table
|
||||
# severity: HIGH,CRITICAL
|
||||
# # format: sarif
|
||||
# # output: trivy-results.sarif
|
||||
#
|
||||
# # - name: Upload Trivy scan results to GitHub Security tab
|
||||
# # uses: github/codeql-action/upload-sarif@v3
|
||||
# # with:
|
||||
# # sarif_file: trivy-results.sarif
|
||||
|
||||
scan-trivy:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# Backend
|
||||
- name: Pull backend docker image
|
||||
run: docker pull onyxdotapp/onyx-backend:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on backend
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-backend:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
|
||||
|
||||
# Web server
|
||||
- name: Pull web server docker image
|
||||
run: docker pull onyxdotapp/onyx-web-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on web server
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-web-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
|
||||
# Model server
|
||||
- name: Pull model server docker image
|
||||
run: docker pull onyxdotapp/onyx-model-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-model-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
161
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
Normal file
161
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
Normal file
@@ -0,0 +1,161 @@
|
||||
name: Post-Merge Beta Cherry-Pick
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
outputs:
|
||||
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
|
||||
pr_number: ${{ steps.gate.outputs.pr_number }}
|
||||
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Resolve merged PR and checkbox state
|
||||
id: gate
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
# For the commit that triggered this workflow (HEAD on main), fetch all
|
||||
# associated PRs and keep only the PR that was actually merged into main
|
||||
# with this exact merge commit SHA.
|
||||
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
|
||||
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
|
||||
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
|
||||
|
||||
if [ "${match_count}" -gt 1 ]; then
|
||||
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
|
||||
fi
|
||||
|
||||
if [ -z "$pr_number" ]; then
|
||||
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Read the PR once so we can gate behavior and infer preferred actor.
|
||||
pr_json="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}")"
|
||||
pr_body="$(printf '%s' "$pr_json" | jq -r '.body // ""')"
|
||||
merged_by="$(printf '%s' "$pr_json" | jq -r '.merged_by.login // ""')"
|
||||
|
||||
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
|
||||
echo "merged_by=$merged_by" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
|
||||
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox checked for PR #${pr_number}."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
|
||||
|
||||
- name: Checkout repository
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: true
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Configure git identity
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Create cherry-pick PR to latest release
|
||||
id: run_cherry_pick
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
continue-on-error: true
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
|
||||
run: |
|
||||
set -o pipefail
|
||||
output_file="$(mktemp)"
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
|
||||
exit_code="${PIPESTATUS[0]}"
|
||||
|
||||
if [ "${exit_code}" -eq 0 ]; then
|
||||
echo "status=success" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "status=failure" >> "$GITHUB_OUTPUT"
|
||||
|
||||
reason="command-failed"
|
||||
if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then
|
||||
reason="merge-conflict"
|
||||
fi
|
||||
echo "reason=${reason}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
{
|
||||
echo "details<<EOF"
|
||||
tail -n 40 "$output_file"
|
||||
echo "EOF"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Mark workflow as failed if cherry-pick failed
|
||||
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
|
||||
run: |
|
||||
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
|
||||
exit 1
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
needs:
|
||||
- cherry-pick-to-latest-release
|
||||
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build cherry-pick failure summary
|
||||
id: failure-summary
|
||||
env:
|
||||
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
|
||||
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
|
||||
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
|
||||
run: |
|
||||
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
|
||||
|
||||
reason_text="cherry-pick command failed"
|
||||
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
|
||||
reason_text="merge conflict during cherry-pick"
|
||||
fi
|
||||
|
||||
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
|
||||
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
|
||||
if [ -n "${details_excerpt}" ]; then
|
||||
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
|
||||
fi
|
||||
|
||||
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Notify #cherry-pick-prs about cherry-pick failure
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
|
||||
title: "🚨 Automated Cherry-Pick Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
@@ -1,28 +0,0 @@
|
||||
name: Require beta cherry-pick consideration
|
||||
concurrency:
|
||||
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, reopened, synchronize]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
beta-cherrypick-check:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Check PR body for beta cherry-pick consideration
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
|
||||
echo "Cherry-pick consideration box is checked. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
|
||||
exit 1
|
||||
@@ -45,9 +45,6 @@ env:
|
||||
# TODO: debug why this is failing and enable
|
||||
CODE_INTERPRETER_BASE_URL: http://localhost:8000
|
||||
|
||||
# OpenSearch
|
||||
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
@@ -118,9 +115,9 @@ jobs:
|
||||
- name: Create .env file for Docker Compose
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
CODE_INTERPRETER_BETA_ENABLED=true
|
||||
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
|
||||
DISABLE_TELEMETRY=true
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=true
|
||||
EOF
|
||||
|
||||
- name: Set up Standard Dependencies
|
||||
@@ -129,7 +126,6 @@ jobs:
|
||||
docker compose \
|
||||
-f docker-compose.yml \
|
||||
-f docker-compose.dev.yml \
|
||||
-f docker-compose.opensearch.yml \
|
||||
up -d \
|
||||
minio \
|
||||
relational_db \
|
||||
|
||||
5
.github/workflows/pr-helm-chart-testing.yml
vendored
5
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -41,8 +41,7 @@ jobs:
|
||||
version: v3.19.0
|
||||
|
||||
- name: Set up chart-testing
|
||||
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
|
||||
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
|
||||
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
|
||||
with:
|
||||
uv_version: "0.9.9"
|
||||
|
||||
@@ -92,7 +91,7 @@ jobs:
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
|
||||
helm repo update
|
||||
|
||||
- name: Install Redis operator
|
||||
|
||||
4
.github/workflows/pr-integration-tests.yml
vendored
4
.github/workflows/pr-integration-tests.yml
vendored
@@ -20,6 +20,7 @@ env:
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
SLACK_BOT_TOKEN_TEST_SPACE: ${{ secrets.SLACK_BOT_TOKEN_TEST_SPACE }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
@@ -423,6 +424,7 @@ jobs:
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
@@ -443,6 +445,7 @@ jobs:
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
-e ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${{ matrix.edition == 'ee' && 'true' || 'false' }} \
|
||||
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
|
||||
/app/tests/integration/${{ matrix.test-dir.path }}
|
||||
|
||||
@@ -701,6 +704,7 @@ jobs:
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
|
||||
9
.github/workflows/pr-playwright-tests.yml
vendored
9
.github/workflows/pr-playwright-tests.yml
vendored
@@ -22,6 +22,9 @@ env:
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
|
||||
GOOGLE_PSE_API_KEY: ${{ secrets.GOOGLE_PSE_API_KEY }}
|
||||
GOOGLE_PSE_SEARCH_ENGINE_ID: ${{ secrets.GOOGLE_PSE_SEARCH_ENGINE_ID }}
|
||||
|
||||
# for federated slack tests
|
||||
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
|
||||
@@ -300,6 +303,7 @@ jobs:
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
|
||||
EXA_API_KEY=${EXA_API_KEY_VALUE}
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
@@ -589,7 +593,10 @@ jobs:
|
||||
# Post a single combined visual regression comment after all matrix jobs finish
|
||||
visual-regression-comment:
|
||||
needs: [playwright-tests]
|
||||
if: always() && github.event_name == 'pull_request'
|
||||
if: >-
|
||||
always() &&
|
||||
github.event_name == 'pull_request' &&
|
||||
needs.playwright-tests.result != 'cancelled'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
permissions:
|
||||
|
||||
@@ -89,6 +89,10 @@ env:
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
|
||||
# Github
|
||||
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
|
||||
|
||||
217
.github/workflows/reusable-nightly-llm-provider-chat.yml
vendored
Normal file
217
.github/workflows/reusable-nightly-llm-provider-chat.yml
vendored
Normal file
@@ -0,0 +1,217 @@
|
||||
name: Reusable Nightly LLM Provider Chat Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
openai_models:
|
||||
description: "Comma-separated models for openai"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
anthropic_models:
|
||||
description: "Comma-separated models for anthropic"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
bedrock_models:
|
||||
description: "Comma-separated models for bedrock"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
vertex_ai_models:
|
||||
description: "Comma-separated models for vertex_ai"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
strict:
|
||||
description: "Default NIGHTLY_LLM_STRICT passed to tests"
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
secrets:
|
||||
openai_api_key:
|
||||
required: false
|
||||
anthropic_api_key:
|
||||
required: false
|
||||
bedrock_api_key:
|
||||
required: false
|
||||
vertex_ai_custom_config_json:
|
||||
required: false
|
||||
DOCKER_USERNAME:
|
||||
required: true
|
||||
DOCKER_TOKEN:
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build-backend-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build backend image
|
||||
uses: ./.github/actions/build-backend-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build model server image
|
||||
uses: ./.github/actions/build-model-server-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-integration-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build integration image
|
||||
uses: ./.github/actions/build-integration-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
[
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- provider: openai
|
||||
models: ${{ inputs.openai_models }}
|
||||
api_key_secret: openai_api_key
|
||||
custom_config_secret: ""
|
||||
required: true
|
||||
- provider: anthropic
|
||||
models: ${{ inputs.anthropic_models }}
|
||||
api_key_secret: anthropic_api_key
|
||||
custom_config_secret: ""
|
||||
required: true
|
||||
- provider: bedrock
|
||||
models: ${{ inputs.bedrock_models }}
|
||||
api_key_secret: bedrock_api_key
|
||||
custom_config_secret: ""
|
||||
required: false
|
||||
- provider: vertex_ai
|
||||
models: ${{ inputs.vertex_ai_models }}
|
||||
api_key_secret: ""
|
||||
custom_config_secret: vertex_ai_custom_config_json
|
||||
required: false
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ matrix.provider }}
|
||||
models: ${{ matrix.models }}
|
||||
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
|
||||
strict: ${{ inputs.strict && 'true' || 'false' }}
|
||||
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
|
||||
path: |
|
||||
${{ github.workspace }}/api_server.log
|
||||
${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose down -v
|
||||
13
.github/workflows/zizmor.yml
vendored
13
.github/workflows/zizmor.yml
vendored
@@ -5,6 +5,8 @@ on:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["**"]
|
||||
paths:
|
||||
- ".github/**"
|
||||
|
||||
permissions: {}
|
||||
|
||||
@@ -21,29 +23,18 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Detect changes
|
||||
id: filter
|
||||
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
zizmor:
|
||||
- '.github/**'
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Run zizmor
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Upload SARIF file
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -7,6 +7,7 @@
|
||||
.zed
|
||||
.cursor
|
||||
!/.cursor/mcp.json
|
||||
!/.cursor/skills/
|
||||
|
||||
# macos
|
||||
.DS_store
|
||||
|
||||
@@ -548,7 +548,7 @@ class in the utils over directly calling the APIs with a library like `requests`
|
||||
calling the utilities directly (e.g. do NOT create admin users with
|
||||
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
|
||||
|
||||
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
|
||||
A great example of this type of test is `backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
@@ -616,3 +616,9 @@ This is a minimal list - feel free to include more. Do NOT write code as part of
|
||||
Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
|
||||
## Best Practices
|
||||
|
||||
In addition to the other content in this file, best practices for contributing
|
||||
to the codebase can be found at `contributing_guides/best_practices.md`.
|
||||
Understand its contents and follow them.
|
||||
|
||||
@@ -21,15 +21,14 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, NamedTuple
|
||||
from typing import NamedTuple
|
||||
|
||||
from alembic.config import Config
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.db.engine.sql_engine import is_valid_schema_name
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
|
||||
@@ -105,56 +104,6 @@ def get_head_revision() -> str | None:
|
||||
return script.get_current_head()
|
||||
|
||||
|
||||
def get_schemas_needing_migration(
    tenant_schemas: List[str], head_rev: str
) -> List[str]:
    """Filter *tenant_schemas* down to those that are not at *head_rev*.

    A schema is reported when it has no ``alembic_version`` table, or when
    the ``version_num`` read from that table differs from *head_rev* (a
    schema whose table yields no row is reported as well).

    Raises:
        ValueError: If a schema owning an ``alembic_version`` table has a
            name rejected by ``is_valid_schema_name``.
    """
    if not tenant_schemas:
        return []

    with SqlEngine.get_engine().connect() as conn:
        # Which of the requested schemas own an alembic_version table?
        table_rows = conn.execute(
            text(
                "SELECT table_schema FROM information_schema.tables "
                "WHERE table_name = 'alembic_version' "
                "AND table_schema = ANY(:schemas)"
            ),
            {"schemas": tenant_schemas},
        )
        schemas_with_table = {row[0] for row in table_rows}

        # No version table at all means the schema definitely needs migrating.
        pending = [
            schema for schema in tenant_schemas if schema not in schemas_with_table
        ]
        if not schemas_with_table:
            return pending

        # Schema names are interpolated as SQL identifiers below, so they
        # are validated before any query text is built.
        for schema in schemas_with_table:
            if not is_valid_schema_name(schema):
                raise ValueError(f"Invalid schema name: {schema}")

        # One UNION ALL query fetches every schema's revision. Rows are
        # tagged with an integer index rather than the schema name so no
        # name has to be embedded in a string literal.
        ordered = list(schemas_with_table)
        union_sql = " UNION ALL ".join(
            f'SELECT {idx} AS idx, version_num FROM "{schema}".alembic_version'
            for idx, schema in enumerate(ordered)
        )
        current_rev = {
            ordered[row[0]]: row[1] for row in conn.execute(text(union_sql))
        }

        for schema in schemas_with_table:
            if current_rev.get(schema) != head_rev:
                pending.append(schema)

    return pending
|
||||
|
||||
|
||||
def run_migrations_parallel(
|
||||
schemas: list[str],
|
||||
max_workers: int,
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
"""code interpreter seed
|
||||
|
||||
Revision ID: 07b98176f1de
|
||||
Revises: 7cb492013621
|
||||
Create Date: 2026-02-23 15:55:07.606784
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "07b98176f1de"
|
||||
down_revision = "7cb492013621"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Insert the seed ``code_interpreter_server`` row with ``server_enabled`` true."""
    # NOTE: the table is meant to hold exactly one row.
    seed_stmt = sa.text(
        "INSERT INTO code_interpreter_server (server_enabled) VALUES (true)"
    )
    op.execute(seed_stmt)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Delete every row from ``code_interpreter_server``."""
    clear_stmt = sa.text("DELETE FROM code_interpreter_server")
    op.execute(clear_stmt)
|
||||
@@ -0,0 +1,28 @@
|
||||
"""add scim_username to scim_user_mapping
|
||||
|
||||
Revision ID: 0bb4558f35df
|
||||
Revises: 631fd2504136
|
||||
Create Date: 2026-02-20 10:45:30.340188
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0bb4558f35df"
|
||||
down_revision = "631fd2504136"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable string column ``scim_username`` to ``scim_user_mapping``."""
    username_column = sa.Column("scim_username", sa.String(), nullable=True)
    op.add_column("scim_user_mapping", username_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop ``scim_username`` from ``scim_user_mapping``."""
    table_name, column_name = "scim_user_mapping", "scim_username"
    op.drop_column(table_name, column_name)
|
||||
@@ -0,0 +1,32 @@
|
||||
"""add approx_chunk_count_in_vespa to opensearch tenant migration
|
||||
|
||||
Revision ID: 631fd2504136
|
||||
Revises: c7f2e1b4a9d3
|
||||
Create Date: 2026-02-18 21:07:52.831215
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "631fd2504136"
|
||||
down_revision = "c7f2e1b4a9d3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add nullable integer ``approx_chunk_count_in_vespa`` to the migration record table."""
    chunk_count_column = sa.Column(
        "approx_chunk_count_in_vespa", sa.Integer(), nullable=True
    )
    op.add_column("opensearch_tenant_migration_record", chunk_count_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop ``approx_chunk_count_in_vespa`` from ``opensearch_tenant_migration_record``."""
    table_name = "opensearch_tenant_migration_record"
    op.drop_column(table_name, "approx_chunk_count_in_vespa")
|
||||
@@ -0,0 +1,48 @@
|
||||
"""add enterprise and name fields to scim_user_mapping
|
||||
|
||||
Revision ID: 7616121f6e97
|
||||
Revises: 07b98176f1de
|
||||
Create Date: 2026-02-23 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7616121f6e97"
|
||||
down_revision = "07b98176f1de"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add five nullable SCIM attribute columns to ``scim_user_mapping``.

    Columns are added in this order: ``department``, ``manager``,
    ``given_name``, ``family_name`` (all strings) and ``scim_emails_json``
    (text).
    """
    new_columns = (
        ("department", sa.String),
        ("manager", sa.String),
        ("given_name", sa.String),
        ("family_name", sa.String),
        ("scim_emails_json", sa.Text),
    )
    for column_name, column_type in new_columns:
        op.add_column(
            "scim_user_mapping",
            sa.Column(column_name, column_type(), nullable=True),
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the five SCIM attribute columns, in reverse order of their addition."""
    for column_name in (
        "scim_emails_json",
        "family_name",
        "given_name",
        "manager",
        "department",
    ):
        op.drop_column("scim_user_mapping", column_name)
|
||||
@@ -0,0 +1,31 @@
|
||||
"""code interpreter server model
|
||||
|
||||
Revision ID: 7cb492013621
|
||||
Revises: 0bb4558f35df
|
||||
Create Date: 2026-02-22 18:54:54.007265
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7cb492013621"
|
||||
down_revision = "0bb4558f35df"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create ``code_interpreter_server`` with an integer PK and a ``server_enabled`` flag.

    ``server_enabled`` is NOT NULL and has a server-side default of true.
    """
    id_column = sa.Column("id", sa.Integer, primary_key=True)
    enabled_column = sa.Column(
        "server_enabled", sa.Boolean, nullable=False, server_default=sa.true()
    )
    op.create_table("code_interpreter_server", id_column, enabled_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``code_interpreter_server`` table."""
    table_name = "code_interpreter_server"
    op.drop_table(table_name)
|
||||
@@ -0,0 +1,33 @@
|
||||
"""add needs_persona_sync to user_file
|
||||
|
||||
Revision ID: 8ffcc2bcfc11
|
||||
Revises: 7616121f6e97
|
||||
Create Date: 2026-02-23 10:48:48.343826
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8ffcc2bcfc11"
|
||||
down_revision = "7616121f6e97"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add NOT NULL boolean ``needs_persona_sync`` to ``user_file``.

    The column carries a server default of ``false``.
    """
    sync_flag_column = sa.Column(
        "needs_persona_sync",
        sa.Boolean(),
        nullable=False,
        server_default=sa.text("false"),
    )
    op.add_column("user_file", sync_flag_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop ``needs_persona_sync`` from ``user_file``."""
    table_name, column_name = "user_file", "needs_persona_sync"
    op.drop_column(table_name, column_name)
|
||||
@@ -0,0 +1,70 @@
|
||||
"""llm provider deprecate fields
|
||||
|
||||
Revision ID: c0c937d5c9e5
|
||||
Revises: 8ffcc2bcfc11
|
||||
Create Date: 2026-02-25 17:35:46.125102
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c0c937d5c9e5"
|
||||
down_revision = "8ffcc2bcfc11"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Relax three legacy constraints on ``llm_provider``.

    In order: make ``default_model_name`` nullable, drop the unique
    constraint on ``is_default_provider``, and remove the server default
    from ``is_default_vision_provider``.
    """
    table = "llm_provider"

    # default_model_name was NOT NULL; allow NULL from now on.
    op.alter_column(
        table,
        "default_model_name",
        existing_type=sa.String(),
        nullable=True,
    )

    # Defaults are now tracked via LLMModelFlow, so the uniqueness
    # guarantee on is_default_provider is no longer wanted.
    op.drop_constraint(
        "llm_provider_is_default_provider_key",
        table,
        type_="unique",
    )

    # is_default_vision_provider previously had server_default=false().
    op.alter_column(
        table,
        "is_default_vision_provider",
        existing_type=sa.Boolean(),
        server_default=None,
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Restore the ``llm_provider`` constraints removed by the upgrade.

    NULL ``default_model_name`` values are rewritten to the empty string
    before the NOT NULL constraint is reinstated. The unique constraint on
    ``is_default_provider`` and the ``false`` server default on
    ``is_default_vision_provider`` are then recreated.
    """
    table = "llm_provider"

    # Backfill first, otherwise re-adding NOT NULL would reject existing NULL rows.
    op.execute(
        "UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
    )
    op.alter_column(
        table,
        "default_model_name",
        existing_type=sa.String(),
        nullable=False,
    )

    # Bring back uniqueness on is_default_provider.
    unique_columns = ["is_default_provider"]
    op.create_unique_constraint(
        "llm_provider_is_default_provider_key",
        table,
        unique_columns,
    )

    # Bring back the server-side default for is_default_vision_provider.
    op.alter_column(
        table,
        "is_default_vision_provider",
        existing_type=sa.Boolean(),
        server_default=sa.false(),
    )
|
||||
@@ -0,0 +1,31 @@
|
||||
"""add sharing_scope to build_session
|
||||
|
||||
Revision ID: c7f2e1b4a9d3
|
||||
Revises: 19c0ccb01687
|
||||
Create Date: 2026-02-17 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "c7f2e1b4a9d3"
|
||||
down_revision = "19c0ccb01687"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add NOT NULL string ``sharing_scope`` to ``build_session``.

    The column has a server default of ``"private"``.
    """
    scope_column = sa.Column(
        "sharing_scope",
        sa.String(),
        nullable=False,
        server_default="private",
    )
    op.add_column("build_session", scope_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop ``sharing_scope`` from ``build_session``."""
    table_name, column_name = "build_session", "sharing_scope"
    op.drop_column(table_name, column_name)
|
||||
@@ -263,9 +263,15 @@ def refresh_license_cache(
|
||||
|
||||
try:
|
||||
payload = verify_license_signature(license_record.license_data)
|
||||
# Derive source from payload: manual licenses lack stripe_customer_id
|
||||
source: LicenseSource = (
|
||||
LicenseSource.AUTO_FETCH
|
||||
if payload.stripe_customer_id
|
||||
else LicenseSource.MANUAL_UPLOAD
|
||||
)
|
||||
return update_license_cache(
|
||||
payload,
|
||||
source=LicenseSource.AUTO_FETCH,
|
||||
source=source,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
|
||||
709
backend/ee/onyx/db/scim.py
Normal file
709
backend/ee/onyx/db/scim.py
Normal file
@@ -0,0 +1,709 @@
|
||||
"""SCIM Data Access Layer.
|
||||
|
||||
All database operations for SCIM provisioning — token management, user
|
||||
mappings, and group mappings. Extends the base DAL (see ``onyx.db.dal``).
|
||||
|
||||
Usage from FastAPI::
|
||||
|
||||
def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
|
||||
return ScimDAL(db_session)
|
||||
|
||||
@router.post("/tokens")
|
||||
def create_token(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
|
||||
token = dal.create_token(name=..., hashed_token=..., ...)
|
||||
dal.commit()
|
||||
return token
|
||||
|
||||
Usage from background tasks::
|
||||
|
||||
with ScimDAL.from_tenant("tenant_abc") as dal:
|
||||
mapping = dal.create_user_mapping(external_id="idp-123", user_id=uid)
|
||||
dal.commit()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import delete as sa_delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import SQLColumnExpression
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from ee.onyx.server.scim.filtering import ScimFilter
|
||||
from ee.onyx.server.scim.filtering import ScimFilterOperator
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from onyx.db.dal import DAL
|
||||
from onyx.db.models import ScimGroupMapping
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ScimDAL(DAL):
|
||||
"""Data Access Layer for SCIM provisioning operations.
|
||||
|
||||
Methods mutate but do NOT commit — call ``dal.commit()`` explicitly
|
||||
when you want to persist changes. This follows the existing ``_no_commit``
|
||||
convention and lets callers batch multiple operations into one transaction.
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Token operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def create_token(
    self,
    name: str,
    hashed_token: str,
    token_display: str,
    created_by_id: UUID,
) -> ScimToken:
    """Issue a new SCIM bearer token, deactivating active ones first.

    Every ``ScimToken`` row whose ``is_active`` flag is set is marked
    inactive before the new row is added. The session is flushed but not
    committed.

    Returns:
        The newly added ``ScimToken``.
    """
    # Only one token may be active at a time: retire the current ones.
    active_stmt = select(ScimToken).where(ScimToken.is_active.is_(True))
    for existing in self._session.scalars(active_stmt).all():
        existing.is_active = False

    new_token = ScimToken(
        name=name,
        hashed_token=hashed_token,
        token_display=token_display,
        created_by_id=created_by_id,
    )
    self._session.add(new_token)
    self._session.flush()
    return new_token
|
||||
|
||||
def get_active_token(self) -> ScimToken | None:
    """Fetch a token whose ``is_active`` flag is set, or ``None`` if there is none."""
    active_stmt = select(ScimToken).where(ScimToken.is_active.is_(True))
    return self._session.scalar(active_stmt)
|
||||
|
||||
def get_token_by_hash(self, hashed_token: str) -> ScimToken | None:
    """Find the token whose stored ``hashed_token`` equals *hashed_token*.

    Returns ``None`` when no token matches.
    """
    by_hash_stmt = select(ScimToken).where(ScimToken.hashed_token == hashed_token)
    return self._session.scalar(by_hash_stmt)
|
||||
|
||||
def revoke_token(self, token_id: int) -> None:
    """Mark the token with primary key *token_id* as inactive.

    Raises:
        ValueError: If no token with that ID exists.
    """
    target = self._session.get(ScimToken, token_id)
    if not target:
        raise ValueError(f"SCIM token with id {token_id} not found")
    target.is_active = False
|
||||
|
||||
def update_token_last_used(self, token_id: int) -> None:
    """Set ``last_used_at`` on the token to the database's ``now()``.

    Does nothing when no token with *token_id* exists.
    """
    target = self._session.get(ScimToken, token_id)
    if not target:
        return
    # A SQL expression is assigned rather than a Python datetime.
    target.last_used_at = func.now()  # type: ignore[assignment]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# User mapping operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def create_user_mapping(
    self,
    external_id: str,
    user_id: UUID,
    scim_username: str | None = None,
    fields: ScimMappingFields | None = None,
) -> ScimUserMapping:
    """Link a SCIM ``externalId`` to an Onyx user and flush the new row.

    When *fields* is omitted, a default-constructed ``ScimMappingFields``
    supplies the department / manager / name / emails values.

    Returns:
        The newly added ``ScimUserMapping``.
    """
    extra = fields if fields else ScimMappingFields()
    new_mapping = ScimUserMapping(
        external_id=external_id,
        user_id=user_id,
        scim_username=scim_username,
        department=extra.department,
        manager=extra.manager,
        given_name=extra.given_name,
        family_name=extra.family_name,
        scim_emails_json=extra.scim_emails_json,
    )
    self._session.add(new_mapping)
    self._session.flush()
    return new_mapping
|
||||
|
||||
def get_user_mapping_by_external_id(
    self, external_id: str
) -> ScimUserMapping | None:
    """Find the user mapping whose ``external_id`` equals *external_id*, if any."""
    by_external_id = select(ScimUserMapping).where(
        ScimUserMapping.external_id == external_id
    )
    return self._session.scalar(by_external_id)
|
||||
|
||||
def get_user_mapping_by_user_id(self, user_id: UUID) -> ScimUserMapping | None:
    """Find the user mapping belonging to the Onyx user *user_id*, if any."""
    by_user_id = select(ScimUserMapping).where(ScimUserMapping.user_id == user_id)
    return self._session.scalar(by_user_id)
|
||||
|
||||
def list_user_mappings(
    self,
    start_index: int = 1,
    count: int = 100,
) -> tuple[list[ScimUserMapping], int]:
    """List user mappings with SCIM-style pagination.

    Args:
        start_index: 1-based start index (SCIM convention). Values below
            1 are treated as 1.
        count: Maximum number of results to return. Negative values are
            treated as 0 (RFC 7644 §3.4.2.4) instead of being passed to
            the database as a negative LIMIT.

    Returns:
        A tuple of (mappings, total_count), where total_count is the
        number of mappings before pagination is applied.
    """
    total = (
        self._session.scalar(select(func.count()).select_from(ScimUserMapping)) or 0
    )

    # Convert SCIM's 1-based index into a 0-based SQL offset.
    offset = max(start_index - 1, 0)
    limit = max(count, 0)
    mappings = list(
        self._session.scalars(
            select(ScimUserMapping)
            .order_by(ScimUserMapping.id)
            .offset(offset)
            .limit(limit)
        ).all()
    )

    return mappings, total
|
||||
|
||||
def update_user_mapping_external_id(
    self,
    mapping_id: int,
    external_id: str,
) -> ScimUserMapping:
    """Overwrite ``external_id`` on the user mapping with ID *mapping_id*.

    Returns:
        The updated mapping.

    Raises:
        ValueError: If no mapping with that ID exists.
    """
    target = self._session.get(ScimUserMapping, mapping_id)
    if not target:
        raise ValueError(f"SCIM user mapping with id {mapping_id} not found")

    target.external_id = external_id
    return target
|
||||
|
||||
def delete_user_mapping(self, mapping_id: int) -> None:
    """Remove the user mapping with ID *mapping_id*.

    A missing mapping is logged as a warning and otherwise ignored.
    """
    target = self._session.get(ScimUserMapping, mapping_id)
    if target:
        self._session.delete(target)
        return
    logger.warning("SCIM user mapping %d not found during delete", mapping_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# User query operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_user(self, user_id: UUID) -> User | None:
    """Look up the user whose ``id`` equals *user_id*; ``None`` if absent."""
    by_id = select(User).where(User.id == user_id)  # type: ignore[arg-type]
    return self._session.scalar(by_id)
|
||||
|
||||
def get_user_by_email(self, email: str) -> User | None:
    """Look up a user by email, comparing lower-cased values on both sides."""
    by_email = select(User).where(func.lower(User.email) == func.lower(email))
    return self._session.scalar(by_email)
|
||||
|
||||
def add_user(self, user: User) -> None:
    """Register *user* with the session, then flush so it gets an ID."""
    session = self._session
    session.add(user)
    session.flush()
|
||||
|
||||
def update_user(
    self,
    user: User,
    *,
    email: str | None = None,
    is_active: bool | None = None,
    personal_name: str | None = None,
) -> None:
    """Apply the supplied attribute values to *user*.

    An argument left as ``None`` leaves the matching attribute untouched;
    any other value (including ``False`` or ``""``) is assigned.
    """
    requested = {
        "email": email,
        "is_active": is_active,
        "personal_name": personal_name,
    }
    for attr_name, new_value in requested.items():
        if new_value is not None:
            setattr(user, attr_name, new_value)
|
||||
|
||||
def deactivate_user(self, user: User) -> None:
    """Set ``is_active`` to ``False`` on *user*."""
    setattr(user, "is_active", False)
|
||||
|
||||
def list_users(
    self,
    scim_filter: ScimFilter | None,
    start_index: int = 1,
    count: int = 100,
) -> tuple[list[tuple[User, ScimUserMapping | None]], int]:
    """Query users with optional SCIM filter and pagination.

    Users whose role is ``SLACK_USER`` or ``EXT_PERM_USER`` are always
    excluded. Supported filter attributes (matched case-insensitively) are
    ``userName`` (applied to ``User.email``), ``active`` and ``externalId``.

    Args:
        scim_filter: Parsed filter, or ``None`` for no filtering.
        start_index: 1-based index of the first result (SCIM convention).
            Values below 1 are treated as 1.
        count: Maximum number of results to return. Negative values are
            treated as 0 (RFC 7644 §3.4.2.4) instead of being passed to
            the database as a negative LIMIT.

    Returns:
        A tuple of (list of (user, mapping) pairs, total_count), where
        total_count is the number of matching users before pagination.

    Raises:
        ValueError: If the filter uses an unsupported attribute.
    """
    query = select(User).where(
        User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
    )

    if scim_filter:
        attr = scim_filter.attribute.lower()
        if attr == "username":
            # arg-type: fastapi-users types User.email as str, not a column expression
            # assignment: union return type widens but query is still Select[tuple[User]]
            query = _apply_scim_string_op(query, User.email, scim_filter)  # type: ignore[arg-type, assignment]
        elif attr == "active":
            # NOTE: the filter operator is not consulted here; any value
            # other than "true" (case-insensitive) selects inactive users.
            query = query.where(
                User.is_active.is_(scim_filter.value.lower() == "true")  # type: ignore[attr-defined]
            )
        elif attr == "externalid":
            # NOTE: exact-match lookup; the filter operator is not consulted.
            mapping = self.get_user_mapping_by_external_id(scim_filter.value)
            if not mapping:
                return [], 0
            query = query.where(User.id == mapping.user_id)  # type: ignore[arg-type]
        else:
            raise ValueError(
                f"Unsupported filter attribute: {scim_filter.attribute}"
            )

    # Count total matching rows first, then paginate. SCIM uses 1-based
    # indexing (RFC 7644 §3.4.2), so we convert to a 0-based offset.
    total = (
        self._session.scalar(select(func.count()).select_from(query.subquery()))
        or 0
    )

    offset = max(start_index - 1, 0)
    limit = max(count, 0)
    users = list(
        self._session.scalars(
            query.order_by(User.id).offset(offset).limit(limit)  # type: ignore[arg-type]
        )
        .unique()
        .all()
    )

    # Batch-fetch SCIM mappings to avoid N+1 queries
    mapping_map = self._get_user_mappings_batch([u.id for u in users])
    return [(u, mapping_map.get(u.id)) for u in users], total
|
||||
|
||||
def sync_user_external_id(
    self,
    user_id: UUID,
    new_external_id: str | None,
    scim_username: str | None = None,
    fields: ScimMappingFields | None = None,
) -> None:
    """Reconcile a user's SCIM mapping with *new_external_id*.

    * Falsy *new_external_id*: any existing mapping is deleted.
    * No existing mapping: one is created from the given arguments.
    * Existing mapping: its external ID is updated if it differs,
      ``scim_username`` is overwritten only when a non-``None`` value is
      given, and — when *fields* is given — every mapping field is
      overwritten, ``None`` included, so callers can clear a value.
    """
    existing = self.get_user_mapping_by_user_id(user_id)

    if not new_external_id:
        if existing:
            self.delete_user_mapping(existing.id)
        return

    if not existing:
        self.create_user_mapping(
            external_id=new_external_id,
            user_id=user_id,
            scim_username=scim_username,
            fields=fields,
        )
        return

    if existing.external_id != new_external_id:
        existing.external_id = new_external_id
    if scim_username is not None:
        existing.scim_username = scim_username
    if fields is None:
        return

    # Written unconditionally so a None in *fields* clears the stored value.
    for field_name in (
        "department",
        "manager",
        "given_name",
        "family_name",
        "scim_emails_json",
    ):
        setattr(existing, field_name, getattr(fields, field_name))
|
||||
|
||||
def _get_user_mappings_batch(
|
||||
self, user_ids: list[UUID]
|
||||
) -> dict[UUID, ScimUserMapping]:
|
||||
"""Batch-fetch SCIM user mappings keyed by user ID."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
mappings = self._session.scalars(
|
||||
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
|
||||
).all()
|
||||
return {m.user_id: m for m in mappings}
|
||||
|
||||
def get_user_groups(self, user_id: UUID) -> list[tuple[int, str]]:
    """Return ``(group_id, group_name)`` pairs for the groups *user_id* is in.

    Groups flagged ``is_up_for_deletion`` are left out. An empty list is
    returned when the user has no membership rows.
    """
    membership_stmt = select(User__UserGroup).where(
        User__UserGroup.user_id == user_id
    )
    member_of = [
        membership.user_group_id
        for membership in self._session.scalars(membership_stmt).all()
    ]
    if not member_of:
        return []

    live_groups_stmt = select(UserGroup).where(
        UserGroup.id.in_(member_of),
        UserGroup.is_up_for_deletion.is_(False),
    )
    live_groups = self._session.scalars(live_groups_stmt).all()
    return [(group.id, group.name) for group in live_groups]
|
||||
|
||||
def get_users_groups_batch(
    self, user_ids: list[UUID]
) -> dict[UUID, list[tuple[int, str]]]:
    """Load group memberships for many users with two queries in total.

    Returns:
        ``user_id → [(group_id, group_name), ...]``. Groups flagged
        ``is_up_for_deletion`` are omitted, and users with no remaining
        groups have no key in the result.
    """
    if not user_ids:
        return {}

    memberships = self._session.scalars(
        select(User__UserGroup).where(User__UserGroup.user_id.in_(user_ids))
    ).all()

    distinct_group_ids = list({m.user_group_id for m in memberships})
    if not distinct_group_ids:
        return {}

    live_groups = self._session.scalars(
        select(UserGroup).where(
            UserGroup.id.in_(distinct_group_ids),
            UserGroup.is_up_for_deletion.is_(False),
        )
    ).all()
    name_by_group_id = {group.id: group.name for group in live_groups}

    grouped: dict[UUID, list[tuple[int, str]]] = {}
    for m in memberships:
        # Skip rows without a user or whose group was filtered out above.
        if not m.user_id or m.user_group_id not in name_by_group_id:
            continue
        grouped.setdefault(m.user_id, []).append(
            (m.user_group_id, name_by_group_id[m.user_group_id])
        )
    return grouped
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Group mapping operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def create_group_mapping(
    self,
    external_id: str,
    user_group_id: int,
) -> ScimGroupMapping:
    """Link a SCIM ``externalId`` to an Onyx user group and flush the new row.

    Returns:
        The newly added ``ScimGroupMapping``.
    """
    new_mapping = ScimGroupMapping(
        external_id=external_id,
        user_group_id=user_group_id,
    )
    session = self._session
    session.add(new_mapping)
    session.flush()
    return new_mapping
|
||||
|
||||
def get_group_mapping_by_external_id(
    self, external_id: str
) -> ScimGroupMapping | None:
    """Find the group mapping whose ``external_id`` equals *external_id*, if any."""
    by_external_id = select(ScimGroupMapping).where(
        ScimGroupMapping.external_id == external_id
    )
    return self._session.scalar(by_external_id)
|
||||
|
||||
def get_group_mapping_by_group_id(
    self, user_group_id: int
) -> ScimGroupMapping | None:
    """Find the group mapping attached to the Onyx group *user_group_id*, if any."""
    by_group_id = select(ScimGroupMapping).where(
        ScimGroupMapping.user_group_id == user_group_id
    )
    return self._session.scalar(by_group_id)
|
||||
|
||||
def list_group_mappings(
    self,
    start_index: int = 1,
    count: int = 100,
) -> tuple[list[ScimGroupMapping], int]:
    """List group mappings with SCIM-style pagination.

    Args:
        start_index: 1-based start index (SCIM convention). Values below
            1 are treated as 1.
        count: Maximum number of results to return. Negative values are
            treated as 0 (RFC 7644 §3.4.2.4) instead of being passed to
            the database as a negative LIMIT.

    Returns:
        A tuple of (mappings, total_count), where total_count is the
        number of mappings before pagination is applied.
    """
    total = (
        self._session.scalar(select(func.count()).select_from(ScimGroupMapping))
        or 0
    )

    # Convert SCIM's 1-based index into a 0-based SQL offset.
    offset = max(start_index - 1, 0)
    limit = max(count, 0)
    mappings = list(
        self._session.scalars(
            select(ScimGroupMapping)
            .order_by(ScimGroupMapping.id)
            .offset(offset)
            .limit(limit)
        ).all()
    )

    return mappings, total
|
||||
|
||||
def delete_group_mapping(self, mapping_id: int) -> None:
    """Remove the group mapping with ID *mapping_id*.

    A missing mapping is logged as a warning and otherwise ignored.
    """
    target = self._session.get(ScimGroupMapping, mapping_id)
    if target:
        self._session.delete(target)
        return
    logger.warning("SCIM group mapping %d not found during delete", mapping_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Group query operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_group(self, group_id: int) -> UserGroup | None:
    """Load the group with ID *group_id*.

    Returns ``None`` both when the group does not exist and when it is
    flagged ``is_up_for_deletion``.
    """
    found = self._session.get(UserGroup, group_id)
    pending_deletion = bool(found) and found.is_up_for_deletion
    return None if pending_deletion else found
|
||||
|
||||
def get_group_by_name(self, name: str) -> UserGroup | None:
    """Look up a group whose ``name`` equals *name* exactly; ``None`` if absent."""
    by_name = select(UserGroup).where(UserGroup.name == name)
    return self._session.scalar(by_name)
|
||||
|
||||
def add_group(self, group: UserGroup) -> None:
    """Register *group* with the session, then flush so it gets an ID."""
    session = self._session
    session.add(group)
    session.flush()
|
||||
|
||||
def update_group(
    self,
    group: UserGroup,
    *,
    name: str | None = None,
) -> None:
    """Rename *group* when *name* is given and stamp its modification time.

    ``time_last_modified_by_user`` is set to the database's ``now()`` on
    every call, whether or not the name changes.
    """
    rename_requested = name is not None
    if rename_requested:
        group.name = name
    group.time_last_modified_by_user = func.now()
|
||||
|
||||
def delete_group(self, group: UserGroup) -> None:
    """Mark *group* for deletion in the session (no flush, no commit)."""
    session = self._session
    session.delete(group)
|
||||
|
||||
def list_groups(
    self,
    scim_filter: ScimFilter | None,
    start_index: int = 1,
    count: int = 100,
) -> tuple[list[tuple[UserGroup, str | None]], int]:
    """List groups, optionally narrowed by a SCIM filter, one page at a time.

    Groups flagged ``is_up_for_deletion`` are always excluded. The
    ``displayName`` filter is applied to ``UserGroup.name``; the
    ``externalId`` filter is resolved through the group mapping lookup.

    Args:
        scim_filter: Parsed filter, or None for no filtering.
        start_index: 1-based index of the first result; values below 1
            are treated as 1.
        count: Maximum number of groups to return.

    Returns:
        A tuple of (list of (group, external_id) pairs, total_count), where
        total_count is the number of matching groups before pagination.

    Raises:
        ValueError: If the filter uses an unsupported attribute.
    """
    stmt = select(UserGroup).where(UserGroup.is_up_for_deletion.is_(False))

    if scim_filter:
        filter_attr = scim_filter.attribute.lower()
        if filter_attr not in ("displayname", "externalid"):
            raise ValueError(
                f"Unsupported filter attribute: {scim_filter.attribute}"
            )
        if filter_attr == "displayname":
            # assignment: union return type widens but stmt is still Select[tuple[UserGroup]]
            stmt = _apply_scim_string_op(stmt, UserGroup.name, scim_filter)  # type: ignore[assignment]
        else:
            ext_mapping = self.get_group_mapping_by_external_id(scim_filter.value)
            if not ext_mapping:
                # Unknown external ID: nothing can match.
                return [], 0
            stmt = stmt.where(UserGroup.id == ext_mapping.user_group_id)

    count_stmt = select(func.count()).select_from(stmt.subquery())
    total = self._session.scalar(count_stmt) or 0

    skip = max(start_index - 1, 0)
    page_stmt = stmt.order_by(UserGroup.id).offset(skip).limit(count)
    page = list(self._session.scalars(page_stmt).all())

    external_ids = self._get_group_external_ids([grp.id for grp in page])
    return [(grp, external_ids.get(grp.id)) for grp in page], total
|
||||
|
||||
def get_group_members(self, group_id: int) -> list[tuple[UUID, str | None]]:
    """Return the members of a group as (user_id, email) pairs.

    Membership rows without a user_id are skipped. The email is None when
    no matching user row is found for a membership.
    """
    memberships = self._session.scalars(
        select(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
    ).all()

    member_ids = [m.user_id for m in memberships if m.user_id]
    if not member_ids:
        return []

    user_query = select(User).where(User.id.in_(member_ids))  # type: ignore[attr-defined]
    email_by_id = {
        u.id: u.email for u in self._session.scalars(user_query).unique().all()
    }

    # Preserve the order of the membership rows.
    return [(uid, email_by_id.get(uid)) for uid in member_ids]
|
||||
|
||||
def validate_member_ids(self, uuids: list[UUID]) -> list[UUID]:
    """Report which of the given UUIDs have no corresponding user.

    Returns:
        The unknown IDs in their input order; an empty list when every ID
        exists (or when *uuids* is empty).
    """
    if not uuids:
        return []
    lookup = select(User).where(User.id.in_(uuids))  # type: ignore[attr-defined]
    known = {row.id for row in self._session.scalars(lookup).unique().all()}
    return [candidate for candidate in uuids if candidate not in known]
|
||||
|
||||
def upsert_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
    """Insert membership rows for *user_ids*, skipping ones that already exist.

    Does nothing when *user_ids* is empty.
    """
    if not user_ids:
        return
    rows = [{"user_id": uid, "user_group_id": group_id} for uid in user_ids]
    # Conflicts on the (user_group_id, user_id) pair are silently ignored.
    conflict_target = [User__UserGroup.user_group_id, User__UserGroup.user_id]
    stmt = (
        pg_insert(User__UserGroup)
        .values(rows)
        .on_conflict_do_nothing(index_elements=conflict_target)
    )
    self._session.execute(stmt)
|
||||
|
||||
def replace_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
    """Make *user_ids* the complete membership of the group.

    All existing membership rows for the group are deleted first, then
    the given users are inserted.
    """
    clear_existing = sa_delete(User__UserGroup).where(
        User__UserGroup.user_group_id == group_id
    )
    self._session.execute(clear_existing)
    self.upsert_group_members(group_id, user_ids)
|
||||
|
||||
def remove_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
    """Delete the membership rows linking *user_ids* to the group.

    Does nothing when *user_ids* is empty.
    """
    if not user_ids:
        return
    in_group = User__UserGroup.user_group_id == group_id
    is_target_user = User__UserGroup.user_id.in_(user_ids)
    self._session.execute(sa_delete(User__UserGroup).where(in_group, is_target_user))
|
||||
|
||||
def delete_group_with_members(self, group: UserGroup) -> None:
    """Delete a group after first removing every membership row for it."""
    drop_memberships = sa_delete(User__UserGroup).where(
        User__UserGroup.user_group_id == group.id
    )
    self._session.execute(drop_memberships)
    self._session.delete(group)
|
||||
|
||||
def sync_group_external_id(
|
||||
self, group_id: int, new_external_id: str | None
|
||||
) -> None:
|
||||
"""Create, update, or delete the external ID mapping for a group."""
|
||||
mapping = self.get_group_mapping_by_group_id(group_id)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
else:
|
||||
self.create_group_mapping(
|
||||
external_id=new_external_id, user_group_id=group_id
|
||||
)
|
||||
elif mapping:
|
||||
self.delete_group_mapping(mapping.id)
|
||||
|
||||
def _get_group_external_ids(self, group_ids: list[int]) -> dict[int, str]:
    """Look up external IDs for many groups in a single query.

    Returns:
        A dict from group ID to external ID, containing only the groups
        that have a mapping row. Empty when *group_ids* is empty.
    """
    if not group_ids:
        return {}
    stmt = select(ScimGroupMapping).where(
        ScimGroupMapping.user_group_id.in_(group_ids)
    )
    external_by_group: dict[int, str] = {}
    for row in self._session.scalars(stmt).all():
        external_by_group[row.user_group_id] = row.external_id
    return external_by_group
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level helpers (used by DAL methods above)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _apply_scim_string_op(
    query: Select[tuple[User]] | Select[tuple[UserGroup]],
    column: SQLColumnExpression[str],
    scim_filter: ScimFilter,
) -> Select[tuple[User]] | Select[tuple[UserGroup]]:
    """Add a WHERE clause to *query* implementing a SCIM string operator.

    Supported operators:
        EQUAL: case-insensitive exact match (both sides lower-cased).
        CONTAINS: case-insensitive substring match.
        STARTS_WITH: case-insensitive prefix match.

    The substring and prefix matches pass ``autoescape=True`` so the filter
    value is escaped for use in the LIKE pattern.

    Raises:
        ValueError: If the filter's operator is not one of the above.
    """
    needle = scim_filter.value
    op = scim_filter.operator

    if op == ScimFilterOperator.EQUAL:
        condition = func.lower(column) == needle.lower()
    elif op == ScimFilterOperator.CONTAINS:
        condition = column.icontains(needle, autoescape=True)
    elif op == ScimFilterOperator.STARTS_WITH:
        condition = column.istartswith(needle, autoescape=True)
    else:
        raise ValueError(f"Unsupported string filter operator: {scim_filter.operator}")

    return query.where(condition)
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
@@ -18,11 +19,15 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__UserGroup
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
from onyx.db.models import User
|
||||
@@ -195,8 +200,60 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def _add_user_group_snapshot_eager_loads(
    stmt: Select,
) -> Select:
    """Attach the select-in eager loads used when snapshotting a UserGroup.

    The loads cover the group's users, membership relationships, cc-pair
    relationships (with connector, credential and credential owner),
    document sets, and personas (with their own nested relationships),
    as needed by UserGroup.from_model snapshot creation.
    """

    def _document_set_children() -> tuple:
        # Build fresh loader options per call; the same set of DocumentSet
        # relationships is loaded under both the group and each persona.
        return (
            selectinload(DocumentSet.connector_credential_pairs).selectinload(
                ConnectorCredentialPair.connector
            ),
            selectinload(DocumentSet.users),
            selectinload(DocumentSet.groups),
            selectinload(DocumentSet.federated_connectors).selectinload(
                FederatedConnector__DocumentSet.federated_connector
            ),
        )

    cc_pair_loads = (
        selectinload(UserGroup.cc_pair_relationships)
        .selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
        .options(
            selectinload(ConnectorCredentialPair.connector),
            selectinload(ConnectorCredentialPair.credential).selectinload(
                Credential.user
            ),
        )
    )

    document_set_loads = selectinload(UserGroup.document_sets).options(
        *_document_set_children()
    )

    persona_loads = selectinload(UserGroup.personas).options(
        selectinload(Persona.tools),
        selectinload(Persona.hierarchy_nodes),
        selectinload(Persona.attached_documents).selectinload(
            Document.parent_hierarchy_node
        ),
        selectinload(Persona.labels),
        selectinload(Persona.document_sets).options(*_document_set_children()),
        selectinload(Persona.user),
        selectinload(Persona.user_files),
        selectinload(Persona.users),
        selectinload(Persona.groups),
    )

    return stmt.options(
        selectinload(UserGroup.users),
        selectinload(UserGroup.user_group_relationships),
        cc_pair_loads,
        document_set_loads,
        persona_loads,
    )
|
||||
|
||||
|
||||
def fetch_user_groups(
|
||||
db_session: Session, only_up_to_date: bool = True
|
||||
db_session: Session,
|
||||
only_up_to_date: bool = True,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
) -> Sequence[UserGroup]:
|
||||
"""
|
||||
Fetches user groups from the database.
|
||||
@@ -209,6 +266,8 @@ def fetch_user_groups(
|
||||
db_session (Session): The SQLAlchemy session used to query the database.
|
||||
only_up_to_date (bool, optional): Flag to determine whether to filter the results
|
||||
to include only up to date user groups. Defaults to `True`.
|
||||
eager_load_for_snapshot: If True, adds eager loading for all relationships
|
||||
needed by UserGroup.from_model snapshot creation.
|
||||
|
||||
Returns:
|
||||
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
|
||||
@@ -216,11 +275,16 @@ def fetch_user_groups(
|
||||
stmt = select(UserGroup)
|
||||
if only_up_to_date:
|
||||
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
|
||||
return db_session.scalars(stmt).all()
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def fetch_user_groups_for_user(
|
||||
db_session: Session, user_id: UUID, only_curator_groups: bool = False
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
only_curator_groups: bool = False,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
) -> Sequence[UserGroup]:
|
||||
stmt = (
|
||||
select(UserGroup)
|
||||
@@ -230,7 +294,9 @@ def fetch_user_groups_for_user(
|
||||
)
|
||||
if only_curator_groups:
|
||||
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
|
||||
return db_session.scalars(stmt).all()
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def construct_document_id_select_by_usergroup(
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
get_sharepoint_external_groups,
|
||||
)
|
||||
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
|
||||
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -43,14 +47,27 @@ def sharepoint_group_sync(
|
||||
|
||||
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
|
||||
|
||||
# Process each site
|
||||
enumerate_all = connector_config.get(
|
||||
"exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
|
||||
)
|
||||
|
||||
msal_app = connector.msal_app
|
||||
sp_tenant_domain = connector.sp_tenant_domain
|
||||
sp_domain_suffix = connector.sharepoint_domain_suffix
|
||||
for site_descriptor in site_descriptors:
|
||||
logger.debug(f"Processing site: {site_descriptor.url}")
|
||||
|
||||
ctx = connector._create_rest_client_context(site_descriptor.url)
|
||||
ctx = ClientContext(site_descriptor.url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
|
||||
)
|
||||
|
||||
# Get external groups for this site
|
||||
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
|
||||
external_groups = get_sharepoint_external_groups(
|
||||
ctx,
|
||||
connector.graph_client,
|
||||
graph_api_base=connector.graph_api_base,
|
||||
get_access_token=connector._get_graph_access_token,
|
||||
enumerate_all_ad_groups=enumerate_all,
|
||||
)
|
||||
|
||||
# Yield each group
|
||||
for group in external_groups:
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import re
|
||||
import time
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests as _requests
|
||||
from office365.graph_client import GraphClient # type: ignore[import-untyped]
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
|
||||
from office365.runtime.client_request import ClientRequestException # type: ignore
|
||||
@@ -14,7 +17,10 @@ from pydantic import BaseModel
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
|
||||
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
|
||||
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
|
||||
from onyx.connectors.sharepoint.connector import sleep_and_retry
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -33,6 +39,70 @@ LIMITED_ACCESS_ROLE_TYPES = [1, 9]
|
||||
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
|
||||
|
||||
|
||||
AD_GROUP_ENUMERATION_THRESHOLD = 100_000
|
||||
|
||||
|
||||
def _graph_api_get(
    url: str,
    get_access_token: Callable[[], str],
    params: dict[str, str] | None = None,
) -> dict[str, Any]:
    """Authenticated Graph API GET with retry on transient errors.

    Makes up to ``GRAPH_API_MAX_RETRIES + 1`` attempts. A fresh access
    token is requested from *get_access_token* on every attempt.

    Args:
        url: Absolute URL to request.
        get_access_token: Callable returning a bearer token.
        params: Optional query parameters.

    Returns:
        The decoded JSON body of the first successful response.

    Raises:
        requests.ConnectionError, requests.Timeout, requests.HTTPError:
            Re-raised once the final attempt fails.
        RuntimeError: Safety net if the loop ends without returning.
    """
    for attempt in range(GRAPH_API_MAX_RETRIES + 1):
        access_token = get_access_token()
        headers = {"Authorization": f"Bearer {access_token}"}
        try:
            resp = _requests.get(
                url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
            )
            if (
                resp.status_code in GRAPH_API_RETRYABLE_STATUSES
                and attempt < GRAPH_API_MAX_RETRIES
            ):
                backoff = 2**attempt
                # Retry-After may legally be an HTTP-date (or be malformed);
                # int() would raise ValueError and abort the whole request,
                # so fall back to exponential backoff in that case.
                try:
                    wait = int(resp.headers.get("Retry-After", str(backoff)))
                except ValueError:
                    wait = backoff
                # Clamp to [0, 60]: time.sleep rejects negative values.
                wait = max(0, min(wait, 60))
                logger.warning(
                    f"Graph API {resp.status_code} on attempt {attempt + 1}, "
                    f"retrying in {wait}s: {url}"
                )
                time.sleep(wait)
                continue
            resp.raise_for_status()
            return resp.json()
        except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
            if attempt < GRAPH_API_MAX_RETRIES:
                wait = min(2**attempt, 60)
                logger.warning(
                    f"Graph API connection error on attempt {attempt + 1}, "
                    f"retrying in {wait}s: {url}"
                )
                time.sleep(wait)
                continue
            # Out of attempts: surface the original exception to the caller.
            raise
    raise RuntimeError(
        f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
    )
|
||||
|
||||
|
||||
def _iter_graph_collection(
    initial_url: str,
    get_access_token: Callable[[], str],
    params: dict[str, str] | None = None,
) -> Generator[dict[str, Any], None, None]:
    """Yield every item of a paged Graph API collection, one at a time.

    Each page's items are read from its ``value`` key and the next page is
    taken from ``@odata.nextLink``; iteration ends when that link is absent
    or empty.
    """
    next_url: str | None = initial_url
    query = params
    while next_url:
        page = _graph_api_get(next_url, get_access_token, query)
        # Query parameters are only sent with the first request; follow-up
        # requests use the nextLink URL as returned (presumably it already
        # carries the query — confirm against Graph paging docs).
        query = None
        yield from page.get("value", [])
        next_url = page.get("@odata.nextLink")
|
||||
|
||||
|
||||
def _normalize_email(email: str) -> str:
    """Return *email* with every occurrence of MICROSOFT_DOMAIN removed."""
    # str.replace returns the input unchanged when the marker is absent.
    return email.replace(MICROSOFT_DOMAIN, "")
|
||||
|
||||
|
||||
class SharepointGroup(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
|
||||
@@ -527,8 +597,12 @@ def get_external_access_from_sharepoint(
|
||||
)
|
||||
elif site_page:
|
||||
site_url = site_page.get("webUrl")
|
||||
# Prefer server-relative URL to avoid OData filters that break on apostrophes
|
||||
server_relative_url = unquote(urlparse(site_url).path)
|
||||
# Keep percent-encoding intact so the path matches the encoding
|
||||
# used by the Office365 library's SPResPath.create_relative(),
|
||||
# which compares against urlparse(context.base_url).path.
|
||||
# Decoding (e.g. %27 → ') causes a mismatch that duplicates
|
||||
# the site prefix in the constructed URL.
|
||||
server_relative_url = urlparse(site_url).path
|
||||
file_obj = client_context.web.get_file_by_server_relative_url(
|
||||
server_relative_url
|
||||
)
|
||||
@@ -572,8 +646,65 @@ def get_external_access_from_sharepoint(
|
||||
)
|
||||
|
||||
|
||||
def _enumerate_ad_groups_paginated(
    get_access_token: Callable[[], str],
    already_resolved: set[str],
    graph_api_base: str,
) -> Generator[ExternalUserGroup, None, None]:
    """Walk all Azure AD groups page by page, yielding one ExternalUserGroup each.

    Each yielded group is identified as ``"<displayName>_<id>"`` and carries
    the normalized emails of its members (userPrincipalName preferred, mail
    as fallback). Groups whose suffixed name is already in *already_resolved*
    are skipped, as are entries lacking an id or display name.

    Enumeration stops early, with a warning, once more than
    AD_GROUP_ENUMERATION_THRESHOLD groups have been seen.
    """
    group_query: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
    member_query: dict[str, str] = {
        "$select": "userPrincipalName,mail",
        "$top": "999",
    }
    seen = 0

    for entry in _iter_graph_collection(
        f"{graph_api_base}/groups", get_access_token, group_query
    ):
        gid: str = entry.get("id", "")
        label: str = entry.get("displayName", "")
        if not (gid and label):
            continue

        seen += 1
        if seen > AD_GROUP_ENUMERATION_THRESHOLD:
            logger.warning(
                f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
                "groups — stopping to avoid excessive memory/API usage. "
                "Remaining groups will be resolved from role assignments only."
            )
            return

        suffixed = f"{label}_{gid}"
        if suffixed in already_resolved:
            continue

        # Members without either address field are left out.
        emails: list[str] = [
            _normalize_email(addr)
            for member in _iter_graph_collection(
                f"{graph_api_base}/groups/{gid}/members",
                get_access_token,
                member_query,
            )
            if (addr := member.get("userPrincipalName") or member.get("mail"))
        ]

        yield ExternalUserGroup(id=suffixed, user_emails=emails)

    logger.info(f"Enumerated {seen} Azure AD groups via paginated Graph API")
|
||||
|
||||
|
||||
def get_sharepoint_external_groups(
|
||||
client_context: ClientContext, graph_client: GraphClient
|
||||
client_context: ClientContext,
|
||||
graph_client: GraphClient,
|
||||
graph_api_base: str,
|
||||
get_access_token: Callable[[], str] | None = None,
|
||||
enumerate_all_ad_groups: bool = False,
|
||||
) -> list[ExternalUserGroup]:
|
||||
|
||||
groups: set[SharepointGroup] = set()
|
||||
@@ -629,57 +760,22 @@ def get_sharepoint_external_groups(
|
||||
client_context, graph_client, groups, is_group_sync=True
|
||||
)
|
||||
|
||||
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
|
||||
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
|
||||
azure_ad_groups = sleep_and_retry(
|
||||
graph_client.groups.get_all(page_loaded=lambda _: None),
|
||||
"get_sharepoint_external_groups:get_azure_ad_groups",
|
||||
)
|
||||
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
|
||||
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
|
||||
ad_groups_to_emails: dict[str, set[str]] = {}
|
||||
for group in azure_ad_groups:
|
||||
# If the group is already identified, we don't need to get the members
|
||||
if group.display_name in identified_groups:
|
||||
continue
|
||||
# AD groups allows same display name for multiple groups, so we need to add the GUID to the name
|
||||
name = group.display_name
|
||||
name = _get_group_name_with_suffix(group.id, name, graph_client)
|
||||
external_user_groups: list[ExternalUserGroup] = [
|
||||
ExternalUserGroup(id=group_name, user_emails=list(emails))
|
||||
for group_name, emails in groups_and_members.groups_to_emails.items()
|
||||
]
|
||||
|
||||
members = sleep_and_retry(
|
||||
group.members.get_all(page_loaded=lambda _: None),
|
||||
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
|
||||
if not enumerate_all_ad_groups or get_access_token is None:
|
||||
logger.info(
|
||||
"Skipping exhaustive Azure AD group enumeration. "
|
||||
"Only groups found in site role assignments are included."
|
||||
)
|
||||
for member in members:
|
||||
member_data = member.to_json()
|
||||
user_principal_name = member_data.get("userPrincipalName")
|
||||
mail = member_data.get("mail")
|
||||
if not ad_groups_to_emails.get(name):
|
||||
ad_groups_to_emails[name] = set()
|
||||
if user_principal_name:
|
||||
if MICROSOFT_DOMAIN in user_principal_name:
|
||||
user_principal_name = user_principal_name.replace(
|
||||
MICROSOFT_DOMAIN, ""
|
||||
)
|
||||
ad_groups_to_emails[name].add(user_principal_name)
|
||||
elif mail:
|
||||
if MICROSOFT_DOMAIN in mail:
|
||||
mail = mail.replace(MICROSOFT_DOMAIN, "")
|
||||
ad_groups_to_emails[name].add(mail)
|
||||
return external_user_groups
|
||||
|
||||
external_user_groups: list[ExternalUserGroup] = []
|
||||
for group_name, emails in groups_and_members.groups_to_emails.items():
|
||||
external_user_group = ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(emails),
|
||||
)
|
||||
external_user_groups.append(external_user_group)
|
||||
|
||||
for group_name, emails in ad_groups_to_emails.items():
|
||||
external_user_group = ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(emails),
|
||||
)
|
||||
external_user_groups.append(external_user_group)
|
||||
already_resolved = set(groups_and_members.groups_to_emails.keys())
|
||||
for group in _enumerate_ad_groups_paginated(
|
||||
get_access_token, already_resolved, graph_api_base
|
||||
):
|
||||
external_user_groups.append(group)
|
||||
|
||||
return external_user_groups
|
||||
|
||||
@@ -31,6 +31,7 @@ from ee.onyx.server.query_and_chat.query_backend import (
|
||||
from ee.onyx.server.query_and_chat.search_backend import router as search_router
|
||||
from ee.onyx.server.query_history.api import router as query_history_router
|
||||
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
|
||||
from ee.onyx.server.scim.api import scim_router
|
||||
from ee.onyx.server.seeding import seed_db
|
||||
from ee.onyx.server.tenants.api import router as tenants_router
|
||||
from ee.onyx.server.token_rate_limits.api import (
|
||||
@@ -162,6 +163,11 @@ def get_application() -> FastAPI:
|
||||
# Tenant management
|
||||
include_router_with_global_prefix_prepended(application, tenants_router)
|
||||
|
||||
# SCIM 2.0 — protocol endpoints (unauthenticated by Onyx session auth;
|
||||
# they use their own SCIM bearer token auth).
|
||||
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
|
||||
application.include_router(scim_router)
|
||||
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_ee_router_auth(application)
|
||||
|
||||
|
||||
@@ -5,6 +5,11 @@ from onyx.server.auth_check import PUBLIC_ENDPOINT_SPECS
|
||||
|
||||
|
||||
EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
|
||||
# SCIM 2.0 service discovery — unauthenticated so IdPs can probe
|
||||
# before bearer token configuration is complete
|
||||
("/scim/v2/ServiceProviderConfig", {"GET"}),
|
||||
("/scim/v2/ResourceTypes", {"GET"}),
|
||||
("/scim/v2/Schemas", {"GET"}),
|
||||
# needs to be accessible prior to user login
|
||||
("/enterprise-settings", {"GET"}),
|
||||
("/enterprise-settings/logo", {"GET"}),
|
||||
|
||||
@@ -13,6 +13,7 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
|
||||
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
|
||||
from ee.onyx.server.enterprise_settings.store import get_logo_filename
|
||||
@@ -22,6 +23,10 @@ from ee.onyx.server.enterprise_settings.store import load_settings
|
||||
from ee.onyx.server.enterprise_settings.store import store_analytics_script
|
||||
from ee.onyx.server.enterprise_settings.store import store_settings
|
||||
from ee.onyx.server.enterprise_settings.store import upload_logo
|
||||
from ee.onyx.server.scim.auth import generate_scim_token
|
||||
from ee.onyx.server.scim.models import ScimTokenCreate
|
||||
from ee.onyx.server.scim.models import ScimTokenCreatedResponse
|
||||
from ee.onyx.server.scim.models import ScimTokenResponse
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user_with_expired_token
|
||||
from onyx.auth.users import get_user_manager
|
||||
@@ -198,3 +203,63 @@ def upload_custom_analytics_script(
|
||||
@basic_router.get("/custom-analytics-script")
|
||||
def fetch_custom_analytics_script() -> str | None:
|
||||
return load_analytics_script()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SCIM token management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
    """FastAPI dependency: wrap the request's DB session in a ScimDAL."""
    dal = ScimDAL(db_session)
    return dal
|
||||
|
||||
|
||||
@admin_router.get("/scim/token")
def get_active_scim_token(
    _: User = Depends(current_admin_user),
    dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenResponse:
    """Return metadata for the active SCIM token.

    Raises:
        HTTPException: 404 when there is no active token.
    """
    active = dal.get_active_token()
    if active:
        return ScimTokenResponse(
            id=active.id,
            name=active.name,
            token_display=active.token_display,
            is_active=active.is_active,
            created_at=active.created_at,
            last_used_at=active.last_used_at,
        )
    raise HTTPException(status_code=404, detail="No active SCIM token")
|
||||
|
||||
|
||||
@admin_router.post("/scim/token", status_code=201)
def create_scim_token(
    body: ScimTokenCreate,
    user: User = Depends(current_admin_user),
    dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenCreatedResponse:
    """Issue a new SCIM bearer token and commit it.

    Only one token is active at a time — creating a new token automatically
    revokes all previous tokens. The response is the only place the raw
    token value is returned; it cannot be retrieved again.
    """
    raw_token, hashed_token, token_display = generate_scim_token()

    created = dal.create_token(
        name=body.name,
        hashed_token=hashed_token,
        token_display=token_display,
        created_by_id=user.id,
    )
    dal.commit()

    return ScimTokenCreatedResponse(
        id=created.id,
        name=created.name,
        token_display=created.token_display,
        is_active=created.is_active,
        created_at=created.created_at,
        last_used_at=created.last_used_at,
        raw_token=raw_token,
    )
|
||||
|
||||
@@ -34,7 +34,7 @@ class SendSearchQueryRequest(BaseModel):
|
||||
filters: BaseFilters | None = None
|
||||
num_docs_fed_to_llm_selection: int | None = None
|
||||
run_query_expansion: bool = False
|
||||
num_hits: int = 50
|
||||
num_hits: int = 30
|
||||
|
||||
include_content: bool = False
|
||||
stream: bool = False
|
||||
|
||||
957
backend/ee/onyx/server/scim/api.py
Normal file
957
backend/ee/onyx/server/scim/api.py
Normal file
@@ -0,0 +1,957 @@
|
||||
"""SCIM 2.0 API endpoints (RFC 7644).
|
||||
|
||||
This module provides the FastAPI router for SCIM service discovery,
|
||||
User CRUD, and Group CRUD. Identity providers (Okta, Azure AD) call
|
||||
these endpoints to provision and manage users and groups.
|
||||
|
||||
Service discovery endpoints are unauthenticated — IdPs may probe them
|
||||
before bearer token configuration is complete. All other endpoints
|
||||
require a valid SCIM bearer token.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Query
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimError
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimServiceProviderConfig
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.base import serialize_emails
|
||||
from ee.onyx.server.scim.schema_definitions import ENTERPRISE_USER_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
# Module-level logger for the SCIM router.
logger = setup_logger()
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
    """JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1).

    Used by the helpers and endpoints in this module that build their SCIM
    payload by hand instead of returning a model for FastAPI to serialize.
    """

    # Media type this response class advertises for its body.
    media_type = "application/scim+json"
|
||||
|
||||
|
||||
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
# changed to kebab-case.


# Every endpoint in this module is registered on this router under /scim/v2.
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])

# Shared helper used by create_user to generate and hash a random password.
_pw_helper = PasswordHelper()
|
||||
|
||||
|
||||
def _get_provider(
    _token: ScimToken = Depends(verify_scim_token),
) -> ScimProvider:
    """Resolve the SCIM provider for the current request.

    Currently returns whatever ``get_default_provider()`` yields for all
    requests (described as OktaProvider when this note was first written —
    not verifiable from this module). When multi-provider support is added
    (ENG-3652), this will resolve based on token metadata or tenant
    configuration — no endpoint changes required.

    Args:
        _token: The verified SCIM token. It is not read here; declaring it
            ties this dependency to ``verify_scim_token``.

    Returns:
        The provider the endpoints use to build SCIM resources.
    """
    return get_default_provider()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service Discovery Endpoints (unauthenticated)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@scim_router.get("/ServiceProviderConfig")
def get_service_provider_config() -> ScimServiceProviderConfig:
    """Advertise supported SCIM features (RFC 7643 §5).

    Takes no token dependency, so this endpoint is unauthenticated.

    Returns:
        The module-level ``SERVICE_PROVIDER_CONFIG`` object, unchanged.
    """
    # NOTE(review): unlike /ResourceTypes and /Schemas this does not go through
    # ScimJSONResponse, so the response presumably uses FastAPI's default JSON
    # content type rather than application/scim+json — confirm this is intended.
    return SERVICE_PROVIDER_CONFIG
|
||||
|
||||
|
||||
@scim_router.get("/ResourceTypes")
def get_resource_types() -> ScimJSONResponse:
    """Return the supported SCIM resource types (RFC 7643 §6).

    The User and Group resource types are delivered inside a ListResponse
    envelope (RFC 7644 §3.4.2): IdPs such as Entra ID expect a JSON object
    here, not a bare array.
    """
    serialized = []
    for resource_type in (USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE):
        serialized.append(resource_type.model_dump(exclude_none=True, by_alias=True))

    envelope = {
        "schemas": [SCIM_LIST_RESPONSE_SCHEMA],
        "totalResults": len(serialized),
        "Resources": serialized,
    }
    return ScimJSONResponse(content=envelope)
|
||||
|
||||
|
||||
@scim_router.get("/Schemas")
def get_schemas() -> ScimJSONResponse:
    """Return the SCIM schema definitions (RFC 7643 §7).

    The User, Group and enterprise-User schema definitions are delivered
    inside a ListResponse envelope (RFC 7644 §3.4.2): IdPs such as Entra ID
    expect a JSON object here, not a bare array.
    """
    serialized = []
    for schema_def in (USER_SCHEMA_DEF, GROUP_SCHEMA_DEF, ENTERPRISE_USER_SCHEMA_DEF):
        serialized.append(schema_def.model_dump(exclude_none=True))

    envelope = {
        "schemas": [SCIM_LIST_RESPONSE_SCHEMA],
        "totalResults": len(serialized),
        "Resources": serialized,
    }
    return ScimJSONResponse(content=envelope)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _scim_error_response(status: int, detail: str) -> ScimJSONResponse:
    """Log and return a SCIM error body (RFC 7644 §3.12).

    Args:
        status: HTTP status code; also placed in the body as a string.
        detail: Human-readable explanation placed in the body.

    Returns:
        A ``ScimJSONResponse`` carrying the serialized ``ScimError``.
    """
    # Every error response is logged at warning level before it is returned.
    logger.warning("SCIM error response: status=%s detail=%s", status, detail)
    payload = ScimError(status=str(status), detail=detail).model_dump(
        exclude_none=True
    )
    return ScimJSONResponse(status_code=status, content=payload)
|
||||
|
||||
|
||||
def _parse_excluded_attributes(raw: str | None) -> set[str]:
    """Split the ``excludedAttributes`` query value (RFC 7644 §3.4.2.5).

    Args:
        raw: Comma-separated attribute names, or ``None``/empty when the
            parameter was not supplied.

    Returns:
        The attribute names, whitespace-trimmed and lowercased; blank entries
        are dropped. Empty set when nothing was supplied.
    """
    names: set[str] = set()
    if raw:
        for piece in raw.split(","):
            cleaned = piece.strip()
            if cleaned:
                names.add(cleaned.lower())
    return names
|
||||
|
||||
|
||||
def _apply_exclusions(
    resource: ScimUserResource | ScimGroupResource,
    excluded: set[str],
) -> dict:
    """Serialize a SCIM resource, omitting attributes the IdP excluded.

    RFC 7644 §3.4.2.5 lets the IdP pass ``?excludedAttributes=groups,emails``
    to reduce response payload size. We strip those fields after serialization
    so the rest of the pipeline doesn't need to know about them.

    Args:
        resource: The resource to serialize (dumped with ``exclude_none`` and
            ``by_alias``).
        excluded: Lowercased top-level attribute names to drop; matching
            against the serialized camelCase keys is case-insensitive.

    Returns:
        The serialized resource without the excluded top-level keys.
        ``schemas`` and ``id`` are never removed.
    """
    # RFC 7644 §3.9: excludedAttributes has no effect on attributes that must
    # always be returned, so a request excluding them must not strip them.
    always_returned = frozenset({"schemas", "id"})

    data = resource.model_dump(exclude_none=True, by_alias=True)
    return {
        key: value
        for key, value in data.items()
        if key.lower() in always_returned or key.lower() not in excluded
    }
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
    """Check whether one more seat is available.

    Looks up ``check_seat_availability`` from ``onyx.db.license`` through
    ``fetch_ee_implementation_or_noop`` and asks it for a single seat.

    Returns:
        ``None`` when a seat is available (or no checker was resolved),
        otherwise the checker's error message, falling back to
        ``"Seat limit reached"``.
    """
    seat_checker = fetch_ee_implementation_or_noop(
        "onyx.db.license", "check_seat_availability", None
    )
    # NOTE(review): if the "noop" fallback is a callable rather than None, this
    # guard never fires and the call below decides the outcome — confirm.
    if seat_checker is None:
        return None

    outcome = seat_checker(dal.session, seats_needed=1)
    if outcome.available:
        return None
    return outcome.error_message or "Seat limit reached"
|
||||
|
||||
|
||||
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | ScimJSONResponse:
    """Resolve *user_id* to a user, or produce a SCIM 404 response.

    A value that is not a valid UUID is reported exactly like an unknown
    user. Callers tell the two outcomes apart with ``isinstance``.
    """
    try:
        parsed_id = UUID(user_id)
    except ValueError:
        return _scim_error_response(404, f"User {user_id} not found")

    found = dal.get_user(parsed_id)
    if found:
        return found
    return _scim_error_response(404, f"User {user_id} not found")
|
||||
|
||||
|
||||
def _scim_name_to_str(name: ScimName | None) -> str | None:
    """Derive a single display string from a SCIM ``name`` object.

    Returns:
        ``name.formatted`` when the client supplied it (the client knows which
        display string it wants); otherwise the non-empty given and family
        names joined by a space; ``None`` when there is no name or no usable
        component, so the caller can decide whether to touch personal_name.
    """
    if not name:
        return None
    if name.formatted:
        return name.formatted

    components = [name.givenName, name.familyName]
    joined = " ".join(filter(None, components))
    return joined if joined else None
|
||||
|
||||
|
||||
def _scim_resource_response(
    resource: ScimUserResource | ScimGroupResource | ScimListResponse,
    status_code: int = 200,
) -> ScimJSONResponse:
    """Wrap a SCIM model in an ``application/scim+json`` response.

    The model is dumped by alias with ``None`` values omitted.
    """
    return ScimJSONResponse(
        content=resource.model_dump(exclude_none=True, by_alias=True),
        status_code=status_code,
    )
|
||||
|
||||
|
||||
def _build_list_response(
    resources: list[ScimUserResource | ScimGroupResource],
    total: int,
    start_index: int,
    count: int,
    excluded: set[str] | None = None,
) -> ScimListResponse | ScimJSONResponse:
    """Assemble the ListResponse envelope for a page of resources.

    RFC 7644 §3.4.2.5 — IdPs may ask for attributes to be omitted through the
    ``excludedAttributes`` query parameter. With no exclusions the typed
    envelope is serialized as a whole; with exclusions each resource is
    serialized and filtered individually, then attached to the dumped envelope.
    """
    if not excluded:
        full_listing = ScimListResponse(
            totalResults=total,
            startIndex=start_index,
            itemsPerPage=count,
            Resources=resources,
        )
        return _scim_resource_response(full_listing)

    payload = ScimListResponse(
        totalResults=total,
        startIndex=start_index,
        itemsPerPage=count,
    ).model_dump(exclude_none=True)
    payload["Resources"] = [_apply_exclusions(item, excluded) for item in resources]
    return ScimJSONResponse(content=payload)
|
||||
|
||||
|
||||
def _extract_enterprise_fields(
    resource: ScimUserResource,
) -> tuple[str | None, str | None]:
    """Read department and manager from the enterprise extension.

    Returns:
        ``(department, manager_value)``. Both are ``None`` when the resource
        has no enterprise extension; the manager is ``None`` when the
        extension has no manager.
    """
    extension = resource.enterprise_extension
    if extension:
        department = extension.department
        manager_ref = extension.manager
        return department, (manager_ref.value if manager_ref else None)
    return None, None
|
||||
|
||||
|
||||
def _mapping_to_fields(
    mapping: ScimUserMapping | None,
) -> ScimMappingFields | None:
    """Copy the round-trip fields off a stored SCIM user mapping.

    Returns ``None`` when there is no mapping; otherwise a ``ScimMappingFields``
    with department, manager, given/family name and the stored emails JSON.
    """
    if mapping:
        return ScimMappingFields(
            department=mapping.department,
            manager=mapping.manager,
            given_name=mapping.given_name,
            family_name=mapping.family_name,
            scim_emails_json=mapping.scim_emails_json,
        )
    return None
|
||||
|
||||
|
||||
def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
    """Collect the mapping fields carried by an incoming SCIM user resource.

    Department and manager come from the enterprise extension, the name parts
    from ``resource.name`` (``None`` when absent), and the emails are stored
    in the form produced by ``serialize_emails``.
    """
    department, manager = _extract_enterprise_fields(resource)

    name = resource.name
    given = name.givenName if name else None
    family = name.familyName if name else None

    return ScimMappingFields(
        department=department,
        manager=manager,
        given_name=given,
        family_name=family,
        scim_emails_json=serialize_emails(resource.emails),
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User CRUD (RFC 7644 §3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@scim_router.get("/Users", response_model=None)
def list_users(
    filter: str | None = Query(None),
    excludedAttributes: str | None = None,
    startIndex: int = Query(1, ge=1),
    count: int = Query(100, ge=0, le=500),
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimListResponse | ScimJSONResponse:
    """Return one page of users, optionally narrowed by a SCIM filter.

    The parameter names are the SCIM query parameter names and so are not
    snake_case. A ``ValueError`` from parsing the filter or from the DAL query
    becomes a SCIM 400 response.
    """
    dal = ScimDAL(db_session)
    # Record token usage right away, independent of how the query turns out.
    dal.update_token_last_used(_token.id)
    dal.commit()

    try:
        parsed_filter = parse_scim_filter(filter)
        page, total = dal.list_users(parsed_filter, startIndex, count)
    except ValueError as exc:
        return _scim_error_response(400, str(exc))

    # One batched lookup for the group memberships of every user on the page.
    groups_by_user = dal.get_users_groups_batch([u.id for u, _ in page])

    resources: list[ScimUserResource | ScimGroupResource] = []
    for user, mapping in page:
        if mapping:
            external_id = mapping.external_id
            scim_username = mapping.scim_username
        else:
            external_id = None
            scim_username = None
        resources.append(
            provider.build_user_resource(
                user,
                external_id,
                groups=groups_by_user.get(user.id, []),
                scim_username=scim_username,
                fields=_mapping_to_fields(mapping),
            )
        )

    excluded = _parse_excluded_attributes(excludedAttributes)
    return _build_list_response(
        resources,
        total,
        startIndex,
        count,
        excluded=excluded,
    )
|
||||
|
||||
|
||||
@scim_router.get("/Users/{user_id}", response_model=None)
def get_user(
    user_id: str,
    excludedAttributes: str | None = None,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
    """Return a single user by ID, or a SCIM 404 when it cannot be resolved."""
    dal = ScimDAL(db_session)
    # Record token usage right away, independent of the lookup result.
    dal.update_token_last_used(_token.id)
    dal.commit()

    fetched = _fetch_user_or_404(user_id, dal)
    if isinstance(fetched, ScimJSONResponse):
        return fetched
    user = fetched

    mapping = dal.get_user_mapping_by_user_id(user.id)
    if mapping:
        external_id = mapping.external_id
        scim_username = mapping.scim_username
    else:
        external_id = None
        scim_username = None

    resource = provider.build_user_resource(
        user,
        external_id,
        groups=dal.get_user_groups(user.id),
        scim_username=scim_username,
        fields=_mapping_to_fields(mapping),
    )

    # RFC 7644 §3.4.2.5 — the IdP may ask for attributes to be left out.
    excluded = _parse_excluded_attributes(excludedAttributes)
    if not excluded:
        return _scim_resource_response(resource)
    return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
|
||||
|
||||
|
||||
@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
    user_resource: ScimUserResource,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
    """Create a new user from a SCIM provisioning request.

    Returns:
        201 with the created user resource on success; a SCIM error response
        with 400 when ``externalId`` is missing, 403 when the seat check
        reports an error, or 409 when the email is already taken.
    """
    dal = ScimDAL(db_session)
    # The only commit in this function is on the success path below; the
    # early error returns do not commit.
    dal.update_token_last_used(_token.id)

    # The SCIM userName (whitespace-trimmed) is used as the user's email.
    email = user_resource.userName.strip()

    # externalId is how the IdP correlates this user on subsequent requests.
    # Without it, the IdP can't find the user and will try to re-create,
    # hitting a 409 conflict — so we require it up front.
    if not user_resource.externalId:
        return _scim_error_response(400, "externalId is required")

    # Enforce seat limit
    seat_error = _check_seat_availability(dal)
    if seat_error:
        return _scim_error_response(403, seat_error)

    # Check for existing user
    if dal.get_user_by_email(email):
        return _scim_error_response(409, f"User with email {email} already exists")

    # Create user with a random password (SCIM users authenticate via IdP)
    personal_name = _scim_name_to_str(user_resource.name)
    user = User(
        email=email,
        hashed_password=_pw_helper.hash(_pw_helper.generate()),
        role=UserRole.BASIC,
        is_active=user_resource.active,
        is_verified=True,
        personal_name=personal_name,
    )

    try:
        dal.add_user(user)
    except IntegrityError:
        # Presumably a uniqueness violation that slipped past the email check
        # above; it is reported with the same 409 after rolling back.
        dal.rollback()
        return _scim_error_response(409, f"User with email {email} already exists")

    # Create SCIM mapping (externalId is validated above, always present)
    external_id = user_resource.externalId
    scim_username = user_resource.userName.strip()
    fields = _fields_from_resource(user_resource)
    dal.create_user_mapping(
        external_id=external_id,
        user_id=user.id,
        scim_username=scim_username,
        fields=fields,
    )

    dal.commit()

    # No groups are passed: the response is built without a group lookup.
    return _scim_resource_response(
        provider.build_user_resource(
            user,
            external_id,
            scim_username=scim_username,
            fields=fields,
        ),
        status_code=201,
    )
|
||||
|
||||
|
||||
@scim_router.put("/Users/{user_id}", response_model=None)
def replace_user(
    user_id: str,
    user_resource: ScimUserResource,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
    """Replace a user entirely (RFC 7644 §3.5.1).

    Returns:
        The updated user resource; a SCIM 404 when the user cannot be
        resolved, or 403 when reactivation fails the seat check.
    """
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    result = _fetch_user_or_404(user_id, dal)
    if isinstance(result, ScimJSONResponse):
        return result
    user = result

    # Handle activation (need seat check) / deactivation
    # Only an inactive -> active transition runs the seat check.
    if user_resource.active and not user.is_active:
        seat_error = _check_seat_availability(dal)
        if seat_error:
            return _scim_error_response(403, seat_error)

    personal_name = _scim_name_to_str(user_resource.name)

    dal.update_user(
        user,
        email=user_resource.userName.strip(),
        is_active=user_resource.active,
        personal_name=personal_name,
    )

    # Unlike create_user, externalId is not validated here and may be empty.
    # NOTE(review): presumably sync_user_external_id handles a missing
    # externalId — confirm against the DAL.
    new_external_id = user_resource.externalId
    scim_username = user_resource.userName.strip()
    fields = _fields_from_resource(user_resource)
    dal.sync_user_external_id(
        user.id,
        new_external_id,
        scim_username=scim_username,
        fields=fields,
    )

    dal.commit()

    return _scim_resource_response(
        provider.build_user_resource(
            user,
            new_external_id,
            groups=dal.get_user_groups(user.id),
            scim_username=scim_username,
            fields=fields,
        )
    )
|
||||
|
||||
|
||||
@scim_router.patch("/Users/{user_id}", response_model=None)
def patch_user(
    user_id: str,
    patch_request: ScimPatchRequest,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
    """Partially update a user (RFC 7644 §3.5.2).

    This is the primary endpoint for user deprovisioning — Okta sends
    ``PATCH {"active": false}`` rather than DELETE.

    The current user is rendered as a SCIM resource, the PATCH operations are
    applied to that representation, and the differences are written back to
    the user row and its SCIM mapping.

    Returns:
        The updated user resource; a SCIM 404 when the user cannot be
        resolved, the status/detail carried by a ``ScimPatchError``, or 403
        when reactivation fails the seat check.
    """
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    result = _fetch_user_or_404(user_id, dal)
    if isinstance(result, ScimJSONResponse):
        return result
    user = result

    mapping = dal.get_user_mapping_by_user_id(user.id)
    external_id = mapping.external_id if mapping else None
    current_scim_username = mapping.scim_username if mapping else None
    current_fields = _mapping_to_fields(mapping)

    current = provider.build_user_resource(
        user,
        external_id,
        groups=dal.get_user_groups(user.id),
        scim_username=current_scim_username,
        fields=current_fields,
    )

    try:
        patched, ent_data = apply_user_patch(
            patch_request.Operations, current, provider.ignored_patch_paths
        )
    except ScimPatchError as e:
        return _scim_error_response(e.status, e.detail)

    # Apply changes back to the DB model
    # Only an inactive -> active transition runs the seat check.
    if patched.active != user.is_active:
        if patched.active:
            seat_error = _check_seat_availability(dal)
            if seat_error:
                return _scim_error_response(403, seat_error)

    # Track the scim_username — if userName was patched, update it
    new_scim_username = patched.userName.strip() if patched.userName else None

    # Only change the email when the patched userName is non-blank and differs
    # (case-insensitively) from the stored one. Reusing the guarded value keeps
    # a missing or blank userName from crashing on ``.strip()`` or from being
    # passed on as an empty email.
    new_email = (
        new_scim_username
        if new_scim_username and new_scim_username.lower() != user.email.lower()
        else None
    )

    # If displayName was explicitly patched (different from the original), use
    # it as personal_name directly. Otherwise, derive from name components.
    personal_name: str | None
    if patched.displayName and patched.displayName != current.displayName:
        personal_name = patched.displayName
    else:
        personal_name = _scim_name_to_str(patched.name)

    dal.update_user(
        user,
        email=new_email,
        is_active=patched.active if patched.active != user.is_active else None,
        personal_name=personal_name,
    )

    # Build updated fields by merging PATCH enterprise data with current values
    cf = current_fields or ScimMappingFields()
    fields = ScimMappingFields(
        department=ent_data.get("department", cf.department),
        manager=ent_data.get("manager", cf.manager),
        given_name=patched.name.givenName if patched.name else cf.given_name,
        family_name=patched.name.familyName if patched.name else cf.family_name,
        scim_emails_json=(
            serialize_emails(patched.emails)
            if patched.emails is not None
            else cf.scim_emails_json
        ),
    )

    dal.sync_user_external_id(
        user.id,
        patched.externalId,
        scim_username=new_scim_username,
        fields=fields,
    )

    dal.commit()

    return _scim_resource_response(
        provider.build_user_resource(
            user,
            patched.externalId,
            groups=dal.get_user_groups(user.id),
            scim_username=new_scim_username,
            fields=fields,
        )
    )
|
||||
|
||||
|
||||
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
def delete_user(
    user_id: str,
    _token: ScimToken = Depends(verify_scim_token),
    db_session: Session = Depends(get_session),
) -> Response | ScimJSONResponse:
    """Delete a user (RFC 7644 §3.6).

    Deactivates the user and removes the SCIM mapping. Note that Okta
    typically uses PATCH active=false instead of DELETE.
    A second DELETE returns 404 per RFC 7644 §3.6.

    Returns:
        An empty 204 response on success; a SCIM 404 when the user cannot be
        resolved or has no SCIM mapping.
    """
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    result = _fetch_user_or_404(user_id, dal)
    if isinstance(result, ScimJSONResponse):
        return result
    user = result

    # If no SCIM mapping exists, the user was already deleted from
    # SCIM's perspective — return 404 per RFC 7644 §3.6.
    mapping = dal.get_user_mapping_by_user_id(user.id)
    if not mapping:
        return _scim_error_response(404, f"User {user_id} not found")

    # The user row is kept and deactivated; only the mapping is removed.
    dal.deactivate_user(user)
    dal.delete_user_mapping(mapping.id)

    dal.commit()

    return Response(status_code=204)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | ScimJSONResponse:
    """Resolve *group_id* to a group, or produce a SCIM 404 response.

    A value that ``int()`` rejects is reported exactly like an unknown group.
    Callers tell the two outcomes apart with ``isinstance``.
    """
    try:
        numeric_id = int(group_id)
    except ValueError:
        return _scim_error_response(404, f"Group {group_id} not found")

    found = dal.get_group(numeric_id)
    if found:
        return found
    return _scim_error_response(404, f"Group {group_id} not found")
|
||||
|
||||
|
||||
def _parse_member_uuids(
    members: list[ScimGroupMember],
) -> tuple[list[UUID], str | None]:
    """Convert each member's ``value`` string into a UUID.

    Returns:
        ``(uuids, None)`` when every value parses. On the first value that
        does not parse, ``([], message)`` naming that value.
    """
    uuids: list[UUID] = []
    for m in members:
        try:
            parsed = UUID(m.value)
        except ValueError:
            # Stop at the first bad value and discard what was parsed so far.
            return [], f"Invalid member ID: {m.value}"
        uuids.append(parsed)
    return uuids, None
|
||||
|
||||
|
||||
def _validate_and_parse_members(
    members: list[ScimGroupMember], dal: ScimDAL
) -> tuple[list[UUID], str | None]:
    """Parse member IDs and check them with ``dal.validate_member_ids``.

    Returns:
        ``(uuids, None)`` on success; ``([], message)`` when a value is not a
        UUID or when the DAL reports IDs as missing. With no members the DAL
        is not consulted.
    """
    member_ids, parse_error = _parse_member_uuids(members)
    if parse_error:
        return [], parse_error

    if not member_ids:
        return member_ids, None

    missing = dal.validate_member_ids(member_ids)
    if missing:
        return [], f"Member(s) not found: {', '.join(str(u) for u in missing)}"
    return member_ids, None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group CRUD (RFC 7644 §3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@scim_router.get("/Groups", response_model=None)
def list_groups(
    filter: str | None = Query(None),
    excludedAttributes: str | None = None,
    startIndex: int = Query(1, ge=1),
    count: int = Query(100, ge=0, le=500),
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimListResponse | ScimJSONResponse:
    """Return one page of groups, optionally narrowed by a SCIM filter.

    The parameter names are the SCIM query parameter names and so are not
    snake_case. A ``ValueError`` from parsing the filter or from the DAL query
    becomes a SCIM 400 response.
    """
    dal = ScimDAL(db_session)
    # Record token usage right away, independent of how the query turns out.
    dal.update_token_last_used(_token.id)
    dal.commit()

    try:
        parsed_filter = parse_scim_filter(filter)
        page, total = dal.list_groups(parsed_filter, startIndex, count)
    except ValueError as exc:
        return _scim_error_response(400, str(exc))

    # Members are looked up per group (one DAL call for each group on the page).
    resources: list[ScimUserResource | ScimGroupResource] = []
    for group, ext_id in page:
        group_members = dal.get_group_members(group.id)
        resources.append(provider.build_group_resource(group, group_members, ext_id))

    excluded = _parse_excluded_attributes(excludedAttributes)
    return _build_list_response(
        resources,
        total,
        startIndex,
        count,
        excluded=excluded,
    )
|
||||
|
||||
|
||||
@scim_router.get("/Groups/{group_id}", response_model=None)
def get_group(
    group_id: str,
    excludedAttributes: str | None = None,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
    """Return a single group by ID, or a SCIM 404 when it cannot be resolved."""
    dal = ScimDAL(db_session)
    # Record token usage right away, independent of the lookup result.
    dal.update_token_last_used(_token.id)
    dal.commit()

    fetched = _fetch_group_or_404(group_id, dal)
    if isinstance(fetched, ScimJSONResponse):
        return fetched
    group = fetched

    mapping = dal.get_group_mapping_by_group_id(group.id)
    members = dal.get_group_members(group.id)
    external_id = mapping.external_id if mapping else None
    resource = provider.build_group_resource(group, members, external_id)

    # RFC 7644 §3.4.2.5 — the IdP may ask for attributes to be left out.
    excluded = _parse_excluded_attributes(excludedAttributes)
    if not excluded:
        return _scim_resource_response(resource)
    return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
|
||||
|
||||
|
||||
@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
    group_resource: ScimGroupResource,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
    """Create a new group from a SCIM provisioning request.

    Returns:
        201 with the created group resource on success; a SCIM error response
        with 409 when the display name is already taken, or 400 when a member
        ID is invalid or unknown.
    """
    dal = ScimDAL(db_session)
    # The only commit in this function is on the success path below; the
    # early error returns do not commit.
    dal.update_token_last_used(_token.id)

    if dal.get_group_by_name(group_resource.displayName):
        return _scim_error_response(
            409, f"Group with name '{group_resource.displayName}' already exists"
        )

    # Members are validated before the group row is added.
    member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
    if err:
        return _scim_error_response(400, err)

    db_group = UserGroup(
        name=group_resource.displayName,
        is_up_to_date=True,
        time_last_modified_by_user=func.now(),
    )
    try:
        dal.add_group(db_group)
    except IntegrityError:
        # Presumably a uniqueness violation that slipped past the name check
        # above; it is reported with the same 409 after rolling back.
        dal.rollback()
        return _scim_error_response(
            409, f"Group with name '{group_resource.displayName}' already exists"
        )

    dal.upsert_group_members(db_group.id, member_uuids)

    # A mapping row is only created when the IdP supplied an externalId.
    external_id = group_resource.externalId
    if external_id:
        dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)

    dal.commit()

    # Members are re-read after the commit for the response body.
    members = dal.get_group_members(db_group.id)
    return _scim_resource_response(
        provider.build_group_resource(db_group, members, external_id),
        status_code=201,
    )
|
||||
|
||||
|
||||
@scim_router.put("/Groups/{group_id}", response_model=None)
def replace_group(
    group_id: str,
    group_resource: ScimGroupResource,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
    """Replace a group entirely (RFC 7644 §3.5.1).

    Returns:
        The updated group resource; a SCIM 404 when the group cannot be
        resolved, or 400 when a member ID is invalid or unknown.
    """
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    result = _fetch_group_or_404(group_id, dal)
    if isinstance(result, ScimJSONResponse):
        return result
    group = result

    # Members are validated before anything on the group is changed.
    member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
    if err:
        return _scim_error_response(400, err)

    # Name, membership and externalId all come from the request body.
    dal.update_group(group, name=group_resource.displayName)
    dal.replace_group_members(group.id, member_uuids)
    dal.sync_group_external_id(group.id, group_resource.externalId)

    dal.commit()

    # Members are re-read after the commit for the response body.
    members = dal.get_group_members(group.id)
    return _scim_resource_response(
        provider.build_group_resource(group, members, group_resource.externalId)
    )
|
||||
|
||||
|
||||
@scim_router.patch("/Groups/{group_id}", response_model=None)
def patch_group(
    group_id: str,
    patch_request: ScimPatchRequest,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
    """Partially update a group (RFC 7644 §3.5.2).

    Handles member add/remove operations from Okta and Azure AD.

    The current group is rendered as a SCIM resource, the PATCH operations are
    applied to that representation, and the resulting name, member additions,
    member removals and externalId are written back.

    Returns:
        The updated group resource; a SCIM 404 when the group cannot be
        resolved, the status/detail carried by a ``ScimPatchError``, or 400
        when an added member ID is not found.
    """
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    result = _fetch_group_or_404(group_id, dal)
    if isinstance(result, ScimJSONResponse):
        return result
    group = result

    mapping = dal.get_group_mapping_by_group_id(group.id)
    external_id = mapping.external_id if mapping else None

    current_members = dal.get_group_members(group.id)
    current = provider.build_group_resource(group, current_members, external_id)

    try:
        patched, added_ids, removed_ids = apply_group_patch(
            patch_request.Operations, current, provider.ignored_patch_paths
        )
    except ScimPatchError as e:
        return _scim_error_response(e.status, e.detail)

    # None is passed as the name when the patch left displayName unchanged.
    new_name = patched.displayName if patched.displayName != group.name else None
    # The name update runs before the added members are validated below, and
    # the 400 return below does not call rollback.
    dal.update_group(group, name=new_name)

    if added_ids:
        # IDs that are not valid UUID strings are skipped here, not rejected.
        add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
        if add_uuids:
            missing = dal.validate_member_ids(add_uuids)
            if missing:
                return _scim_error_response(
                    400,
                    f"Member(s) not found: {', '.join(str(u) for u in missing)}",
                )
            dal.upsert_group_members(group.id, add_uuids)

    if removed_ids:
        # Removed IDs are filtered the same way and are not validated.
        remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
        dal.remove_group_members(group.id, remove_uuids)

    dal.sync_group_external_id(group.id, patched.externalId)
    dal.commit()

    # Members are re-read after the commit for the response body.
    members = dal.get_group_members(group.id)
    return _scim_resource_response(
        provider.build_group_resource(group, members, patched.externalId)
    )
|
||||
|
||||
|
||||
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
def delete_group(
    group_id: str,
    _token: ScimToken = Depends(verify_scim_token),
    db_session: Session = Depends(get_session),
) -> Response | ScimJSONResponse:
    """Delete a group (RFC 7644 §3.6).

    Returns:
        An empty 204 response on success; a SCIM 404 when the group cannot be
        resolved.
    """
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    result = _fetch_group_or_404(group_id, dal)
    if isinstance(result, ScimJSONResponse):
        return result
    group = result

    # Unlike delete_user, a group without a SCIM mapping is still deleted.
    mapping = dal.get_group_mapping_by_group_id(group.id)
    if mapping:
        dal.delete_group_mapping(mapping.id)

    dal.delete_group_with_members(group)
    dal.commit()

    return Response(status_code=204)
|
||||
|
||||
|
||||
def _is_valid_uuid(value: str) -> bool:
|
||||
"""Check if a string is a valid UUID."""
|
||||
try:
|
||||
UUID(value)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
104
backend/ee/onyx/server/scim/auth.py
Normal file
104
backend/ee/onyx/server/scim/auth.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""SCIM bearer token authentication.
|
||||
|
||||
SCIM endpoints are authenticated via bearer tokens that admins create in the
|
||||
Onyx UI. This module provides:
|
||||
|
||||
- ``verify_scim_token``: FastAPI dependency that extracts, hashes, and
|
||||
validates the token from the Authorization header.
|
||||
- ``generate_scim_token``: Creates a new cryptographically random token
|
||||
and returns the raw value, its SHA-256 hash, and a display suffix.
|
||||
|
||||
Token format: ``onyx_scim_<random>`` where ``<random>`` is the URL-safe
base64 encoding of 48 random bytes from ``secrets.token_urlsafe``.
|
||||
|
||||
The hash is stored in the ``scim_token`` table; the raw value is shown to
|
||||
the admin exactly once at creation time.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from onyx.auth.utils import get_hashed_bearer_token_from_request
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
|
||||
SCIM_TOKEN_PREFIX = "onyx_scim_"
|
||||
SCIM_TOKEN_LENGTH = 48
|
||||
|
||||
|
||||
def _hash_scim_token(token: str) -> str:
|
||||
"""SHA-256 hash a SCIM token. No salt needed — tokens are random."""
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def generate_scim_token() -> tuple[str, str, str]:
    """Create a fresh SCIM bearer token.

    Returns:
        ``(raw_token, hashed_token, token_display)``. ``raw_token`` is the
        prefix followed by a ``secrets.token_urlsafe`` string,
        ``hashed_token`` is its SHA-256 hex digest, and ``token_display``
        is a masked form exposing only the last 4 characters of the raw
        token.
    """
    random_part = secrets.token_urlsafe(SCIM_TOKEN_LENGTH)
    raw_token = f"{SCIM_TOKEN_PREFIX}{random_part}"
    token_display = f"{SCIM_TOKEN_PREFIX}****{raw_token[-4:]}"
    return raw_token, _hash_scim_token(raw_token), token_display
|
||||
|
||||
|
||||
def _get_hashed_scim_token_from_request(request: Request) -> str | None:
    """Extract the SCIM bearer token from ``request`` and return its hash.

    Delegates to ``get_hashed_bearer_token_from_request``, passing the SCIM
    prefix as the only valid prefix and ``_hash_scim_token`` as the hash
    function. The helper's result (a hash string or ``None``) is returned
    unchanged.
    """
    hashed_token = get_hashed_bearer_token_from_request(
        request,
        valid_prefixes=[SCIM_TOKEN_PREFIX],
        hash_fn=_hash_scim_token,
    )
    return hashed_token
|
||||
|
||||
|
||||
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
    """FastAPI dependency that wraps the injected session in a ``ScimDAL``."""
    scim_dal = ScimDAL(db_session)
    return scim_dal
|
||||
|
||||
|
||||
def verify_scim_token(
    request: Request,
    dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimToken:
    """FastAPI dependency that authenticates SCIM requests.

    Extracts the bearer token from the Authorization header, hashes it,
    looks it up in the database, and verifies it is active.

    Note:
        This dependency does NOT update ``last_used_at`` — the endpoint
        should do that via ``ScimDAL.update_token_last_used()`` so the
        timestamp write is part of the endpoint's transaction.

    Returns:
        The matching, active ``ScimToken`` row.

    Raises:
        HTTPException(401): If the token is missing, invalid, or inactive.
            Every 401 carries a ``WWW-Authenticate: Bearer`` challenge.
    """
    # A 401 response must include a WWW-Authenticate challenge
    # (RFC 9110 §15.5.2; RFC 6750 §3 for the Bearer scheme).
    challenge = {"WWW-Authenticate": "Bearer"}

    hashed = _get_hashed_scim_token_from_request(request)
    if not hashed:
        raise HTTPException(
            status_code=401,
            detail="Missing or invalid SCIM bearer token",
            headers=challenge,
        )

    token = dal.get_token_by_hash(hashed)

    if not token:
        raise HTTPException(
            status_code=401,
            detail="Invalid SCIM bearer token",
            headers=challenge,
        )

    if not token.is_active:
        raise HTTPException(
            status_code=401,
            detail="SCIM token has been revoked",
            headers=challenge,
        )

    return token
|
||||
@@ -7,12 +7,14 @@ SCIM protocol schemas follow the wire format defined in:
|
||||
Admin API schemas are internal to Onyx and used for SCIM token management.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -30,6 +32,10 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
|
||||
)
|
||||
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
|
||||
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
|
||||
SCIM_ENTERPRISE_USER_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -62,6 +68,43 @@ class ScimMeta(BaseModel):
|
||||
location: str | None = None
|
||||
|
||||
|
||||
class ScimUserGroupRef(BaseModel):
    """Group reference within a User resource (RFC 7643 §4.1.2, read-only)."""

    # Identifier of the referenced group (required).
    value: str
    # Optional human-readable label for the group.
    display: str | None = None
|
||||
|
||||
|
||||
class ScimManagerRef(BaseModel):
    """Manager sub-attribute for the enterprise extension (RFC 7643 §4.3)."""

    # The manager's ``value`` sub-attribute; ``None`` when not supplied.
    value: str | None = None
|
||||
|
||||
|
||||
class ScimEnterpriseExtension(BaseModel):
    """Enterprise User extension attributes (RFC 7643 §4.3).

    Only ``department`` and ``manager`` are modelled; both are optional.
    """

    department: str | None = None
    # Nested reference object rather than a bare string.
    manager: ScimManagerRef | None = None
|
||||
|
||||
|
||||
@dataclass
class ScimMappingFields:
    """Stored SCIM mapping fields that need to round-trip through the IdP.

    Entra ID sends structured name components, email metadata, and enterprise
    extension attributes that must be returned verbatim in subsequent GET
    responses. These fields are persisted on ScimUserMapping and threaded
    through the DAL, provider, and endpoint layers.

    Every field defaults to ``None``, meaning "no stored value".
    """

    # Enterprise extension values.
    department: str | None = None
    manager: str | None = None
    # Structured name components.
    given_name: str | None = None
    family_name: str | None = None
    # Emails kept as a JSON string — presumably the IdP's ``emails`` list
    # serialized as-is; confirm against the code that writes the mapping.
    scim_emails_json: str | None = None
|
||||
|
||||
|
||||
class ScimUserResource(BaseModel):
|
||||
"""SCIM User resource representation (RFC 7643 §4.1).
|
||||
|
||||
@@ -70,14 +113,22 @@ class ScimUserResource(BaseModel):
|
||||
to match the SCIM wire format (not Python convention).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
|
||||
id: str | None = None # Onyx's internal user ID, set on responses
|
||||
externalId: str | None = None # IdP's identifier for this user
|
||||
userName: str # Typically the user's email address
|
||||
name: ScimName | None = None
|
||||
displayName: str | None = None
|
||||
emails: list[ScimEmail] = Field(default_factory=list)
|
||||
active: bool = True
|
||||
groups: list[ScimUserGroupRef] = Field(default_factory=list)
|
||||
meta: ScimMeta | None = None
|
||||
enterprise_extension: ScimEnterpriseExtension | None = Field(
|
||||
default=None,
|
||||
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
|
||||
)
|
||||
|
||||
|
||||
class ScimGroupMember(BaseModel):
|
||||
@@ -120,12 +171,53 @@ class ScimPatchOperationType(str, Enum):
|
||||
REMOVE = "remove"
|
||||
|
||||
|
||||
class ScimPatchResourceValue(BaseModel):
    """Partial resource dict for path-less PATCH replace operations.

    When an IdP sends a PATCH without a ``path``, the ``value`` is a dict
    of resource attributes to set. IdPs may include read-only fields
    (``id``, ``schemas``, ``meta``) alongside actual changes — these are
    stripped by the provider's ``ignored_patch_paths`` before processing.

    ``extra="allow"`` lets unknown attributes pass through so the patch
    handler can decide what to do with them (ignore or reject).
    """

    model_config = ConfigDict(extra="allow")

    # Attributes that carry actual changes. All default to ``None`` so a
    # payload may contain any subset of them.
    active: bool | None = None
    userName: str | None = None
    displayName: str | None = None
    externalId: str | None = None
    name: ScimName | None = None
    members: list[ScimGroupMember] | None = None
    # Read-only fields some IdPs echo back; declared so they parse cleanly.
    id: str | None = None
    schemas: list[str] | None = None
    meta: ScimMeta | None = None
|
||||
|
||||
|
||||
ScimPatchValue = str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None
|
||||
|
||||
|
||||
class ScimPatchOperation(BaseModel):
    """Single PATCH operation (RFC 7644 §3.5.2)."""

    op: ScimPatchOperationType
    # Target attribute path; ``None`` for path-less operations whose
    # ``value`` is a partial resource.
    path: str | None = None
    # Declared once, using the shared ``ScimPatchValue`` alias.
    value: ScimPatchValue = None

    @field_validator("op", mode="before")
    @classmethod
    def normalize_operation(cls, v: object) -> object:
        """Normalize op to lowercase for case-insensitive matching.

        Some IdPs (e.g. Entra ID) send capitalized ops like ``"Replace"``
        instead of ``"replace"``. This is safe for all providers since the
        enum values are lowercase. If a future provider requires other
        pre-processing quirks, move patch deserialization into the provider
        subclass instead of adding more special cases here.

        Non-string input is returned unchanged so normal validation can
        reject it.
        """
        return v.lower() if isinstance(v, str) else v
|
||||
|
||||
|
||||
class ScimPatchRequest(BaseModel):
|
||||
@@ -195,10 +287,39 @@ class ScimServiceProviderConfig(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class ScimSchemaAttribute(BaseModel):
    """Attribute definition within a SCIM Schema (RFC 7643 §7).

    Field names are camelCase to match the SCIM wire format.
    """

    # Required: attribute name and SCIM data type.
    name: str
    type: str
    # Optional characteristics, each with a default.
    multiValued: bool = False
    required: bool = False
    description: str = ""
    caseExact: bool = False
    mutability: str = "readWrite"
    returned: str = "default"
    uniqueness: str = "none"
    # Nested attribute definitions; the self-reference is written as a
    # string forward reference.
    subAttributes: list["ScimSchemaAttribute"] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ScimSchemaDefinition(BaseModel):
    """SCIM Schema definition (RFC 7643 §7).

    Served at GET /scim/v2/Schemas. Describes the attributes available
    on each resource type so IdPs know which fields they can provision.
    """

    # Defaults to the Schema schema URN; a fresh list is built per instance.
    schemas: list[str] = Field(default_factory=lambda: [SCIM_SCHEMA_SCHEMA])
    id: str
    name: str
    description: str
    # Attribute definitions for this schema; empty when none are given.
    attributes: list[ScimSchemaAttribute] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ScimSchemaExtension(BaseModel):
|
||||
"""Schema extension reference within ResourceType (RFC 7643 §6)."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
schema_: str = Field(alias="schema")
|
||||
required: bool
|
||||
@@ -211,7 +332,7 @@ class ScimResourceType(BaseModel):
|
||||
types are available (Users, Groups) and their respective endpoints.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
|
||||
id: str
|
||||
|
||||
@@ -14,13 +14,70 @@ responsible for persisting changes.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimPatchValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lowercased enterprise extension URN for case-insensitive matching
|
||||
_ENTERPRISE_URN_LOWER = SCIM_ENTERPRISE_USER_SCHEMA.lower()
|
||||
|
||||
# Pattern for email filter paths, e.g.:
|
||||
# emails[primary eq true].value (Okta)
|
||||
# emails[type eq "work"].value (Azure AD / Entra ID)
|
||||
_EMAIL_FILTER_RE = re.compile(
|
||||
r"^emails\[.+\]\.value$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Pattern for member removal path: members[value eq "user-id"]
|
||||
_MEMBER_FILTER_RE = re.compile(
|
||||
r'^members\[value\s+eq\s+"([^"]+)"\]$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dispatch tables for user PATCH paths
|
||||
#
|
||||
# Maps lowercased SCIM path → (camelCase key, target dict name).
|
||||
# "data" writes to the top-level resource dict, "name" writes to the
|
||||
# name sub-object dict. This replaces the elif chains for simple fields.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_USER_REPLACE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"active": ("active", "data"),
|
||||
"username": ("userName", "data"),
|
||||
"externalid": ("externalId", "data"),
|
||||
"name.givenname": ("givenName", "name"),
|
||||
"name.familyname": ("familyName", "name"),
|
||||
"name.formatted": ("formatted", "name"),
|
||||
}
|
||||
|
||||
_USER_REMOVE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"externalid": ("externalId", "data"),
|
||||
"name.givenname": ("givenName", "name"),
|
||||
"name.familyname": ("familyName", "name"),
|
||||
"name.formatted": ("formatted", "name"),
|
||||
"displayname": ("displayName", "data"),
|
||||
}
|
||||
|
||||
_GROUP_REPLACE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"displayname": ("displayName", "data"),
|
||||
"externalid": ("externalId", "data"),
|
||||
}
|
||||
|
||||
|
||||
class ScimPatchError(Exception):
|
||||
"""Raised when a PATCH operation cannot be applied."""
|
||||
@@ -31,94 +88,223 @@ class ScimPatchError(Exception):
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
# Pattern for member removal path: members[value eq "user-id"]
|
||||
_MEMBER_FILTER_RE = re.compile(
|
||||
r'^members\[value\s+eq\s+"([^"]+)"\]$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
@dataclass
class _UserPatchCtx:
    """Bundles the mutable state for user PATCH operations."""

    # Top-level resource dict that operations write into.
    data: dict[str, Any]
    # The ``name`` sub-object, kept separately while operations run.
    name_data: dict[str, Any]
    # Enterprise extension values touched by operations; starts empty.
    ent_data: dict[str, str | None] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User PATCH
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_user_patch(
    operations: list[ScimPatchOperation],
    current: ScimUserResource,
    ignored_paths: frozenset[str] = frozenset(),
) -> tuple[ScimUserResource, dict[str, str | None]]:
    """Apply SCIM PATCH operations to a user resource.

    The operations are applied to a ``model_dump()`` copy of ``current``
    and the result is re-validated into a new ``ScimUserResource``.

    Args:
        operations: The PATCH operations to apply.
        current: The current user resource state.
        ignored_paths: SCIM attribute paths to silently skip (from provider).

    Returns:
        A tuple of (modified user resource, enterprise extension data dict).
        The enterprise dict has keys ``"department"`` and ``"manager"``
        with values set only when a PATCH operation touched them.

    Raises:
        ScimPatchError: If an operation targets an unsupported path.
    """
    data = current.model_dump()
    ctx = _UserPatchCtx(data=data, name_data=data.get("name") or {})

    for op in operations:
        # ADD and REPLACE share one code path.
        if op.op in (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD):
            _apply_user_replace(op, ctx, ignored_paths)
        elif op.op == ScimPatchOperationType.REMOVE:
            _apply_user_remove(op, ctx, ignored_paths)
        else:
            raise ScimPatchError(
                f"Unsupported operation '{op.op.value}' on User resource"
            )

    # Fold the separately tracked name sub-object back in before validating.
    ctx.data["name"] = ctx.name_data
    return ScimUserResource.model_validate(ctx.data), ctx.ent_data
|
||||
|
||||
|
||||
def _apply_user_replace(
    op: ScimPatchOperation,
    ctx: _UserPatchCtx,
    ignored_paths: frozenset[str],
) -> None:
    """Apply a replace/add operation to user data.

    Args:
        op: The operation; its path is matched case-insensitively.
        ctx: Mutable patch state that receives the new values.
        ignored_paths: SCIM attribute paths to silently skip.

    Raises:
        ScimPatchError: If a path-less operation does not carry a resource
            value, or if an explicit path is unsupported.
    """
    path = (op.path or "").lower()

    if not path:
        # No path — value is a resource dict of top-level attributes to set.
        if isinstance(op.value, ScimPatchResourceValue):
            # Only attributes present in the payload are applied; unknown
            # ones are skipped (strict=False) rather than rejected.
            for key, val in op.value.model_dump(exclude_unset=True).items():
                _set_user_field(key.lower(), val, ctx, ignored_paths, strict=False)
        else:
            raise ScimPatchError("Replace without path requires a dict value")
        return

    _set_user_field(path, op.value, ctx, ignored_paths)
|
||||
|
||||
|
||||
def _apply_user_remove(
    op: ScimPatchOperation,
    ctx: _UserPatchCtx,
    ignored_paths: frozenset[str],
) -> None:
    """Clear the field targeted by a remove operation.

    The path is lowercased, then looked up in ``_USER_REMOVE_PATHS``. A hit
    sets the mapped key to ``None`` in either the top-level dict or the name
    dict. Paths in ``ignored_paths`` are skipped silently.

    Raises:
        ScimPatchError: If the path is missing or not removable.
    """
    target_path = (op.path or "").lower()
    if not target_path:
        raise ScimPatchError("Remove operation requires a path")

    if target_path in ignored_paths:
        return

    mapping = _USER_REMOVE_PATHS.get(target_path)
    if not mapping:
        raise ScimPatchError(
            f"Unsupported remove path '{target_path}' for User PATCH"
        )

    field_name, bucket = mapping
    destination = ctx.data if bucket == "data" else ctx.name_data
    destination[field_name] = None
|
||||
|
||||
|
||||
def _set_user_field(
    path: str,
    value: ScimPatchValue,
    ctx: _UserPatchCtx,
    ignored_paths: frozenset[str],
    *,
    strict: bool = True,
) -> None:
    """Set a single field on user data by SCIM path.

    Args:
        path: Lowercased SCIM attribute path.
        value: The value to store.
        ctx: Mutable patch state that receives the value.
        ignored_paths: SCIM attribute paths to silently skip.
        strict: When ``False`` (path-less replace), unknown attributes are
            silently skipped. When ``True`` (explicit path), they raise.

    Raises:
        ScimPatchError: If ``path`` is unsupported and ``strict`` is true.
    """
    if path in ignored_paths:
        return

    # Simple field writes handled by the dispatch table
    entry = _USER_REPLACE_PATHS.get(path)
    if entry:
        key, target = entry
        target_dict = ctx.data if target == "data" else ctx.name_data
        target_dict[key] = value
        return

    # displayName sets both the top-level field and the name.formatted sub-field
    if path == "displayname":
        ctx.data["displayName"] = value
        ctx.name_data["formatted"] = value
    elif path == "name":
        # The value may arrive as a plain dict or as a
        # ScimPatchResourceValue (pydantic may parse raw dicts into it), so
        # normalise through _to_dict instead of testing for dict only.
        name_fields = _to_dict(value)
        if name_fields is not None:
            ctx.name_data.update(name_fields)
    elif path == "emails":
        if isinstance(value, list):
            ctx.data["emails"] = value
    elif _EMAIL_FILTER_RE.match(path):
        _update_primary_email(ctx.data, value)
    elif path.startswith(_ENTERPRISE_URN_LOWER):
        _set_enterprise_field(path, value, ctx.ent_data)
    elif not strict:
        return
    else:
        raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
|
||||
|
||||
|
||||
def _update_primary_email(data: dict[str, Any], value: ScimPatchValue) -> None:
|
||||
"""Update the primary email entry via an email filter path."""
|
||||
emails: list[dict] = data.get("emails") or []
|
||||
for email_entry in emails:
|
||||
if email_entry.get("primary"):
|
||||
email_entry["value"] = value
|
||||
break
|
||||
else:
|
||||
emails.append({"value": value, "type": "work", "primary": True})
|
||||
data["emails"] = emails
|
||||
|
||||
|
||||
def _to_dict(value: ScimPatchValue) -> dict | None:
|
||||
"""Coerce a SCIM patch value to a plain dict if possible.
|
||||
|
||||
Pydantic may parse raw dicts as ``ScimPatchResourceValue`` (which uses
|
||||
``extra="allow"``), so we also dump those back to a dict.
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, ScimPatchResourceValue):
|
||||
return value.model_dump(exclude_unset=True)
|
||||
return None
|
||||
|
||||
|
||||
def _set_enterprise_field(
    path: str,
    value: ScimPatchValue,
    ent_data: dict[str, str | None],
) -> None:
    """Record enterprise extension values from a URN path or value dict.

    Two forms are handled:

    * ``path`` equals the enterprise URN and ``value`` is dict-like
      (path-less PATCH), e.g. ``{"department": "Eng", "manager": {...}}``.
    * ``path`` is the URN followed by an attribute name, e.g.
      ``"urn:...:user:department"``.

    Only ``department`` and ``manager`` are stored. Any other attribute is
    logged and ignored.
    """
    if path == _ENTERPRISE_URN_LOWER:
        # Whole-extension form: copy over only the keys that are present.
        payload = _to_dict(value)
        if payload is None:
            return
        if "department" in payload:
            ent_data["department"] = payload["department"]
        manager_obj = payload.get("manager")
        if isinstance(manager_obj, dict):
            ent_data["manager"] = manager_obj.get("value")
        return

    # Dotted URN form: whatever follows the URN names the attribute.
    attribute = path[len(_ENTERPRISE_URN_LOWER) :].lstrip(":").lower()

    if attribute == "department":
        ent_data["department"] = None if value is None else str(value)
        return

    if attribute != "manager":
        # Unknown enterprise attributes are silently ignored rather than
        # rejected — IdPs may send attributes we don't model yet.
        logger.warning(
            "Ignoring unknown enterprise extension attribute '%s'", attribute
        )
        return

    # manager may be a ``{"value": ...}`` object or a bare string.
    manager_payload = _to_dict(value)
    if manager_payload is not None:
        ent_data["manager"] = manager_payload.get("value")
    elif isinstance(value, str):
        ent_data["manager"] = value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group PATCH
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_group_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimGroupResource,
|
||||
ignored_paths: frozenset[str] = frozenset(),
|
||||
) -> tuple[ScimGroupResource, list[str], list[str]]:
|
||||
"""Apply SCIM PATCH operations to a group resource.
|
||||
|
||||
Args:
|
||||
operations: The PATCH operations to apply.
|
||||
current: The current group resource state.
|
||||
ignored_paths: SCIM attribute paths to silently skip (from provider).
|
||||
|
||||
Returns:
|
||||
A tuple of (modified group, added member IDs, removed member IDs).
|
||||
The caller uses the member ID lists to update the database.
|
||||
@@ -133,7 +319,9 @@ def apply_group_patch(
|
||||
|
||||
for op in operations:
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
|
||||
_apply_group_replace(
|
||||
op, data, current_members, added_ids, removed_ids, ignored_paths
|
||||
)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_group_add(op, current_members, added_ids)
|
||||
elif op.op == ScimPatchOperationType.REMOVE:
|
||||
@@ -154,38 +342,48 @@ def _apply_group_replace(
|
||||
current_members: list[dict],
|
||||
added_ids: list[str],
|
||||
removed_ids: list[str],
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a replace operation to group data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
if isinstance(op.value, dict):
|
||||
for key, val in op.value.items():
|
||||
if isinstance(op.value, ScimPatchResourceValue):
|
||||
dumped = op.value.model_dump(exclude_unset=True)
|
||||
for key, val in dumped.items():
|
||||
if key.lower() == "members":
|
||||
_replace_members(val, current_members, added_ids, removed_ids)
|
||||
else:
|
||||
_set_group_field(key.lower(), val, data)
|
||||
_set_group_field(key.lower(), val, data, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
if path == "members":
|
||||
_replace_members(op.value, current_members, added_ids, removed_ids)
|
||||
_replace_members(
|
||||
_members_to_dicts(op.value), current_members, added_ids, removed_ids
|
||||
)
|
||||
return
|
||||
|
||||
_set_group_field(path, op.value, data)
|
||||
_set_group_field(path, op.value, data, ignored_paths)
|
||||
|
||||
|
||||
def _members_to_dicts(
|
||||
value: str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None,
|
||||
) -> list[dict]:
|
||||
"""Convert a member list value to a list of dicts for internal processing."""
|
||||
if not isinstance(value, list):
|
||||
raise ScimPatchError("Replace members requires a list value")
|
||||
return [m.model_dump(exclude_none=True) for m in value]
|
||||
|
||||
|
||||
def _replace_members(
|
||||
value: str | list | dict | bool | None,
|
||||
value: list[dict],
|
||||
current_members: list[dict],
|
||||
added_ids: list[str],
|
||||
removed_ids: list[str],
|
||||
) -> None:
|
||||
"""Replace the entire group member list."""
|
||||
if not isinstance(value, list):
|
||||
raise ScimPatchError("Replace members requires a list value")
|
||||
|
||||
old_ids = {m["value"] for m in current_members}
|
||||
new_ids = {m.get("value", "") for m in value}
|
||||
|
||||
@@ -197,16 +395,21 @@ def _replace_members(
|
||||
|
||||
def _set_group_field(
    path: str,
    value: ScimPatchValue,
    data: dict,
    ignored_paths: frozenset[str],
) -> None:
    """Set a single field on group data by SCIM path.

    Args:
        path: Lowercased SCIM attribute path.
        value: The value to store.
        data: Group resource dict that receives the value.
        ignored_paths: SCIM attribute paths to silently skip.

    Raises:
        ScimPatchError: If ``path`` is neither ignored nor listed in
            ``_GROUP_REPLACE_PATHS``.
    """
    if path in ignored_paths:
        return

    entry = _GROUP_REPLACE_PATHS.get(path)
    if entry:
        # Group fields all live in the top-level dict, so the target-dict
        # half of the table entry is unused here.
        key, _ = entry
        data[key] = value
        return

    raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
|
||||
|
||||
|
||||
def _apply_group_add(
|
||||
@@ -223,8 +426,10 @@ def _apply_group_add(
|
||||
if not isinstance(op.value, list):
|
||||
raise ScimPatchError("Add members requires a list value")
|
||||
|
||||
member_dicts = [m.model_dump(exclude_none=True) for m in op.value]
|
||||
|
||||
existing_ids = {m["value"] for m in members}
|
||||
for member_data in op.value:
|
||||
for member_data in member_dicts:
|
||||
member_id = member_data.get("value", "")
|
||||
if member_id and member_id not in existing_ids:
|
||||
members.append(member_data)
|
||||
|
||||
0
backend/ee/onyx/server/scim/providers/__init__.py
Normal file
0
backend/ee/onyx/server/scim/providers/__init__.py
Normal file
210
backend/ee/onyx/server/scim/providers/base.py
Normal file
210
backend/ee/onyx/server/scim/providers/base.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Base SCIM provider abstraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimEnterpriseExtension
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimManagerRef
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"id",
|
||||
"schemas",
|
||||
"meta",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ScimProvider(ABC):
    """Base class for provider-specific SCIM behavior.

    Subclass this to handle IdP-specific quirks. The base class provides
    RFC 7643-compliant response builders that populate all standard fields.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier for this provider (e.g. ``"okta"``)."""
        ...

    @property
    @abstractmethod
    def ignored_patch_paths(self) -> frozenset[str]:
        """SCIM attribute paths to silently skip in PATCH value-object dicts.

        IdPs may include read-only or meta fields alongside actual changes
        (e.g. Okta sends ``{"id": "...", "active": false}``). Paths listed
        here are silently dropped instead of raising an error.
        """
        ...

    @property
    def user_schemas(self) -> list[str]:
        """Schema URIs to include in User resource responses.

        Override in subclasses to advertise additional schemas (e.g. the
        enterprise extension for Entra ID).
        """
        return [SCIM_USER_SCHEMA]

    def build_user_resource(
        self,
        user: User,
        external_id: str | None = None,
        groups: list[tuple[int, str]] | None = None,
        scim_username: str | None = None,
        fields: ScimMappingFields | None = None,
    ) -> ScimUserResource:
        """Build a SCIM User response from an Onyx User.

        Args:
            user: The Onyx user model.
            external_id: The IdP's external identifier for this user.
            groups: List of ``(group_id, group_name)`` tuples for the
                ``groups`` read-only attribute. Pass ``None`` or ``[]``
                for newly-created users.
            scim_username: The original-case userName from the IdP. Falls
                back to ``user.email`` (lowercase) when not available.
            fields: Stored mapping fields that the IdP expects round-tripped.

        Returns:
            The populated ``ScimUserResource``. Its ``enterprise_extension``
            is ``None`` unless a department or manager value is stored.
        """
        f = fields or ScimMappingFields()
        group_refs = [
            ScimUserGroupRef(value=str(gid), display=gname)
            for gid, gname in (groups or [])
        ]

        username = scim_username or user.email

        # Build enterprise extension when at least one value is present.
        # Dynamically add the enterprise URN to schemas per RFC 7643 §3.0.
        enterprise_ext: ScimEnterpriseExtension | None = None
        # Copy so that appending below never mutates a list owned by the
        # ``user_schemas`` property of a subclass.
        schemas = list(self.user_schemas)
        if f.department is not None or f.manager is not None:
            manager_ref = (
                ScimManagerRef(value=f.manager) if f.manager is not None else None
            )
            enterprise_ext = ScimEnterpriseExtension(
                department=f.department,
                manager=manager_ref,
            )
            if SCIM_ENTERPRISE_USER_SCHEMA not in schemas:
                schemas.append(SCIM_ENTERPRISE_USER_SCHEMA)

        name = self.build_scim_name(user, f)
        emails = _deserialize_emails(f.scim_emails_json, username)

        resource = ScimUserResource(
            schemas=schemas,
            id=str(user.id),
            externalId=external_id,
            userName=username,
            name=name,
            displayName=user.personal_name,
            emails=emails,
            active=user.is_active,
            groups=group_refs,
            meta=ScimMeta(resourceType="User"),
        )
        # Attached after construction; stays None when no enterprise values
        # are stored.
        resource.enterprise_extension = enterprise_ext
        return resource

    def build_group_resource(
        self,
        group: UserGroup,
        members: list[tuple[UUID, str | None]],
        external_id: str | None = None,
    ) -> ScimGroupResource:
        """Build a SCIM Group response from an Onyx UserGroup.

        Args:
            group: The Onyx user group model.
            members: ``(user_id, email)`` tuples; the email is used as the
                member's ``display`` value and may be ``None``.
            external_id: The IdP's external identifier for this group.
        """
        scim_members = [
            ScimGroupMember(value=str(uid), display=email) for uid, email in members
        ]
        return ScimGroupResource(
            id=str(group.id),
            externalId=external_id,
            displayName=group.name,
            members=scim_members,
            meta=ScimMeta(resourceType="Group"),
        )

    def build_scim_name(
        self,
        user: User,
        fields: ScimMappingFields,
    ) -> ScimName | None:
        """Build SCIM name components for the response.

        Round-trips stored ``given_name``/``family_name`` when available (so
        the IdP gets back what it sent). Falls back to splitting
        ``personal_name`` for users provisioned before we stored components.
        Providers may override for custom behavior.

        Returns:
            A ``ScimName``, or ``None`` when no name components are stored
            and ``personal_name`` is empty or contains only whitespace.
        """
        if fields.given_name is not None or fields.family_name is not None:
            return ScimName(
                givenName=fields.given_name,
                familyName=fields.family_name,
                formatted=user.personal_name,
            )
        if not user.personal_name:
            return None
        # Split on the first run of whitespace (after trimming) so that
        # leading or repeated spaces cannot produce an empty givenName or a
        # familyName that starts with a space.
        parts = user.personal_name.strip().split(None, 1)
        if not parts:
            # personal_name held nothing but whitespace.
            return None
        return ScimName(
            givenName=parts[0],
            familyName=parts[1] if len(parts) > 1 else None,
            # ``formatted`` keeps the stored value untouched.
            formatted=user.personal_name,
        )
|
||||
|
||||
|
||||
def _deserialize_emails(stored_json: str | None, username: str) -> list[ScimEmail]:
    """Rebuild the stored SCIM email entries, or synthesize a default one.

    Args:
        stored_json: JSON text holding a list of email entry dicts, or a
            falsy value when nothing was stored.
        username: Address used for the default entry.

    Returns:
        The entries parsed from ``stored_json`` when it decodes to a
        non-empty list; otherwise a single primary ``work`` email whose
        value is ``username``.
    """
    if not stored_json:
        return [ScimEmail(value=username, type="work", primary=True)]

    try:
        parsed = json.loads(stored_json)
        if isinstance(parsed, list) and parsed:
            return [ScimEmail(**entry) for entry in parsed]
    except (json.JSONDecodeError, TypeError, ValidationError):
        # Unreadable payload: log it and fall through to the default entry.
        logger.warning(
            "Corrupt scim_emails_json, falling back to default: %s", stored_json
        )

    # Reached for corrupt payloads and for valid JSON that is not a
    # non-empty list.
    return [ScimEmail(value=username, type="work", primary=True)]
|
||||
|
||||
|
||||
def serialize_emails(emails: list[ScimEmail]) -> str | None:
    """Turn SCIM email entries into a JSON string suitable for storage.

    Returns ``None`` when ``emails`` is empty; otherwise a JSON array with
    one object per entry, dumped with ``exclude_none=True``.
    """
    if emails:
        payload = [entry.model_dump(exclude_none=True) for entry in emails]
        return json.dumps(payload)
    return None
|
||||
|
||||
|
||||
def get_default_provider() -> ScimProvider:
    """Return the default SCIM provider.

    Currently returns ``OktaProvider`` since Okta is the primary supported
    IdP. When provider detection is added (via token metadata or tenant
    config), this can be replaced with dynamic resolution.
    """
    # Imported at call time rather than at module load.
    # NOTE(review): the okta module imports ScimProvider from the base
    # provider module, so a top-level import here is presumably circular.
    from ee.onyx.server.scim.providers.okta import OktaProvider

    provider: ScimProvider = OktaProvider()
    return provider
|
||||
36
backend/ee/onyx/server/scim/providers/entra.py
Normal file
36
backend/ee/onyx/server/scim/providers/entra.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Entra ID (Azure AD) SCIM provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
_ENTRA_IGNORED_PATCH_PATHS = COMMON_IGNORED_PATCH_PATHS
|
||||
|
||||
|
||||
class EntraProvider(ScimProvider):
    """Entra ID (Azure AD) SCIM provider.

    Entra behavioral notes:
    - Sends capitalized PATCH ops (``"Add"``, ``"Replace"``, ``"Remove"``)
      — handled by ``ScimPatchOperation.normalize_op`` validator.
    - Sends the enterprise extension URN as a key in path-less PATCH value
      dicts — handled by ``_set_enterprise_field`` in ``patch.py`` to
      store department/manager values.
    - Expects the enterprise extension schema in ``schemas`` arrays and
      ``/Schemas`` + ``/ResourceTypes`` discovery endpoints.
    """

    @property
    def name(self) -> str:
        """Short identifier for this provider."""
        return "entra"

    @property
    def ignored_patch_paths(self) -> frozenset[str]:
        """PATCH value-object paths to drop silently for Entra."""
        return _ENTRA_IGNORED_PATCH_PATHS

    @property
    def user_schemas(self) -> list[str]:
        """Core User schema followed by the enterprise extension URN."""
        schema_uris = [SCIM_USER_SCHEMA]
        schema_uris.append(SCIM_ENTERPRISE_USER_SCHEMA)
        return schema_uris
|
||||
26
backend/ee/onyx/server/scim/providers/okta.py
Normal file
26
backend/ee/onyx/server/scim/providers/okta.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Okta SCIM provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
|
||||
class OktaProvider(ScimProvider):
    """Okta SCIM provider.

    Okta behavioral notes:
    - Uses ``PATCH {"active": false}`` for deprovisioning (not DELETE)
    - Sends path-less PATCH with value dicts containing extra fields
      (``id``, ``schemas``)
    - Expects ``displayName`` and ``groups`` in user responses
    - Only uses ``eq`` operator for ``userName`` filter
    """

    @property
    def name(self) -> str:
        """Short identifier for this provider."""
        return "okta"

    @property
    def ignored_patch_paths(self) -> frozenset[str]:
        """PATCH value-object paths to drop silently for Okta."""
        ignored: frozenset[str] = COMMON_IGNORED_PATCH_PATHS
        return ignored
|
||||
173
backend/ee/onyx/server/scim/schema_definitions.py
Normal file
173
backend/ee/onyx/server/scim/schema_definitions.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Static SCIM service discovery responses (RFC 7643 §5, §6, §7).
|
||||
|
||||
Pre-built at import time — these never change at runtime. Separated from
|
||||
api.py to keep the endpoint module focused on request handling.
|
||||
"""
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
from ee.onyx.server.scim.models import ScimSchemaAttribute
|
||||
from ee.onyx.server.scim.models import ScimSchemaDefinition
|
||||
from ee.onyx.server.scim.models import ScimServiceProviderConfig
|
||||
|
||||
# Service provider configuration built from the model's defaults.
SERVICE_PROVIDER_CONFIG = ScimServiceProviderConfig()

# Resource type descriptor for Users; lists the enterprise extension as an
# optional schema extension.
USER_RESOURCE_TYPE = ScimResourceType.model_validate(
    dict(
        id="User",
        name="User",
        endpoint="/scim/v2/Users",
        description="SCIM User resource",
        schema=SCIM_USER_SCHEMA,
        schemaExtensions=[
            dict(schema=SCIM_ENTERPRISE_USER_SCHEMA, required=False),
        ],
    )
)

# Resource type descriptor for Groups (no schema extensions).
GROUP_RESOURCE_TYPE = ScimResourceType.model_validate(
    dict(
        id="Group",
        name="Group",
        endpoint="/scim/v2/Groups",
        description="SCIM Group resource",
        schema=SCIM_GROUP_SCHEMA,
    )
)
|
||||
|
||||
# Schema definition advertised for the core User resource. Includes
# ``displayName`` and ``groups`` so the advertised schema matches the
# attributes the provider's User response builder actually emits.
USER_SCHEMA_DEF = ScimSchemaDefinition(
    id=SCIM_USER_SCHEMA,
    name="User",
    description="SCIM core User schema",
    attributes=[
        ScimSchemaAttribute(
            name="userName",
            type="string",
            required=True,
            uniqueness="server",
            description="Unique identifier for the user, typically an email address.",
        ),
        ScimSchemaAttribute(
            name="name",
            type="complex",
            description="The components of the user's name.",
            subAttributes=[
                ScimSchemaAttribute(
                    name="givenName",
                    type="string",
                    description="The user's first name.",
                ),
                ScimSchemaAttribute(
                    name="familyName",
                    type="string",
                    description="The user's last name.",
                ),
                ScimSchemaAttribute(
                    name="formatted",
                    type="string",
                    description="The full name, including all middle names and titles.",
                ),
            ],
        ),
        ScimSchemaAttribute(
            name="displayName",
            type="string",
            description="The name of the user, suitable for display to end users.",
        ),
        ScimSchemaAttribute(
            name="emails",
            type="complex",
            multiValued=True,
            description="Email addresses for the user.",
            subAttributes=[
                ScimSchemaAttribute(
                    name="value",
                    type="string",
                    description="Email address value.",
                ),
                ScimSchemaAttribute(
                    name="type",
                    type="string",
                    description="Label for this email (e.g. 'work').",
                ),
                ScimSchemaAttribute(
                    name="primary",
                    type="boolean",
                    description="Whether this is the primary email.",
                ),
            ],
        ),
        ScimSchemaAttribute(
            name="active",
            type="boolean",
            description="Whether the user account is active.",
        ),
        ScimSchemaAttribute(
            name="groups",
            type="complex",
            multiValued=True,
            mutability="readOnly",
            description="Groups to which the user belongs.",
            subAttributes=[
                ScimSchemaAttribute(
                    name="value",
                    type="string",
                    mutability="readOnly",
                    description="Identifier of the user's group.",
                ),
                ScimSchemaAttribute(
                    name="display",
                    type="string",
                    mutability="readOnly",
                    description="Display name of the user's group.",
                ),
            ],
        ),
        ScimSchemaAttribute(
            name="externalId",
            type="string",
            description="Identifier from the provisioning client (IdP).",
            caseExact=True,
        ),
    ],
)
|
||||
|
||||
# Schema definition advertised for the Enterprise User extension; covers the
# department and manager attributes.
ENTERPRISE_USER_SCHEMA_DEF = ScimSchemaDefinition(
    id=SCIM_ENTERPRISE_USER_SCHEMA,
    name="EnterpriseUser",
    description="Enterprise User extension (RFC 7643 §4.3)",
    attributes=[
        ScimSchemaAttribute(
            type="string",
            name="department",
            description="Department.",
        ),
        ScimSchemaAttribute(
            type="complex",
            name="manager",
            description="The user's manager.",
            subAttributes=[
                ScimSchemaAttribute(
                    type="string",
                    name="value",
                    description="Manager user ID.",
                ),
            ],
        ),
    ],
)
|
||||
|
||||
# Schema definition advertised for the core Group resource.
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
    id=SCIM_GROUP_SCHEMA,
    name="Group",
    description="SCIM core Group schema",
    attributes=[
        ScimSchemaAttribute(
            type="string",
            name="displayName",
            required=True,
            description="Human-readable name for the group.",
        ),
        ScimSchemaAttribute(
            type="complex",
            name="members",
            multiValued=True,
            description="Members of the group.",
            subAttributes=[
                ScimSchemaAttribute(
                    type="string",
                    name="value",
                    description="User ID of the group member.",
                ),
                # The member display name is advertised as read-only.
                ScimSchemaAttribute(
                    type="string",
                    name="display",
                    mutability="readOnly",
                    description="Display name of the group member.",
                ),
            ],
        ),
        ScimSchemaAttribute(
            type="string",
            name="externalId",
            caseExact=True,
            description="Identifier from the provisioning client (IdP).",
        ),
    ],
)
|
||||
@@ -1,10 +1,13 @@
|
||||
"""EE Settings API - provides license-aware settings override."""
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -41,6 +44,14 @@ def check_ee_features_enabled() -> bool:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if not metadata:
|
||||
# Cache miss — warm from DB so cold-start doesn't block EE features
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(f"Failed to load license from DB: {db_error}")
|
||||
|
||||
if metadata and metadata.status != _BLOCKING_STATUS:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
return True
|
||||
@@ -82,6 +93,18 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if not metadata:
|
||||
# Cache miss (e.g. after TTL expiry). Fall back to DB so
|
||||
# the /settings request doesn't falsely return GATED_ACCESS
|
||||
# while the cache is cold.
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(
|
||||
f"Failed to load license from DB for settings: {db_error}"
|
||||
)
|
||||
|
||||
if metadata:
|
||||
if metadata.status == _BLOCKING_STATUS:
|
||||
settings.application_status = metadata.status
|
||||
@@ -90,7 +113,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
settings.ee_features_enabled = True
|
||||
else:
|
||||
# No license found.
|
||||
# No license found in cache or DB.
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
# Legacy EE flag is set → prior EE usage (e.g. permission
|
||||
# syncing) means indexed data may need protection.
|
||||
|
||||
@@ -37,12 +37,15 @@ def list_user_groups(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
user_groups = fetch_user_groups(
|
||||
db_session, only_up_to_date=False, eager_load_for_snapshot=True
|
||||
)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
only_curator_groups=user.role == UserRole.CURATOR,
|
||||
eager_load_for_snapshot=True,
|
||||
)
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@@ -53,7 +53,8 @@ class UserGroup(BaseModel):
|
||||
id=cc_pair_relationship.cc_pair.id,
|
||||
name=cc_pair_relationship.cc_pair.name,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair_relationship.cc_pair.connector
|
||||
cc_pair_relationship.cc_pair.connector,
|
||||
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair_relationship.cc_pair.credential
|
||||
|
||||
@@ -58,16 +58,27 @@ class OAuthTokenManager:
|
||||
if not user_token.token_data:
|
||||
raise ValueError("No token data available for refresh")
|
||||
|
||||
if (
|
||||
self.oauth_config.client_id is None
|
||||
or self.oauth_config.client_secret is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OAuth client_id and client_secret are required for token refresh"
|
||||
)
|
||||
|
||||
token_data = self._unwrap_token_data(user_token.token_data)
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
|
||||
"client_secret": self._unwrap_sensitive_str(
|
||||
self.oauth_config.client_secret
|
||||
),
|
||||
}
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": self.oauth_config.client_id,
|
||||
"client_secret": self.oauth_config.client_secret,
|
||||
},
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -115,15 +126,26 @@ class OAuthTokenManager:
|
||||
|
||||
def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
|
||||
"""Exchange authorization code for access token"""
|
||||
if (
|
||||
self.oauth_config.client_id is None
|
||||
or self.oauth_config.client_secret is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OAuth client_id and client_secret are required for code exchange"
|
||||
)
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
|
||||
"client_secret": self._unwrap_sensitive_str(
|
||||
self.oauth_config.client_secret
|
||||
),
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": self.oauth_config.client_id,
|
||||
"client_secret": self.oauth_config.client_secret,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -141,8 +163,13 @@ class OAuthTokenManager:
|
||||
oauth_config: OAuthConfig, redirect_uri: str, state: str
|
||||
) -> str:
|
||||
"""Build OAuth authorization URL"""
|
||||
if oauth_config.client_id is None:
|
||||
raise ValueError("OAuth client_id is required to build authorization URL")
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"client_id": oauth_config.client_id,
|
||||
"client_id": OAuthTokenManager._unwrap_sensitive_str(
|
||||
oauth_config.client_id
|
||||
),
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
@@ -161,6 +188,12 @@ class OAuthTokenManager:
|
||||
|
||||
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
|
||||
|
||||
@staticmethod
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
    """Return the plain string behind ``value``.

    A ``SensitiveValue`` is unwrapped via ``get_value(apply_mask=False)``;
    any other value is handed back unchanged.
    """
    if not isinstance(value, SensitiveValue):
        return value
    return value.get_value(apply_mask=False)
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_token_data(
|
||||
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
|
||||
|
||||
@@ -121,6 +121,7 @@ from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
@@ -137,6 +138,8 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"
|
||||
|
||||
|
||||
def is_user_admin(user: User) -> bool:
    """Report whether ``user`` holds the ``UserRole.ADMIN`` role."""
    role = user.role
    return role == UserRole.ADMIN
|
||||
@@ -208,22 +211,34 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
|
||||
return int(value.decode("utf-8")) == 1
|
||||
|
||||
|
||||
def workspace_invite_only_enabled() -> bool:
    """Return the ``invite_only_enabled`` flag from the loaded settings."""
    return load_settings().invite_only_enabled
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}:
|
||||
# SSO providers manage membership; allow JIT provisioning regardless of invites
|
||||
return
|
||||
|
||||
whitelist = get_invited_users()
|
||||
if not whitelist:
|
||||
if not workspace_invite_only_enabled():
|
||||
return
|
||||
|
||||
whitelist = get_invited_users()
|
||||
|
||||
if not email:
|
||||
raise PermissionError("Email must be specified")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Email must be specified"},
|
||||
)
|
||||
|
||||
try:
|
||||
email_info = validate_email(email, check_deliverability=False)
|
||||
except EmailUndeliverableError:
|
||||
raise PermissionError("Email is not valid")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Email is not valid"},
|
||||
)
|
||||
|
||||
for email_whitelist in whitelist:
|
||||
try:
|
||||
@@ -240,7 +255,13 @@ def verify_email_is_invited(email: str) -> None:
|
||||
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
|
||||
return
|
||||
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"code": REGISTER_INVITE_ONLY_CODE,
|
||||
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
|
||||
@@ -256,13 +277,32 @@ def verify_email_domain(email: str) -> None:
|
||||
detail="Email is not valid",
|
||||
)
|
||||
|
||||
domain = email.split("@")[-1].lower()
|
||||
local_part, domain = email.split("@")
|
||||
domain = domain.lower()
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
# Reject googlemail.com addresses and ask the user to sign up with the equivalent gmail.com address (both deliver to the same inbox)
|
||||
if domain == "googlemail.com":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
|
||||
)
|
||||
|
||||
if "+" in local_part and domain != "onyx.app":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"reason": "Email addresses with '+' are not allowed. Please use your base email address."
|
||||
},
|
||||
)
|
||||
|
||||
# Check if email uses a disposable/temporary domain
|
||||
if is_disposable_email(email):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Disposable email addresses are not allowed. Please use a permanent email address.",
|
||||
detail={
|
||||
"reason": "Disposable email addresses are not allowed. Please use a permanent email address."
|
||||
},
|
||||
)
|
||||
|
||||
# Check domain whitelist if configured
|
||||
@@ -1650,7 +1690,10 @@ def get_oauth_router(
|
||||
if redirect_url is not None:
|
||||
authorize_redirect_url = redirect_url
|
||||
else:
|
||||
authorize_redirect_url = str(request.url_for(callback_route_name))
|
||||
# Use WEB_DOMAIN instead of request.url_for() to prevent host
|
||||
# header poisoning — request.url_for() trusts the Host header.
|
||||
callback_path = request.app.url_path_for(callback_route_name)
|
||||
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
|
||||
|
||||
next_url = request.query_params.get("next", "/")
|
||||
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
"""Celery tasks for hierarchy fetching."""
|
||||
|
||||
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
|
||||
check_for_hierarchy_fetching,
|
||||
)
|
||||
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
|
||||
connector_hierarchy_fetching_task,
|
||||
)
|
||||
|
||||
__all__ = ["check_for_hierarchy_fetching", "connector_hierarchy_fetching_task"]
|
||||
|
||||
@@ -41,3 +41,14 @@ assert (
|
||||
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S = 30 # 30 seconds.
|
||||
|
||||
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
|
||||
|
||||
# WARNING: Do not change these values without knowing what changes also need to
|
||||
# be made to OpenSearchTenantMigrationRecord.
|
||||
GET_VESPA_CHUNKS_PAGE_SIZE = 500
|
||||
GET_VESPA_CHUNKS_SLICE_COUNT = 4
|
||||
|
||||
# String used to indicate in the vespa_visit_continuation_token mapping that the
|
||||
# slice has finished and there is nothing left to visit.
|
||||
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN = (
|
||||
"FINISHED_VISITING_SLICE_CONTINUATION_TOKEN"
|
||||
)
|
||||
|
||||
@@ -8,6 +8,12 @@ from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
GET_VESPA_CHUNKS_PAGE_SIZE,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
|
||||
)
|
||||
@@ -42,12 +48,19 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
GET_VESPA_CHUNKS_PAGE_SIZE = 1000
|
||||
def is_continuation_token_done_for_all_slices(
    continuation_token_map: dict[int, str | None],
) -> bool:
    """Report whether every slice in the map has finished being visited.

    Args:
        continuation_token_map: Maps a slice index to that slice's
            continuation token.

    Returns:
        True only when the map is non-empty and every token equals
        ``FINISHED_VISITING_SLICE_CONTINUATION_TOKEN``.
    """
    # ``all()`` is vacuously true for an empty iterable; an empty map means
    # no slice has been recorded yet, which must not count as "done".
    if not continuation_token_map:
        return False
    return all(
        continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
        for continuation_token in continuation_token_map.values()
    )
|
||||
|
||||
|
||||
# shared_task allows this task to be shared across celery app instances.
|
||||
@@ -76,11 +89,15 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
|
||||
Uses Vespa's Visit API to iterate through ALL chunks in bulk (not
|
||||
per-document), transform them, and index them into OpenSearch. Progress is
|
||||
tracked via a continuation token stored in the
|
||||
tracked via a continuation token map stored in the
|
||||
OpenSearchTenantMigrationRecord.
|
||||
|
||||
The first time we see no continuation token and non-zero chunks migrated, we
|
||||
consider the migration complete and all subsequent invocations are no-ops.
|
||||
Once the continuation token of every slice equals
|
||||
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN, we consider the migration
|
||||
complete and all subsequent invocations are no-ops.
|
||||
|
||||
We divide the index into GET_VESPA_CHUNKS_SLICE_COUNT independent slices
|
||||
where progress is tracked for each slice.
|
||||
|
||||
Returns:
|
||||
None if OpenSearch migration is not enabled, or if the lock could not be
|
||||
@@ -133,8 +150,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
opensearch_document_index = OpenSearchDocumentIndex(
|
||||
index_name=search_settings.index_name, tenant_state=tenant_state
|
||||
tenant_state=tenant_state,
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
)
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
@@ -153,15 +174,28 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
|
||||
)
|
||||
|
||||
approx_chunk_count_in_vespa: int | None = None
|
||||
get_chunk_count_start_time = time.monotonic()
|
||||
try:
|
||||
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Error getting approximate chunk count in Vespa. Moving on..."
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Took {time.monotonic() - get_chunk_count_start_time:.3f} seconds to attempt to get "
|
||||
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
|
||||
)
|
||||
|
||||
while (
|
||||
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
|
||||
and lock.owned()
|
||||
):
|
||||
(
|
||||
continuation_token,
|
||||
continuation_token_map,
|
||||
total_chunks_migrated,
|
||||
) = get_vespa_visit_state(db_session)
|
||||
if continuation_token is None and total_chunks_migrated > 0:
|
||||
if is_continuation_token_done_for_all_slices(continuation_token_map):
|
||||
task_logger.info(
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
|
||||
f"Total chunks migrated: {total_chunks_migrated}."
|
||||
@@ -170,19 +204,19 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
break
|
||||
task_logger.debug(
|
||||
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
|
||||
f"Continuation token: {continuation_token}"
|
||||
f"Continuation token map: {continuation_token_map}"
|
||||
)
|
||||
|
||||
get_vespa_chunks_start_time = time.monotonic()
|
||||
raw_vespa_chunks, next_continuation_token = (
|
||||
raw_vespa_chunks, next_continuation_token_map = (
|
||||
vespa_document_index.get_all_raw_document_chunks_paginated(
|
||||
continuation_token=continuation_token,
|
||||
continuation_token_map=continuation_token_map,
|
||||
page_size=GET_VESPA_CHUNKS_PAGE_SIZE,
|
||||
)
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} "
|
||||
f"seconds. Next continuation token: {next_continuation_token}"
|
||||
f"seconds. Next continuation token map: {next_continuation_token_map}"
|
||||
)
|
||||
|
||||
opensearch_document_chunks, errored_chunks = (
|
||||
@@ -212,14 +246,11 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
total_chunks_errored_this_task += len(errored_chunks)
|
||||
update_vespa_visit_progress_with_commit(
|
||||
db_session,
|
||||
continuation_token=next_continuation_token,
|
||||
continuation_token_map=next_continuation_token_map,
|
||||
chunks_processed=len(opensearch_document_chunks),
|
||||
chunks_errored=len(errored_chunks),
|
||||
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
|
||||
)
|
||||
|
||||
if next_continuation_token is None and len(raw_vespa_chunks) == 0:
|
||||
task_logger.info("Vespa reported no more chunks to migrate.")
|
||||
break
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
task_logger.exception("Error in the OpenSearch migration task.")
|
||||
|
||||
@@ -22,6 +22,7 @@ from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import METADATA_SUFFIX
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER
|
||||
@@ -37,6 +38,36 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
# Names of the Vespa chunk fields collected for transformation.
# NOTE(review): presumably this is the field list requested from Vespa when
# chunks are read for conversion into OpenSearch chunks - confirm against the
# caller that consumes this constant.
FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
    DOCUMENT_ID,
    CHUNK_ID,
    TITLE,
    TITLE_EMBEDDING,
    CONTENT,
    EMBEDDINGS,
    SOURCE_TYPE,
    METADATA_LIST,
    DOC_UPDATED_AT,
    HIDDEN,
    BOOST,
    SEMANTIC_IDENTIFIER,
    IMAGE_FILE_NAME,
    SOURCE_LINKS,
    BLURB,
    DOC_SUMMARY,
    CHUNK_CONTEXT,
    METADATA_SUFFIX,
    DOCUMENT_SETS,
    USER_PROJECT,
    PERSONAS,
    PRIMARY_OWNERS,
    SECONDARY_OWNERS,
    ACCESS_CONTROL_LIST,
    # The tenant field is only included when MULTI_TENANT is enabled; it is
    # always the last entry.
    *((TENANT_ID,) if MULTI_TENANT else ()),
]
|
||||
|
||||
|
||||
def _extract_content_vector(embeddings: Any) -> list[float]:
|
||||
"""Extracts the full chunk embedding vector from Vespa's embeddings tensor.
|
||||
|
||||
@@ -247,6 +278,7 @@ def transform_vespa_chunks_to_opensearch_chunks(
|
||||
)
|
||||
)
|
||||
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
|
||||
personas: list[int] | None = vespa_chunk.get(PERSONAS)
|
||||
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
|
||||
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
|
||||
|
||||
@@ -296,6 +328,7 @@ def transform_vespa_chunks_to_opensearch_chunks(
|
||||
metadata_suffix=metadata_suffix,
|
||||
document_sets=document_sets,
|
||||
user_projects=user_projects,
|
||||
personas=personas,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
tenant_id=tenant_state,
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
"""Celery tasks for connector pruning."""
|
||||
|
||||
from onyx.background.celery.tasks.pruning.tasks import check_for_pruning # noqa: F401
|
||||
from onyx.background.celery.tasks.pruning.tasks import ( # noqa: F401
|
||||
connector_pruning_generator_task,
|
||||
)
|
||||
|
||||
__all__ = ["check_for_pruning", "connector_pruning_generator_task"]
|
||||
|
||||
@@ -5,14 +5,18 @@ from uuid import UUID
|
||||
|
||||
import httpx
|
||||
import sqlalchemy as sa
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
@@ -21,12 +25,16 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
@@ -57,14 +65,73 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
def _user_file_queued_key(user_file_id: str | UUID) -> str:
    """Return the Redis key marking a queued process_single_user_file task.

    The beat generator writes this key (TTL of
    CELERY_USER_FILE_PROCESSING_TASK_EXPIRES) just before it enqueues the
    per-file task, and the worker removes it as soon as it starts handling
    that task. While the key exists the beat skips the file, so a file that
    already has a task waiting in the queue is not enqueued a second time.
    """
    key = f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
    return key
|
||||
|
||||
|
||||
def user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
    """Return the name of the per-file Redis lock used during project sync."""
    lock_name = f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
    return lock_name
|
||||
|
||||
|
||||
def _user_file_project_sync_queued_key(user_file_id: str | UUID) -> str:
    """Return the Redis guard key marking a queued project-sync task for a file."""
    guard_key = (
        f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_QUEUED_PREFIX}:{user_file_id}"
    )
    return guard_key
|
||||
|
||||
|
||||
def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
    """Build the USER_FILE_DELETE_LOCK_PREFIX key string for the given user file."""
    lock_name = f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
    return lock_name
|
||||
|
||||
|
||||
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
    """Return the current length of the USER_FILE_PROJECT_SYNC queue.

    Args:
        celery_app: Celery application whose broker connection is inspected.

    Returns:
        Whatever celery_get_queue_length reports for the project-sync queue.
    """
    # Walk down to the client object that the broker channel exposes; it is
    # handed to celery_get_queue_length as the Redis client.
    broker_channel = celery_app.broker_connection().channel()
    broker_redis: Redis = broker_channel.client  # type: ignore
    return celery_get_queue_length(
        OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
        broker_redis,
    )
|
||||
|
||||
|
||||
def enqueue_user_file_project_sync_task(
    *,
    celery_app: Celery,
    redis_client: Redis,
    user_file_id: str | UUID,
    tenant_id: str,
    priority: OnyxCeleryPriority = OnyxCeleryPriority.HIGH,
) -> bool:
    """Publish a project-sync task for one user file unless one is already queued.

    Args:
        celery_app: Celery application used to publish the task.
        redis_client: Redis client that holds the per-file "queued" guard key.
        user_file_id: Identifier of the user file to sync.
        tenant_id: Tenant passed through to the task kwargs.
        priority: Celery priority for the published task.

    Returns:
        True if the task was published, False if the guard key was already
        present and nothing was sent.

    Raises:
        Exception: Anything raised by ``send_task`` is re-raised after the
            guard key has been removed again.
    """
    guard_key = _user_file_project_sync_queued_key(user_file_id)

    # SET with nx=True only succeeds when the key does not exist yet, which
    # makes the check-and-mark step atomic; ex= lets the guard expire on its
    # own if nobody ever deletes it.
    guard_acquired = redis_client.set(
        guard_key,
        1,
        nx=True,
        ex=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
    )
    if not guard_acquired:
        return False

    try:
        task_kwargs = {"user_file_id": str(user_file_id), "tenant_id": tenant_id}
        celery_app.send_task(
            OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
            kwargs=task_kwargs,
            queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
            priority=priority,
            expires=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
        )
    except Exception:
        # Publishing failed, so no task is in the queue: drop the guard so a
        # later attempt is not blocked until the TTL runs out.
        redis_client.delete(guard_key)
        raise

    return True
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0))
|
||||
def _visit_chunks(
|
||||
*,
|
||||
@@ -120,7 +187,24 @@ def _get_document_chunk_count(
|
||||
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
|
||||
|
||||
Uses direct Redis locks to avoid overlapping runs.
|
||||
Three mechanisms prevent queue runaway:
|
||||
|
||||
1. **Queue depth backpressure** – if the broker queue already has more than
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
|
||||
entirely. Workers are clearly behind; adding more tasks would only make
|
||||
the backlog worse.
|
||||
|
||||
2. **Per-file queued guard** – before enqueuing a task we set a short-lived
|
||||
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
|
||||
already exists the file already has a live task in the queue, so we skip
|
||||
it. The worker deletes the key the moment it picks up the task so the
|
||||
next beat cycle can re-enqueue if the file is still PROCESSING.
|
||||
|
||||
3. **Task expiry** – every enqueued task carries an `expires` value equal to
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
|
||||
the queue after that deadline, Celery discards it without touching the DB.
|
||||
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
|
||||
Redis restart), stale tasks evict themselves rather than piling up forever.
|
||||
"""
|
||||
task_logger.info("check_user_file_processing - Starting")
|
||||
|
||||
@@ -135,7 +219,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_user_file_processing - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -148,12 +246,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
nx=True,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
# If task submission fails, clear the guard immediately so the
|
||||
# next beat cycle can retry enqueuing this file.
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
|
||||
finally:
|
||||
@@ -161,7 +282,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
|
||||
f"tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -304,6 +426,12 @@ def process_single_user_file(
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
@@ -557,8 +685,8 @@ def process_single_user_file_delete(
|
||||
ignore_result=True,
|
||||
)
|
||||
def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with PROJECT_SYNC status and enqueue per-file tasks."""
|
||||
task_logger.info("check_for_user_file_project_sync - Starting")
|
||||
"""Scan for user files needing project sync and enqueue per-file tasks."""
|
||||
task_logger.info("Starting")
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
@@ -570,13 +698,25 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
queue_depth = get_user_file_project_sync_queue_depth(self.app)
|
||||
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"Queue depth {queue_depth} exceeds "
|
||||
f"{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}, skipping enqueue for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
select(UserFile.id).where(
|
||||
sa.and_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
sa.or_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.needs_persona_sync.is_(True),
|
||||
),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
@@ -586,19 +726,23 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
if not enqueue_user_file_project_sync_task(
|
||||
celery_app=self.app,
|
||||
redis_client=redis_client,
|
||||
user_file_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
):
|
||||
skipped_guard += 1
|
||||
continue
|
||||
enqueued += 1
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_for_user_file_project_sync - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
f"Enqueued {enqueued} "
|
||||
f"Skipped guard {skipped_guard} tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -617,8 +761,10 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_project_sync_lock_key(user_file_id),
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
@@ -630,7 +776,11 @@ def process_single_user_file_project_sync(
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
user_file = db_session.execute(
|
||||
select(UserFile)
|
||||
.where(UserFile.id == _as_uuid(user_file_id))
|
||||
.options(selectinload(UserFile.assistants))
|
||||
).scalar_one_or_none()
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
|
||||
@@ -658,13 +808,17 @@ def process_single_user_file_project_sync(
|
||||
]
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=project_ids,
|
||||
personas=persona_ids,
|
||||
),
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
@@ -672,6 +826,7 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
user_file.needs_persona_sync = False
|
||||
user_file.last_project_sync_at = datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
)
|
||||
|
||||
@@ -58,6 +58,8 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
@@ -156,36 +158,7 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
|
||||
logger.warning(
|
||||
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
|
||||
)
|
||||
cleaned_doc = doc.model_copy()
|
||||
|
||||
# Postgres cannot handle NUL characters in text fields
|
||||
if "\x00" in cleaned_doc.id:
|
||||
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
|
||||
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
|
||||
|
||||
if cleaned_doc.title and "\x00" in cleaned_doc.title:
|
||||
logger.warning(
|
||||
f"NUL characters found in document title: {cleaned_doc.title}"
|
||||
)
|
||||
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
|
||||
|
||||
if "\x00" in cleaned_doc.semantic_identifier:
|
||||
logger.warning(
|
||||
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
|
||||
)
|
||||
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
|
||||
"\x00", ""
|
||||
)
|
||||
|
||||
for section in cleaned_doc.sections:
|
||||
if section.link is not None:
|
||||
section.link = section.link.replace("\x00", "")
|
||||
|
||||
# since text can be longer, just replace to avoid double scan
|
||||
if isinstance(section, TextSection) and section.text is not None:
|
||||
section.text = section.text.replace("\x00", "")
|
||||
|
||||
cleaned_batch.append(cleaned_doc)
|
||||
cleaned_batch.append(sanitize_document_for_postgres(doc))
|
||||
|
||||
return cleaned_batch
|
||||
|
||||
@@ -602,10 +575,13 @@ def connector_document_extraction(
|
||||
|
||||
# Process hierarchy nodes batch - upsert to Postgres and cache in Redis
|
||||
if hierarchy_node_batch:
|
||||
hierarchy_node_batch_cleaned = (
|
||||
sanitize_hierarchy_nodes_for_postgres(hierarchy_node_batch)
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=hierarchy_node_batch,
|
||||
nodes=hierarchy_node_batch_cleaned,
|
||||
source=db_connector.source,
|
||||
commit=True,
|
||||
is_connector_public=is_connector_public,
|
||||
@@ -624,7 +600,7 @@ def connector_document_extraction(
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Persisted and cached {len(hierarchy_node_batch)} hierarchy nodes "
|
||||
f"Persisted and cached {len(hierarchy_node_batch_cleaned)} hierarchy nodes "
|
||||
f"for attempt={index_attempt_id}"
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
@@ -163,13 +162,11 @@ class ChatStateContainer:
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
func: Callable[..., None],
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
@@ -180,19 +177,18 @@ def run_chat_loop_with_state_containers(
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
*args: Additional positional arguments for func
|
||||
**kwargs: Additional keyword arguments for func
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
arg1, arg2, kwarg1=value1
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
@@ -201,9 +197,7 @@ def run_chat_loop_with_state_containers(
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
# Ensure state_container is passed explicitly, removing it from kwargs if present
|
||||
kwargs_with_state = {**kwargs, "state_container": state_container}
|
||||
func(emitter, *args, **kwargs_with_state)
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
@@ -45,6 +46,7 @@ from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
IMAGE_GENERATION_TOOL_NAME = "generate_image"
|
||||
|
||||
|
||||
def create_chat_session_from_request(
|
||||
@@ -422,10 +424,44 @@ def convert_chat_history_basic(
|
||||
return list(reversed(trimmed_reversed))
|
||||
|
||||
|
||||
def _build_tool_call_response_history_message(
    tool_name: str,
    generated_images: list[dict] | None,
    tool_call_response: str | None,
) -> str:
    """Choose the history text for a TOOL_CALL_RESPONSE message.

    Every tool other than the image generation tool gets the generic
    TOOL_CALL_RESPONSE_CROSS_MESSAGE. For the image generation tool the
    result is, in order of preference:

    1. a JSON list of ``{"file_id", "revised_prompt"}`` objects built from
       ``generated_images`` (entries without a string ``file_id`` are dropped,
       a non-string ``revised_prompt`` becomes ``""``),
    2. the stored ``tool_call_response`` if it is non-empty,
    3. TOOL_CALL_RESPONSE_CROSS_MESSAGE.

    Args:
        tool_name: Name of the tool that produced the call.
        generated_images: Image records attached to the tool call, if any.
        tool_call_response: Raw stored response text of the tool call, if any.

    Returns:
        The message string to place in the chat history.
    """
    is_image_tool = tool_name == IMAGE_GENERATION_TOOL_NAME
    if not is_image_tool:
        return TOOL_CALL_RESPONSE_CROSS_MESSAGE

    image_entries: list[dict[str, str]] = []
    for image in generated_images or ():
        file_id = image.get("file_id")
        prompt = image.get("revised_prompt")
        if isinstance(file_id, str):
            image_entries.append(
                {
                    "file_id": file_id,
                    "revised_prompt": prompt if isinstance(prompt, str) else "",
                }
            )

    if image_entries:
        return json.dumps(image_entries)

    # No usable image records: fall back to the stored response, then to the
    # generic placeholder.
    return tool_call_response or TOOL_CALL_RESPONSE_CROSS_MESSAGE
|
||||
|
||||
|
||||
def convert_chat_history(
|
||||
chat_history: list[ChatMessage],
|
||||
files: list[ChatLoadedFile],
|
||||
project_image_files: list[ChatLoadedFile],
|
||||
context_image_files: list[ChatLoadedFile],
|
||||
additional_context: str | None,
|
||||
token_counter: Callable[[str], int],
|
||||
tool_id_to_name_map: dict[int, str],
|
||||
@@ -505,11 +541,11 @@ def convert_chat_history(
|
||||
)
|
||||
|
||||
# Add the user message with image files attached
|
||||
# If this is the last USER message, also include project_image_files
|
||||
# Note: project image file tokens are NOT counted in the token count
|
||||
# If this is the last USER message, also include context_image_files
|
||||
# Note: context image file tokens are NOT counted in the token count
|
||||
if idx == last_user_message_idx:
|
||||
if project_image_files:
|
||||
image_files.extend(project_image_files)
|
||||
if context_image_files:
|
||||
image_files.extend(context_image_files)
|
||||
|
||||
if additional_context:
|
||||
simple_messages.append(
|
||||
@@ -582,10 +618,24 @@ def convert_chat_history(
|
||||
|
||||
# Add TOOL_CALL_RESPONSE messages for each tool call in this turn
|
||||
for tool_call in turn_tool_calls:
|
||||
tool_name = tool_id_to_name_map.get(
|
||||
tool_call.tool_id, "unknown"
|
||||
)
|
||||
tool_response_message = (
|
||||
_build_tool_call_response_history_message(
|
||||
tool_name=tool_name,
|
||||
generated_images=tool_call.generated_images,
|
||||
tool_call_response=tool_call.tool_call_response,
|
||||
)
|
||||
)
|
||||
simple_messages.append(
|
||||
ChatMessageSimple(
|
||||
message=TOOL_CALL_RESPONSE_CROSS_MESSAGE,
|
||||
token_count=20, # Tiny overestimate
|
||||
message=tool_response_message,
|
||||
token_count=(
|
||||
token_counter(tool_response_message)
|
||||
if tool_name == IMAGE_GENERATION_TOOL_NAME
|
||||
else 20
|
||||
),
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
|
||||
@@ -15,10 +15,10 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.chat.prompt_utils import build_reminder_message
|
||||
from onyx.chat.prompt_utils import build_system_prompt
|
||||
@@ -30,6 +30,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.memory import add_memory
|
||||
from onyx.db.memory import update_memory_at_index
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
@@ -57,6 +58,7 @@ from onyx.tools.tool_implementations.images.models import (
|
||||
FinalImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
@@ -68,6 +70,18 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
|
||||
"""Detect XML-style marshaled tool calls emitted as plain text."""
|
||||
if not text:
|
||||
return False
|
||||
lowered = text.lower()
|
||||
return (
|
||||
"<function_calls" in lowered
|
||||
and "<invoke" in lowered
|
||||
and "<parameter" in lowered
|
||||
)
|
||||
|
||||
|
||||
def _should_keep_bedrock_tool_definitions(
|
||||
llm: object, simple_chat_history: list[ChatMessageSimple]
|
||||
) -> bool:
|
||||
@@ -122,38 +136,56 @@ def _try_fallback_tool_extraction(
|
||||
reasoning_but_no_answer_or_tools = (
|
||||
llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls
|
||||
)
|
||||
xml_tool_call_text_detected = no_tool_calls and (
|
||||
_looks_like_xml_tool_call_payload(llm_step_result.answer)
|
||||
or _looks_like_xml_tool_call_payload(llm_step_result.raw_answer)
|
||||
or _looks_like_xml_tool_call_payload(llm_step_result.reasoning)
|
||||
)
|
||||
should_try_fallback = (
|
||||
tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls
|
||||
) or reasoning_but_no_answer_or_tools
|
||||
(tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls)
|
||||
or reasoning_but_no_answer_or_tools
|
||||
or xml_tool_call_text_detected
|
||||
)
|
||||
|
||||
if not should_try_fallback:
|
||||
return llm_step_result, False
|
||||
|
||||
# Try to extract from answer first, then fall back to reasoning
|
||||
extracted_tool_calls: list[ToolCallKickoff] = []
|
||||
|
||||
if llm_step_result.answer:
|
||||
extracted_tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=llm_step_result.answer,
|
||||
tool_definitions=tool_defs,
|
||||
placement=Placement(turn_index=turn_index),
|
||||
)
|
||||
if (
|
||||
not extracted_tool_calls
|
||||
and llm_step_result.raw_answer
|
||||
and llm_step_result.raw_answer != llm_step_result.answer
|
||||
):
|
||||
extracted_tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=llm_step_result.raw_answer,
|
||||
tool_definitions=tool_defs,
|
||||
placement=Placement(turn_index=turn_index),
|
||||
)
|
||||
if not extracted_tool_calls and llm_step_result.reasoning:
|
||||
extracted_tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=llm_step_result.reasoning,
|
||||
tool_definitions=tool_defs,
|
||||
placement=Placement(turn_index=turn_index),
|
||||
)
|
||||
|
||||
if extracted_tool_calls:
|
||||
logger.info(
|
||||
f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
|
||||
f"as fallback (tool_choice was REQUIRED but no tool calls returned)"
|
||||
"as fallback"
|
||||
)
|
||||
return (
|
||||
LlmStepResult(
|
||||
reasoning=llm_step_result.reasoning,
|
||||
answer=llm_step_result.answer,
|
||||
tool_calls=extracted_tool_calls,
|
||||
raw_answer=llm_step_result.raw_answer,
|
||||
),
|
||||
True,
|
||||
)
|
||||
@@ -171,17 +203,17 @@ def _try_fallback_tool_extraction(
|
||||
MAX_LLM_CYCLES = 6
|
||||
|
||||
|
||||
def _build_project_file_citation_mapping(
|
||||
project_file_metadata: list[ProjectFileMetadata],
|
||||
def _build_context_file_citation_mapping(
|
||||
file_metadata: list[ContextFileMetadata],
|
||||
starting_citation_num: int = 1,
|
||||
) -> CitationMapping:
|
||||
"""Build citation mapping for project files.
|
||||
"""Build citation mapping for context files.
|
||||
|
||||
Converts project file metadata into SearchDoc objects that can be cited.
|
||||
Converts context file metadata into SearchDoc objects that can be cited.
|
||||
Citation numbers start from the provided starting number.
|
||||
|
||||
Args:
|
||||
project_file_metadata: List of project file metadata
|
||||
file_metadata: List of context file metadata
|
||||
starting_citation_num: Starting citation number (default: 1)
|
||||
|
||||
Returns:
|
||||
@@ -189,8 +221,7 @@ def _build_project_file_citation_mapping(
|
||||
"""
|
||||
citation_mapping: CitationMapping = {}
|
||||
|
||||
for idx, file_meta in enumerate(project_file_metadata, start=starting_citation_num):
|
||||
# Create a SearchDoc for each project file
|
||||
for idx, file_meta in enumerate(file_metadata, start=starting_citation_num):
|
||||
search_doc = SearchDoc(
|
||||
document_id=file_meta.file_id,
|
||||
chunk_ind=0,
|
||||
@@ -210,29 +241,28 @@ def _build_project_file_citation_mapping(
|
||||
|
||||
|
||||
def _build_project_message(
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
context_files: ExtractedContextFiles | None,
|
||||
token_counter: Callable[[str], int] | None,
|
||||
) -> list[ChatMessageSimple]:
|
||||
"""Build messages for project / tool-backed files.
|
||||
"""Build messages for context-injected / tool-backed files.
|
||||
|
||||
Returns up to two messages:
|
||||
1. The full-text project files message (if project_file_texts is populated).
|
||||
1. The full-text files message (if file_texts is populated).
|
||||
2. A lightweight metadata message for files the LLM should access via the
|
||||
FileReaderTool (e.g. oversized chat-attached files or project files that
|
||||
don't fit in context).
|
||||
FileReaderTool (e.g. oversized files that don't fit in context).
|
||||
"""
|
||||
if not project_files:
|
||||
if not context_files:
|
||||
return []
|
||||
|
||||
messages: list[ChatMessageSimple] = []
|
||||
if project_files.project_file_texts:
|
||||
if context_files.file_texts:
|
||||
messages.append(
|
||||
_create_project_files_message(project_files, token_counter=None)
|
||||
_create_context_files_message(context_files, token_counter=None)
|
||||
)
|
||||
if project_files.file_metadata_for_tool and token_counter:
|
||||
if context_files.file_metadata_for_tool and token_counter:
|
||||
messages.append(
|
||||
_create_file_tool_metadata_message(
|
||||
project_files.file_metadata_for_tool, token_counter
|
||||
context_files.file_metadata_for_tool, token_counter
|
||||
)
|
||||
)
|
||||
return messages
|
||||
@@ -243,7 +273,7 @@ def construct_message_history(
|
||||
custom_agent_prompt: ChatMessageSimple | None,
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
reminder_message: ChatMessageSimple | None,
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
context_files: ExtractedContextFiles | None,
|
||||
available_tokens: int,
|
||||
last_n_user_messages: int | None = None,
|
||||
token_counter: Callable[[str], int] | None = None,
|
||||
@@ -257,7 +287,7 @@ def construct_message_history(
|
||||
|
||||
# Build the project / file-metadata messages up front so we can use their
|
||||
# actual token counts for the budget.
|
||||
project_messages = _build_project_message(project_files, token_counter)
|
||||
project_messages = _build_project_message(context_files, token_counter)
|
||||
project_messages_tokens = sum(m.token_count for m in project_messages)
|
||||
|
||||
history_token_budget = available_tokens
|
||||
@@ -413,17 +443,17 @@ def construct_message_history(
|
||||
)
|
||||
|
||||
# Attach project images to the last user message
|
||||
if project_files and project_files.project_image_files:
|
||||
if context_files and context_files.image_files:
|
||||
existing_images = last_user_message.image_files or []
|
||||
last_user_message = ChatMessageSimple(
|
||||
message=last_user_message.message,
|
||||
token_count=last_user_message.token_count,
|
||||
message_type=last_user_message.message_type,
|
||||
image_files=existing_images + project_files.project_image_files,
|
||||
image_files=existing_images + context_files.image_files,
|
||||
)
|
||||
|
||||
# Build the final message list according to README ordering:
|
||||
# [system], [history_before_last_user], [custom_agent], [project_files],
|
||||
# [system], [history_before_last_user], [custom_agent], [context_files],
|
||||
# [forgotten_files], [last_user_message], [messages_after_last_user], [reminder]
|
||||
result = [system_prompt] if system_prompt else []
|
||||
|
||||
@@ -434,14 +464,14 @@ def construct_message_history(
|
||||
if custom_agent_prompt:
|
||||
result.append(custom_agent_prompt)
|
||||
|
||||
# 3. Add project files / file-metadata messages (inserted before last user message)
|
||||
# 3. Add context files / file-metadata messages (inserted before last user message)
|
||||
result.extend(project_messages)
|
||||
|
||||
# 4. Add forgotten-files metadata (right before the user's question)
|
||||
if forgotten_files_message:
|
||||
result.append(forgotten_files_message)
|
||||
|
||||
# 5. Add last user message (with project images attached)
|
||||
# 5. Add last user message (with context images attached)
|
||||
result.append(last_user_message)
|
||||
|
||||
# 6. Add messages after last user message (tool calls, responses, etc.)
|
||||
@@ -451,7 +481,42 @@ def construct_message_history(
|
||||
if reminder_message:
|
||||
result.append(reminder_message)
|
||||
|
||||
return result
|
||||
return _drop_orphaned_tool_call_responses(result)
|
||||
|
||||
|
||||
def _drop_orphaned_tool_call_responses(
    messages: list[ChatMessageSimple],
) -> list[ChatMessageSimple]:
    """Remove tool responses that have no matching earlier assistant tool call.

    History truncation can discard an ASSISTANT message carrying tool calls while
    a later TOOL_CALL_RESPONSE for one of those calls survives. Some providers
    (e.g. Ollama) reject such a history with an "unexpected tool call id" error,
    so those dangling responses are filtered out here.

    Args:
        messages: Messages in conversation order.

    Returns:
        A new list in the same order, without the orphaned tool responses. All
        other messages are passed through unchanged.
    """
    seen_tool_call_ids: set[str] = set()
    kept: list[ChatMessageSimple] = []

    for message in messages:
        kind = message.message_type

        if kind == MessageType.ASSISTANT and message.tool_calls:
            # Remember every id announced so far; only later responses may use it.
            seen_tool_call_ids.update(
                call.tool_call_id for call in message.tool_calls
            )
        elif kind == MessageType.TOOL_CALL_RESPONSE:
            response_id = message.tool_call_id
            if not response_id or response_id not in seen_tool_call_ids:
                logger.debug(
                    "Dropping orphaned tool response with tool_call_id=%s while "
                    "constructing message history",
                    response_id,
                )
                continue

        kept.append(message)

    return kept
|
||||
|
||||
|
||||
def _create_file_tool_metadata_message(
|
||||
@@ -480,11 +545,11 @@ def _create_file_tool_metadata_message(
|
||||
)
|
||||
|
||||
|
||||
def _create_project_files_message(
|
||||
project_files: ExtractedProjectFiles,
|
||||
def _create_context_files_message(
|
||||
context_files: ExtractedContextFiles,
|
||||
token_counter: Callable[[str], int] | None, # noqa: ARG001
|
||||
) -> ChatMessageSimple:
|
||||
"""Convert project files to a ChatMessageSimple message.
|
||||
"""Convert context files to a ChatMessageSimple message.
|
||||
|
||||
Format follows the README specification for document representation.
|
||||
"""
|
||||
@@ -492,7 +557,7 @@ def _create_project_files_message(
|
||||
|
||||
# Format as documents JSON as described in README
|
||||
documents_list = []
|
||||
for idx, file_text in enumerate(project_files.project_file_texts, start=1):
|
||||
for idx, file_text in enumerate(context_files.file_texts, start=1):
|
||||
documents_list.append(
|
||||
{
|
||||
"document": idx,
|
||||
@@ -503,10 +568,10 @@ def _create_project_files_message(
|
||||
documents_json = json.dumps({"documents": documents_list}, indent=2)
|
||||
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
|
||||
|
||||
# Use pre-calculated token count from project_files
|
||||
# Use pre-calculated token count from context_files
|
||||
return ChatMessageSimple(
|
||||
message=message_content,
|
||||
token_count=project_files.total_token_count,
|
||||
token_count=context_files.total_token_count,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
|
||||
@@ -517,7 +582,7 @@ def run_llm_loop(
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
tools: list[Tool],
|
||||
custom_agent_prompt: str | None,
|
||||
project_files: ExtractedProjectFiles,
|
||||
context_files: ExtractedContextFiles,
|
||||
persona: Persona | None,
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
@@ -560,9 +625,9 @@ def run_llm_loop(
|
||||
|
||||
# Add project file citation mappings if project files are present
|
||||
project_citation_mapping: CitationMapping = {}
|
||||
if project_files.project_file_metadata:
|
||||
project_citation_mapping = _build_project_file_citation_mapping(
|
||||
project_files.project_file_metadata
|
||||
if context_files.file_metadata:
|
||||
project_citation_mapping = _build_context_file_citation_mapping(
|
||||
context_files.file_metadata
|
||||
)
|
||||
citation_processor.update_citation_mapping(project_citation_mapping)
|
||||
|
||||
@@ -580,16 +645,22 @@ def run_llm_loop(
|
||||
# TODO allow citing of images in Projects. Since attached to the last user message, it has no text associated with it.
|
||||
# One future workaround is to include the images as separate user messages with citation information and process those.
|
||||
always_cite_documents: bool = bool(
|
||||
project_files.project_as_filter or project_files.project_file_texts
|
||||
context_files.use_as_search_filter or context_files.file_texts
|
||||
)
|
||||
should_cite_documents: bool = False
|
||||
ran_image_gen: bool = False
|
||||
just_ran_web_search: bool = False
|
||||
has_called_search_tool: bool = False
|
||||
code_interpreter_file_generated: bool = False
|
||||
fallback_extraction_attempted: bool = False
|
||||
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
|
||||
|
||||
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
|
||||
# Fetch this in a short-lived session so the long-running stream loop does
|
||||
# not pin a connection just to keep read state alive.
|
||||
with get_session_with_current_tenant() as prompt_db_session:
|
||||
default_base_system_prompt: str = get_default_base_system_prompt(
|
||||
prompt_db_session
|
||||
)
|
||||
system_prompt = None
|
||||
custom_agent_prompt_msg = None
|
||||
|
||||
@@ -696,6 +767,7 @@ def run_llm_loop(
|
||||
),
|
||||
include_citation_reminder=should_cite_documents
|
||||
or always_cite_documents,
|
||||
include_file_reminder=code_interpreter_file_generated,
|
||||
is_last_cycle=out_of_cycles,
|
||||
)
|
||||
|
||||
@@ -714,7 +786,7 @@ def run_llm_loop(
|
||||
custom_agent_prompt=custom_agent_prompt_msg,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=reminder_msg,
|
||||
project_files=project_files,
|
||||
context_files=context_files,
|
||||
available_tokens=available_tokens,
|
||||
token_counter=token_counter,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
@@ -835,6 +907,18 @@ def run_llm_loop(
|
||||
if tool_call.tool_name == SearchTool.NAME:
|
||||
has_called_search_tool = True
|
||||
|
||||
# Track if code interpreter generated files with download links
|
||||
if (
|
||||
tool_call.tool_name == PythonTool.NAME
|
||||
and not code_interpreter_file_generated
|
||||
):
|
||||
try:
|
||||
parsed = json.loads(tool_response.llm_facing_response)
|
||||
if parsed.get("generated_files"):
|
||||
code_interpreter_file_generated = True
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
|
||||
# Build a mapping of tool names to tool objects for getting tool_id
|
||||
tools_by_name = {tool.name: tool for tool in final_tools}
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Sequence
|
||||
from html import unescape
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -18,6 +20,7 @@ from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
@@ -56,6 +59,112 @@ from onyx.utils.text_processing import find_all_json_objects
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# One <invoke ...>...</invoke> element: "attrs" is the raw attribute text of the
# opening tag, "body" is everything up to the first closing tag (non-greedy).
_XML_INVOKE_BLOCK_RE = re.compile(
    r"<invoke\b(?P<attrs>[^>]*)>(?P<body>.*?)</invoke>",
    flags=re.DOTALL | re.IGNORECASE,
)

# One <parameter ...>...</parameter> element inside an invoke body: "attrs" is
# the raw attribute text, "value" is the (possibly multi-line) element content.
_XML_PARAMETER_RE = re.compile(
    r"<parameter\b(?P<attrs>[^>]*)>(?P<value>.*?)</parameter>",
    flags=re.DOTALL | re.IGNORECASE,
)
|
||||
_FUNCTION_CALLS_OPEN_MARKER = "<function_calls"
|
||||
_FUNCTION_CALLS_CLOSE_MARKER = "</function_calls>"
|
||||
|
||||
|
||||
class _XmlToolCallContentFilter:
|
||||
"""Streaming filter that strips XML-style tool call payload blocks from text."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pending = ""
|
||||
self._inside_function_calls_block = False
|
||||
|
||||
def process(self, content: str) -> str:
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
self._pending += content
|
||||
output_parts: list[str] = []
|
||||
|
||||
while self._pending:
|
||||
pending_lower = self._pending.lower()
|
||||
|
||||
if self._inside_function_calls_block:
|
||||
end_idx = pending_lower.find(_FUNCTION_CALLS_CLOSE_MARKER)
|
||||
if end_idx == -1:
|
||||
# Keep buffering until we see the close marker.
|
||||
return "".join(output_parts)
|
||||
|
||||
# Drop the whole function_calls block.
|
||||
self._pending = self._pending[
|
||||
end_idx + len(_FUNCTION_CALLS_CLOSE_MARKER) :
|
||||
]
|
||||
self._inside_function_calls_block = False
|
||||
continue
|
||||
|
||||
start_idx = _find_function_calls_open_marker(pending_lower)
|
||||
if start_idx == -1:
|
||||
# Keep only a possible prefix of "<function_calls" in the buffer so
|
||||
# marker splits across chunks are handled correctly.
|
||||
tail_len = _matching_open_marker_prefix_len(self._pending)
|
||||
emit_upto = len(self._pending) - tail_len
|
||||
if emit_upto > 0:
|
||||
output_parts.append(self._pending[:emit_upto])
|
||||
self._pending = self._pending[emit_upto:]
|
||||
return "".join(output_parts)
|
||||
|
||||
if start_idx > 0:
|
||||
output_parts.append(self._pending[:start_idx])
|
||||
|
||||
# Enter block-stripping mode and keep scanning for close marker.
|
||||
self._pending = self._pending[start_idx:]
|
||||
self._inside_function_calls_block = True
|
||||
|
||||
return "".join(output_parts)
|
||||
|
||||
def flush(self) -> str:
|
||||
if self._inside_function_calls_block:
|
||||
# Drop any incomplete block at stream end.
|
||||
self._pending = ""
|
||||
self._inside_function_calls_block = False
|
||||
return ""
|
||||
|
||||
remaining = self._pending
|
||||
self._pending = ""
|
||||
return remaining
|
||||
|
||||
|
||||
def _matching_open_marker_prefix_len(text: str) -> int:
|
||||
"""Return longest suffix of text that matches prefix of "<function_calls"."""
|
||||
max_len = min(len(text), len(_FUNCTION_CALLS_OPEN_MARKER) - 1)
|
||||
text_lower = text.lower()
|
||||
marker_lower = _FUNCTION_CALLS_OPEN_MARKER
|
||||
|
||||
for candidate_len in range(max_len, 0, -1):
|
||||
if text_lower.endswith(marker_lower[:candidate_len]):
|
||||
return candidate_len
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def _is_valid_function_calls_open_follower(char: str | None) -> bool:
|
||||
return char is None or char in {">", " ", "\t", "\n", "\r"}
|
||||
|
||||
|
||||
def _find_function_calls_open_marker(text_lower: str) -> int:
|
||||
"""Find '<function_calls' with a valid tag boundary follower."""
|
||||
search_from = 0
|
||||
while True:
|
||||
idx = text_lower.find(_FUNCTION_CALLS_OPEN_MARKER, search_from)
|
||||
if idx == -1:
|
||||
return -1
|
||||
|
||||
follower_pos = idx + len(_FUNCTION_CALLS_OPEN_MARKER)
|
||||
follower = text_lower[follower_pos] if follower_pos < len(text_lower) else None
|
||||
if _is_valid_function_calls_open_follower(follower):
|
||||
return idx
|
||||
|
||||
search_from = idx + 1
|
||||
|
||||
|
||||
def _sanitize_llm_output(value: str) -> str:
|
||||
"""Remove characters that PostgreSQL's text/JSONB types cannot store.
|
||||
@@ -272,14 +381,7 @@ def _extract_tool_call_kickoffs(
|
||||
tab_index_calculated = 0
|
||||
for tool_call_data in id_to_tool_call_map.values():
|
||||
if tool_call_data.get("id") and tool_call_data.get("name"):
|
||||
try:
|
||||
tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
|
||||
except json.JSONDecodeError:
|
||||
# If parsing fails, try empty dict, most tools would fail though
|
||||
logger.error(
|
||||
f"Failed to parse tool call arguments: {tool_call_data['arguments']}"
|
||||
)
|
||||
tool_args = {}
|
||||
tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
|
||||
|
||||
tool_calls.append(
|
||||
ToolCallKickoff(
|
||||
@@ -307,8 +409,9 @@ def extract_tool_calls_from_response_text(
|
||||
"""Extract tool calls from LLM response text by matching JSON against tool definitions.
|
||||
|
||||
This is a fallback mechanism for when the LLM was expected to return tool calls
|
||||
but didn't use the proper tool call format. It searches for JSON objects in the
|
||||
response text that match the structure of available tools.
|
||||
but didn't use the proper tool call format. It searches for tool calls embedded
|
||||
in response text (JSON first, then XML-like invoke blocks) that match available
|
||||
tool definitions.
|
||||
|
||||
Args:
|
||||
response_text: The LLM's text response to search for tool calls
|
||||
@@ -333,10 +436,9 @@ def extract_tool_calls_from_response_text(
|
||||
if not tool_name_to_def:
|
||||
return []
|
||||
|
||||
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
|
||||
# Find all JSON objects in the response text
|
||||
json_objects = find_all_json_objects(response_text)
|
||||
|
||||
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
|
||||
prev_json_obj: dict[str, Any] | None = None
|
||||
prev_tool_call: tuple[str, dict[str, Any]] | None = None
|
||||
|
||||
@@ -364,6 +466,14 @@ def extract_tool_calls_from_response_text(
|
||||
prev_json_obj = json_obj
|
||||
prev_tool_call = matched_tool_call
|
||||
|
||||
# Some providers/models emit XML-style function calls instead of JSON objects.
|
||||
# Keep this as a fallback behind JSON extraction to preserve current behavior.
|
||||
if not matched_tool_calls:
|
||||
matched_tool_calls = _extract_xml_tool_calls_from_response_text(
|
||||
response_text=response_text,
|
||||
tool_name_to_def=tool_name_to_def,
|
||||
)
|
||||
|
||||
tool_calls: list[ToolCallKickoff] = []
|
||||
for tab_index, (tool_name, tool_args) in enumerate(matched_tool_calls):
|
||||
tool_calls.append(
|
||||
@@ -386,6 +496,88 @@ def extract_tool_calls_from_response_text(
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _extract_xml_tool_calls_from_response_text(
    response_text: str,
    tool_name_to_def: dict[str, dict],
) -> list[tuple[str, dict[str, Any]]]:
    """Collect tool calls written as XML-style invoke blocks in response text.

    Example of a supported payload:
    <function_calls>
    <invoke name="internal_search">
    <parameter name="queries" string="false">["foo"]</parameter>
    </invoke>
    </function_calls>

    Args:
        response_text: Raw LLM text to scan.
        tool_name_to_def: Known tool definitions keyed by tool name; invoke
            blocks naming any other tool are ignored.

    Returns:
        (tool_name, arguments) pairs in the order the invoke blocks appear.
    """
    calls: list[tuple[str, dict[str, Any]]] = []

    for invoke in _XML_INVOKE_BLOCK_RE.finditer(response_text):
        name = _extract_xml_attribute(invoke.group("attrs"), "name")
        if not (name and name in tool_name_to_def):
            continue

        arguments: dict[str, Any] = {}
        for param in _XML_PARAMETER_RE.finditer(invoke.group("body")):
            attrs = param.group("attrs")
            key = _extract_xml_attribute(attrs, "name")
            if key:
                # Parameters without a usable name attribute are skipped.
                arguments[key] = _parse_xml_parameter_value(
                    raw_value=param.group("value"),
                    string_attr=_extract_xml_attribute(attrs, "string"),
                )

        calls.append((name, arguments))

    return calls
|
||||
|
||||
|
||||
def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None:
    """Pull one quoted attribute value out of a tag's raw attribute text.

    The name is matched case-insensitively and the value may use single or
    double quotes. The value is stripped, HTML-unescaped and sanitized.

    Returns:
        The cleaned value, or None when the attribute is not present.
    """
    pattern = rf"""\b{re.escape(attr_name)}\s*=\s*(['"])(.*?)\1"""
    found = re.search(pattern, attrs, flags=re.IGNORECASE | re.DOTALL)
    if found is None:
        return None

    raw = found.group(2).strip()
    return _sanitize_llm_output(unescape(raw))
|
||||
|
||||
|
||||
def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any:
    """Turn the text of an XML-style <parameter> element into a Python value.

    Args:
        raw_value: Element content exactly as it appeared in the response.
        string_attr: Value of the element's ``string`` attribute, if any.
            "true" (any casing) forces the result to stay a string.

    Returns:
        The cleaned string when ``string_attr`` is "true" or the text is not
        valid JSON; otherwise the JSON-decoded value.
    """
    text = _sanitize_llm_output(unescape(raw_value).strip())

    keep_as_string = string_attr is not None and string_attr.lower() == "true"
    if keep_as_string:
        return text

    try:
        decoded = json.loads(text)
    except json.JSONDecodeError:
        return text
    return decoded
|
||||
|
||||
|
||||
def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Extract and parse an arguments/parameters value from a tool-call-like object.
|
||||
|
||||
Looks for "arguments" or "parameters" keys, handles JSON-string values,
|
||||
and returns a dict if successful, or None otherwise.
|
||||
"""
|
||||
arguments = obj.get("arguments", obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
return arguments
|
||||
return None
|
||||
|
||||
|
||||
def _try_match_json_to_tool(
|
||||
json_obj: dict[str, Any],
|
||||
tool_name_to_def: dict[str, dict],
|
||||
@@ -408,13 +600,8 @@ def _try_match_json_to_tool(
|
||||
# Format 1: Direct tool call format {"name": "...", "arguments": {...}}
|
||||
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
|
||||
tool_name = json_obj["name"]
|
||||
arguments = json_obj.get("arguments", json_obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
arguments = _resolve_tool_arguments(json_obj)
|
||||
if arguments is not None:
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 2: Function call format {"function": {"name": "...", "arguments": {...}}}
|
||||
@@ -422,13 +609,8 @@ def _try_match_json_to_tool(
|
||||
func_obj = json_obj["function"]
|
||||
if "name" in func_obj and func_obj["name"] in tool_name_to_def:
|
||||
tool_name = func_obj["name"]
|
||||
arguments = func_obj.get("arguments", func_obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
arguments = _resolve_tool_arguments(func_obj)
|
||||
if arguments is not None:
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 3: Tool name as key {"tool_name": {...arguments...}}
|
||||
@@ -495,6 +677,107 @@ def _extract_nested_arguments_obj(
|
||||
return None
|
||||
|
||||
|
||||
def _build_structured_assistant_message(msg: ChatMessageSimple) -> AssistantMessage:
    """Convert an assistant history message into a structured AssistantMessage.

    Each tool call on the message becomes a "function" ToolCall whose arguments
    are JSON-encoded. An empty message text is sent as None content, and a
    message without tool calls gets ``tool_calls=None``.
    """
    structured_calls: list[ToolCall] | None = None

    if msg.tool_calls:
        structured_calls = []
        for call in msg.tool_calls:
            function = FunctionCall(
                name=call.tool_name,
                arguments=json.dumps(call.tool_arguments),
            )
            structured_calls.append(
                ToolCall(id=call.tool_call_id, type="function", function=function)
            )

    return AssistantMessage(
        role="assistant",
        content=msg.message or None,
        tool_calls=structured_calls,
    )
|
||||
|
||||
|
||||
def _build_structured_tool_response_message(msg: ChatMessageSimple) -> ToolMessage:
    """Convert a tool-response history message into a structured ToolMessage.

    Raises:
        ValueError: If the message carries no tool_call_id.
    """
    call_id = msg.tool_call_id
    if call_id:
        return ToolMessage(
            role="tool",
            content=msg.message,
            tool_call_id=call_id,
        )

    raise ValueError(
        "Tool call response message encountered but tool_call_id is not available. "
        f"Message: {msg}"
    )
|
||||
|
||||
|
||||
class _HistoryMessageFormatter:
    """Interface for converting history messages into LLM chat messages.

    Subclasses override both methods; the base implementations only raise.
    """

    def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
        """Convert an assistant history message. Must be overridden."""
        raise NotImplementedError

    def format_tool_response_message(
        self, msg: ChatMessageSimple
    ) -> ToolMessage | UserMessage:
        """Convert a tool-response history message. Must be overridden."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class _DefaultHistoryMessageFormatter(_HistoryMessageFormatter):
    """Formatter that keeps tool calls and tool results in structured form."""

    def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
        """Return the structured assistant message, tool calls included."""
        return _build_structured_assistant_message(msg)

    def format_tool_response_message(self, msg: ChatMessageSimple) -> ToolMessage:
        """Return the structured tool message; raises ValueError without an id."""
        return _build_structured_tool_response_message(msg)
|
||||
|
||||
|
||||
class _OllamaHistoryMessageFormatter(_HistoryMessageFormatter):
    """Formatter that renders tool calls and tool results as plain text.

    Instead of structured tool-call fields, tool calls are appended to the
    assistant text as "[Tool Call] ..." lines and tool results are sent as user
    messages starting with a "[Tool Result] ..." line.
    """

    def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
        """Flatten an assistant message and its tool calls into text content.

        Messages without tool calls use the regular structured conversion.
        """
        if not msg.tool_calls:
            return _build_structured_assistant_message(msg)

        # Original text (if any) comes first, then one line per tool call.
        lines: list[str] = [msg.message] if msg.message else []
        for call in msg.tool_calls:
            lines.append(
                f"[Tool Call] name={call.tool_name} "
                f"id={call.tool_call_id} args={json.dumps(call.tool_arguments)}"
            )

        return AssistantMessage(
            role="assistant",
            content="\n".join(lines),
            tool_calls=None,
        )

    def format_tool_response_message(self, msg: ChatMessageSimple) -> UserMessage:
        """Render a tool response as a user message tagged with its call id.

        Raises:
            ValueError: If the message carries no tool_call_id.
        """
        call_id = msg.tool_call_id
        if call_id:
            return UserMessage(
                role="user",
                content=f"[Tool Result] id={call_id}\n{msg.message}",
            )

        raise ValueError(
            "Tool call response message encountered but tool_call_id is not available. "
            f"Message: {msg}"
        )
|
||||
|
||||
|
||||
# The formatter classes define no instance attributes, so a single shared
# instance of each is reused for every call.
_DEFAULT_HISTORY_MESSAGE_FORMATTER = _DefaultHistoryMessageFormatter()
_OLLAMA_HISTORY_MESSAGE_FORMATTER = _OllamaHistoryMessageFormatter()


def _get_history_message_formatter(llm_config: LLMConfig) -> _HistoryMessageFormatter:
    """Pick the history formatter for the configured model provider.

    Returns the Ollama formatter when the provider is OLLAMA_CHAT and the
    default structured formatter for every other provider.
    """
    uses_ollama_chat = llm_config.model_provider == LlmProviderNames.OLLAMA_CHAT
    return (
        _OLLAMA_HISTORY_MESSAGE_FORMATTER
        if uses_ollama_chat
        else _DEFAULT_HISTORY_MESSAGE_FORMATTER
    )
|
||||
|
||||
|
||||
def translate_history_to_llm_format(
|
||||
history: list[ChatMessageSimple],
|
||||
llm_config: LLMConfig,
|
||||
@@ -505,6 +788,10 @@ def translate_history_to_llm_format(
|
||||
handling different message types and image files for multimodal support.
|
||||
"""
|
||||
messages: list[ChatCompletionMessage] = []
|
||||
history_message_formatter = _get_history_message_formatter(llm_config)
|
||||
# Note: cacheability is computed from pre-translation ChatMessageSimple types.
|
||||
# Some providers flatten tool history into plain assistant/user text, so this split
|
||||
# may be less semantically meaningful, but it remains safe and order-preserving.
|
||||
last_cacheable_msg_idx = -1
|
||||
all_previous_msgs_cacheable = True
|
||||
|
||||
@@ -586,39 +873,10 @@ def translate_history_to_llm_format(
|
||||
messages.append(reminder_msg)
|
||||
|
||||
elif msg.message_type == MessageType.ASSISTANT:
|
||||
tool_calls_list: list[ToolCall] | None = None
|
||||
if msg.tool_calls:
|
||||
tool_calls_list = [
|
||||
ToolCall(
|
||||
id=tc.tool_call_id,
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=tc.tool_name,
|
||||
arguments=json.dumps(tc.tool_arguments),
|
||||
),
|
||||
)
|
||||
for tc in msg.tool_calls
|
||||
]
|
||||
|
||||
assistant_msg = AssistantMessage(
|
||||
role="assistant",
|
||||
content=msg.message or None,
|
||||
tool_calls=tool_calls_list,
|
||||
)
|
||||
messages.append(assistant_msg)
|
||||
messages.append(history_message_formatter.format_assistant_message(msg))
|
||||
|
||||
elif msg.message_type == MessageType.TOOL_CALL_RESPONSE:
|
||||
if not msg.tool_call_id:
|
||||
raise ValueError(
|
||||
f"Tool call response message encountered but tool_call_id is not available. Message: {msg}"
|
||||
)
|
||||
|
||||
tool_msg = ToolMessage(
|
||||
role="tool",
|
||||
content=msg.message,
|
||||
tool_call_id=msg.tool_call_id,
|
||||
)
|
||||
messages.append(tool_msg)
|
||||
messages.append(history_message_formatter.format_tool_response_message(msg))
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -698,7 +956,8 @@ def run_llm_step_pkt_generator(
|
||||
tool_definitions: List of tool definitions available to the LLM.
|
||||
tool_choice: Tool choice configuration (e.g., "auto", "required", "none").
|
||||
llm: Language model interface to use for generation.
|
||||
turn_index: Current turn index in the conversation.
|
||||
placement: Placement info (turn_index, tab_index, sub_turn_index) for
|
||||
positioning packets in the conversation UI.
|
||||
state_container: Container for storing chat state (reasoning, answers).
|
||||
citation_processor: Optional processor for extracting and formatting citations
|
||||
from the response. If provided, processes tokens to identify citations.
|
||||
@@ -710,7 +969,14 @@ def run_llm_step_pkt_generator(
|
||||
custom_token_processor: Optional callable that processes each token delta
|
||||
before yielding. Receives (delta, processor_state) and returns
|
||||
(modified_delta, new_processor_state). Can return None for delta to skip.
|
||||
sub_turn_index: Optional sub-turn index for nested tool/agent calls.
|
||||
max_tokens: Optional maximum number of tokens for the LLM response.
|
||||
use_existing_tab_index: If True, use the tab_index from placement for all
|
||||
tool calls instead of auto-incrementing.
|
||||
is_deep_research: If True, treat content before tool calls as reasoning
|
||||
when tool_choice is REQUIRED.
|
||||
pre_answer_processing_time: Optional time spent processing before the
|
||||
answer started, recorded in state_container for analytics.
|
||||
timeout_override: Optional timeout override for the LLM call.
|
||||
|
||||
Yields:
|
||||
Packet: Streaming packets containing:
|
||||
@@ -736,8 +1002,15 @@ def run_llm_step_pkt_generator(
|
||||
tab_index = placement.tab_index
|
||||
sub_turn_index = placement.sub_turn_index
|
||||
|
||||
def _current_placement() -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
)
|
||||
|
||||
llm_msg_history = translate_history_to_llm_format(history, llm.config)
|
||||
has_reasoned = 0
|
||||
has_reasoned = False
|
||||
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
logger.debug(
|
||||
@@ -749,6 +1022,8 @@ def run_llm_step_pkt_generator(
|
||||
answer_start = False
|
||||
accumulated_reasoning = ""
|
||||
accumulated_answer = ""
|
||||
accumulated_raw_answer = ""
|
||||
xml_tool_call_content_filter = _XmlToolCallContentFilter()
|
||||
|
||||
processor_state: Any = None
|
||||
|
||||
@@ -764,6 +1039,112 @@ def run_llm_step_pkt_generator(
|
||||
)
|
||||
stream_start_time = time.monotonic()
|
||||
first_action_recorded = False
|
||||
|
||||
def _emit_citation_results(
|
||||
results: Generator[str | CitationInfo, None, None],
|
||||
) -> Generator[Packet, None, None]:
|
||||
"""Yield packets for citation processor results (str or CitationInfo)."""
|
||||
nonlocal accumulated_answer
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=result,
|
||||
)
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(result.citation_number)
|
||||
|
||||
def _close_reasoning_if_active() -> Generator[Packet, None, None]:
|
||||
"""Emit ReasoningDone and increment turns if reasoning is in progress."""
|
||||
nonlocal reasoning_start
|
||||
nonlocal has_reasoned
|
||||
nonlocal turn_index
|
||||
nonlocal sub_turn_index
|
||||
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = True
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
|
||||
def _emit_content_chunk(content_chunk: str) -> Generator[Packet, None, None]:
|
||||
nonlocal accumulated_answer
|
||||
nonlocal accumulated_reasoning
|
||||
nonlocal answer_start
|
||||
nonlocal reasoning_start
|
||||
nonlocal turn_index
|
||||
nonlocal sub_turn_index
|
||||
|
||||
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
|
||||
# about which tool to call, not an actual answer to the user.
|
||||
# Treat this content as reasoning instead of answer.
|
||||
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
|
||||
accumulated_reasoning += content_chunk
|
||||
if state_container:
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=ReasoningDelta(reasoning=content_chunk),
|
||||
)
|
||||
reasoning_start = True
|
||||
return
|
||||
|
||||
# Normal flow for AUTO or NONE tool choice
|
||||
yield from _close_reasoning_if_active()
|
||||
|
||||
if not answer_start:
|
||||
# Store pre-answer processing time in state container for save_chat
|
||||
if state_container and pre_answer_processing_time is not None:
|
||||
state_container.set_pre_answer_processing_time(
|
||||
pre_answer_processing_time
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_documents,
|
||||
pre_answer_processing_seconds=pre_answer_processing_time,
|
||||
),
|
||||
)
|
||||
answer_start = True
|
||||
|
||||
if citation_processor:
|
||||
yield from _emit_citation_results(
|
||||
citation_processor.process_token(content_chunk)
|
||||
)
|
||||
else:
|
||||
accumulated_answer += content_chunk
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=AgentResponseDelta(content=content_chunk),
|
||||
)
|
||||
|
||||
for packet in llm.stream(
|
||||
prompt=llm_msg_history,
|
||||
tools=tool_definitions,
|
||||
@@ -822,152 +1203,34 @@ def run_llm_step_pkt_generator(
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
placement=_current_placement(),
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
placement=_current_placement(),
|
||||
obj=ReasoningDelta(reasoning=delta.reasoning_content),
|
||||
)
|
||||
reasoning_start = True
|
||||
|
||||
if delta.content:
|
||||
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
|
||||
# about which tool to call, not an actual answer to the user.
|
||||
# Treat this content as reasoning instead of answer.
|
||||
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
|
||||
# Treat content as reasoning when we know tool calls are coming
|
||||
accumulated_reasoning += delta.content
|
||||
if state_container:
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDelta(reasoning=delta.content),
|
||||
)
|
||||
reasoning_start = True
|
||||
else:
|
||||
# Normal flow for AUTO or NONE tool choice
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
|
||||
if not answer_start:
|
||||
# Store pre-answer processing time in state container for save_chat
|
||||
if state_container and pre_answer_processing_time is not None:
|
||||
state_container.set_pre_answer_processing_time(
|
||||
pre_answer_processing_time
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_documents,
|
||||
pre_answer_processing_seconds=pre_answer_processing_time,
|
||||
),
|
||||
)
|
||||
answer_start = True
|
||||
|
||||
if citation_processor:
|
||||
for result in citation_processor.process_token(delta.content):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(
|
||||
accumulated_answer
|
||||
)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
# Track emitted citation for saving
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(
|
||||
result.citation_number
|
||||
)
|
||||
else:
|
||||
# When citation_processor is None, use delta.content directly without modification
|
||||
accumulated_answer += delta.content
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=delta.content),
|
||||
)
|
||||
# Keep raw content for fallback extraction. Display content can be
|
||||
# filtered and, in deep-research REQUIRED mode, routed as reasoning.
|
||||
accumulated_raw_answer += delta.content
|
||||
filtered_content = xml_tool_call_content_filter.process(delta.content)
|
||||
if filtered_content:
|
||||
yield from _emit_content_chunk(filtered_content)
|
||||
|
||||
if delta.tool_calls:
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
yield from _close_reasoning_if_active()
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
|
||||
# Flush any tail text buffered while checking for split "<function_calls" markers.
|
||||
filtered_content_tail = xml_tool_call_content_filter.flush()
|
||||
if filtered_content_tail:
|
||||
yield from _emit_content_chunk(filtered_content_tail)
|
||||
|
||||
# Flush custom token processor to get any final tool calls
|
||||
if custom_token_processor:
|
||||
flush_delta, processor_state = custom_token_processor(None, processor_state)
|
||||
@@ -1023,50 +1286,14 @@ def run_llm_step_pkt_generator(
|
||||
|
||||
# This may happen if the custom token processor is used to modify other packets into reasoning
|
||||
# Then there won't necessarily be anything else to come after the reasoning tokens
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(turn_index, sub_turn_index)
|
||||
reasoning_start = False
|
||||
yield from _close_reasoning_if_active()
|
||||
|
||||
# Flush any remaining content from citation processor
|
||||
# Reasoning is always first so this should use the post-incremented value of turn_index
|
||||
# Note that this doesn't need to handle any sub-turns as those docs will not have citations
|
||||
# as clickable items and will be stripped out instead.
|
||||
if citation_processor:
|
||||
for result in citation_processor.process_token(None):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
# Track emitted citation for saving
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(result.citation_number)
|
||||
yield from _emit_citation_results(citation_processor.process_token(None))
|
||||
|
||||
# Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
|
||||
# Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)
|
||||
@@ -1088,8 +1315,9 @@ def run_llm_step_pkt_generator(
|
||||
reasoning=accumulated_reasoning if accumulated_reasoning else None,
|
||||
answer=accumulated_answer if accumulated_answer else None,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
raw_answer=accumulated_raw_answer if accumulated_raw_answer else None,
|
||||
),
|
||||
bool(has_reasoned),
|
||||
has_reasoned,
|
||||
)
|
||||
|
||||
|
||||
@@ -1144,4 +1372,4 @@ def run_llm_step(
|
||||
emitter.emit(packet)
|
||||
except StopIteration as e:
|
||||
llm_step_result, has_reasoned = e.value
|
||||
return llm_step_result, bool(has_reasoned)
|
||||
return llm_step_result, has_reasoned
|
||||
|
||||
@@ -31,13 +31,6 @@ class CustomToolResponse(BaseModel):
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ProjectSearchConfig(BaseModel):
|
||||
"""Configuration for search tool availability in project context."""
|
||||
|
||||
search_usage: SearchToolUsage
|
||||
disable_forced_tool: bool
|
||||
|
||||
|
||||
class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
@@ -132,8 +125,8 @@ class ChatMessageSimple(BaseModel):
|
||||
file_id: str | None = None
|
||||
|
||||
|
||||
class ProjectFileMetadata(BaseModel):
|
||||
"""Metadata for a project file to enable citation support."""
|
||||
class ContextFileMetadata(BaseModel):
|
||||
"""Metadata for a context-injected file to enable citation support."""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
@@ -167,21 +160,32 @@ class ChatHistoryResult(BaseModel):
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata]
|
||||
|
||||
|
||||
class ExtractedProjectFiles(BaseModel):
|
||||
project_file_texts: list[str]
|
||||
project_image_files: list[ChatLoadedFile]
|
||||
project_as_filter: bool
|
||||
class ExtractedContextFiles(BaseModel):
|
||||
"""Result of attempting to load user files (from a project or persona) into context."""
|
||||
|
||||
file_texts: list[str]
|
||||
image_files: list[ChatLoadedFile]
|
||||
use_as_search_filter: bool
|
||||
total_token_count: int
|
||||
# Metadata for project files to enable citations
|
||||
project_file_metadata: list[ProjectFileMetadata]
|
||||
# None if not a project
|
||||
project_uncapped_token_count: int | None
|
||||
# Lightweight metadata for files exposed via FileReaderTool
|
||||
# (populated when files don't fit in context and vector DB is disabled)
|
||||
# (populated when files don't fit in context and vector DB is disabled).
|
||||
file_metadata: list[ContextFileMetadata]
|
||||
uncapped_token_count: int | None
|
||||
file_metadata_for_tool: list[FileToolMetadata] = []
|
||||
|
||||
|
||||
class SearchParams(BaseModel):
|
||||
"""Resolved search filter IDs and search-tool usage for a chat turn."""
|
||||
|
||||
search_project_id: int | None
|
||||
search_persona_id: int | None
|
||||
search_usage: SearchToolUsage
|
||||
|
||||
|
||||
class LlmStepResult(BaseModel):
|
||||
reasoning: str | None
|
||||
answer: str | None
|
||||
tool_calls: list[ToolCallKickoff] | None
|
||||
# Raw LLM text before any display-oriented filtering/sanitization.
|
||||
# Used for fallback tool-call extraction when providers emit calls as text.
|
||||
raw_answer: str | None = None
|
||||
|
||||
@@ -3,6 +3,7 @@ IMPORTANT: familiarize yourself with the design concepts prior to contributing t
|
||||
An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import io
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
@@ -33,11 +34,11 @@ from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ProjectSearchConfig
|
||||
from onyx.chat.models import SearchParams
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import ToolCallResponse
|
||||
from onyx.chat.prompt_utils import calculate_reserved_tokens
|
||||
@@ -62,11 +63,12 @@ from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
from onyx.file_store.utils import verify_user_files
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
@@ -139,12 +141,12 @@ def _collect_available_file_ids(
|
||||
pass
|
||||
|
||||
if project_id:
|
||||
project_files = get_user_files_from_project(
|
||||
user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
for uf in project_files:
|
||||
for uf in user_files:
|
||||
user_file_ids.add(uf.id)
|
||||
|
||||
return _AvailableFiles(
|
||||
@@ -192,9 +194,67 @@ def _convert_loaded_files_to_chat_files(
|
||||
return chat_files
|
||||
|
||||
|
||||
def _extract_project_file_texts_and_images(
|
||||
def resolve_context_user_files(
|
||||
persona: Persona,
|
||||
project_id: int | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[UserFile]:
|
||||
"""Apply the precedence rule to decide which user files to load.
|
||||
|
||||
A custom persona fully supersedes the project. When a chat uses a
|
||||
custom persona, the project is purely organisational — its files are
|
||||
never loaded and never made searchable.
|
||||
|
||||
Custom persona → persona's own user_files (may be empty).
|
||||
Default persona inside a project → project files.
|
||||
Otherwise → empty list.
|
||||
"""
|
||||
if persona.id != DEFAULT_PERSONA_ID:
|
||||
return list(persona.user_files) if persona.user_files else []
|
||||
if project_id:
|
||||
return get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def _empty_extracted_context_files() -> ExtractedContextFiles:
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=False,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=None,
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_in_memory_file(f: InMemoryChatFile) -> str | None:
|
||||
"""Extract text content from an InMemoryChatFile.
|
||||
|
||||
PLAIN_TEXT: the content is pre-extracted UTF-8 plaintext stored during
|
||||
ingestion — decode directly.
|
||||
DOC / CSV / other text types: the content is the original file bytes —
|
||||
use extract_file_text which handles encoding detection and format parsing.
|
||||
"""
|
||||
try:
|
||||
if f.file_type == ChatFileType.PLAIN_TEXT:
|
||||
return f.content.decode("utf-8", errors="ignore").replace("\x00", "")
|
||||
return extract_file_text(
|
||||
file=io.BytesIO(f.content),
|
||||
file_name=f.filename or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to extract text from file {f.file_id}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def extract_context_files(
|
||||
user_files: list[UserFile],
|
||||
llm_max_context_window: int,
|
||||
reserved_token_count: int,
|
||||
db_session: Session,
|
||||
@@ -203,8 +263,12 @@ def _extract_project_file_texts_and_images(
|
||||
# 60% of the LLM's max context window. The other benefit is that for projects with
|
||||
# more files, this makes it so that we don't throw away the history too quickly every time.
|
||||
max_llm_context_percentage: float = 0.6,
|
||||
) -> ExtractedProjectFiles:
|
||||
"""Extract text content from project files if they fit within the context window.
|
||||
) -> ExtractedContextFiles:
|
||||
"""Load user files into context if they fit; otherwise flag for search.
|
||||
|
||||
The caller is responsible for deciding *which* user files to pass in
|
||||
(project files, persona files, etc.). This function only cares about
|
||||
the all-or-nothing fit check and the actual content loading.
|
||||
|
||||
Args:
|
||||
project_id: The project ID to load files from
|
||||
@@ -213,160 +277,95 @@ def _extract_project_file_texts_and_images(
|
||||
reserved_token_count: Number of tokens to reserve for other content
|
||||
db_session: Database session
|
||||
max_llm_context_percentage: Maximum percentage of the LLM context window to use.
|
||||
|
||||
Returns:
|
||||
ExtractedProjectFiles containing:
|
||||
- List of text content strings from project files (text files only)
|
||||
- List of image files from project (ChatLoadedFile objects)
|
||||
- Project id if the the project should be provided as a filter in search or None if not.
|
||||
ExtractedContextFiles containing:
|
||||
- List of text content strings from context files (text files only)
|
||||
- List of image files from context (ChatLoadedFile objects)
|
||||
- Total token count of all extracted files
|
||||
- File metadata for context files
|
||||
- Uncapped token count of all extracted files
|
||||
- File metadata for files that don't fit in context and vector DB is disabled
|
||||
"""
|
||||
# TODO I believe this is not handling all file types correctly.
|
||||
project_as_filter = False
|
||||
if not project_id:
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=None,
|
||||
)
|
||||
# TODO(yuhong): I believe this is not handling all file types correctly.
|
||||
|
||||
if not user_files:
|
||||
return _empty_extracted_context_files()
|
||||
|
||||
aggregate_tokens = sum(uf.token_count or 0 for uf in user_files)
|
||||
max_actual_tokens = (
|
||||
llm_max_context_window - reserved_token_count
|
||||
) * max_llm_context_percentage
|
||||
|
||||
# Calculate total token count for all user files in the project
|
||||
project_tokens = get_project_token_count(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
if aggregate_tokens >= max_actual_tokens:
|
||||
tool_metadata = []
|
||||
use_as_search_filter = not DISABLE_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB:
|
||||
tool_metadata = _build_file_tool_metadata_for_user_files(user_files)
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=use_as_search_filter,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
file_metadata_for_tool=tool_metadata,
|
||||
)
|
||||
|
||||
# Files fit — load them into context
|
||||
user_file_map = {str(uf.id): uf for uf in user_files}
|
||||
in_memory_files = load_in_memory_chat_files(
|
||||
user_file_ids=[uf.id for uf in user_files],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
project_file_texts: list[str] = []
|
||||
project_image_files: list[ChatLoadedFile] = []
|
||||
project_file_metadata: list[ProjectFileMetadata] = []
|
||||
file_texts: list[str] = []
|
||||
image_files: list[ChatLoadedFile] = []
|
||||
file_metadata: list[ContextFileMetadata] = []
|
||||
total_token_count = 0
|
||||
if project_tokens < max_actual_tokens:
|
||||
# Load project files into memory using cached plaintext when available
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if project_user_files:
|
||||
# Create a mapping from file_id to UserFile for token count lookup
|
||||
user_file_map = {str(file.id): file for file in project_user_files}
|
||||
|
||||
project_file_ids = [file.id for file in project_user_files]
|
||||
in_memory_project_files = load_in_memory_chat_files(
|
||||
user_file_ids=project_file_ids,
|
||||
db_session=db_session,
|
||||
for f in in_memory_files:
|
||||
uf = user_file_map.get(str(f.file_id))
|
||||
if f.file_type.is_text_file():
|
||||
text_content = _extract_text_from_in_memory_file(f)
|
||||
if not text_content:
|
||||
continue
|
||||
file_texts.append(text_content)
|
||||
file_metadata.append(
|
||||
ContextFileMetadata(
|
||||
file_id=str(f.file_id),
|
||||
filename=f.filename or f"file_{f.file_id}",
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
if uf and uf.token_count:
|
||||
total_token_count += uf.token_count
|
||||
elif f.file_type == ChatFileType.IMAGE:
|
||||
token_count = uf.token_count if uf and uf.token_count else 0
|
||||
total_token_count += token_count
|
||||
image_files.append(
|
||||
ChatLoadedFile(
|
||||
file_id=f.file_id,
|
||||
content=f.content,
|
||||
file_type=f.file_type,
|
||||
filename=f.filename,
|
||||
content_text=None,
|
||||
token_count=token_count,
|
||||
)
|
||||
)
|
||||
|
||||
# Extract text content from loaded files
|
||||
for file in in_memory_project_files:
|
||||
if file.file_type.is_text_file():
|
||||
try:
|
||||
text_content = file.content.decode("utf-8", errors="ignore")
|
||||
# Strip null bytes
|
||||
text_content = text_content.replace("\x00", "")
|
||||
if text_content:
|
||||
project_file_texts.append(text_content)
|
||||
# Add metadata for citation support
|
||||
project_file_metadata.append(
|
||||
ProjectFileMetadata(
|
||||
file_id=str(file.file_id),
|
||||
filename=file.filename or f"file_{file.file_id}",
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
# Add token count for text file
|
||||
user_file = user_file_map.get(str(file.file_id))
|
||||
if user_file and user_file.token_count:
|
||||
total_token_count += user_file.token_count
|
||||
except Exception:
|
||||
# Skip files that can't be decoded
|
||||
pass
|
||||
elif file.file_type == ChatFileType.IMAGE:
|
||||
# Convert InMemoryChatFile to ChatLoadedFile
|
||||
user_file = user_file_map.get(str(file.file_id))
|
||||
token_count = (
|
||||
user_file.token_count
|
||||
if user_file and user_file.token_count
|
||||
else 0
|
||||
)
|
||||
total_token_count += token_count
|
||||
chat_loaded_file = ChatLoadedFile(
|
||||
file_id=file.file_id,
|
||||
content=file.content,
|
||||
file_type=file.file_type,
|
||||
filename=file.filename,
|
||||
content_text=None, # Images don't have text content
|
||||
token_count=token_count,
|
||||
)
|
||||
project_image_files.append(chat_loaded_file)
|
||||
else:
|
||||
if DISABLE_VECTOR_DB:
|
||||
# Without a vector DB we can't use project-as-filter search.
|
||||
# Instead, build lightweight metadata so the LLM can call the
|
||||
# FileReaderTool to inspect individual files on demand.
|
||||
file_metadata_for_tool = _build_file_tool_metadata_for_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata_for_tool=file_metadata_for_tool,
|
||||
)
|
||||
project_as_filter = True
|
||||
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=project_file_texts,
|
||||
project_image_files=project_image_files,
|
||||
project_as_filter=project_as_filter,
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts,
|
||||
image_files=image_files,
|
||||
use_as_search_filter=False,
|
||||
total_token_count=total_token_count,
|
||||
project_file_metadata=project_file_metadata,
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata=file_metadata,
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
)
|
||||
|
||||
|
||||
APPROX_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_project(
|
||||
project_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[FileToolMetadata]:
|
||||
"""Build lightweight FileToolMetadata for every file in a project.
|
||||
|
||||
Used when files are too large to fit in context and the vector DB is
|
||||
disabled, so the LLM needs to know which files it can read via the
|
||||
FileReaderTool.
|
||||
"""
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return [
|
||||
FileToolMetadata(
|
||||
file_id=str(uf.id),
|
||||
filename=uf.name,
|
||||
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
|
||||
)
|
||||
for uf in project_user_files
|
||||
]
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_user_files(
|
||||
user_files: list[UserFile],
|
||||
) -> list[FileToolMetadata]:
|
||||
@@ -381,55 +380,46 @@ def _build_file_tool_metadata_for_user_files(
|
||||
]
|
||||
|
||||
|
||||
def _get_project_search_availability(
|
||||
def determine_search_params(
|
||||
persona_id: int,
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
loaded_project_files: bool,
|
||||
project_has_files: bool,
|
||||
forced_tool_id: int | None,
|
||||
search_tool_id: int | None,
|
||||
) -> ProjectSearchConfig:
|
||||
"""Determine search tool availability based on project context.
|
||||
extracted_context_files: ExtractedContextFiles,
|
||||
) -> SearchParams:
|
||||
"""Decide which search filter IDs and search-tool usage apply for a chat turn.
|
||||
|
||||
Search is disabled when ALL of the following are true:
|
||||
- User is in a project
|
||||
- Using the default persona (not a custom agent)
|
||||
- Project files are already loaded in context
|
||||
A custom persona fully supersedes the project — project files are never
|
||||
searchable and the search tool config is entirely controlled by the
|
||||
persona. The project_id filter is only set for the default persona.
|
||||
|
||||
When search is disabled and the user tried to force the search tool,
|
||||
that forcing is also disabled.
|
||||
|
||||
Returns AUTO (follow persona config) in all other cases.
|
||||
For the default persona inside a project:
|
||||
- Files overflow → ENABLED (vector DB scopes to these files)
|
||||
- Files fit → DISABLED (content already in prompt)
|
||||
- No files at all → DISABLED (nothing to search)
|
||||
"""
|
||||
# Not in a project, this should have no impact on search tool availability
|
||||
if not project_id:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
|
||||
|
||||
# Custom persona in project - let persona config decide
|
||||
# Even if there are no files in the project, it's still guided by the persona config.
|
||||
if persona_id != DEFAULT_PERSONA_ID:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
search_project_id: int | None = None
|
||||
search_persona_id: int | None = None
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
if is_custom_persona:
|
||||
search_persona_id = persona_id
|
||||
else:
|
||||
search_project_id = project_id
|
||||
|
||||
# If in a project with the default persona and the files have been already loaded into the context or
|
||||
# there are no files in the project, disable search as there is nothing to search for.
|
||||
if loaded_project_files or not project_has_files:
|
||||
user_forced_search = (
|
||||
forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
)
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.DISABLED,
|
||||
disable_forced_tool=user_forced_search,
|
||||
)
|
||||
search_usage = SearchToolUsage.AUTO
|
||||
if not is_custom_persona and project_id:
|
||||
has_context_files = bool(extracted_context_files.uncapped_token_count)
|
||||
files_loaded_in_context = bool(extracted_context_files.file_texts)
|
||||
|
||||
# Default persona in a project with files, but also the files have not been loaded into the context already.
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.ENABLED, disable_forced_tool=False
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
search_usage = SearchToolUsage.ENABLED
|
||||
elif files_loaded_in_context or not has_context_files:
|
||||
search_usage = SearchToolUsage.DISABLED
|
||||
|
||||
return SearchParams(
|
||||
search_project_id=search_project_id,
|
||||
search_persona_id=search_persona_id,
|
||||
search_usage=search_usage,
|
||||
)
|
||||
|
||||
|
||||
@@ -661,26 +651,37 @@ def handle_stream_message_objects(
|
||||
user_memory_context=prompt_memory_context,
|
||||
)
|
||||
|
||||
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
|
||||
extracted_project_files = _extract_project_file_texts_and_images(
|
||||
# Determine which user files to use. A custom persona fully
|
||||
# supersedes the project — project files are never loaded or
|
||||
# searchable when a custom persona is in play. Only the default
|
||||
# persona inside a project uses the project's files.
|
||||
context_user_files = resolve_context_user_files(
|
||||
persona=persona,
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
extracted_context_files = extract_context_files(
|
||||
user_files=context_user_files,
|
||||
llm_max_context_window=llm.config.max_input_tokens,
|
||||
reserved_token_count=reserved_token_count,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# When the vector DB is disabled, persona-attached user_files have no
|
||||
# search pipeline path. Inject them as file_metadata_for_tool so the
|
||||
# LLM can read them via the FileReaderTool.
|
||||
if DISABLE_VECTOR_DB and persona.user_files:
|
||||
persona_file_metadata = _build_file_tool_metadata_for_user_files(
|
||||
persona.user_files
|
||||
)
|
||||
# Merge persona file metadata into the extracted project files
|
||||
extracted_project_files.file_metadata_for_tool.extend(persona_file_metadata)
|
||||
search_params = determine_search_params(
|
||||
persona_id=persona.id,
|
||||
project_id=chat_session.project_id,
|
||||
extracted_context_files=extracted_context_files,
|
||||
)
|
||||
|
||||
# Also grant access to persona-attached user files for FileReaderTool
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Build a mapping of tool_id to tool_name for history reconstruction
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
|
||||
@@ -689,30 +690,17 @@ def handle_stream_message_objects(
|
||||
None,
|
||||
)
|
||||
|
||||
# Determine if search should be disabled for this project context
|
||||
forced_tool_id = new_msg_req.forced_tool_id
|
||||
project_search_config = _get_project_search_availability(
|
||||
project_id=chat_session.project_id,
|
||||
persona_id=persona.id,
|
||||
loaded_project_files=bool(extracted_project_files.project_file_texts),
|
||||
project_has_files=bool(
|
||||
extracted_project_files.project_uncapped_token_count
|
||||
),
|
||||
forced_tool_id=new_msg_req.forced_tool_id,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
if project_search_config.disable_forced_tool:
|
||||
if (
|
||||
search_params.search_usage == SearchToolUsage.DISABLED
|
||||
and forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
):
|
||||
forced_tool_id = None
|
||||
|
||||
emitter = get_default_emitter()
|
||||
|
||||
# Also grant access to persona-attached user files
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Construct tools based on the persona configurations
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
@@ -722,11 +710,8 @@ def handle_stream_message_objects(
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=new_msg_req.internal_search_filters,
|
||||
project_id=(
|
||||
chat_session.project_id
|
||||
if extracted_project_files.project_as_filter
|
||||
else None
|
||||
),
|
||||
project_id=search_params.search_project_id,
|
||||
persona_id=search_params.search_persona_id,
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
@@ -744,7 +729,7 @@ def handle_stream_message_objects(
|
||||
chat_file_ids=available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=project_search_config.search_usage,
|
||||
search_usage_forcing_setting=search_params.search_usage,
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
@@ -783,7 +768,7 @@ def handle_stream_message_objects(
|
||||
chat_history_result = convert_chat_history(
|
||||
chat_history=chat_history,
|
||||
files=files,
|
||||
project_image_files=extracted_project_files.project_image_files,
|
||||
context_image_files=extracted_context_files.image_files,
|
||||
additional_context=additional_context,
|
||||
token_counter=token_counter,
|
||||
tool_id_to_name_map=tool_id_to_name_map,
|
||||
@@ -856,6 +841,11 @@ def handle_stream_message_objects(
|
||||
reserved_tokens=reserved_token_count,
|
||||
)
|
||||
|
||||
# Release any read transaction before entering the long-running LLM stream.
|
||||
# Without this, the request-scoped session can keep a connection checked out
|
||||
# for the full stream duration.
|
||||
db_session.commit()
|
||||
|
||||
# The stream generator can resume on a different worker thread after early yields.
|
||||
# Set this right before launching the LLM loop so run_in_background copies the right context.
|
||||
if new_msg_req.mock_llm_response is not None:
|
||||
@@ -874,46 +864,54 @@ def handle_stream_message_objects(
|
||||
# (user has already responded to a clarification question)
|
||||
skip_clarification = is_last_assistant_message_clarification(chat_history)
|
||||
|
||||
# NOTE: we _could_ pass in a zero argument function since emitter and state_container
|
||||
# are just passed in immediately anyways, but the abstraction is cleaner this way.
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_deep_research_llm_loop,
|
||||
lambda emitter, state_container: run_deep_research_llm_loop(
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
),
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_llm_loop,
|
||||
lambda emitter, state_container: run_llm_loop(
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
context_files=extracted_context_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
),
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.db.user_file import calculate_user_files_token_count
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.prompts.chat_prompts import CITATION_REMINDER
|
||||
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
|
||||
from onyx.prompts.chat_prompts import FILE_REMINDER
|
||||
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
|
||||
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
|
||||
from onyx.prompts.prompt_utils import get_company_context
|
||||
@@ -125,6 +126,7 @@ def calculate_reserved_tokens(
|
||||
def build_reminder_message(
|
||||
reminder_text: str | None,
|
||||
include_citation_reminder: bool,
|
||||
include_file_reminder: bool,
|
||||
is_last_cycle: bool,
|
||||
) -> str | None:
|
||||
reminder = reminder_text.strip() if reminder_text else ""
|
||||
@@ -132,6 +134,8 @@ def build_reminder_message(
|
||||
reminder += "\n\n" + LAST_CYCLE_CITATION_REMINDER
|
||||
if include_citation_reminder:
|
||||
reminder += "\n\n" + CITATION_REMINDER
|
||||
if include_file_reminder:
|
||||
reminder += "\n\n" + FILE_REMINDER
|
||||
reminder = reminder.strip()
|
||||
return reminder if reminder else None
|
||||
|
||||
@@ -186,7 +190,7 @@ def _build_user_information_section(
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
return USER_INFORMATION_HEADER + "".join(sections)
|
||||
return USER_INFORMATION_HEADER + "\n".join(sections)
|
||||
|
||||
|
||||
def build_system_prompt(
|
||||
@@ -224,23 +228,21 @@ def build_system_prompt(
|
||||
system_prompt += REQUIRE_CITATION_GUIDANCE
|
||||
|
||||
if include_all_guidance:
|
||||
system_prompt += (
|
||||
TOOL_SECTION_HEADER
|
||||
+ TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
+ INTERNAL_SEARCH_GUIDANCE
|
||||
+ WEB_SEARCH_GUIDANCE.format(
|
||||
tool_sections = [
|
||||
TOOL_DESCRIPTION_SEARCH_GUIDANCE,
|
||||
INTERNAL_SEARCH_GUIDANCE,
|
||||
WEB_SEARCH_GUIDANCE.format(
|
||||
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
)
|
||||
+ OPEN_URLS_GUIDANCE
|
||||
+ PYTHON_TOOL_GUIDANCE
|
||||
+ GENERATE_IMAGE_GUIDANCE
|
||||
+ MEMORY_GUIDANCE
|
||||
)
|
||||
),
|
||||
OPEN_URLS_GUIDANCE,
|
||||
PYTHON_TOOL_GUIDANCE,
|
||||
GENERATE_IMAGE_GUIDANCE,
|
||||
MEMORY_GUIDANCE,
|
||||
]
|
||||
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_sections)
|
||||
return system_prompt
|
||||
|
||||
if tools:
|
||||
system_prompt += TOOL_SECTION_HEADER
|
||||
|
||||
has_web_search = any(isinstance(tool, WebSearchTool) for tool in tools)
|
||||
has_internal_search = any(isinstance(tool, SearchTool) for tool in tools)
|
||||
has_open_urls = any(isinstance(tool, OpenURLTool) for tool in tools)
|
||||
@@ -250,12 +252,14 @@ def build_system_prompt(
|
||||
)
|
||||
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
|
||||
|
||||
tool_guidance_sections: list[str] = []
|
||||
|
||||
if has_web_search or has_internal_search or include_all_guidance:
|
||||
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
tool_guidance_sections.append(TOOL_DESCRIPTION_SEARCH_GUIDANCE)
|
||||
|
||||
# These are not included at the Tool level because the ordering may matter.
|
||||
if has_internal_search or include_all_guidance:
|
||||
system_prompt += INTERNAL_SEARCH_GUIDANCE
|
||||
tool_guidance_sections.append(INTERNAL_SEARCH_GUIDANCE)
|
||||
|
||||
if has_web_search or include_all_guidance:
|
||||
site_disabled_guidance = ""
|
||||
@@ -265,20 +269,23 @@ def build_system_prompt(
|
||||
)
|
||||
if web_search_tool and not web_search_tool.supports_site_filter:
|
||||
site_disabled_guidance = WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
system_prompt += WEB_SEARCH_GUIDANCE.format(
|
||||
site_colon_disabled=site_disabled_guidance
|
||||
tool_guidance_sections.append(
|
||||
WEB_SEARCH_GUIDANCE.format(site_colon_disabled=site_disabled_guidance)
|
||||
)
|
||||
|
||||
if has_open_urls or include_all_guidance:
|
||||
system_prompt += OPEN_URLS_GUIDANCE
|
||||
tool_guidance_sections.append(OPEN_URLS_GUIDANCE)
|
||||
|
||||
if has_python or include_all_guidance:
|
||||
system_prompt += PYTHON_TOOL_GUIDANCE
|
||||
tool_guidance_sections.append(PYTHON_TOOL_GUIDANCE)
|
||||
|
||||
if has_generate_image or include_all_guidance:
|
||||
system_prompt += GENERATE_IMAGE_GUIDANCE
|
||||
tool_guidance_sections.append(GENERATE_IMAGE_GUIDANCE)
|
||||
|
||||
if has_memory or include_all_guidance:
|
||||
system_prompt += MEMORY_GUIDANCE
|
||||
tool_guidance_sections.append(MEMORY_GUIDANCE)
|
||||
|
||||
if tool_guidance_sections:
|
||||
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_guidance_sections)
|
||||
|
||||
return system_prompt
|
||||
|
||||
@@ -210,10 +210,10 @@ AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
)
|
||||
SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
|
||||
SMTP_SERVER = os.environ.get("SMTP_SERVER") or ""
|
||||
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
|
||||
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
SMTP_USER = os.environ.get("SMTP_USER") or ""
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS") or ""
|
||||
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
|
||||
|
||||
SENDGRID_API_KEY = os.environ.get("SENDGRID_API_KEY") or ""
|
||||
@@ -251,7 +251,9 @@ DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int(
|
||||
os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50
|
||||
)
|
||||
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get(
|
||||
"OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!"
|
||||
)
|
||||
USING_AWS_MANAGED_OPENSEARCH = (
|
||||
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
|
||||
)
|
||||
@@ -263,6 +265,18 @@ OPENSEARCH_PROFILING_DISABLED = (
|
||||
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# When enabled, OpenSearch returns detailed score breakdowns for each hit.
|
||||
# Useful for debugging and tuning search relevance. Has ~10-30% performance overhead according to documentation.
|
||||
# In practice, for hybrid search the slowdown appears to be closer to 1000x.
|
||||
OPENSEARCH_EXPLAIN_ENABLED = (
|
||||
os.environ.get("OPENSEARCH_EXPLAIN_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Analyzer used for full-text fields (title, content). Use OpenSearch built-in analyzer
|
||||
# names (e.g. "english", "standard", "german"). Affects stemming and tokenization;
|
||||
# existing indices need reindexing after a change.
|
||||
OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "english"
|
||||
|
||||
# This is the "base" config for now, the idea is that at least for our dev
|
||||
# environments we always want to be dual indexing into both OpenSearch and Vespa
|
||||
# to stress test the new codepaths. Only enable this if there is some instance
|
||||
@@ -270,6 +284,9 @@ OPENSEARCH_PROFILING_DISABLED = (
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
# NOTE: This effectively does nothing anymore, admins can now toggle whether
|
||||
# retrieval is through OpenSearch. This value is only used as a final fallback
|
||||
# in case that doesn't work for whatever reason.
|
||||
# Given that the "base" config above is true, this enables whether we want to
|
||||
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
|
||||
# in the event we see issues with OpenSearch retrieval in our dev environments.
|
||||
@@ -277,6 +294,12 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
# Whether we should check for and create an index if necessary every time we
|
||||
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
@@ -625,6 +648,14 @@ SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
|
||||
)
|
||||
|
||||
# When True, group sync enumerates every Azure AD group in the tenant (expensive).
|
||||
# When False (default), only groups found in site role assignments are synced.
|
||||
# Can be overridden per-connector via the "exhaustive_ad_enumeration" key in
|
||||
# connector_specific_config.
|
||||
SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION = (
|
||||
os.environ.get("SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION", "").lower() == "true"
|
||||
)
|
||||
|
||||
BLOB_STORAGE_SIZE_THRESHOLD = int(
|
||||
os.environ.get("BLOB_STORAGE_SIZE_THRESHOLD", 20 * 1024 * 1024)
|
||||
)
|
||||
|
||||
@@ -157,6 +157,25 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
# How long a queued user-file task is valid before workers discard it.
|
||||
# Should be longer than the beat interval (20 s) but short enough to prevent
|
||||
# indefinite queue growth. Workers drop tasks older than this without touching
|
||||
# the DB, so a shorter value = faster drain of stale duplicates.
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Maximum number of tasks allowed in the user-file-processing queue before the
|
||||
# beat generator stops adding more. Prevents unbounded queue growth when workers
|
||||
# fall behind.
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
|
||||
# How long a queued user-file-project-sync task remains valid.
|
||||
# Should be short enough to discard stale queue entries under load while still
|
||||
# allowing workers enough time to pick up new tasks.
|
||||
CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Max queue depth before user-file-project-sync producers stop enqueuing.
|
||||
# This applies backpressure when workers are falling behind.
|
||||
USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
@@ -443,8 +462,12 @@ class OnyxRedisLocks:
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a task is already queued.
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_PROJECT_SYNC_QUEUED_PREFIX = "da_lock:user_file_project_sync_queued"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ from onyx.connectors.google_drive.file_retrieval import get_external_access_for_
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
from onyx.connectors.google_drive.file_retrieval import get_shared_drive_name
|
||||
from onyx.connectors.google_drive.file_retrieval import has_link_only_permission
|
||||
from onyx.connectors.google_drive.models import DriveRetrievalStage
|
||||
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint
|
||||
@@ -156,10 +157,7 @@ def _is_shared_drive_root(folder: GoogleDriveFileType) -> bool:
|
||||
return False
|
||||
|
||||
# For shared drive content, the root has id == driveId
|
||||
if drive_id and folder_id == drive_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(drive_id and folder_id == drive_id)
|
||||
|
||||
|
||||
def _public_access() -> ExternalAccess:
|
||||
@@ -616,6 +614,16 @@ class GoogleDriveConnector(
|
||||
# empty parents due to permission limitations)
|
||||
# Check shared drive root first (simple ID comparison)
|
||||
if _is_shared_drive_root(folder):
|
||||
# files().get() returns 'Drive' for shared drive roots;
|
||||
# fetch the real name via drives().get().
|
||||
# Try both the retriever and admin since the admin may
|
||||
# not have access to private shared drives.
|
||||
drive_name = self._get_shared_drive_name(
|
||||
current_id, file.user_email
|
||||
)
|
||||
if drive_name:
|
||||
node.display_name = drive_name
|
||||
node.node_type = HierarchyNodeType.SHARED_DRIVE
|
||||
reached_terminal = True
|
||||
break
|
||||
|
||||
@@ -691,6 +699,15 @@ class GoogleDriveConnector(
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_shared_drive_name(self, drive_id: str, retriever_email: str) -> str | None:
    """Fetch the name of a shared drive, trying the retriever and then the admin.

    Args:
        drive_id: ID of the shared drive to look up.
        retriever_email: Email of the user the file was retrieved as.

    Returns:
        The first non-empty name returned by ``get_shared_drive_name``, or
        ``None`` when every attempt comes back empty.
    """
    # dict.fromkeys de-duplicates (the retriever may already be the admin)
    # while preserving insertion order, so the retriever is always tried
    # first. A set literal would leave the attempt order up to string hashing.
    candidate_emails = dict.fromkeys((retriever_email, self.primary_admin_email))
    for email in candidate_emails:
        svc = get_drive_service(self.creds, email)
        name = get_shared_drive_name(svc, drive_id)
        if name:
            return name
    return None
|
||||
|
||||
def get_all_drive_ids(self) -> set[str]:
|
||||
return self._get_all_drives_for_user(self.primary_admin_email)
|
||||
|
||||
|
||||
@@ -154,6 +154,26 @@ def _get_hierarchy_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
return HIERARCHY_FIELDS
|
||||
|
||||
|
||||
def get_shared_drive_name(
    service: Resource,
    drive_id: str,
) -> str | None:
    """Look up a shared drive's real name through the drives().get() API.

    files().get() reports 'Drive' as the name of a shared drive root folder;
    only drives().get() returns the name the user actually assigned.

    Returns the drive's ``name`` field, or ``None`` when the request fails
    with HTTP 403 or 404. Any other ``HttpError`` is re-raised.
    """
    try:
        response = service.drives().get(driveId=drive_id, fields="name").execute()
        return response.get("name")
    except HttpError as err:
        # Anything other than "forbidden" / "not found" is unexpected here.
        if err.resp.status not in (403, 404):
            raise
        logger.debug(f"Cannot access drive {drive_id}: {err}")
        return None
|
||||
|
||||
|
||||
def get_external_access_for_folder(
|
||||
folder: GoogleDriveFileType,
|
||||
google_domain: str,
|
||||
|
||||
@@ -16,6 +16,22 @@ from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_RATE_LIMIT_REASONS = {"userRateLimitExceeded", "rateLimitExceeded"}
|
||||
|
||||
|
||||
def _is_rate_limit_error(error: HttpError) -> bool:
|
||||
"""Google sometimes returns rate-limit errors as 403 with reason
|
||||
'userRateLimitExceeded' instead of 429. This helper detects both."""
|
||||
if error.resp.status == 429:
|
||||
return True
|
||||
if error.resp.status != 403:
|
||||
return False
|
||||
error_details = getattr(error, "error_details", None) or []
|
||||
for detail in error_details:
|
||||
if isinstance(detail, dict) and detail.get("reason") in _RATE_LIMIT_REASONS:
|
||||
return True
|
||||
return "userRateLimitExceeded" in str(error) or "rateLimitExceeded" in str(error)
|
||||
|
||||
|
||||
# Google Drive APIs are quite flakey and may 500 for an
|
||||
# extended period of time. This is now addressed by checkpointing.
|
||||
@@ -57,7 +73,7 @@ def _execute_with_retry(request: Any) -> Any:
|
||||
except HttpError as error:
|
||||
attempt += 1
|
||||
|
||||
if error.resp.status == 429:
|
||||
if _is_rate_limit_error(error):
|
||||
# Attempt to get 'Retry-After' from headers
|
||||
retry_after = error.resp.get("Retry-After")
|
||||
if retry_after:
|
||||
@@ -140,16 +156,16 @@ def _execute_single_retrieval(
|
||||
)
|
||||
logger.error(f"Error executing request: {e}")
|
||||
raise e
|
||||
elif _is_rate_limit_error(e):
|
||||
results = _execute_with_retry(
|
||||
lambda: retrieval_function(**request_kwargs).execute()
|
||||
)
|
||||
elif e.resp.status == 404 or e.resp.status == 403:
|
||||
if continue_on_404_or_403:
|
||||
logger.debug(f"Error executing request: {e}")
|
||||
results = {}
|
||||
else:
|
||||
raise e
|
||||
elif e.resp.status == 429:
|
||||
results = _execute_with_retry(
|
||||
lambda: retrieval_function(**request_kwargs).execute()
|
||||
)
|
||||
else:
|
||||
logger.exception("Error executing request:")
|
||||
raise e
|
||||
|
||||
96
backend/onyx/connectors/microsoft_graph_env.py
Normal file
96
backend/onyx/connectors/microsoft_graph_env.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Inverse mapping from user-facing Microsoft host URLs to the SDK's AzureEnvironment.
|
||||
|
||||
The office365 library's GraphClient requires an ``AzureEnvironment`` string
|
||||
(e.g. ``"Global"``, ``"GCC High"``) to route requests to the correct national
|
||||
cloud. Our connectors instead expose free-text ``authority_host`` and
|
||||
``graph_api_host`` fields so the frontend doesn't need to know about SDK
|
||||
internals.
|
||||
|
||||
This module bridges the gap: given the two host URLs the user configured, it
|
||||
resolves the matching ``AzureEnvironment`` value (and the implied SharePoint
|
||||
domain suffix) so callers can pass ``environment=…`` to ``GraphClient``.
|
||||
"""
|
||||
|
||||
from office365.graph_client import AzureEnvironment # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
|
||||
|
||||
class MicrosoftGraphEnvironment(BaseModel):
    """A single entry of the host-to-environment lookup table.

    Ties one ``AzureEnvironment`` value to the Graph API host, the login
    authority host and the SharePoint domain suffix that belong to it.
    """

    # AzureEnvironment string passed to GraphClient (e.g. "Global", "GCC High").
    environment: str
    # Base URL of the Microsoft Graph API, without a trailing slash.
    graph_host: str
    # Base URL of the login authority, without a trailing slash.
    authority_host: str
    # Domain suffix of SharePoint sites in this cloud (e.g. "sharepoint.com").
    sharepoint_domain_suffix: str
|
||||
|
||||
|
||||
# Known national clouds, one row per AzureEnvironment:
# (environment, graph host, authority host, SharePoint domain suffix).
_ENVIRONMENTS: list[MicrosoftGraphEnvironment] = [
    MicrosoftGraphEnvironment(
        environment=environment,
        graph_host=graph_host,
        authority_host=authority_host,
        sharepoint_domain_suffix=sharepoint_domain_suffix,
    )
    for environment, graph_host, authority_host, sharepoint_domain_suffix in (
        (
            AzureEnvironment.Global,
            "https://graph.microsoft.com",
            "https://login.microsoftonline.com",
            "sharepoint.com",
        ),
        (
            AzureEnvironment.USGovernmentHigh,
            "https://graph.microsoft.us",
            "https://login.microsoftonline.us",
            "sharepoint.us",
        ),
        (
            AzureEnvironment.USGovernmentDoD,
            "https://dod-graph.microsoft.us",
            "https://login.microsoftonline.us",
            "sharepoint.us",
        ),
        (
            AzureEnvironment.China,
            "https://microsoftgraph.chinacloudapi.cn",
            "https://login.chinacloudapi.cn",
            "sharepoint.cn",
        ),
        (
            AzureEnvironment.Germany,
            "https://graph.microsoft.de",
            "https://login.microsoftonline.de",
            "sharepoint.de",
        ),
    )
]

# Lookup from Graph API host to its environment row. Graph hosts are unique
# across the rows above; authority hosts are not (both US Government rows
# share one), which is why the graph host is the key.
_GRAPH_HOST_INDEX: dict[str, MicrosoftGraphEnvironment] = {
    entry.graph_host: entry for entry in _ENVIRONMENTS
}
|
||||
|
||||
|
||||
def resolve_microsoft_environment(
    graph_api_host: str,
    authority_host: str,
) -> MicrosoftGraphEnvironment:
    """Return the ``MicrosoftGraphEnvironment`` that matches the supplied hosts.

    Both hosts are normalised before matching: surrounding whitespace and
    trailing slashes are dropped and the value is lower-cased.

    Args:
        graph_api_host: Microsoft Graph API base URL configured by the user.
        authority_host: Login authority base URL configured by the user.

    Returns:
        The table entry whose ``graph_host`` equals the normalised
        ``graph_api_host``.

    Raises:
        ConnectorValidationError: when the graph host is unknown, or when the
            combination is internally inconsistent (e.g. a GCC-High graph
            host paired with a commercial authority host).
    """
    # The hosts are free text typed by users, and URL scheme/host are
    # case-insensitive, so normalise before comparing against the
    # all-lowercase table entries.
    graph_api_host = graph_api_host.strip().rstrip("/").lower()
    authority_host = authority_host.strip().rstrip("/").lower()

    env = _GRAPH_HOST_INDEX.get(graph_api_host)
    if env is None:
        known = ", ".join(sorted(_GRAPH_HOST_INDEX))
        raise ConnectorValidationError(
            f"Unsupported Microsoft Graph API host '{graph_api_host}'. "
            f"Recognised hosts: {known}"
        )

    if env.authority_host != authority_host:
        raise ConnectorValidationError(
            f"Authority host '{authority_host}' is inconsistent with "
            f"graph API host '{graph_api_host}'. "
            f"Expected authority host '{env.authority_host}' "
            f"for the {env.environment} environment."
        )

    return env
|
||||
@@ -6,6 +6,7 @@ from typing import cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
@@ -167,6 +168,14 @@ class DocumentBase(BaseModel):
|
||||
# list of strings.
|
||||
metadata: dict[str, str | list[str]]
|
||||
|
||||
@field_validator("metadata", mode="before")
@classmethod
def _coerce_metadata_values(cls, v: Any) -> Any:
    """Coerce raw metadata values to ``str`` or ``list[str]``.

    Registered with ``mode="before"``, so it receives the raw input for the
    ``metadata`` field. List and tuple values are converted element-wise to
    a list of strings; every other value is passed through ``str()``.
    """
    if not isinstance(v, dict):
        # Not a mapping: return it untouched rather than failing on
        # ``v.items()`` with an AttributeError. The field's declared type is
        # then expected to reject it with a regular validation error.
        return v
    return {
        # Tuples are treated like lists so ("a", "b") becomes ["a", "b"]
        # instead of the single string "('a', 'b')".
        key: [str(item) for item in val]
        if isinstance(val, (list, tuple))
        else str(val)
        for key, val in v.items()
    }
|
||||
|
||||
# UTC time
|
||||
doc_updated_at: datetime | None = None
|
||||
chunk_count: int | None = None
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,6 +11,7 @@ from dateutil import parser
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
@@ -258,3 +259,21 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
slim_doc_batch = []
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
def validate_connector_settings(self) -> None:
    """Run basic sanity checks on the connector configuration.

    Checks that ``base_url`` starts with an http(s) scheme and that
    ``get_all_post_ids`` succeeds with the configured bot token. This is
    deliberately minimal; more checks could be added.

    Raises:
        ConnectorValidationError: if the base URL has no http(s) scheme, or
            if fetching post IDs fails for any reason other than missing
            credentials.
        ConnectorMissingCredentialError: re-raised unchanged when
            ``get_all_post_ids`` raises it.
    """
    if not self.base_url.startswith(("https://", "http://")):
        raise ConnectorValidationError(
            "Base URL must start with https:// or http://"
        )

    try:
        get_all_post_ids(self.slab_bot_token)
    except ConnectorMissingCredentialError:
        # Keep the more specific credential error intact for callers.
        raise
    except Exception as e:
        # Chain explicitly so the original failure stays attached as the cause.
        raise ConnectorValidationError(
            f"Failed to fetch posts from Slab: {e}"
        ) from e
|
||||
|
||||
@@ -23,6 +23,7 @@ from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -50,12 +51,15 @@ class TeamsCheckpoint(ConnectorCheckpoint):
|
||||
todo_team_ids: list[str] | None = None
|
||||
|
||||
|
||||
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
|
||||
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
|
||||
|
||||
|
||||
class TeamsConnector(
|
||||
CheckpointedConnectorWithPermSync[TeamsCheckpoint],
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
MAX_WORKERS = 10
|
||||
AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -63,12 +67,19 @@ class TeamsConnector(
|
||||
# are not necessarily guaranteed to be unique
|
||||
teams: list[str] = [],
|
||||
max_workers: int = MAX_WORKERS,
|
||||
authority_host: str = DEFAULT_AUTHORITY_HOST,
|
||||
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
|
||||
) -> None:
|
||||
self.graph_client: GraphClient | None = None
|
||||
self.msal_app: msal.ConfidentialClientApplication | None = None
|
||||
self.max_workers = max_workers
|
||||
self.requested_team_list: list[str] = teams
|
||||
|
||||
resolved_env = resolve_microsoft_environment(graph_api_host, authority_host)
|
||||
self._azure_environment = resolved_env.environment
|
||||
self.authority_host = resolved_env.authority_host
|
||||
self.graph_api_host = resolved_env.graph_host
|
||||
|
||||
# impls for BaseConnector
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -76,7 +87,7 @@ class TeamsConnector(
|
||||
teams_client_secret = credentials["teams_client_secret"]
|
||||
teams_directory_id = credentials["teams_directory_id"]
|
||||
|
||||
authority_url = f"{TeamsConnector.AUTHORITY_URL_PREFIX}{teams_directory_id}"
|
||||
authority_url = f"{self.authority_host}/{teams_directory_id}"
|
||||
self.msal_app = msal.ConfidentialClientApplication(
|
||||
authority=authority_url,
|
||||
client_id=teams_client_id,
|
||||
@@ -91,7 +102,7 @@ class TeamsConnector(
|
||||
raise RuntimeError("MSAL app is not initialized")
|
||||
|
||||
token = self.msal_app.acquire_token_for_client(
|
||||
scopes=["https://graph.microsoft.com/.default"]
|
||||
scopes=[f"{self.graph_api_host}/.default"]
|
||||
)
|
||||
|
||||
if not isinstance(token, dict):
|
||||
@@ -99,7 +110,9 @@ class TeamsConnector(
|
||||
|
||||
return token
|
||||
|
||||
self.graph_client = GraphClient(_acquire_token_func)
|
||||
self.graph_client = GraphClient(
|
||||
_acquire_token_func, environment=self._azure_environment
|
||||
)
|
||||
return None
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
|
||||
@@ -32,6 +32,7 @@ from onyx.context.search.federated.slack_search_utils import should_include_mess
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.document import DocumentSource
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.document_index_utils import (
|
||||
get_multipass_config,
|
||||
@@ -905,13 +906,15 @@ def convert_slack_score(slack_score: float) -> float:
|
||||
def slack_retrieval(
|
||||
query: ChunkIndexRequest,
|
||||
access_token: str,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
connector: FederatedConnectorDetail | None = None, # noqa: ARG001
|
||||
entities: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
slack_event_context: SlackContext | None = None,
|
||||
bot_token: str | None = None, # Add bot token parameter
|
||||
team_id: str | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB query (no session needed)
|
||||
search_settings: SearchSettings | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
Main entry point for Slack federated search with entity filtering.
|
||||
@@ -925,7 +928,7 @@ def slack_retrieval(
|
||||
Args:
|
||||
query: Search query object
|
||||
access_token: User OAuth access token
|
||||
db_session: Database session
|
||||
db_session: Database session (optional if search_settings provided)
|
||||
connector: Federated connector detail (unused, kept for backwards compat)
|
||||
entities: Connector-level config (entity filtering configuration)
|
||||
limit: Maximum number of results
|
||||
@@ -1153,7 +1156,10 @@ def slack_retrieval(
|
||||
|
||||
# chunk index docs into doc aware chunks
|
||||
# a single index doc can get split into multiple chunks
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
if search_settings is None:
|
||||
if db_session is None:
|
||||
raise ValueError("Either db_session or search_settings must be provided")
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
embedder = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings
|
||||
)
|
||||
|
||||
@@ -72,6 +72,7 @@ class BaseFilters(BaseModel):
|
||||
class UserFileFilters(BaseModel):
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
|
||||
class AssistantKnowledgeFilters(BaseModel):
|
||||
|
||||
@@ -18,8 +18,10 @@ from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.natural_language_processing.english_stopwords import strip_stopwords
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.secondary_llm_flows.source_filter import extract_source_filter
|
||||
from onyx.secondary_llm_flows.time_filter import extract_time_filter
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -38,10 +40,11 @@ def _build_index_filters(
|
||||
user_provided_filters: BaseFilters | None,
|
||||
user: User, # Used for ACLs, anonymous users only see public docs
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
user_file_ids: list[UUID] | None,
|
||||
persona_document_sets: list[str] | None,
|
||||
persona_time_cutoff: datetime | None,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
auto_detect_filters: bool = False,
|
||||
query: str | None = None,
|
||||
llm: LLM | None = None,
|
||||
@@ -49,18 +52,19 @@ def _build_index_filters(
|
||||
# Assistant knowledge filters
|
||||
attached_document_ids: list[str] | None = None,
|
||||
hierarchy_node_ids: list[int] | None = None,
|
||||
# Pre-fetched ACL filters (skips DB query when provided)
|
||||
acl_filters: list[str] | None = None,
|
||||
) -> IndexFilters:
|
||||
if auto_detect_filters and (llm is None or query is None):
|
||||
raise RuntimeError("LLM and query are required for auto detect filters")
|
||||
|
||||
base_filters = user_provided_filters or BaseFilters()
|
||||
|
||||
if (
|
||||
user_provided_filters
|
||||
and user_provided_filters.document_set is None
|
||||
and persona_document_sets is not None
|
||||
):
|
||||
base_filters.document_set = persona_document_sets
|
||||
document_set_filter = (
|
||||
base_filters.document_set
|
||||
if base_filters.document_set is not None
|
||||
else persona_document_sets
|
||||
)
|
||||
|
||||
time_filter = base_filters.time_cutoff or persona_time_cutoff
|
||||
source_filter = base_filters.source_type
|
||||
@@ -103,15 +107,21 @@ def _build_index_filters(
|
||||
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
|
||||
logger.debug("Added USER_FILE to source_filter for user knowledge search")
|
||||
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
if bypass_acl:
|
||||
user_acl_filters = None
|
||||
elif acl_filters is not None:
|
||||
user_acl_filters = acl_filters
|
||||
else:
|
||||
if db_session is None:
|
||||
raise ValueError("Either db_session or acl_filters must be provided")
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
source_type=source_filter,
|
||||
document_set=persona_document_sets,
|
||||
document_set=document_set_filter,
|
||||
time_cutoff=time_filter,
|
||||
tags=base_filters.tags,
|
||||
access_control_list=user_acl_filters,
|
||||
@@ -252,11 +262,17 @@ def search_pipeline(
|
||||
user: User,
|
||||
# Used for default filters and settings
|
||||
persona: Persona | None,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
# If a persona_id is provided, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB queries (no session needed)
|
||||
acl_filters: list[str] | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
user_uploaded_persona_files: list[UUID] | None = (
|
||||
[user_file.id for user_file in persona.user_files] if persona else None
|
||||
@@ -287,6 +303,7 @@ def search_pipeline(
|
||||
user_provided_filters=chunk_search_request.user_selected_filters,
|
||||
user=user,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
user_file_ids=user_uploaded_persona_files,
|
||||
persona_document_sets=persona_document_sets,
|
||||
persona_time_cutoff=persona_time_cutoff,
|
||||
@@ -297,6 +314,7 @@ def search_pipeline(
|
||||
bypass_acl=chunk_search_request.bypass_acl,
|
||||
attached_document_ids=attached_document_ids,
|
||||
hierarchy_node_ids=hierarchy_node_ids,
|
||||
acl_filters=acl_filters,
|
||||
)
|
||||
|
||||
query_keywords = strip_stopwords(chunk_search_request.query)
|
||||
@@ -315,6 +333,8 @@ def search_pipeline(
|
||||
user_id=user.id if user else None,
|
||||
document_index=document_index,
|
||||
db_session=db_session,
|
||||
embedding_model=embedding_model,
|
||||
prefetched_federated_retrieval_infos=prefetched_federated_retrieval_infos,
|
||||
)
|
||||
|
||||
# For some specific connectors like Salesforce, a user that has access to an object doesn't mean
|
||||
|
||||
@@ -14,9 +14,11 @@ from onyx.context.search.utils import get_query_embedding
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.federated_connectors.federated_retrieval import (
|
||||
get_federated_retrieval_functions,
|
||||
)
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
@@ -50,9 +52,14 @@ def combine_retrieval_results(
|
||||
def _embed_and_search(
|
||||
query_request: ChunkIndexRequest,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
query_embedding = get_query_embedding(query_request.query, db_session)
|
||||
query_embedding = get_query_embedding(
|
||||
query_request.query,
|
||||
db_session=db_session,
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
hybrid_alpha = query_request.hybrid_alpha or HYBRID_ALPHA
|
||||
|
||||
@@ -78,7 +85,9 @@ def search_chunks(
|
||||
query_request: ChunkIndexRequest,
|
||||
user_id: UUID | None,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
run_queries: list[tuple[Callable, tuple]] = []
|
||||
|
||||
@@ -88,14 +97,22 @@ def search_chunks(
|
||||
else None
|
||||
)
|
||||
|
||||
# Federated retrieval
|
||||
federated_retrieval_infos = get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
# Federated retrieval — use pre-fetched if available, otherwise query DB
|
||||
if prefetched_federated_retrieval_infos is not None:
|
||||
federated_retrieval_infos = prefetched_federated_retrieval_infos
|
||||
else:
|
||||
if db_session is None:
|
||||
raise ValueError(
|
||||
"Either db_session or prefetched_federated_retrieval_infos "
|
||||
"must be provided"
|
||||
)
|
||||
federated_retrieval_infos = get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
|
||||
federated_sources = set(
|
||||
federated_retrieval_info.source.to_non_federated_source()
|
||||
@@ -114,7 +131,10 @@ def search_chunks(
|
||||
|
||||
if normal_search_enabled:
|
||||
run_queries.append(
|
||||
(_embed_and_search, (query_request, document_index, db_session))
|
||||
(
|
||||
_embed_and_search,
|
||||
(query_request, document_index, db_session, embedding_model),
|
||||
)
|
||||
)
|
||||
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
|
||||
|
||||
@@ -64,23 +64,34 @@ def inference_section_from_single_chunk(
|
||||
)
|
||||
|
||||
|
||||
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
def get_query_embeddings(
|
||||
queries: list[str],
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
) -> list[Embedding]:
|
||||
if embedding_model is None:
|
||||
if db_session is None:
|
||||
raise ValueError("Either db_session or embedding_model must be provided")
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
|
||||
query_embedding = embedding_model.encode(queries, text_type=EmbedTextType.QUERY)
|
||||
return query_embedding
|
||||
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def get_query_embedding(query: str, db_session: Session) -> Embedding:
|
||||
return get_query_embeddings([query], db_session)[0]
|
||||
def get_query_embedding(
|
||||
query: str,
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
) -> Embedding:
|
||||
return get_query_embeddings(
|
||||
[query], db_session=db_session, embedding_model=embedding_model
|
||||
)[0]
|
||||
|
||||
|
||||
def convert_inference_sections_to_search_docs(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user