Compare commits


82 Commits

Author SHA1 Message Date
Nik
48e7428069 chore(helm): remove broken code-interpreter dependency
The code-interpreter Helm chart repo at
https://onyx-dot-app.github.io/code-interpreter/ returns 404,
causing ct lint to fail in CI. Remove it from Chart.yaml
dependencies, Chart.lock, ct.yaml chart-repos, and the CI
workflow's helm repo add step.
2026-02-19 20:17:14 -08:00
Evan Lohn
fc6a37850b feat: delta sync sharepoint (#8532)
Co-authored-by: CE11-Kishan <CE11-Kishan@users.noreply.github.com>
2026-02-20 03:26:54 +00:00
Raunak Bhagat
aa6fec3d58 Fix/pat UI (#8617) 2026-02-19 19:23:26 -08:00
Raunak Bhagat
efa6005e36 feat(opal): Add widthVariant to Interactive.Container (#8610) 2026-02-20 03:14:30 +00:00
Nikolas Garza
921bfc72f4 fix(helm): route /openapi.json to api_server in nginx config (#8612) 2026-02-19 19:05:14 -08:00
Justin Tahara
812603152d feat(web): FE Changes for Brave Web Search 3/3 (#8597) 2026-02-20 02:43:30 +00:00
Raunak Bhagat
6779d8fbd7 feat(opal): Add Content layout component (#8534)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 02:35:27 +00:00
Justin Tahara
2c9826e4a9 feat(web): Logical Hardening for Brave Web Search 2/3 (#8595) 2026-02-20 02:09:49 +00:00
Justin Tahara
5b54687077 feat(web): Initial Framework for Brave Web Search 1/3 (#8594) 2026-02-20 01:38:19 +00:00
Raunak Bhagat
0f7e2ee674 feat(opal): Add Tag component and resync colors with Figma (#8533)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 01:14:52 +00:00
Danelegend
ea466648d9 feat: file preview from llm (#8604) 2026-02-20 00:47:03 +00:00
Evan Lohn
a402911ee6 feat: sharepoint scalability 1 (#8531) 2026-02-20 00:41:48 +00:00
Wenxi
7ae9ba807d fix(manage-users): exclude slack users from /users list (#8602) 2026-02-19 23:41:25 +00:00
Nikolas Garza
1f79223c42 feat(ods): add --continue flag and cp alias to cherry-pick (#8601) 2026-02-19 22:31:06 +00:00
Danelegend
c0c2247d5a feat: Modal Header no icon (#8596) 2026-02-19 21:59:18 +00:00
Danelegend
2989ceda41 feat: add file formatting reminder (#8524) 2026-02-19 21:51:06 +00:00
Wenxi
c825f5eca6 feat: whitelist invite setting, allow users to register and invite, new accoount blocked page (#8527) 2026-02-19 20:16:50 +00:00
Jamison Lahman
a8965def79 chore(playwright): fix chat scrolling non-determinism (#8584) 2026-02-19 19:26:09 +00:00
Jamison Lahman
59e1ad51ba chore(fe): fix drop-down overflow in API Key modal (#8574) 2026-02-19 19:15:10 +00:00
Jamison Lahman
0e70a8f826 chore(fe): remove close button from image gen tooltip (#8585) 2026-02-19 18:31:02 +00:00
SubashMohan
0891737dfd fix: update SourceTag component to use variant prop for sizing (#8582) 2026-02-19 17:27:04 +00:00
victoria reese
5a20112670 fix: define fallback only on custom metrics (#8566) 2026-02-19 09:19:44 -08:00
SubashMohan
584f2e2638 fix(ui): Fix admin UI layout, sticky headers, LLM filtering, and project view issues (#8496) 2026-02-19 14:02:22 +00:00
Yuhong Sun
aa24b16ec1 chore: Opensearch tuning (#8518) 2026-02-19 05:43:45 +00:00
acaprau
50aa9d7df6 feat(opensearch): Display percentage progress in the migration page (#8575) 2026-02-19 05:33:29 +00:00
acaprau
bfda586054 chore(opensearch): Make the migration visit from 4 independent Vespa slices concurrently (#8570) 2026-02-19 04:26:59 +00:00
roshan
e04392fbb1 feat(craft): shareable webapp URLs with two-tier access control (#8510)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-19 04:04:21 +00:00
Wenxi
e46c6c5175 fix: init all celery tasks for autodiscovery (#8539) 2026-02-19 03:45:40 +00:00
Nikolas Garza
f59792b4ac feat(metrics): add SQLAlchemy connection pool Prometheus metrics (#8543) 2026-02-19 02:32:42 +00:00
Justin Tahara
973b9456e9 fix(vertex ai): Sort Model Names (#8572) 2026-02-19 02:24:47 +00:00
Justin Tahara
aa8d126513 fix(ollama): Ollama Chat fixes (#8522) 2026-02-19 01:37:07 +00:00
Jamison Lahman
a6da5add49 chore(fe): header text content wraps when responsive (#8565) 2026-02-19 01:25:12 +00:00
Jamison Lahman
3356f90437 chore(fe): Button handles empty string as text, use in header (#8563) 2026-02-19 01:24:10 +00:00
Danelegend
27c254ecf9 fix: /llm/provider route returns all providers (#8545) 2026-02-19 00:13:27 +00:00
Jamison Lahman
09678b3c8e chore(fe): whitelabeling header nits (#8561) 2026-02-18 23:39:54 +00:00
Justin Tahara
ecdb962e24 chore(llm): Cleaning up arg parsing (#8555) 2026-02-18 23:06:19 +00:00
Justin Tahara
63b9b91565 chore(llm): Extract _close_reasoning (#8550) 2026-02-18 22:59:07 +00:00
Justin Tahara
14770e6e90 chore(llm): Extract citation flush helper (#8554) 2026-02-18 22:31:51 +00:00
Justin Tahara
14807d986a chore(llm): Using bool for has_reasoned (#8549) 2026-02-18 22:18:09 +00:00
Justin Tahara
290eb98020 chore(llm): Extract _make_placement (#8552) 2026-02-18 21:58:58 +00:00
Nikolas Garza
fe6fa3d034 feat(scim): add SCIM Group CRUD endpoints (#8456) 2026-02-18 21:52:56 +00:00
Nikolas Garza
2a60a02e0e feat(scim): add admin SCIM token management API (#8538) 2026-02-18 21:41:09 +00:00
Justin Tahara
3bcd666e90 chore(llm): Cleaning up _extract_tool_call_kickoffs (#8548) 2026-02-18 21:39:36 +00:00
Jamison Lahman
684013732c chore(fe): update human message size (#8547) 2026-02-18 21:27:02 +00:00
Nikolas Garza
367dcb8f8b feat(scim): add SCIM User CRUD endpoints (#8455) 2026-02-18 21:26:52 +00:00
Justin Tahara
59dfed0bc8 chore(llm): Cleaning up Docstring (#8546) 2026-02-18 21:25:32 +00:00
Jamison Lahman
7a719b54bb chore(playwright): worker-aware users for test isolation (#8544) 2026-02-18 21:13:07 +00:00
Jamison Lahman
25ef5ff010 chore(fe): update SourceTag tag size (#8540) 2026-02-18 20:49:11 +00:00
Danelegend
53f9f042a1 fix: model config not populating flow during sync (#8542) 2026-02-18 20:08:59 +00:00
Jamison Lahman
3469f0c979 chore(gha): rm nightly license scan workflow (#8541) 2026-02-18 19:55:23 +00:00
Jamison Lahman
f688efbcd6 chore(gha): update helm/chart-testing-action version (#8536) 2026-02-18 11:21:17 -08:00
Jamison Lahman
250658a8b2 chore(playwright): screenshots wait for animations (#8530) 2026-02-18 18:33:33 +00:00
SubashMohan
5150ffc3e0 feat(chat): new chat sharing UI (#8471) 2026-02-18 15:15:34 +00:00
Jamison Lahman
858c1dbe4a chore(playwright): chat rendering tests (#8526) 2026-02-18 05:41:54 +00:00
Evan Lohn
a8e7353227 fix: shared drive node names (#8520) 2026-02-18 05:34:00 +00:00
Nikolas Garza
343cda35cb feat(observability): add production Prometheus instrumentation module (#8503) 2026-02-18 05:09:32 +00:00
Wenxi
1cbe47d85e chore: increase firecrawl test timeout to 60s for e2e test (#8529) 2026-02-18 05:02:45 +00:00
Jamison Lahman
221658132a chore(playwright): skip visual regression report on cancelled (#8528)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-17 20:56:22 -08:00
Jamison Lahman
fe8fb9eb75 fix(fe): center-align modals relative to chat container (#8517) 2026-02-18 04:27:38 +00:00
Wenxi
f7925584b8 fix: create release notifications on cloud (#8525) 2026-02-18 03:47:23 +00:00
Nikolas Garza
00b0e15ed7 feat(scim): add SCIM 2.0 service discovery endpoints (#8454) 2026-02-18 02:07:49 +00:00
Jamison Lahman
c2968e3bfe chore(playwright): update skill re: user isolation best-practices (#8521) 2026-02-17 17:52:03 -08:00
Justin Tahara
978f0a9d35 fix(ollama): Cleaning up DeepSeek (#8519) 2026-02-18 01:44:34 +00:00
Danelegend
410340fe37 chore: add linguist-language (#8515)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-18 00:57:27 +00:00
Wenxi
a0545c7eb3 fix: open_url broken on non-normalized urls and enable web crawl tests (#8508) 2026-02-18 00:41:21 +00:00
Jamison Lahman
aa46a8bba2 chore(gha): only run zizmor when .github/ changes (#8516) 2026-02-17 16:25:13 -08:00
Jamison Lahman
cd5aaa0302 chore(playwright): prefer absolute imports (#8511) 2026-02-17 16:22:51 -08:00
roshan
db33efaeaa fix: Enable search UI in Chrome extension side panel (#8486)
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-17 23:59:05 +00:00
Justin Tahara
c28d37dff8 chore(playwright): Cleanup MCP Tests (#8512) 2026-02-17 23:57:22 +00:00
Jamison Lahman
696e72bcbb chore(playwright): organize tests into directories (#8509) 2026-02-17 15:37:35 -08:00
Nikolas Garza
cfdac8083a feat(scim): add SCIM bearer token authentication (#8423) 2026-02-17 23:28:00 +00:00
Jamison Lahman
781aab67fa chore(cursor): playwright SKILL.md (#8506) 2026-02-17 15:12:01 -08:00
Justin Tahara
b14d357d55 chore(mcp): Adding more Playwright Tests (#8452) 2026-02-17 22:55:35 +00:00
Nikolas Garza
60bc1ce8a1 feat(scim): add SCIM database CRUD operations (DAL) (#8424) 2026-02-17 22:43:54 +00:00
roshan
c89e82ee58 feat(craft): ACP session persistence across sandbox sleep/wake (#8466)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 22:32:49 +00:00
Evan Lohn
529ab8179f chore: undeprecate doc sets (#8505) 2026-02-17 21:56:49 +00:00
Nikolas Garza
19716874b2 fix(ee): small ux fixes for licensing (#8498) 2026-02-17 21:55:06 +00:00
Justin Tahara
1ac3b8515d chore(playwright): Cleanup for LLM Tests (#8504) 2026-02-17 21:38:24 +00:00
Justin Tahara
0d0c8580ca fix(helm): Log Level Passthrough for Celery (#8495) 2026-02-17 21:15:46 +00:00
roshan
96a38dcc06 fix(craft): ephemeral ACP clients to prevent multi-replica session corruption (#8465)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 21:10:21 +00:00
Jamison Lahman
6fe72e5524 chore(playwright): use user for llm runtime tests (#8502) 2026-02-17 20:21:28 +00:00
Jamison Lahman
1f4ee4d550 chore(playwright): hide toast elements in screenshots (#8501) 2026-02-17 19:42:54 +00:00
268 changed files with 14544 additions and 4345 deletions

.claude/skills Symbolic link

@@ -0,0 +1 @@
../.cursor/skills


@@ -0,0 +1,248 @@
---
name: playwright-e2e-tests
description: Write and maintain Playwright end-to-end tests for the Onyx application. Use when creating new E2E tests, debugging test failures, adding test coverage, or when the user mentions Playwright, E2E tests, or browser testing.
---
# Playwright E2E Tests
## Project Layout
- **Tests**: `web/tests/e2e/` — organized by feature (`auth/`, `admin/`, `chat/`, `assistants/`, `connectors/`, `mcp/`)
- **Config**: `web/playwright.config.ts`
- **Utilities**: `web/tests/e2e/utils/`
- **Constants**: `web/tests/e2e/constants.ts`
- **Global setup**: `web/tests/e2e/global-setup.ts`
- **Output**: `web/output/playwright/`
## Imports
Always use absolute imports with the `@tests/e2e/` prefix — never relative paths (`../`, `../../`). The alias is defined in `web/tsconfig.json` and resolves to `web/tests/`.
```typescript
import { loginAs } from "@tests/e2e/utils/auth";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
import { TEST_ADMIN_CREDENTIALS } from "@tests/e2e/constants";
```
All new files should be `.ts`, not `.js`.
## Running Tests
```bash
# Run a specific test file
npx playwright test web/tests/e2e/chat/default_assistant.spec.ts
# Run a specific project
npx playwright test --project admin
npx playwright test --project exclusive
```
## Test Projects
| Project | Description | Parallelism |
|---------|-------------|-------------|
| `admin` | Standard tests (excludes `@exclusive`) | Parallel |
| `exclusive` | Serial, slower tests (tagged `@exclusive`) | 1 worker |
All tests use `admin_auth.json` storage state by default (pre-authenticated admin session).
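For orientation, here is a minimal sketch of how a project split like this could be expressed. The actual options live in `web/playwright.config.ts`; treat everything below as an illustrative assumption rather than the real config:
```typescript
import { defineConfig } from "@playwright/test";
export default defineConfig({
  projects: [
    {
      name: "admin",
      grepInvert: /@exclusive/, // standard tests, run in parallel
      use: { storageState: "admin_auth.json" },
    },
    {
      // serial, slower tests; the real config also pins this project to a single worker
      name: "exclusive",
      grep: /@exclusive/,
      use: { storageState: "admin_auth.json" },
    },
  ],
});
```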
## Authentication
Global setup (`global-setup.ts`) runs automatically before all tests and handles:
- Server readiness check (polls health endpoint, 60s timeout)
- Provisioning test users: admin, admin2, and a **pool of worker users** (`worker0@example.com` through `worker7@example.com`) (idempotent)
- API login + saving storage states: `admin_auth.json`, `admin2_auth.json`, and `worker{N}_auth.json` for each worker user
- Setting display name to `"worker"` for each worker user
- Promoting admin2 to admin role
- Ensuring a public LLM provider exists
Both test projects set `storageState: "admin_auth.json"`, so **every test starts pre-authenticated as admin with no login code needed**.
When a test needs a different user, use API-based login — never drive the login UI:
```typescript
import { loginAs } from "@tests/e2e/utils/auth";
await page.context().clearCookies();
await loginAs(page, "admin2");
// Log in as the worker-specific user (preferred for test isolation):
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
```
## Test Structure
Tests start pre-authenticated as admin — navigate and test directly:
```typescript
import { test, expect } from "@playwright/test";
test.describe("Feature Name", () => {
test("should describe expected behavior clearly", async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
// Already authenticated as admin — go straight to testing
});
});
```
**User isolation** — tests that modify visible app state (creating assistants, sending chat messages, pinning items) should run as a **worker-specific user** and clean up resources in `afterAll`. Global setup provisions a pool of worker users (`worker0@example.com` through `worker7@example.com`). `loginAsWorkerUser` maps `testInfo.workerIndex` to a pool slot via modulo, so retry workers (which get incrementing indices beyond the pool size) safely reuse existing users. This ensures parallel workers never share user state, keeps usernames deterministic for screenshots, and avoids cross-contamination:
```typescript
import { test } from "@playwright/test";
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
test.beforeEach(async ({ page }, testInfo) => {
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
});
```
If the test requires admin privileges *and* modifies visible state, use `"admin2"` instead — it's a pre-provisioned admin account that keeps the primary `"admin"` clean for other parallel tests. Switch to `"admin"` only for privileged setup (creating providers, configuring tools), then back to the worker user for the actual test. See `chat/default_assistant.spec.ts` for a full example.
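A minimal sketch of that switch-and-return pattern, assuming only a brief privileged step is needed (see `chat/default_assistant.spec.ts` for the real version):
```typescript
import { test } from "@playwright/test";
import { loginAs, loginAsWorkerUser } from "@tests/e2e/utils/auth";
test.beforeEach(async ({ page }, testInfo) => {
  // Brief privileged setup as "admin" (e.g. configuring a provider); kept as short as possible
  await page.context().clearCookies();
  await loginAs(page, "admin");
  // ...privileged setup steps go here...
  // Then switch back to the worker-specific user for the actual test
  await page.context().clearCookies();
  await loginAsWorkerUser(page, testInfo.workerIndex);
});
```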
`loginAsRandomUser` exists for the rare case where the test requires a brand-new user (e.g. onboarding flows). Avoid it elsewhere — it produces non-deterministic usernames that complicate screenshots.
**API resource setup** — only when tests need to create backend resources (image gen configs, web search providers, MCP servers). Use `beforeAll`/`afterAll` with `OnyxApiClient` to create and clean up. See `chat/default_assistant.spec.ts` or `mcp/mcp_oauth_flow.spec.ts` for examples. This is uncommon (~4 of 37 test files). A sketch of the pattern follows the `OnyxApiClient` method list below.
## Key Utilities
### `OnyxApiClient` (`@tests/e2e/utils/onyxApiClient`)
Backend API client for test setup/teardown. Key methods:
- **Connectors**: `createFileConnector()`, `deleteCCPair()`, `pauseConnector()`
- **LLM Providers**: `ensurePublicProvider()`, `createRestrictedProvider()`, `setProviderAsDefault()`
- **Assistants**: `createAssistant()`, `deleteAssistant()`, `findAssistantByName()`
- **User Groups**: `createUserGroup()`, `deleteUserGroup()`, `setUserRole()`
- **Tools**: `createWebSearchProvider()`, `createImageGenerationConfig()`
- **Chat**: `createChatSession()`, `deleteChatSession()`
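A minimal sketch of that setup/teardown pattern using two of the methods above. The constructor and exact method signatures are assumptions; check `utils/onyxApiClient` and the referenced spec files for the real shapes:
```typescript
import { test } from "@playwright/test";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
test.describe("Assistant-backed feature", () => {
  let api: OnyxApiClient;
  let assistantId: number; // assumed ID shape
  test.beforeAll(async () => {
    api = new OnyxApiClient(); // assumed constructor, for illustration only
    const assistant = await api.createAssistant({ name: "E2E-Feature Assistant" });
    assistantId = assistant.id;
  });
  test.afterAll(async () => {
    await api.deleteAssistant(assistantId); // clean up by ID
  });
  test("uses the pre-created assistant", async ({ page }) => {
    await page.goto("/app");
    await page.waitForLoadState("networkidle");
    // ...UI interactions exercising the behavior under test...
  });
});
```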
### `chatActions` (`@tests/e2e/utils/chatActions`)
- `sendMessage(page, message)` — sends a message and waits for AI response
- `startNewChat(page)` — clicks new-chat button and waits for intro
- `verifyDefaultAssistantIsChosen(page)` — checks Onyx logo is visible
- `verifyAssistantIsChosen(page, name)` — checks assistant name display
- `switchModel(page, modelName)` — switches LLM model via popover
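Putting the chat helpers together, a hypothetical flow with the named imports assumed:
```typescript
import { test } from "@playwright/test";
import {
  sendMessage,
  startNewChat,
  verifyDefaultAssistantIsChosen,
} from "@tests/e2e/utils/chatActions";
test("sends a message in a fresh chat", async ({ page }) => {
  await page.goto("/app");
  await page.waitForLoadState("networkidle");
  await startNewChat(page); // new-chat button + wait for intro
  await verifyDefaultAssistantIsChosen(page); // Onyx logo visible
  await sendMessage(page, "Hello!"); // sends and waits for the AI response
});
```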
### `visualRegression` (`@tests/e2e/utils/visualRegression`)
- `expectScreenshot(page, { name, mask?, hide?, fullPage? })`
- `expectElementScreenshot(locator, { name, mask?, hide? })`
- Controlled by `VISUAL_REGRESSION=true` env var
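A hedged example of an element-level screenshot with dynamic content masked; the test IDs below are assumptions for illustration:
```typescript
import { test } from "@playwright/test";
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";
test("sidebar matches baseline", async ({ page }) => {
  await page.goto("/app");
  await page.waitForLoadState("networkidle");
  // No-op unless VISUAL_REGRESSION=true is set in the environment
  await expectElementScreenshot(page.getByTestId("app-sidebar"), {
    name: "app-sidebar",
    mask: [page.getByTestId("chat-session-timestamp")], // keep dynamic timestamps out of the diff
  });
});
```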
### `theme` (`@tests/e2e/utils/theme`)
- `THEMES` — `["light", "dark"] as const` array for iterating over both themes
- `setThemeBeforeNavigation(page, theme)` — sets `next-themes` theme via `localStorage` before navigation
When tests need light/dark screenshots, loop over `THEMES` at the `test.describe` level and call `setThemeBeforeNavigation` in `beforeEach` **before** any `page.goto()`. Include the theme in screenshot names. See `admin/admin_pages.spec.ts` or `chat/chat_message_rendering.spec.ts` for examples:
```typescript
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";
for (const theme of THEMES) {
test.describe(`Feature (${theme} mode)`, () => {
test.beforeEach(async ({ page }) => {
await setThemeBeforeNavigation(page, theme);
});
test("renders correctly", async ({ page }) => {
await page.goto("/app");
await expectScreenshot(page, { name: `feature-${theme}` });
});
});
}
```
### `tools` (`@tests/e2e/utils/tools`)
- `TOOL_IDS` — centralized `data-testid` selectors for tool options
- `openActionManagement(page)` — opens the tool management popover
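A hypothetical sketch of opening the popover and toggling a tool; the `TOOL_IDS` key used here is made up for illustration:
```typescript
import { test } from "@playwright/test";
import { TOOL_IDS, openActionManagement } from "@tests/e2e/utils/tools";
test("toggles a tool from the action popover", async ({ page }) => {
  await page.goto("/app");
  await page.waitForLoadState("networkidle");
  await openActionManagement(page); // opens the tool management popover
  // TOOL_IDS centralizes data-testid selectors; "webSearch" is an assumed key
  await page.getByTestId(TOOL_IDS.webSearch).click();
});
```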
## Locator Strategy
Use locators in this priority order:
1. **`data-testid` / `aria-label`** — preferred for Onyx components
```typescript
page.getByTestId("AppSidebar/new-session")
page.getByLabel("admin-page-title")
```
2. **Role-based** — for standard HTML elements
```typescript
page.getByRole("button", { name: "Create" })
page.getByRole("dialog")
```
3. **Text/Label** — for visible text content
```typescript
page.getByText("Custom Assistant")
page.getByLabel("Email")
```
4. **CSS selectors** — last resort, only when above won't work
```typescript
page.locator('input[name="name"]')
page.locator("#onyx-chat-input-textarea")
```
**Never use** `page.locator` with complex CSS/XPath when a built-in locator works.
## Assertions
Use web-first assertions — they auto-retry until the condition is met:
```typescript
// Visibility
await expect(page.getByTestId("onyx-logo")).toBeVisible({ timeout: 5000 });
// Text content
await expect(page.getByTestId("assistant-name-display")).toHaveText("My Assistant");
// Count
await expect(page.locator('[data-testid="onyx-ai-message"]')).toHaveCount(2, { timeout: 30000 });
// URL
await expect(page).toHaveURL(/chatId=/);
// Element state
await expect(toggle).toBeChecked();
await expect(button).toBeEnabled();
```
**Never use** `assert` statements or hardcoded `page.waitForTimeout()`.
## Waiting Strategy
```typescript
// Wait for load state after navigation
await page.goto("/app");
await page.waitForLoadState("networkidle");
// Wait for specific element
await page.getByTestId("chat-intro").waitFor({ state: "visible", timeout: 10000 });
// Wait for URL change
await page.waitForFunction(() => window.location.href.includes("chatId="), null, { timeout: 10000 });
// Wait for network response
await page.waitForResponse(resp => resp.url().includes("/api/chat") && resp.status() === 200);
```
## Best Practices
1. **Descriptive test names** — clearly state expected behavior: `"should display greeting message when opening new chat"`
2. **API-first setup** — use `OnyxApiClient` for backend state; reserve UI interactions for the behavior under test
3. **User isolation** — tests that modify visible app state (sidebar, chat history) should run as the worker-specific user via `loginAsWorkerUser(page, testInfo.workerIndex)` (not admin) and clean up resources in `afterAll`. Each parallel worker gets its own user, preventing cross-contamination. Reserve `loginAsRandomUser` for flows that require a brand-new user (e.g. onboarding)
4. **DRY helpers** — extract reusable logic into `utils/` with JSDoc comments
5. **No hardcoded waits** — use `waitFor`, `waitForLoadState`, or web-first assertions
6. **Parallel-safe** — no shared mutable state between tests. Prefer static, human-readable names (e.g. `"E2E-CMD Chat 1"`) and clean up resources by ID in `afterAll`. This keeps screenshots deterministic and avoids needing to mask/hide dynamic text. Only fall back to timestamps (`\`test-${Date.now()}\``) when resources cannot be reliably cleaned up or when name collisions across parallel workers would cause functional failures
7. **Error context** — catch and re-throw with useful debug info (page text, URL, etc.); see the sketch after this list
8. **Tag slow tests** — mark serial/slow tests with `@exclusive` in the test title
9. **Visual regression** — use `expectScreenshot()` for UI consistency checks
10. **Minimal comments** — only comment to clarify non-obvious intent; never restate what the next line of code does
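To illustrate item 7, a minimal sketch of catching and re-throwing with debug context (the helper and its messages are assumptions, not an existing util):
```typescript
import { expect, type Page } from "@playwright/test";
// Hypothetical helper: re-throws with the page URL and a snippet of page text for easier triage
async function expectAiMessage(page: Page): Promise<void> {
  try {
    await expect(page.getByTestId("onyx-ai-message")).toBeVisible({ timeout: 30000 });
  } catch (error) {
    const bodyText = await page.locator("body").innerText();
    throw new Error(
      `No AI message appeared at ${page.url()}.\nPage text (truncated):\n${bodyText.slice(0, 500)}`,
      { cause: error }
    );
  }
}
```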


@@ -1,151 +0,0 @@
# Scan for problematic software licenses
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
name: 'Nightly - Scan licenses'
on:
# schedule:
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
workflow_dispatch: # Allows manual triggering
permissions:
actions: read
contents: read
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
timeout-minutes: 45
permissions:
actions: read
contents: read
security-events: write
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- name: Get explicit and transitive dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
pip freeze > requirements-all.txt
- name: Check python
id: license_check_report
uses: pilosus/action-pip-license-checker@e909b0226ff49d3235c99c4585bc617f49fff16a # ratchet:pilosus/action-pip-license-checker@v3
with:
requirements: 'requirements-all.txt'
fail: 'Copyleft'
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: always()
env:
REPORT: ${{ steps.license_check_report.outputs.report }}
run: echo "$REPORT"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
# be careful enabling the sarif and upload as it may spam the security tab
# with a huge amount of items. Work out the issues before enabling upload.
# - name: Run Trivy vulnerability scanner in repo mode
# if: always()
# uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
# with:
# scan-type: fs
# scan-ref: .
# scanners: license
# format: table
# severity: HIGH,CRITICAL
# # format: sarif
# # output: trivy-results.sarif
#
# # - name: Upload Trivy scan results to GitHub Security tab
# # uses: github/codeql-action/upload-sarif@v3
# # with:
# # sarif_file: trivy-results.sarif
scan-trivy:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# Backend
- name: Pull backend docker image
run: docker pull onyxdotapp/onyx-backend:latest
- name: Run Trivy vulnerability scanner on backend
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-backend:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
# Web server
- name: Pull web server docker image
run: docker pull onyxdotapp/onyx-web-server:latest
- name: Run Trivy vulnerability scanner on web server
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-web-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# Model server
- name: Pull model server docker image
run: docker pull onyxdotapp/onyx-model-server:latest
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-model-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0


@@ -41,8 +41,7 @@ jobs:
version: v3.19.0
- name: Set up chart-testing
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
with:
uv_version: "0.9.9"
@@ -92,7 +91,6 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Install Redis operator


@@ -22,6 +22,9 @@ env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
GOOGLE_PSE_API_KEY: ${{ secrets.GOOGLE_PSE_API_KEY }}
GOOGLE_PSE_SEARCH_ENGINE_ID: ${{ secrets.GOOGLE_PSE_SEARCH_ENGINE_ID }}
# for federated slack tests
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
@@ -300,6 +303,7 @@ jobs:
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
INTEGRATION_TESTS_MODE=true
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
EXA_API_KEY=${EXA_API_KEY_VALUE}
REQUIRE_EMAIL_VERIFICATION=false
@@ -589,7 +593,10 @@ jobs:
# Post a single combined visual regression comment after all matrix jobs finish
visual-regression-comment:
needs: [playwright-tests]
if: always() && github.event_name == 'pull_request'
if: >-
always() &&
github.event_name == 'pull_request' &&
needs.playwright-tests.result != 'cancelled'
runs-on: ubuntu-slim
timeout-minutes: 5
permissions:


@@ -5,6 +5,8 @@ on:
branches: ["main"]
pull_request:
branches: ["**"]
paths:
- ".github/**"
permissions: {}
@@ -21,29 +23,18 @@ jobs:
with:
persist-credentials: false
- name: Detect changes
id: filter
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
with:
filters: |
zizmor:
- '.github/**'
- name: Install the latest version of uv
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Run zizmor
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload SARIF file
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
with:
sarif_file: results.sarif

.gitignore vendored

@@ -7,6 +7,7 @@
.zed
.cursor
!/.cursor/mcp.json
!/.cursor/skills/
# macos
.DS_store


@@ -0,0 +1,32 @@
"""add approx_chunk_count_in_vespa to opensearch tenant migration
Revision ID: 631fd2504136
Revises: c7f2e1b4a9d3
Create Date: 2026-02-18 21:07:52.831215
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "631fd2504136"
down_revision = "c7f2e1b4a9d3"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"approx_chunk_count_in_vespa",
sa.Integer(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("opensearch_tenant_migration_record", "approx_chunk_count_in_vespa")


@@ -0,0 +1,31 @@
"""add sharing_scope to build_session
Revision ID: c7f2e1b4a9d3
Revises: 19c0ccb01687
Create Date: 2026-02-17 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "c7f2e1b4a9d3"
down_revision = "19c0ccb01687"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"build_session",
sa.Column(
"sharing_scope",
sa.String(),
nullable=False,
server_default="private",
),
)
def downgrade() -> None:
op.drop_column("build_session", "sharing_scope")


@@ -263,9 +263,15 @@ def refresh_license_cache(
try:
payload = verify_license_signature(license_record.license_data)
# Derive source from payload: manual licenses lack stripe_customer_id
source: LicenseSource = (
LicenseSource.AUTO_FETCH
if payload.stripe_customer_id
else LicenseSource.MANUAL_UPLOAD
)
return update_license_cache(
payload,
source=LicenseSource.AUTO_FETCH,
source=source,
tenant_id=tenant_id,
)
except ValueError as e:

backend/ee/onyx/db/scim.py Normal file

@@ -0,0 +1,604 @@
"""SCIM Data Access Layer.
All database operations for SCIM provisioning — token management, user
mappings, and group mappings. Extends the base DAL (see ``onyx.db.dal``).
Usage from FastAPI::
def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
@router.post("/tokens")
def create_token(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
token = dal.create_token(name=..., hashed_token=..., ...)
dal.commit()
return token
Usage from background tasks::
with ScimDAL.from_tenant("tenant_abc") as dal:
mapping = dal.create_user_mapping(external_id="idp-123", user_id=uid)
dal.commit()
"""
from __future__ import annotations
from uuid import UUID
from sqlalchemy import delete as sa_delete
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import SQLColumnExpression
from sqlalchemy.dialects.postgresql import insert as pg_insert
from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from onyx.db.dal import DAL
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ScimDAL(DAL):
"""Data Access Layer for SCIM provisioning operations.
Methods mutate but do NOT commit — call ``dal.commit()`` explicitly
when you want to persist changes. This follows the existing ``_no_commit``
convention and lets callers batch multiple operations into one transaction.
"""
# ------------------------------------------------------------------
# Token operations
# ------------------------------------------------------------------
def create_token(
self,
name: str,
hashed_token: str,
token_display: str,
created_by_id: UUID,
) -> ScimToken:
"""Create a new SCIM bearer token.
Only one token is active at a time — this method automatically revokes
all existing active tokens before creating the new one.
"""
# Revoke any currently active tokens
active_tokens = list(
self._session.scalars(
select(ScimToken).where(ScimToken.is_active.is_(True))
).all()
)
for t in active_tokens:
t.is_active = False
token = ScimToken(
name=name,
hashed_token=hashed_token,
token_display=token_display,
created_by_id=created_by_id,
)
self._session.add(token)
self._session.flush()
return token
def get_active_token(self) -> ScimToken | None:
"""Return the single currently active token, or None."""
return self._session.scalar(
select(ScimToken).where(ScimToken.is_active.is_(True))
)
def get_token_by_hash(self, hashed_token: str) -> ScimToken | None:
"""Look up a token by its SHA-256 hash."""
return self._session.scalar(
select(ScimToken).where(ScimToken.hashed_token == hashed_token)
)
def revoke_token(self, token_id: int) -> None:
"""Deactivate a token by ID.
Raises:
ValueError: If the token does not exist.
"""
token = self._session.get(ScimToken, token_id)
if not token:
raise ValueError(f"SCIM token with id {token_id} not found")
token.is_active = False
def update_token_last_used(self, token_id: int) -> None:
"""Update the last_used_at timestamp for a token."""
token = self._session.get(ScimToken, token_id)
if token:
token.last_used_at = func.now() # type: ignore[assignment]
# ------------------------------------------------------------------
# User mapping operations
# ------------------------------------------------------------------
def create_user_mapping(
self,
external_id: str,
user_id: UUID,
) -> ScimUserMapping:
"""Create a mapping between a SCIM externalId and an Onyx user."""
mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
self._session.add(mapping)
self._session.flush()
return mapping
def get_user_mapping_by_external_id(
self, external_id: str
) -> ScimUserMapping | None:
"""Look up a user mapping by the IdP's external identifier."""
return self._session.scalar(
select(ScimUserMapping).where(ScimUserMapping.external_id == external_id)
)
def get_user_mapping_by_user_id(self, user_id: UUID) -> ScimUserMapping | None:
"""Look up a user mapping by the Onyx user ID."""
return self._session.scalar(
select(ScimUserMapping).where(ScimUserMapping.user_id == user_id)
)
def list_user_mappings(
self,
start_index: int = 1,
count: int = 100,
) -> tuple[list[ScimUserMapping], int]:
"""List user mappings with SCIM-style pagination.
Args:
start_index: 1-based start index (SCIM convention).
count: Maximum number of results to return.
Returns:
A tuple of (mappings, total_count).
"""
total = (
self._session.scalar(select(func.count()).select_from(ScimUserMapping)) or 0
)
offset = max(start_index - 1, 0)
mappings = list(
self._session.scalars(
select(ScimUserMapping)
.order_by(ScimUserMapping.id)
.offset(offset)
.limit(count)
).all()
)
return mappings, total
def update_user_mapping_external_id(
self,
mapping_id: int,
external_id: str,
) -> ScimUserMapping:
"""Update the external ID on a user mapping.
Raises:
ValueError: If the mapping does not exist.
"""
mapping = self._session.get(ScimUserMapping, mapping_id)
if not mapping:
raise ValueError(f"SCIM user mapping with id {mapping_id} not found")
mapping.external_id = external_id
return mapping
def delete_user_mapping(self, mapping_id: int) -> None:
"""Delete a user mapping by ID. No-op if already deleted."""
mapping = self._session.get(ScimUserMapping, mapping_id)
if not mapping:
logger.warning("SCIM user mapping %d not found during delete", mapping_id)
return
self._session.delete(mapping)
# ------------------------------------------------------------------
# User query operations
# ------------------------------------------------------------------
def get_user(self, user_id: UUID) -> User | None:
"""Fetch a user by ID."""
return self._session.scalar(
select(User).where(User.id == user_id) # type: ignore[arg-type]
)
def get_user_by_email(self, email: str) -> User | None:
"""Fetch a user by email (case-insensitive)."""
return self._session.scalar(
select(User).where(func.lower(User.email) == func.lower(email))
)
def add_user(self, user: User) -> None:
"""Add a new user to the session and flush to assign an ID."""
self._session.add(user)
self._session.flush()
def update_user(
self,
user: User,
*,
email: str | None = None,
is_active: bool | None = None,
personal_name: str | None = None,
) -> None:
"""Update user attributes. Only sets fields that are provided."""
if email is not None:
user.email = email
if is_active is not None:
user.is_active = is_active
if personal_name is not None:
user.personal_name = personal_name
def deactivate_user(self, user: User) -> None:
"""Mark a user as inactive."""
user.is_active = False
def list_users(
self,
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[User, str | None]], int]:
"""Query users with optional SCIM filter and pagination.
Returns:
A tuple of (list of (user, external_id) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
"""
query = select(User).where(
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
)
if scim_filter:
attr = scim_filter.attribute.lower()
if attr == "username":
# arg-type: fastapi-users types User.email as str, not a column expression
# assignment: union return type widens but query is still Select[tuple[User]]
query = _apply_scim_string_op(query, User.email, scim_filter) # type: ignore[arg-type, assignment]
elif attr == "active":
query = query.where(
User.is_active.is_(scim_filter.value.lower() == "true") # type: ignore[attr-defined]
)
elif attr == "externalid":
mapping = self.get_user_mapping_by_external_id(scim_filter.value)
if not mapping:
return [], 0
query = query.where(User.id == mapping.user_id) # type: ignore[arg-type]
else:
raise ValueError(
f"Unsupported filter attribute: {scim_filter.attribute}"
)
# Count total matching rows first, then paginate. SCIM uses 1-based
# indexing (RFC 7644 §3.4.2), so we convert to a 0-based offset.
total = (
self._session.scalar(select(func.count()).select_from(query.subquery()))
or 0
)
offset = max(start_index - 1, 0)
users = list(
self._session.scalars(
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
).all()
)
# Batch-fetch external IDs to avoid N+1 queries
ext_id_map = self._get_user_external_ids([u.id for u in users])
return [(u, ext_id_map.get(u.id)) for u in users], total
def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
"""Create, update, or delete the external ID mapping for a user."""
mapping = self.get_user_mapping_by_user_id(user_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
else:
self.create_user_mapping(external_id=new_external_id, user_id=user_id)
elif mapping:
self.delete_user_mapping(mapping.id)
def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
"""Batch-fetch external IDs for a list of user IDs."""
if not user_ids:
return {}
mappings = self._session.scalars(
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
).all()
return {m.user_id: m.external_id for m in mappings}
# ------------------------------------------------------------------
# Group mapping operations
# ------------------------------------------------------------------
def create_group_mapping(
self,
external_id: str,
user_group_id: int,
) -> ScimGroupMapping:
"""Create a mapping between a SCIM externalId and an Onyx user group."""
mapping = ScimGroupMapping(external_id=external_id, user_group_id=user_group_id)
self._session.add(mapping)
self._session.flush()
return mapping
def get_group_mapping_by_external_id(
self, external_id: str
) -> ScimGroupMapping | None:
"""Look up a group mapping by the IdP's external identifier."""
return self._session.scalar(
select(ScimGroupMapping).where(ScimGroupMapping.external_id == external_id)
)
def get_group_mapping_by_group_id(
self, user_group_id: int
) -> ScimGroupMapping | None:
"""Look up a group mapping by the Onyx user group ID."""
return self._session.scalar(
select(ScimGroupMapping).where(
ScimGroupMapping.user_group_id == user_group_id
)
)
def list_group_mappings(
self,
start_index: int = 1,
count: int = 100,
) -> tuple[list[ScimGroupMapping], int]:
"""List group mappings with SCIM-style pagination.
Args:
start_index: 1-based start index (SCIM convention).
count: Maximum number of results to return.
Returns:
A tuple of (mappings, total_count).
"""
total = (
self._session.scalar(select(func.count()).select_from(ScimGroupMapping))
or 0
)
offset = max(start_index - 1, 0)
mappings = list(
self._session.scalars(
select(ScimGroupMapping)
.order_by(ScimGroupMapping.id)
.offset(offset)
.limit(count)
).all()
)
return mappings, total
def delete_group_mapping(self, mapping_id: int) -> None:
"""Delete a group mapping by ID. No-op if already deleted."""
mapping = self._session.get(ScimGroupMapping, mapping_id)
if not mapping:
logger.warning("SCIM group mapping %d not found during delete", mapping_id)
return
self._session.delete(mapping)
# ------------------------------------------------------------------
# Group query operations
# ------------------------------------------------------------------
def get_group(self, group_id: int) -> UserGroup | None:
"""Fetch a group by ID, returning None if deleted or missing."""
group = self._session.get(UserGroup, group_id)
if group and group.is_up_for_deletion:
return None
return group
def get_group_by_name(self, name: str) -> UserGroup | None:
"""Fetch a group by exact name."""
return self._session.scalar(select(UserGroup).where(UserGroup.name == name))
def add_group(self, group: UserGroup) -> None:
"""Add a new group to the session and flush to assign an ID."""
self._session.add(group)
self._session.flush()
def update_group(
self,
group: UserGroup,
*,
name: str | None = None,
) -> None:
"""Update group attributes and set the modification timestamp."""
if name is not None:
group.name = name
group.time_last_modified_by_user = func.now()
def delete_group(self, group: UserGroup) -> None:
"""Delete a group from the session."""
self._session.delete(group)
def list_groups(
self,
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[UserGroup, str | None]], int]:
"""Query groups with optional SCIM filter and pagination.
Returns:
A tuple of (list of (group, external_id) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
"""
query = select(UserGroup).where(UserGroup.is_up_for_deletion.is_(False))
if scim_filter:
attr = scim_filter.attribute.lower()
if attr == "displayname":
# assignment: union return type widens but query is still Select[tuple[UserGroup]]
query = _apply_scim_string_op(query, UserGroup.name, scim_filter) # type: ignore[assignment]
elif attr == "externalid":
mapping = self.get_group_mapping_by_external_id(scim_filter.value)
if not mapping:
return [], 0
query = query.where(UserGroup.id == mapping.user_group_id)
else:
raise ValueError(
f"Unsupported filter attribute: {scim_filter.attribute}"
)
total = (
self._session.scalar(select(func.count()).select_from(query.subquery()))
or 0
)
offset = max(start_index - 1, 0)
groups = list(
self._session.scalars(
query.order_by(UserGroup.id).offset(offset).limit(count)
).all()
)
ext_id_map = self._get_group_external_ids([g.id for g in groups])
return [(g, ext_id_map.get(g.id)) for g in groups], total
def get_group_members(self, group_id: int) -> list[tuple[UUID, str | None]]:
"""Get group members as (user_id, email) pairs."""
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
).all()
user_ids = [r.user_id for r in rels if r.user_id]
if not user_ids:
return []
users = self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
).all()
users_by_id = {u.id: u for u in users}
return [
(
r.user_id,
users_by_id[r.user_id].email if r.user_id in users_by_id else None,
)
for r in rels
if r.user_id
]
def validate_member_ids(self, uuids: list[UUID]) -> list[UUID]:
"""Return the subset of UUIDs that don't exist as users.
Returns an empty list if all IDs are valid.
"""
if not uuids:
return []
existing_users = self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
).all()
existing_ids = {u.id for u in existing_users}
return [uid for uid in uuids if uid not in existing_ids]
def upsert_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Add user-group relationships, ignoring duplicates."""
if not user_ids:
return
self._session.execute(
pg_insert(User__UserGroup)
.values([{"user_id": uid, "user_group_id": group_id} for uid in user_ids])
.on_conflict_do_nothing(
index_elements=[
User__UserGroup.user_group_id,
User__UserGroup.user_id,
]
)
)
def replace_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Replace all members of a group."""
self._session.execute(
sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
)
self.upsert_group_members(group_id, user_ids)
def remove_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Remove specific members from a group."""
if not user_ids:
return
self._session.execute(
sa_delete(User__UserGroup).where(
User__UserGroup.user_group_id == group_id,
User__UserGroup.user_id.in_(user_ids),
)
)
def delete_group_with_members(self, group: UserGroup) -> None:
"""Remove all member relationships and delete the group."""
self._session.execute(
sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group.id)
)
self._session.delete(group)
def sync_group_external_id(
self, group_id: int, new_external_id: str | None
) -> None:
"""Create, update, or delete the external ID mapping for a group."""
mapping = self.get_group_mapping_by_group_id(group_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
else:
self.create_group_mapping(
external_id=new_external_id, user_group_id=group_id
)
elif mapping:
self.delete_group_mapping(mapping.id)
def _get_group_external_ids(self, group_ids: list[int]) -> dict[int, str]:
"""Batch-fetch external IDs for a list of group IDs."""
if not group_ids:
return {}
mappings = self._session.scalars(
select(ScimGroupMapping).where(
ScimGroupMapping.user_group_id.in_(group_ids)
)
).all()
return {m.user_group_id: m.external_id for m in mappings}
# ---------------------------------------------------------------------------
# Module-level helpers (used by DAL methods above)
# ---------------------------------------------------------------------------
def _apply_scim_string_op(
query: Select[tuple[User]] | Select[tuple[UserGroup]],
column: SQLColumnExpression[str],
scim_filter: ScimFilter,
) -> Select[tuple[User]] | Select[tuple[UserGroup]]:
"""Apply a SCIM string filter operator using SQLAlchemy column operators.
Handles eq (case-insensitive exact), co (contains), and sw (starts with).
SQLAlchemy's operators handle LIKE-pattern escaping internally.
"""
val = scim_filter.value
if scim_filter.operator == ScimFilterOperator.EQUAL:
return query.where(func.lower(column) == val.lower())
elif scim_filter.operator == ScimFilterOperator.CONTAINS:
return query.where(column.icontains(val, autoescape=True))
elif scim_filter.operator == ScimFilterOperator.STARTS_WITH:
return query.where(column.istartswith(val, autoescape=True))
else:
raise ValueError(f"Unsupported string filter operator: {scim_filter.operator}")


@@ -31,6 +31,7 @@ from ee.onyx.server.query_and_chat.query_backend import (
from ee.onyx.server.query_and_chat.search_backend import router as search_router
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.scim.api import scim_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.token_rate_limits.api import (
@@ -162,6 +163,11 @@ def get_application() -> FastAPI:
# Tenant management
include_router_with_global_prefix_prepended(application, tenants_router)
# SCIM 2.0 — protocol endpoints (unauthenticated by Onyx session auth;
# they use their own SCIM bearer token auth).
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
application.include_router(scim_router)
# Ensure all routes have auth enabled or are explicitly marked as public
check_ee_router_auth(application)


@@ -5,6 +5,11 @@ from onyx.server.auth_check import PUBLIC_ENDPOINT_SPECS
EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
# SCIM 2.0 service discovery — unauthenticated so IdPs can probe
# before bearer token configuration is complete
("/scim/v2/ServiceProviderConfig", {"GET"}),
("/scim/v2/ResourceTypes", {"GET"}),
("/scim/v2/Schemas", {"GET"}),
# needs to be accessible prior to user login
("/enterprise-settings", {"GET"}),
("/enterprise-settings/logo", {"GET"}),


@@ -13,6 +13,7 @@ from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.store import get_logo_filename
@@ -22,6 +23,10 @@ from ee.onyx.server.enterprise_settings.store import load_settings
from ee.onyx.server.enterprise_settings.store import store_analytics_script
from ee.onyx.server.enterprise_settings.store import store_settings
from ee.onyx.server.enterprise_settings.store import upload_logo
from ee.onyx.server.scim.auth import generate_scim_token
from ee.onyx.server.scim.models import ScimTokenCreate
from ee.onyx.server.scim.models import ScimTokenCreatedResponse
from ee.onyx.server.scim.models import ScimTokenResponse
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
@@ -198,3 +203,63 @@ def upload_custom_analytics_script(
@basic_router.get("/custom-analytics-script")
def fetch_custom_analytics_script() -> str | None:
return load_analytics_script()
# ---------------------------------------------------------------------------
# SCIM token management
# ---------------------------------------------------------------------------
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
@admin_router.get("/scim/token")
def get_active_scim_token(
_: User = Depends(current_admin_user),
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenResponse:
"""Return the currently active SCIM token's metadata, or 404 if none."""
token = dal.get_active_token()
if not token:
raise HTTPException(status_code=404, detail="No active SCIM token")
return ScimTokenResponse(
id=token.id,
name=token.name,
token_display=token.token_display,
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
)
@admin_router.post("/scim/token", status_code=201)
def create_scim_token(
body: ScimTokenCreate,
user: User = Depends(current_admin_user),
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenCreatedResponse:
"""Create a new SCIM bearer token.
Only one token is active at a time — creating a new token automatically
revokes all previous tokens. The raw token value is returned exactly once
in the response; it cannot be retrieved again.
"""
raw_token, hashed_token, token_display = generate_scim_token()
token = dal.create_token(
name=body.name,
hashed_token=hashed_token,
token_display=token_display,
created_by_id=user.id,
)
dal.commit()
return ScimTokenCreatedResponse(
id=token.id,
name=token.name,
token_display=token.token_display,
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
raw_token=raw_token,
)


@@ -0,0 +1,689 @@
"""SCIM 2.0 API endpoints (RFC 7644).
This module provides the FastAPI router for SCIM service discovery,
User CRUD, and Group CRUD. Identity providers (Okta, Azure AD) call
these endpoints to provision and manage users and groups.
Service discovery endpoints are unauthenticated — IdPs may probe them
before bearer token configuration is complete. All other endpoints
require a valid SCIM bearer token.
"""
from __future__ import annotations
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Query
from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
# changed to kebab-case.
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
# ---------------------------------------------------------------------------
# Service Discovery Endpoints (unauthenticated)
# ---------------------------------------------------------------------------
@scim_router.get("/ServiceProviderConfig")
def get_service_provider_config() -> ScimServiceProviderConfig:
"""Advertise supported SCIM features (RFC 7643 §5)."""
return SERVICE_PROVIDER_CONFIG
@scim_router.get("/ResourceTypes")
def get_resource_types() -> list[ScimResourceType]:
"""List available SCIM resource types (RFC 7643 §6)."""
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
@scim_router.get("/Schemas")
def get_schemas() -> list[ScimSchemaDefinition]:
"""Return SCIM schema definitions (RFC 7643 §7)."""
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _scim_error_response(status: int, detail: str) -> JSONResponse:
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
body = ScimError(status=str(status), detail=detail)
return JSONResponse(
status_code=status,
content=body.model_dump(exclude_none=True),
)
def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
"""Convert an Onyx User to a SCIM User resource representation."""
name = None
if user.personal_name:
parts = user.personal_name.split(" ", 1)
name = ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
return ScimUserResource(
id=str(user.id),
externalId=external_id,
userName=user.email,
name=name,
emails=[ScimEmail(value=user.email, type="work", primary=True)],
active=user.is_active,
meta=ScimMeta(resourceType="User"),
)
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None."""
check_fn = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)
if check_fn is None:
return None
result = check_fn(dal.session, seats_needed=1)
if not result.available:
return result.error_message or "Seat limit reached"
return None
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
try:
uid = UUID(user_id)
except ValueError:
return _scim_error_response(404, f"User {user_id} not found")
user = dal.get_user(uid)
if not user:
return _scim_error_response(404, f"User {user_id} not found")
return user
def _scim_name_to_str(name: ScimName | None) -> str | None:
"""Extract a display name string from a SCIM name object.
Returns None if no name is provided, so the caller can decide
whether to update the user's personal_name.
"""
if not name:
return None
return name.formatted or " ".join(
part for part in [name.givenName, name.familyName] if part
)
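# Illustrative check (not part of the module): how the two name helpers compose.
# Assumes the optional ScimName fields accept None when not supplied.
def _demo_name_helpers() -> None:
    # formatted takes precedence when present...
    assert (
        _scim_name_to_str(
            ScimName(givenName="Ada", familyName=None, formatted="Ada Lovelace King")
        )
        == "Ada Lovelace King"
    )
    # ...otherwise givenName and familyName are joined.
    assert (
        _scim_name_to_str(
            ScimName(givenName="Ada", familyName="Lovelace", formatted=None)
        )
        == "Ada Lovelace"
    )
    # _user_to_scim splits personal_name on the first space only, so a
    # personal_name of "Ada Lovelace King" maps back to
    # ScimName(givenName="Ada", familyName="Lovelace King", formatted="Ada Lovelace King").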
# ---------------------------------------------------------------------------
# User CRUD (RFC 7644 §3)
# ---------------------------------------------------------------------------
@scim_router.get("/Users", response_model=None)
def list_users(
filter: str | None = Query(None),
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List users with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
try:
scim_filter = parse_scim_filter(filter)
except ValueError as e:
return _scim_error_response(400, str(e))
try:
users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
resources: list[ScimUserResource | ScimGroupResource] = [
_user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
]
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
)
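# Illustrative sketch (hypothetical request): how an IdP usually correlates an
# existing user before attempting to create one. Okta and Azure AD typically send an
# equality filter on userName; which filter expressions are accepted is decided by
# parse_scim_filter().
#   GET /scim/v2/Users?filter=userName eq "ada@example.com"&startIndex=1&count=100
#   Authorization: Bearer onyx_scim_...
# The response is the ListResponse built above, with totalResults, startIndex,
# itemsPerPage, and the matching Resources.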
@scim_router.get("/Users/{user_id}", response_model=None)
def get_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Get a single user by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
return _user_to_scim(user, mapping.external_id if mapping else None)
@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Create a new user from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
email = user_resource.userName.strip().lower()
# externalId is how the IdP correlates this user on subsequent requests.
# Without it, the IdP can't find the user and will try to re-create,
# hitting a 409 conflict — so we require it up front.
if not user_resource.externalId:
return _scim_error_response(400, "externalId is required")
# Enforce seat limit
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
# Check for existing user
if dal.get_user_by_email(email):
return _scim_error_response(409, f"User with email {email} already exists")
# Create user with a random password (SCIM users authenticate via IdP)
personal_name = _scim_name_to_str(user_resource.name)
user = User(
email=email,
hashed_password=_pw_helper.hash(_pw_helper.generate()),
role=UserRole.BASIC,
is_active=user_resource.active,
is_verified=True,
personal_name=personal_name,
)
try:
dal.add_user(user)
except IntegrityError:
dal.rollback()
return _scim_error_response(409, f"User with email {email} already exists")
# Create SCIM mapping (externalId is validated above, always present)
external_id = user_resource.externalId
dal.create_user_mapping(external_id=external_id, user_id=user.id)
dal.commit()
return _user_to_scim(user, external_id)
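# Illustrative sketch (hypothetical values): the kind of provisioning request an IdP
# sends to the endpoint above. externalId is required by the check at the top, and
# the password is never taken from the IdP; a random one is generated instead.
#   POST /scim/v2/Users
#   Authorization: Bearer onyx_scim_...
#   {
#     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
#     "externalId": "00u1abcd",                # IdP-side identifier (hypothetical)
#     "userName": "ada@example.com",
#     "name": {"givenName": "Ada", "familyName": "Lovelace"},
#     "emails": [{"value": "ada@example.com", "type": "work", "primary": true}],
#     "active": true
#   }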
@scim_router.put("/Users/{user_id}", response_model=None)
def replace_user(
user_id: str,
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Replace a user entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
# Handle activation (need seat check) / deactivation
if user_resource.active and not user.is_active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
dal.update_user(
user,
email=user_resource.userName.strip().lower(),
is_active=user_resource.active,
personal_name=_scim_name_to_str(user_resource.name),
)
new_external_id = user_resource.externalId
dal.sync_user_external_id(user.id, new_external_id)
dal.commit()
return _user_to_scim(user, new_external_id)
@scim_router.patch("/Users/{user_id}", response_model=None)
def patch_user(
user_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Partially update a user (RFC 7644 §3.5.2).
This is the primary endpoint for user deprovisioning — Okta sends
``PATCH {"active": false}`` rather than DELETE.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
external_id = mapping.external_id if mapping else None
current = _user_to_scim(user, external_id)
try:
patched = apply_user_patch(patch_request.Operations, current)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
# Apply changes back to the DB model
if patched.active != user.is_active:
if patched.active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
dal.update_user(
user,
email=(
patched.userName.strip().lower()
if patched.userName.lower() != user.email
else None
),
is_active=patched.active if patched.active != user.is_active else None,
personal_name=_scim_name_to_str(patched.name),
)
dal.sync_user_external_id(user.id, patched.externalId)
dal.commit()
return _user_to_scim(user, patched.externalId)
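# Illustrative sketch (hypothetical request): the deprovisioning PATCH that Okta
# typically sends instead of DELETE, handled by the endpoint above.
#   PATCH /scim/v2/Users/{user_id}
#   {
#     "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
#     "Operations": [
#       {"op": "replace", "value": {"active": false}}
#     ]
#   }
# apply_user_patch() folds the operation into the current resource; because
# patched.active then differs from user.is_active, the endpoint deactivates the user
# via dal.update_user().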
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
def delete_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | JSONResponse:
"""Delete a user (RFC 7644 §3.6).
Deactivates the user and removes the SCIM mapping. Note that Okta
typically uses PATCH active=false instead of DELETE.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
dal.deactivate_user(user)
mapping = dal.get_user_mapping_by_user_id(user.id)
if mapping:
dal.delete_user_mapping(mapping.id)
dal.commit()
return Response(status_code=204)
# ---------------------------------------------------------------------------
# Group helpers
# ---------------------------------------------------------------------------
def _group_to_scim(
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Convert an Onyx UserGroup to a SCIM Group resource."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
"""Parse *group_id* as int, look up the group, or return a 404 error."""
try:
gid = int(group_id)
except ValueError:
return _scim_error_response(404, f"Group {group_id} not found")
group = dal.get_group(gid)
if not group:
return _scim_error_response(404, f"Group {group_id} not found")
return group
def _parse_member_uuids(
members: list[ScimGroupMember],
) -> tuple[list[UUID], str | None]:
"""Parse member value strings to UUIDs.
Returns (uuid_list, error_message). error_message is None on success.
"""
uuids: list[UUID] = []
for m in members:
try:
uuids.append(UUID(m.value))
except ValueError:
return [], f"Invalid member ID: {m.value}"
return uuids, None
def _validate_and_parse_members(
members: list[ScimGroupMember], dal: ScimDAL
) -> tuple[list[UUID], str | None]:
"""Parse and validate member UUIDs exist in the database.
Returns (uuid_list, error_message). error_message is None on success.
"""
uuids, err = _parse_member_uuids(members)
if err:
return [], err
if uuids:
missing = dal.validate_member_ids(uuids)
if missing:
return [], f"Member(s) not found: {', '.join(str(u) for u in missing)}"
return uuids, None
# ---------------------------------------------------------------------------
# Group CRUD (RFC 7644 §3)
# ---------------------------------------------------------------------------
@scim_router.get("/Groups", response_model=None)
def list_groups(
filter: str | None = Query(None),
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List groups with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
try:
scim_filter = parse_scim_filter(filter)
except ValueError as e:
return _scim_error_response(400, str(e))
try:
groups_with_ext_ids, total = dal.list_groups(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
resources: list[ScimUserResource | ScimGroupResource] = [
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
for group, ext_id in groups_with_ext_ids
]
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
)
@scim_router.get("/Groups/{group_id}", response_model=None)
def get_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Get a single group by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, mapping.external_id if mapping else None)
@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Create a new group from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
if dal.get_group_by_name(group_resource.displayName):
return _scim_error_response(
409, f"Group with name '{group_resource.displayName}' already exists"
)
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
if err:
return _scim_error_response(400, err)
db_group = UserGroup(
name=group_resource.displayName,
is_up_to_date=True,
time_last_modified_by_user=func.now(),
)
try:
dal.add_group(db_group)
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"Group with name '{group_resource.displayName}' already exists"
)
dal.upsert_group_members(db_group.id, member_uuids)
external_id = group_resource.externalId
if external_id:
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
dal.commit()
members = dal.get_group_members(db_group.id)
return _group_to_scim(db_group, members, external_id)
@scim_router.put("/Groups/{group_id}", response_model=None)
def replace_group(
group_id: str,
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Replace a group entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
if err:
return _scim_error_response(400, err)
dal.update_group(group, name=group_resource.displayName)
dal.replace_group_members(group.id, member_uuids)
dal.sync_group_external_id(group.id, group_resource.externalId)
dal.commit()
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, group_resource.externalId)
@scim_router.patch("/Groups/{group_id}", response_model=None)
def patch_group(
group_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Partially update a group (RFC 7644 §3.5.2).
Handles member add/remove operations from Okta and Azure AD.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
external_id = mapping.external_id if mapping else None
current_members = dal.get_group_members(group.id)
current = _group_to_scim(group, current_members, external_id)
try:
patched, added_ids, removed_ids = apply_group_patch(
patch_request.Operations, current
)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
new_name = patched.displayName if patched.displayName != group.name else None
dal.update_group(group, name=new_name)
if added_ids:
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
if add_uuids:
missing = dal.validate_member_ids(add_uuids)
if missing:
return _scim_error_response(
400,
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
)
dal.upsert_group_members(group.id, add_uuids)
if removed_ids:
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
dal.remove_group_members(group.id, remove_uuids)
dal.sync_group_external_id(group.id, patched.externalId)
dal.commit()
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, patched.externalId)
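# Illustrative sketch (hypothetical request): a typical membership update from an
# IdP. The exact operation shapes accepted are determined by apply_group_patch(),
# which reduces them to added_ids / removed_ids for the code above.
#   PATCH /scim/v2/Groups/{group_id}
#   {
#     "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
#     "Operations": [
#       {"op": "add", "path": "members",
#        "value": [{"value": "3fa85f64-5717-4562-b3fc-2c963f66afa6"}]},
#       {"op": "remove", "path": "members",
#        "value": [{"value": "9b2e1c5a-0d4f-4c1b-8f0a-2d6e7b3c9a11"}]}
#     ]
#   }
# Added member IDs are validated for existence before upsert; removals are applied
# without an existence check.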
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
def delete_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | JSONResponse:
"""Delete a group (RFC 7644 §3.6)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
if mapping:
dal.delete_group_mapping(mapping.id)
dal.delete_group_with_members(group)
dal.commit()
return Response(status_code=204)
def _is_valid_uuid(value: str) -> bool:
"""Check if a string is a valid UUID."""
try:
UUID(value)
return True
except ValueError:
return False

View File

@@ -0,0 +1,104 @@
"""SCIM bearer token authentication.
SCIM endpoints are authenticated via bearer tokens that admins create in the
Onyx UI. This module provides:
- ``verify_scim_token``: FastAPI dependency that extracts, hashes, and
validates the token from the Authorization header.
- ``generate_scim_token``: Creates a new cryptographically random token
and returns the raw value, its SHA-256 hash, and a display suffix.
Token format: ``onyx_scim_<random>`` where ``<random>`` is the URL-safe base64
encoding of 48 random bytes from ``secrets.token_urlsafe`` (roughly 64 characters).
The hash is stored in the ``scim_token`` table; the raw value is shown to
the admin exactly once at creation time.
"""
import hashlib
import secrets
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from onyx.auth.utils import get_hashed_bearer_token_from_request
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
SCIM_TOKEN_PREFIX = "onyx_scim_"
SCIM_TOKEN_LENGTH = 48
def _hash_scim_token(token: str) -> str:
"""SHA-256 hash a SCIM token. No salt needed — tokens are random."""
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def generate_scim_token() -> tuple[str, str, str]:
"""Generate a new SCIM bearer token.
Returns:
A tuple of ``(raw_token, hashed_token, token_display)`` where
``token_display`` is a masked version showing only the last 4 chars.
"""
raw_token = SCIM_TOKEN_PREFIX + secrets.token_urlsafe(SCIM_TOKEN_LENGTH)
hashed_token = _hash_scim_token(raw_token)
token_display = SCIM_TOKEN_PREFIX + "****" + raw_token[-4:]
return raw_token, hashed_token, token_display
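# Illustrative check (not part of the module): the shape of what
# generate_scim_token() returns and how the pieces relate. The IdP later presents
# the raw value as "Authorization: Bearer <raw_token>".
def _demo_token_shape() -> None:
    raw_token, hashed_token, token_display = generate_scim_token()
    assert raw_token.startswith(SCIM_TOKEN_PREFIX)
    assert hashed_token == _hash_scim_token(raw_token)
    assert len(hashed_token) == 64  # hex-encoded SHA-256 digest
    assert token_display == SCIM_TOKEN_PREFIX + "****" + raw_token[-4:]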
def _get_hashed_scim_token_from_request(request: Request) -> str | None:
"""Extract and hash a SCIM token from the request Authorization header."""
return get_hashed_bearer_token_from_request(
request,
valid_prefixes=[SCIM_TOKEN_PREFIX],
hash_fn=_hash_scim_token,
)
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
def verify_scim_token(
request: Request,
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimToken:
"""FastAPI dependency that authenticates SCIM requests.
Extracts the bearer token from the Authorization header, hashes it,
looks it up in the database, and verifies it is active.
Note:
This dependency does NOT update ``last_used_at`` — the endpoint
should do that via ``ScimDAL.update_token_last_used()`` so the
timestamp write is part of the endpoint's transaction.
Raises:
HTTPException(401): If the token is missing, invalid, or inactive.
"""
hashed = _get_hashed_scim_token_from_request(request)
if not hashed:
raise HTTPException(
status_code=401,
detail="Missing or invalid SCIM bearer token",
)
token = dal.get_token_by_hash(hashed)
if not token:
raise HTTPException(
status_code=401,
detail="Invalid SCIM bearer token",
)
if not token.is_active:
raise HTTPException(
status_code=401,
detail="SCIM token has been revoked",
)
return token

View File

@@ -30,6 +30,7 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
)
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
# ---------------------------------------------------------------------------
@@ -195,10 +196,39 @@ class ScimServiceProviderConfig(BaseModel):
)
class ScimSchemaAttribute(BaseModel):
"""Attribute definition within a SCIM Schema (RFC 7643 §7)."""
name: str
type: str
multiValued: bool = False
required: bool = False
description: str = ""
caseExact: bool = False
mutability: str = "readWrite"
returned: str = "default"
uniqueness: str = "none"
subAttributes: list["ScimSchemaAttribute"] = Field(default_factory=list)
class ScimSchemaDefinition(BaseModel):
"""SCIM Schema definition (RFC 7643 §7).
Served at GET /scim/v2/Schemas. Describes the attributes available
on each resource type so IdPs know which fields they can provision.
"""
schemas: list[str] = Field(default_factory=lambda: [SCIM_SCHEMA_SCHEMA])
id: str
name: str
description: str
attributes: list[ScimSchemaAttribute] = Field(default_factory=list)
class ScimSchemaExtension(BaseModel):
"""Schema extension reference within ResourceType (RFC 7643 §6)."""
model_config = ConfigDict(populate_by_name=True)
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
schema_: str = Field(alias="schema")
required: bool
@@ -211,7 +241,7 @@ class ScimResourceType(BaseModel):
types are available (Users, Groups) and their respective endpoints.
"""
model_config = ConfigDict(populate_by_name=True)
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
id: str

View File

@@ -0,0 +1,144 @@
"""Static SCIM service discovery responses (RFC 7643 §5, §6, §7).
Pre-built at import time — these never change at runtime. Separated from
api.py to keep the endpoint module focused on request handling.
"""
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaAttribute
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
SERVICE_PROVIDER_CONFIG = ScimServiceProviderConfig()
USER_RESOURCE_TYPE = ScimResourceType.model_validate(
{
"id": "User",
"name": "User",
"endpoint": "/scim/v2/Users",
"description": "SCIM User resource",
"schema": SCIM_USER_SCHEMA,
}
)
GROUP_RESOURCE_TYPE = ScimResourceType.model_validate(
{
"id": "Group",
"name": "Group",
"endpoint": "/scim/v2/Groups",
"description": "SCIM Group resource",
"schema": SCIM_GROUP_SCHEMA,
}
)
USER_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_USER_SCHEMA,
name="User",
description="SCIM core User schema",
attributes=[
ScimSchemaAttribute(
name="userName",
type="string",
required=True,
uniqueness="server",
description="Unique identifier for the user, typically an email address.",
),
ScimSchemaAttribute(
name="name",
type="complex",
description="The components of the user's name.",
subAttributes=[
ScimSchemaAttribute(
name="givenName",
type="string",
description="The user's first name.",
),
ScimSchemaAttribute(
name="familyName",
type="string",
description="The user's last name.",
),
ScimSchemaAttribute(
name="formatted",
type="string",
description="The full name, including all middle names and titles.",
),
],
),
ScimSchemaAttribute(
name="emails",
type="complex",
multiValued=True,
description="Email addresses for the user.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="Email address value.",
),
ScimSchemaAttribute(
name="type",
type="string",
description="Label for this email (e.g. 'work').",
),
ScimSchemaAttribute(
name="primary",
type="boolean",
description="Whether this is the primary email.",
),
],
),
ScimSchemaAttribute(
name="active",
type="boolean",
description="Whether the user account is active.",
),
ScimSchemaAttribute(
name="externalId",
type="string",
description="Identifier from the provisioning client (IdP).",
caseExact=True,
),
],
)
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_GROUP_SCHEMA,
name="Group",
description="SCIM core Group schema",
attributes=[
ScimSchemaAttribute(
name="displayName",
type="string",
required=True,
description="Human-readable name for the group.",
),
ScimSchemaAttribute(
name="members",
type="complex",
multiValued=True,
description="Members of the group.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="User ID of the group member.",
),
ScimSchemaAttribute(
name="display",
type="string",
mutability="readOnly",
description="Display name of the group member.",
),
],
),
ScimSchemaAttribute(
name="externalId",
type="string",
description="Identifier from the provisioning client (IdP).",
caseExact=True,
),
],
)

View File

@@ -1,10 +1,13 @@
"""EE Settings API - provides license-aware settings override."""
from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.db.license import refresh_license_cache
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -41,6 +44,14 @@ def check_ee_features_enabled() -> bool:
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if not metadata:
# Cache miss — warm from DB so cold-start doesn't block EE features
try:
with get_session_with_current_tenant() as db_session:
metadata = refresh_license_cache(db_session, tenant_id)
except SQLAlchemyError as db_error:
logger.warning(f"Failed to load license from DB: {db_error}")
if metadata and metadata.status != _BLOCKING_STATUS:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
return True
@@ -82,6 +93,18 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if not metadata:
# Cache miss (e.g. after TTL expiry). Fall back to DB so
# the /settings request doesn't falsely return GATED_ACCESS
# while the cache is cold.
try:
with get_session_with_current_tenant() as db_session:
metadata = refresh_license_cache(db_session, tenant_id)
except SQLAlchemyError as db_error:
logger.warning(
f"Failed to load license from DB for settings: {db_error}"
)
if metadata:
if metadata.status == _BLOCKING_STATUS:
settings.application_status = metadata.status
@@ -90,7 +113,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
settings.ee_features_enabled = True
else:
# No license found.
# No license found in cache or DB.
if ENTERPRISE_EDITION_ENABLED:
# Legacy EE flag is set → prior EE usage (e.g. permission
# syncing) means indexed data may need protection.

View File

@@ -121,6 +121,7 @@ from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -137,6 +138,8 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"
def is_user_admin(user: User) -> bool:
return user.role == UserRole.ADMIN
@@ -208,22 +211,34 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
return int(value.decode("utf-8")) == 1
def workspace_invite_only_enabled() -> bool:
settings = load_settings()
return settings.invite_only_enabled
def verify_email_is_invited(email: str) -> None:
if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}:
# SSO providers manage membership; allow JIT provisioning regardless of invites
return
whitelist = get_invited_users()
if not whitelist:
if not workspace_invite_only_enabled():
return
whitelist = get_invited_users()
if not email:
raise PermissionError("Email must be specified")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email must be specified"},
)
try:
email_info = validate_email(email, check_deliverability=False)
except EmailUndeliverableError:
raise PermissionError("Email is not valid")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email is not valid"},
)
for email_whitelist in whitelist:
try:
@@ -240,7 +255,13 @@ def verify_email_is_invited(email: str) -> None:
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return
raise PermissionError("User not on allowed user whitelist")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"code": REGISTER_INVITE_ONLY_CODE,
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
},
)
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:

View File

@@ -1,10 +0,0 @@
"""Celery tasks for hierarchy fetching."""
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
check_for_hierarchy_fetching,
)
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
connector_hierarchy_fetching_task,
)
__all__ = ["check_for_hierarchy_fetching", "connector_hierarchy_fetching_task"]

View File

@@ -41,3 +41,14 @@ assert (
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S = 30 # 30 seconds.
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
# WARNING: Do not change these values without knowing what changes also need to
# be made to OpenSearchTenantMigrationRecord.
GET_VESPA_CHUNKS_PAGE_SIZE = 500
GET_VESPA_CHUNKS_SLICE_COUNT = 4
# String used to indicate in the vespa_visit_continuation_token mapping that the
# slice has finished and there is nothing left to visit.
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN = (
"FINISHED_VISITING_SLICE_CONTINUATION_TOKEN"
)

View File

@@ -8,6 +8,12 @@ from celery import Task
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.opensearch_migration.constants import (
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_PAGE_SIZE,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
@@ -47,7 +53,13 @@ from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
GET_VESPA_CHUNKS_PAGE_SIZE = 1000
def is_continuation_token_done_for_all_slices(
continuation_token_map: dict[int, str | None],
) -> bool:
return all(
continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
for continuation_token in continuation_token_map.values()
)
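# Illustrative check (not part of the module): the shape of the per-slice
# continuation token map. Each of the 4 slices (GET_VESPA_CHUNKS_SLICE_COUNT) maps
# to either an opaque Vespa continuation token, None, or the FINISHED sentinel; the
# token string below is hypothetical.
def _demo_slice_completion_check() -> None:
    in_progress: dict[int, str | None] = {
        0: FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
        1: "BGAAABEBEBC",  # hypothetical opaque Vespa token
        2: None,
        3: FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
    }
    assert not is_continuation_token_done_for_all_slices(in_progress)
    done: dict[int, str | None] = {
        slice_idx: FINISHED_VISITING_SLICE_CONTINUATION_TOKEN for slice_idx in range(4)
    }
    assert is_continuation_token_done_for_all_slices(done)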
# shared_task allows this task to be shared across celery app instances.
@@ -76,11 +88,15 @@ def migrate_chunks_from_vespa_to_opensearch_task(
Uses Vespa's Visit API to iterate through ALL chunks in bulk (not
per-document), transform them, and index them into OpenSearch. Progress is
tracked via a continuation token stored in the
tracked via a continuation token map stored in the
OpenSearchTenantMigrationRecord.
The first time we see no continuation token and non-zero chunks migrated, we
consider the migration complete and all subsequent invocations are no-ops.
The first time we see no continuation token map and non-zero chunks
migrated, we consider the migration complete and all subsequent invocations
are no-ops.
We divide the index into GET_VESPA_CHUNKS_SLICE_COUNT independent slices
where progress is tracked for each slice.
Returns:
None if OpenSearch migration is not enabled, or if the lock could not be
@@ -153,15 +169,28 @@ def migrate_chunks_from_vespa_to_opensearch_task(
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
)
approx_chunk_count_in_vespa: int | None = None
get_chunk_count_start_time = time.monotonic()
try:
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
except Exception:
task_logger.exception(
"Error getting approximate chunk count in Vespa. Moving on..."
)
task_logger.debug(
f"Took {time.monotonic() - get_chunk_count_start_time:.3f} seconds to attempt to get "
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
)
while (
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
(
continuation_token,
continuation_token_map,
total_chunks_migrated,
) = get_vespa_visit_state(db_session)
if continuation_token is None and total_chunks_migrated > 0:
if is_continuation_token_done_for_all_slices(continuation_token_map):
task_logger.info(
f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
f"Total chunks migrated: {total_chunks_migrated}."
@@ -170,19 +199,19 @@ def migrate_chunks_from_vespa_to_opensearch_task(
break
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token: {continuation_token}"
f"Continuation token map: {continuation_token_map}"
)
get_vespa_chunks_start_time = time.monotonic()
raw_vespa_chunks, next_continuation_token = (
raw_vespa_chunks, next_continuation_token_map = (
vespa_document_index.get_all_raw_document_chunks_paginated(
continuation_token=continuation_token,
continuation_token_map=continuation_token_map,
page_size=GET_VESPA_CHUNKS_PAGE_SIZE,
)
)
task_logger.debug(
f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} "
f"seconds. Next continuation token: {next_continuation_token}"
f"seconds. Next continuation token map: {next_continuation_token_map}"
)
opensearch_document_chunks, errored_chunks = (
@@ -212,14 +241,11 @@ def migrate_chunks_from_vespa_to_opensearch_task(
total_chunks_errored_this_task += len(errored_chunks)
update_vespa_visit_progress_with_commit(
db_session,
continuation_token=next_continuation_token,
continuation_token_map=next_continuation_token_map,
chunks_processed=len(opensearch_document_chunks),
chunks_errored=len(errored_chunks),
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)
if next_continuation_token is None and len(raw_vespa_chunks) == 0:
task_logger.info("Vespa reported no more chunks to migrate.")
break
except Exception:
traceback.print_exc()
task_logger.exception("Error in the OpenSearch migration task.")

View File

@@ -37,6 +37,35 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
DOCUMENT_ID,
CHUNK_ID,
TITLE,
TITLE_EMBEDDING,
CONTENT,
EMBEDDINGS,
SOURCE_TYPE,
METADATA_LIST,
DOC_UPDATED_AT,
HIDDEN,
BOOST,
SEMANTIC_IDENTIFIER,
IMAGE_FILE_NAME,
SOURCE_LINKS,
BLURB,
DOC_SUMMARY,
CHUNK_CONTEXT,
METADATA_SUFFIX,
DOCUMENT_SETS,
USER_PROJECT,
PRIMARY_OWNERS,
SECONDARY_OWNERS,
ACCESS_CONTROL_LIST,
]
if MULTI_TENANT:
FIELDS_NEEDED_FOR_TRANSFORMATION.append(TENANT_ID)
def _extract_content_vector(embeddings: Any) -> list[float]:
"""Extracts the full chunk embedding vector from Vespa's embeddings tensor.

View File

@@ -1,8 +0,0 @@
"""Celery tasks for connector pruning."""
from onyx.background.celery.tasks.pruning.tasks import check_for_pruning # noqa: F401
from onyx.background.celery.tasks.pruning.tasks import ( # noqa: F401
connector_pruning_generator_task,
)
__all__ = ["check_for_pruning", "connector_pruning_generator_task"]

View File

@@ -57,6 +57,7 @@ from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -68,6 +69,18 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
"""Detect XML-style marshaled tool calls emitted as plain text."""
if not text:
return False
lowered = text.lower()
return (
"<function_calls" in lowered
and "<invoke" in lowered
and "<parameter" in lowered
)
def _should_keep_bedrock_tool_definitions(
llm: object, simple_chat_history: list[ChatMessageSimple]
) -> bool:
@@ -122,38 +135,56 @@ def _try_fallback_tool_extraction(
reasoning_but_no_answer_or_tools = (
llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls
)
xml_tool_call_text_detected = no_tool_calls and (
_looks_like_xml_tool_call_payload(llm_step_result.answer)
or _looks_like_xml_tool_call_payload(llm_step_result.raw_answer)
or _looks_like_xml_tool_call_payload(llm_step_result.reasoning)
)
should_try_fallback = (
tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls
) or reasoning_but_no_answer_or_tools
(tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls)
or reasoning_but_no_answer_or_tools
or xml_tool_call_text_detected
)
if not should_try_fallback:
return llm_step_result, False
# Try to extract from answer first, then fall back to reasoning
extracted_tool_calls: list[ToolCallKickoff] = []
if llm_step_result.answer:
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.answer,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if (
not extracted_tool_calls
and llm_step_result.raw_answer
and llm_step_result.raw_answer != llm_step_result.answer
):
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.raw_answer,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if not extracted_tool_calls and llm_step_result.reasoning:
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.reasoning,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if extracted_tool_calls:
logger.info(
f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
f"as fallback (tool_choice was REQUIRED but no tool calls returned)"
"as fallback"
)
return (
LlmStepResult(
reasoning=llm_step_result.reasoning,
answer=llm_step_result.answer,
tool_calls=extracted_tool_calls,
raw_answer=llm_step_result.raw_answer,
),
True,
)
@@ -451,7 +482,42 @@ def construct_message_history(
if reminder_message:
result.append(reminder_message)
return result
return _drop_orphaned_tool_call_responses(result)
def _drop_orphaned_tool_call_responses(
messages: list[ChatMessageSimple],
) -> list[ChatMessageSimple]:
"""Drop tool response messages whose tool_call_id is not in prior assistant tool calls.
This can happen when history truncation drops an ASSISTANT tool-call message but
leaves a later TOOL_CALL_RESPONSE message in context. Some providers (e.g. Ollama)
reject such history with an "unexpected tool call id" error.
"""
known_tool_call_ids: set[str] = set()
sanitized: list[ChatMessageSimple] = []
for msg in messages:
if msg.message_type == MessageType.ASSISTANT and msg.tool_calls:
for tool_call in msg.tool_calls:
known_tool_call_ids.add(tool_call.tool_call_id)
sanitized.append(msg)
continue
if msg.message_type == MessageType.TOOL_CALL_RESPONSE:
if msg.tool_call_id and msg.tool_call_id in known_tool_call_ids:
sanitized.append(msg)
else:
logger.debug(
"Dropping orphaned tool response with tool_call_id=%s while "
"constructing message history",
msg.tool_call_id,
)
continue
sanitized.append(msg)
return sanitized
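# Illustrative sketch of the sanitization above (message objects are described
# rather than constructed, since ChatMessageSimple fields beyond those used here
# are not shown in this diff): given a truncated history of
#   [ASSISTANT(tool_calls=[id="call_a"]),
#    TOOL_CALL_RESPONSE(tool_call_id="call_a"),
#    TOOL_CALL_RESPONSE(tool_call_id="call_b")]   # its ASSISTANT turn was truncated away
# the response for "call_a" is kept and the orphaned "call_b" response is dropped,
# so providers such as Ollama do not reject the history.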
def _create_file_tool_metadata_message(
@@ -586,6 +652,7 @@ def run_llm_loop(
ran_image_gen: bool = False
just_ran_web_search: bool = False
has_called_search_tool: bool = False
code_interpreter_file_generated: bool = False
fallback_extraction_attempted: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
@@ -696,6 +763,7 @@ def run_llm_loop(
),
include_citation_reminder=should_cite_documents
or always_cite_documents,
include_file_reminder=code_interpreter_file_generated,
is_last_cycle=out_of_cycles,
)
@@ -835,6 +903,18 @@ def run_llm_loop(
if tool_call.tool_name == SearchTool.NAME:
has_called_search_tool = True
# Track if code interpreter generated files with download links
if (
tool_call.tool_name == PythonTool.NAME
and not code_interpreter_file_generated
):
try:
parsed = json.loads(tool_response.llm_facing_response)
if parsed.get("generated_files"):
code_interpreter_file_generated = True
except (json.JSONDecodeError, AttributeError):
pass
# Build a mapping of tool names to tool objects for getting tool_id
tools_by_name = {tool.name: tool for tool in final_tools}

View File

@@ -1,10 +1,12 @@
import json
import re
import time
import uuid
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Mapping
from collections.abc import Sequence
from html import unescape
from typing import Any
from typing import cast
@@ -18,6 +20,7 @@ from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import ChatFileType
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LanguageModelInput
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
@@ -56,6 +59,112 @@ from onyx.utils.text_processing import find_all_json_objects
logger = setup_logger()
_XML_INVOKE_BLOCK_RE = re.compile(
r"<invoke\b(?P<attrs>[^>]*)>(?P<body>.*?)</invoke>",
re.IGNORECASE | re.DOTALL,
)
_XML_PARAMETER_RE = re.compile(
r"<parameter\b(?P<attrs>[^>]*)>(?P<value>.*?)</parameter>",
re.IGNORECASE | re.DOTALL,
)
_FUNCTION_CALLS_OPEN_MARKER = "<function_calls"
_FUNCTION_CALLS_CLOSE_MARKER = "</function_calls>"
class _XmlToolCallContentFilter:
"""Streaming filter that strips XML-style tool call payload blocks from text."""
def __init__(self) -> None:
self._pending = ""
self._inside_function_calls_block = False
def process(self, content: str) -> str:
if not content:
return ""
self._pending += content
output_parts: list[str] = []
while self._pending:
pending_lower = self._pending.lower()
if self._inside_function_calls_block:
end_idx = pending_lower.find(_FUNCTION_CALLS_CLOSE_MARKER)
if end_idx == -1:
# Keep buffering until we see the close marker.
return "".join(output_parts)
# Drop the whole function_calls block.
self._pending = self._pending[
end_idx + len(_FUNCTION_CALLS_CLOSE_MARKER) :
]
self._inside_function_calls_block = False
continue
start_idx = _find_function_calls_open_marker(pending_lower)
if start_idx == -1:
# Keep only a possible prefix of "<function_calls" in the buffer so
# marker splits across chunks are handled correctly.
tail_len = _matching_open_marker_prefix_len(self._pending)
emit_upto = len(self._pending) - tail_len
if emit_upto > 0:
output_parts.append(self._pending[:emit_upto])
self._pending = self._pending[emit_upto:]
return "".join(output_parts)
if start_idx > 0:
output_parts.append(self._pending[:start_idx])
# Enter block-stripping mode and keep scanning for close marker.
self._pending = self._pending[start_idx:]
self._inside_function_calls_block = True
return "".join(output_parts)
def flush(self) -> str:
if self._inside_function_calls_block:
# Drop any incomplete block at stream end.
self._pending = ""
self._inside_function_calls_block = False
return ""
remaining = self._pending
self._pending = ""
return remaining
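# Illustrative check (not part of the module): the filter emits text outside
# <function_calls> blocks, buffers a possible marker prefix split across chunks,
# and silently drops everything inside the block.
def _demo_xml_tool_call_filter() -> None:
    filt = _XmlToolCallContentFilter()
    out = filt.process("Hello <function")            # "<function" is buffered as a prefix
    out += filt.process("_calls>\n<invoke ...>")      # inside the block: nothing emitted
    out += filt.process("</function_calls>world!")    # block dropped, trailing text kept
    out += filt.flush()
    assert out == "Hello world!"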
def _matching_open_marker_prefix_len(text: str) -> int:
"""Return longest suffix of text that matches prefix of "<function_calls"."""
max_len = min(len(text), len(_FUNCTION_CALLS_OPEN_MARKER) - 1)
text_lower = text.lower()
marker_lower = _FUNCTION_CALLS_OPEN_MARKER
for candidate_len in range(max_len, 0, -1):
if text_lower.endswith(marker_lower[:candidate_len]):
return candidate_len
return 0
def _is_valid_function_calls_open_follower(char: str | None) -> bool:
return char is None or char in {">", " ", "\t", "\n", "\r"}
def _find_function_calls_open_marker(text_lower: str) -> int:
"""Find '<function_calls' with a valid tag boundary follower."""
search_from = 0
while True:
idx = text_lower.find(_FUNCTION_CALLS_OPEN_MARKER, search_from)
if idx == -1:
return -1
follower_pos = idx + len(_FUNCTION_CALLS_OPEN_MARKER)
follower = text_lower[follower_pos] if follower_pos < len(text_lower) else None
if _is_valid_function_calls_open_follower(follower):
return idx
search_from = idx + 1
def _sanitize_llm_output(value: str) -> str:
"""Remove characters that PostgreSQL's text/JSONB types cannot store.
@@ -272,14 +381,7 @@ def _extract_tool_call_kickoffs(
tab_index_calculated = 0
for tool_call_data in id_to_tool_call_map.values():
if tool_call_data.get("id") and tool_call_data.get("name"):
try:
tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
except json.JSONDecodeError:
# If parsing fails, try empty dict, most tools would fail though
logger.error(
f"Failed to parse tool call arguments: {tool_call_data['arguments']}"
)
tool_args = {}
tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
tool_calls.append(
ToolCallKickoff(
@@ -307,8 +409,9 @@ def extract_tool_calls_from_response_text(
"""Extract tool calls from LLM response text by matching JSON against tool definitions.
This is a fallback mechanism for when the LLM was expected to return tool calls
but didn't use the proper tool call format. It searches for JSON objects in the
response text that match the structure of available tools.
but didn't use the proper tool call format. It searches for tool calls embedded
in response text (JSON first, then XML-like invoke blocks) that match available
tool definitions.
Args:
response_text: The LLM's text response to search for tool calls
@@ -333,10 +436,9 @@ def extract_tool_calls_from_response_text(
if not tool_name_to_def:
return []
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
# Find all JSON objects in the response text
json_objects = find_all_json_objects(response_text)
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
prev_json_obj: dict[str, Any] | None = None
prev_tool_call: tuple[str, dict[str, Any]] | None = None
@@ -364,6 +466,14 @@ def extract_tool_calls_from_response_text(
prev_json_obj = json_obj
prev_tool_call = matched_tool_call
# Some providers/models emit XML-style function calls instead of JSON objects.
# Keep this as a fallback behind JSON extraction to preserve current behavior.
if not matched_tool_calls:
matched_tool_calls = _extract_xml_tool_calls_from_response_text(
response_text=response_text,
tool_name_to_def=tool_name_to_def,
)
tool_calls: list[ToolCallKickoff] = []
for tab_index, (tool_name, tool_args) in enumerate(matched_tool_calls):
tool_calls.append(
@@ -386,6 +496,88 @@ def extract_tool_calls_from_response_text(
return tool_calls
def _extract_xml_tool_calls_from_response_text(
response_text: str,
tool_name_to_def: dict[str, dict],
) -> list[tuple[str, dict[str, Any]]]:
"""Extract XML-style tool calls from response text.
Supports formats such as:
<function_calls>
<invoke name="internal_search">
<parameter name="queries" string="false">["foo"]</parameter>
</invoke>
</function_calls>
"""
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
for invoke_match in _XML_INVOKE_BLOCK_RE.finditer(response_text):
invoke_attrs = invoke_match.group("attrs")
tool_name = _extract_xml_attribute(invoke_attrs, "name")
if not tool_name or tool_name not in tool_name_to_def:
continue
tool_args: dict[str, Any] = {}
invoke_body = invoke_match.group("body")
for parameter_match in _XML_PARAMETER_RE.finditer(invoke_body):
parameter_attrs = parameter_match.group("attrs")
parameter_name = _extract_xml_attribute(parameter_attrs, "name")
if not parameter_name:
continue
string_attr = _extract_xml_attribute(parameter_attrs, "string")
tool_args[parameter_name] = _parse_xml_parameter_value(
raw_value=parameter_match.group("value"),
string_attr=string_attr,
)
matched_tool_calls.append((tool_name, tool_args))
return matched_tool_calls
def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None:
"""Extract a single XML-style attribute value from a tag attribute string."""
attr_match = re.search(
rf"""\b{re.escape(attr_name)}\s*=\s*(['"])(.*?)\1""",
attrs,
flags=re.IGNORECASE | re.DOTALL,
)
if not attr_match:
return None
return _sanitize_llm_output(unescape(attr_match.group(2).strip()))
def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any:
"""Parse a parameter value from XML-style tool call payloads."""
value = _sanitize_llm_output(unescape(raw_value).strip())
if string_attr and string_attr.lower() == "true":
return value
try:
return json.loads(value)
except json.JSONDecodeError:
return value
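# Illustrative check (not part of the module), assuming _sanitize_llm_output leaves
# ordinary ASCII untouched: string="true" keeps the raw text verbatim, anything else
# is JSON-decoded when possible.
def _demo_xml_parameter_parsing() -> None:
    assert _parse_xml_parameter_value('["foo"]', string_attr="false") == ["foo"]
    assert _parse_xml_parameter_value('["foo"]', string_attr="true") == '["foo"]'
    assert _parse_xml_parameter_value("plain text", string_attr=None) == "plain text"
    assert _extract_xml_attribute(' name="internal_search" ', "name") == "internal_search"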
def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None:
"""Extract and parse an arguments/parameters value from a tool-call-like object.
Looks for "arguments" or "parameters" keys, handles JSON-string values,
and returns a dict if successful, or None otherwise.
"""
arguments = obj.get("arguments", obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
return arguments
return None
def _try_match_json_to_tool(
json_obj: dict[str, Any],
tool_name_to_def: dict[str, dict],
@@ -408,13 +600,8 @@ def _try_match_json_to_tool(
# Format 1: Direct tool call format {"name": "...", "arguments": {...}}
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
tool_name = json_obj["name"]
arguments = json_obj.get("arguments", json_obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
arguments = _resolve_tool_arguments(json_obj)
if arguments is not None:
return (tool_name, arguments)
# Format 2: Function call format {"function": {"name": "...", "arguments": {...}}}
@@ -422,13 +609,8 @@ def _try_match_json_to_tool(
func_obj = json_obj["function"]
if "name" in func_obj and func_obj["name"] in tool_name_to_def:
tool_name = func_obj["name"]
arguments = func_obj.get("arguments", func_obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
arguments = _resolve_tool_arguments(func_obj)
if arguments is not None:
return (tool_name, arguments)
# Format 3: Tool name as key {"tool_name": {...arguments...}}
@@ -495,6 +677,107 @@ def _extract_nested_arguments_obj(
return None
def _build_structured_assistant_message(msg: ChatMessageSimple) -> AssistantMessage:
tool_calls_list: list[ToolCall] | None = None
if msg.tool_calls:
tool_calls_list = [
ToolCall(
id=tc.tool_call_id,
type="function",
function=FunctionCall(
name=tc.tool_name,
arguments=json.dumps(tc.tool_arguments),
),
)
for tc in msg.tool_calls
]
return AssistantMessage(
role="assistant",
content=msg.message or None,
tool_calls=tool_calls_list,
)
def _build_structured_tool_response_message(msg: ChatMessageSimple) -> ToolMessage:
if not msg.tool_call_id:
raise ValueError(
"Tool call response message encountered but tool_call_id is not available. "
f"Message: {msg}"
)
return ToolMessage(
role="tool",
content=msg.message,
tool_call_id=msg.tool_call_id,
)
class _HistoryMessageFormatter:
def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
raise NotImplementedError
def format_tool_response_message(
self, msg: ChatMessageSimple
) -> ToolMessage | UserMessage:
raise NotImplementedError
class _DefaultHistoryMessageFormatter(_HistoryMessageFormatter):
def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
return _build_structured_assistant_message(msg)
def format_tool_response_message(self, msg: ChatMessageSimple) -> ToolMessage:
return _build_structured_tool_response_message(msg)
class _OllamaHistoryMessageFormatter(_HistoryMessageFormatter):
def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
if not msg.tool_calls:
return _build_structured_assistant_message(msg)
tool_call_lines = [
(
f"[Tool Call] name={tc.tool_name} "
f"id={tc.tool_call_id} args={json.dumps(tc.tool_arguments)}"
)
for tc in msg.tool_calls
]
assistant_content = (
"\n".join([msg.message, *tool_call_lines])
if msg.message
else "\n".join(tool_call_lines)
)
return AssistantMessage(
role="assistant",
content=assistant_content,
tool_calls=None,
)
def format_tool_response_message(self, msg: ChatMessageSimple) -> UserMessage:
if not msg.tool_call_id:
raise ValueError(
"Tool call response message encountered but tool_call_id is not available. "
f"Message: {msg}"
)
return UserMessage(
role="user",
content=f"[Tool Result] id={msg.tool_call_id}\n{msg.message}",
)
_DEFAULT_HISTORY_MESSAGE_FORMATTER = _DefaultHistoryMessageFormatter()
_OLLAMA_HISTORY_MESSAGE_FORMATTER = _OllamaHistoryMessageFormatter()
def _get_history_message_formatter(llm_config: LLMConfig) -> _HistoryMessageFormatter:
if llm_config.model_provider == LlmProviderNames.OLLAMA_CHAT:
return _OLLAMA_HISTORY_MESSAGE_FORMATTER
return _DEFAULT_HISTORY_MESSAGE_FORMATTER
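# Illustrative sketch of the Ollama flattening above (values hypothetical): an
# assistant turn that called internal_search is rendered as plain text instead of a
# structured tool_calls entry, and the tool result is replayed as a user turn:
#   assistant: "[Tool Call] name=internal_search id=call_1 args={\"queries\": [\"foo\"]}"
#   user:      "[Tool Result] id=call_1\n<tool output text>"
# The default formatter keeps the structured AssistantMessage / ToolMessage forms.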
def translate_history_to_llm_format(
history: list[ChatMessageSimple],
llm_config: LLMConfig,
@@ -505,6 +788,10 @@ def translate_history_to_llm_format(
handling different message types and image files for multimodal support.
"""
messages: list[ChatCompletionMessage] = []
history_message_formatter = _get_history_message_formatter(llm_config)
# Note: cacheability is computed from pre-translation ChatMessageSimple types.
# Some providers flatten tool history into plain assistant/user text, so this split
# may be less semantically meaningful, but it remains safe and order-preserving.
last_cacheable_msg_idx = -1
all_previous_msgs_cacheable = True
@@ -586,39 +873,10 @@ def translate_history_to_llm_format(
messages.append(reminder_msg)
elif msg.message_type == MessageType.ASSISTANT:
tool_calls_list: list[ToolCall] | None = None
if msg.tool_calls:
tool_calls_list = [
ToolCall(
id=tc.tool_call_id,
type="function",
function=FunctionCall(
name=tc.tool_name,
arguments=json.dumps(tc.tool_arguments),
),
)
for tc in msg.tool_calls
]
assistant_msg = AssistantMessage(
role="assistant",
content=msg.message or None,
tool_calls=tool_calls_list,
)
messages.append(assistant_msg)
messages.append(history_message_formatter.format_assistant_message(msg))
elif msg.message_type == MessageType.TOOL_CALL_RESPONSE:
if not msg.tool_call_id:
raise ValueError(
f"Tool call response message encountered but tool_call_id is not available. Message: {msg}"
)
tool_msg = ToolMessage(
role="tool",
content=msg.message,
tool_call_id=msg.tool_call_id,
)
messages.append(tool_msg)
messages.append(history_message_formatter.format_tool_response_message(msg))
else:
logger.warning(
@@ -698,7 +956,8 @@ def run_llm_step_pkt_generator(
tool_definitions: List of tool definitions available to the LLM.
tool_choice: Tool choice configuration (e.g., "auto", "required", "none").
llm: Language model interface to use for generation.
turn_index: Current turn index in the conversation.
placement: Placement info (turn_index, tab_index, sub_turn_index) for
positioning packets in the conversation UI.
state_container: Container for storing chat state (reasoning, answers).
citation_processor: Optional processor for extracting and formatting citations
from the response. If provided, processes tokens to identify citations.
@@ -710,7 +969,14 @@ def run_llm_step_pkt_generator(
custom_token_processor: Optional callable that processes each token delta
before yielding. Receives (delta, processor_state) and returns
(modified_delta, new_processor_state). Can return None for delta to skip.
sub_turn_index: Optional sub-turn index for nested tool/agent calls.
max_tokens: Optional maximum number of tokens for the LLM response.
use_existing_tab_index: If True, use the tab_index from placement for all
tool calls instead of auto-incrementing.
is_deep_research: If True, treat content before tool calls as reasoning
when tool_choice is REQUIRED.
pre_answer_processing_time: Optional time spent processing before the
answer started, recorded in state_container for analytics.
timeout_override: Optional timeout override for the LLM call.
Yields:
Packet: Streaming packets containing:
@@ -736,8 +1002,15 @@ def run_llm_step_pkt_generator(
tab_index = placement.tab_index
sub_turn_index = placement.sub_turn_index
def _current_placement() -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
)
llm_msg_history = translate_history_to_llm_format(history, llm.config)
has_reasoned = 0
has_reasoned = False
if LOG_ONYX_MODEL_INTERACTIONS:
logger.debug(
@@ -749,6 +1022,8 @@ def run_llm_step_pkt_generator(
answer_start = False
accumulated_reasoning = ""
accumulated_answer = ""
accumulated_raw_answer = ""
xml_tool_call_content_filter = _XmlToolCallContentFilter()
processor_state: Any = None
@@ -764,6 +1039,112 @@ def run_llm_step_pkt_generator(
)
stream_start_time = time.monotonic()
first_action_recorded = False
def _emit_citation_results(
results: Generator[str | CitationInfo, None, None],
) -> Generator[Packet, None, None]:
"""Yield packets for citation processor results (str or CitationInfo)."""
nonlocal accumulated_answer
for result in results:
if isinstance(result, str):
accumulated_answer += result
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=_current_placement(),
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
placement=_current_placement(),
obj=result,
)
if state_container:
state_container.add_emitted_citation(result.citation_number)
def _close_reasoning_if_active() -> Generator[Packet, None, None]:
"""Emit ReasoningDone and increment turns if reasoning is in progress."""
nonlocal reasoning_start
nonlocal has_reasoned
nonlocal turn_index
nonlocal sub_turn_index
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = True
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
def _emit_content_chunk(content_chunk: str) -> Generator[Packet, None, None]:
nonlocal accumulated_answer
nonlocal accumulated_reasoning
nonlocal answer_start
nonlocal reasoning_start
nonlocal turn_index
nonlocal sub_turn_index
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
# about which tool to call, not an actual answer to the user.
# Treat this content as reasoning instead of answer.
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
accumulated_reasoning += content_chunk
if state_container:
state_container.set_reasoning_tokens(accumulated_reasoning)
if not reasoning_start:
yield Packet(
placement=_current_placement(),
obj=ReasoningStart(),
)
yield Packet(
placement=_current_placement(),
obj=ReasoningDelta(reasoning=content_chunk),
)
reasoning_start = True
return
# Normal flow for AUTO or NONE tool choice
yield from _close_reasoning_if_active()
if not answer_start:
# Store pre-answer processing time in state container for save_chat
if state_container and pre_answer_processing_time is not None:
state_container.set_pre_answer_processing_time(
pre_answer_processing_time
)
yield Packet(
placement=_current_placement(),
obj=AgentResponseStart(
final_documents=final_documents,
pre_answer_processing_seconds=pre_answer_processing_time,
),
)
answer_start = True
if citation_processor:
yield from _emit_citation_results(
citation_processor.process_token(content_chunk)
)
else:
accumulated_answer += content_chunk
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=_current_placement(),
obj=AgentResponseDelta(content=content_chunk),
)
for packet in llm.stream(
prompt=llm_msg_history,
tools=tool_definitions,
@@ -822,152 +1203,34 @@ def run_llm_step_pkt_generator(
state_container.set_reasoning_tokens(accumulated_reasoning)
if not reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
placement=_current_placement(),
obj=ReasoningStart(),
)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
placement=_current_placement(),
obj=ReasoningDelta(reasoning=delta.reasoning_content),
)
reasoning_start = True
if delta.content:
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
# about which tool to call, not an actual answer to the user.
# Treat this content as reasoning instead of answer.
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
# Treat content as reasoning when we know tool calls are coming
accumulated_reasoning += delta.content
if state_container:
state_container.set_reasoning_tokens(accumulated_reasoning)
if not reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningStart(),
)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDelta(reasoning=delta.content),
)
reasoning_start = True
else:
# Normal flow for AUTO or NONE tool choice
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
if not answer_start:
# Store pre-answer processing time in state container for save_chat
if state_container and pre_answer_processing_time is not None:
state_container.set_pre_answer_processing_time(
pre_answer_processing_time
)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseStart(
final_documents=final_documents,
pre_answer_processing_seconds=pre_answer_processing_time,
),
)
answer_start = True
if citation_processor:
for result in citation_processor.process_token(delta.content):
if isinstance(result, str):
accumulated_answer += result
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(
accumulated_answer
)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=result,
)
# Track emitted citation for saving
if state_container:
state_container.add_emitted_citation(
result.citation_number
)
else:
# When citation_processor is None, use delta.content directly without modification
accumulated_answer += delta.content
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=delta.content),
)
# Keep raw content for fallback extraction. Display content can be
# filtered and, in deep-research REQUIRED mode, routed as reasoning.
accumulated_raw_answer += delta.content
filtered_content = xml_tool_call_content_filter.process(delta.content)
if filtered_content:
yield from _emit_content_chunk(filtered_content)
if delta.tool_calls:
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
yield from _close_reasoning_if_active()
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
# Flush any tail text buffered while checking for split "<function_calls" markers.
filtered_content_tail = xml_tool_call_content_filter.flush()
if filtered_content_tail:
yield from _emit_content_chunk(filtered_content_tail)
# Flush custom token processor to get any final tool calls
if custom_token_processor:
flush_delta, processor_state = custom_token_processor(None, processor_state)
@@ -1023,50 +1286,14 @@ def run_llm_step_pkt_generator(
# This may happen if the custom token processor is used to modify other packets into reasoning
# Then there won't necessarily be anything else to come after the reasoning tokens
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(turn_index, sub_turn_index)
reasoning_start = False
yield from _close_reasoning_if_active()
# Flush any remaining content from citation processor
# Reasoning is always first so this should use the post-incremented value of turn_index
# Note that this doesn't need to handle any sub-turns as those docs will not have citations
# as clickable items and will be stripped out instead.
if citation_processor:
for result in citation_processor.process_token(None):
if isinstance(result, str):
accumulated_answer += result
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=result,
)
# Track emitted citation for saving
if state_container:
state_container.add_emitted_citation(result.citation_number)
yield from _emit_citation_results(citation_processor.process_token(None))
# Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
# Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)
@@ -1088,8 +1315,9 @@ def run_llm_step_pkt_generator(
reasoning=accumulated_reasoning if accumulated_reasoning else None,
answer=accumulated_answer if accumulated_answer else None,
tool_calls=tool_calls if tool_calls else None,
raw_answer=accumulated_raw_answer if accumulated_raw_answer else None,
),
bool(has_reasoned),
has_reasoned,
)
@@ -1144,4 +1372,4 @@ def run_llm_step(
emitter.emit(packet)
except StopIteration as e:
llm_step_result, has_reasoned = e.value
return llm_step_result, bool(has_reasoned)
return llm_step_result, has_reasoned

View File

@@ -185,3 +185,6 @@ class LlmStepResult(BaseModel):
reasoning: str | None
answer: str | None
tool_calls: list[ToolCallKickoff] | None
# Raw LLM text before any display-oriented filtering/sanitization.
# Used for fallback tool-call extraction when providers emit calls as text.
raw_answer: str | None = None

View File

@@ -10,6 +10,7 @@ from onyx.db.user_file import calculate_user_files_token_count
from onyx.file_store.models import FileDescriptor
from onyx.prompts.chat_prompts import CITATION_REMINDER
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
from onyx.prompts.chat_prompts import FILE_REMINDER
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
from onyx.prompts.prompt_utils import get_company_context
@@ -125,6 +126,7 @@ def calculate_reserved_tokens(
def build_reminder_message(
reminder_text: str | None,
include_citation_reminder: bool,
include_file_reminder: bool,
is_last_cycle: bool,
) -> str | None:
reminder = reminder_text.strip() if reminder_text else ""
@@ -132,6 +134,8 @@ def build_reminder_message(
reminder += "\n\n" + LAST_CYCLE_CITATION_REMINDER
if include_citation_reminder:
reminder += "\n\n" + CITATION_REMINDER
if include_file_reminder:
reminder += "\n\n" + FILE_REMINDER
reminder = reminder.strip()
return reminder if reminder else None

View File

@@ -263,6 +263,18 @@ OPENSEARCH_PROFILING_DISABLED = (
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
)
# When enabled, OpenSearch returns detailed score breakdowns for each hit.
# Useful for debugging and tuning search relevance. Has ~10-30% performance overhead according to documentation.
# In practice, however, the impact on hybrid search appears to be closer to 1000x slower.
OPENSEARCH_EXPLAIN_ENABLED = (
os.environ.get("OPENSEARCH_EXPLAIN_ENABLED", "").lower() == "true"
)
# Analyzer used for full-text fields (title, content). Use OpenSearch built-in analyzer
# names (e.g. "english", "standard", "german"). Affects stemming and tokenization;
# existing indices need reindexing after a change.
OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "english"
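As a quick, hypothetical illustration of what this setting controls, OpenSearch's built-in _analyze API shows how an analyzer tokenizes and stems text (the host below is a placeholder for a local dev instance):
import requests

OPENSEARCH_URL = "http://localhost:9200"  # placeholder; adjust host/auth for your deployment

def show_tokens(analyzer: str, text: str) -> list[str]:
    # _analyze returns the tokens the analyzer would produce at index/search time.
    resp = requests.post(
        f"{OPENSEARCH_URL}/_analyze",
        json={"analyzer": analyzer, "text": text},
        timeout=10,
    )
    resp.raise_for_status()
    return [t["token"] for t in resp.json()["tokens"]]

print(show_tokens("english", "The running dogs"))   # stems and drops stopwords, e.g. ['run', 'dog']
print(show_tokens("standard", "The running dogs"))  # tokenizes only: ['the', 'running', 'dogs']
Because the analyzer is baked into the index mapping, changing OPENSEARCH_TEXT_ANALYZER only affects newly created indices; existing ones must be reindexed, as the comment notes.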
# This is the "base" config for now, the idea is that at least for our dev
# environments we always want to be dual indexing into both OpenSearch and Vespa
# to stress test the new codepaths. Only enable this if there is some instance

View File

@@ -46,6 +46,7 @@ from onyx.connectors.google_drive.file_retrieval import get_external_access_for_
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.file_retrieval import get_shared_drive_name
from onyx.connectors.google_drive.file_retrieval import has_link_only_permission
from onyx.connectors.google_drive.models import DriveRetrievalStage
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint
@@ -156,10 +157,7 @@ def _is_shared_drive_root(folder: GoogleDriveFileType) -> bool:
return False
# For shared drive content, the root has id == driveId
if drive_id and folder_id == drive_id:
return True
return False
return bool(drive_id and folder_id == drive_id)
def _public_access() -> ExternalAccess:
@@ -616,6 +614,16 @@ class GoogleDriveConnector(
# empty parents due to permission limitations)
# Check shared drive root first (simple ID comparison)
if _is_shared_drive_root(folder):
# files().get() returns 'Drive' for shared drive roots;
# fetch the real name via drives().get().
# Try both the retriever and admin since the admin may
# not have access to private shared drives.
drive_name = self._get_shared_drive_name(
current_id, file.user_email
)
if drive_name:
node.display_name = drive_name
node.node_type = HierarchyNodeType.SHARED_DRIVE
reached_terminal = True
break
@@ -691,6 +699,15 @@ class GoogleDriveConnector(
)
return None
def _get_shared_drive_name(self, drive_id: str, retriever_email: str) -> str | None:
"""Fetch the name of a shared drive, trying both the retriever and admin."""
for email in {retriever_email, self.primary_admin_email}:
svc = get_drive_service(self.creds, email)
name = get_shared_drive_name(svc, drive_id)
if name:
return name
return None
def get_all_drive_ids(self) -> set[str]:
return self._get_all_drives_for_user(self.primary_admin_email)

View File

@@ -154,6 +154,26 @@ def _get_hierarchy_fields_for_file_type(field_type: DriveFileFieldType) -> str:
return HIERARCHY_FIELDS
def get_shared_drive_name(
service: Resource,
drive_id: str,
) -> str | None:
"""Fetch the actual name of a shared drive via the drives().get() API.
The files().get() API returns 'Drive' as the name for shared drive root
folders. Only drives().get() returns the real user-assigned name.
"""
try:
drive = service.drives().get(driveId=drive_id, fields="name").execute()
return drive.get("name")
except HttpError as e:
if e.resp.status in (403, 404):
logger.debug(f"Cannot access drive {drive_id}: {e}")
else:
raise
return None
def get_external_access_for_folder(
folder: GoogleDriveFileType,
google_domain: str,

File diff suppressed because it is too large Load Diff

74
backend/onyx/db/dal.py Normal file
View File

@@ -0,0 +1,74 @@
"""Base Data Access Layer (DAL) for database operations.
The DAL pattern groups related database operations into cohesive classes
with explicit session management. It supports two usage modes:
1. **External session** (FastAPI endpoints) — the caller provides a session
whose lifecycle is managed by FastAPI's dependency injection.
2. **Self-managed session** (Celery tasks, scripts) — the DAL creates its
own session via the tenant-aware session factory.
Subclasses add domain-specific query methods while inheriting session
management. See ``ee.onyx.db.scim.ScimDAL`` for a concrete example.
Example (FastAPI)::
def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
@router.get("/users")
def list_users(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
return dal.list_user_mappings(...)
Example (Celery)::
with ScimDAL.from_tenant("tenant_abc") as dal:
dal.create_user_mapping(...)
dal.commit()
"""
from __future__ import annotations
from collections.abc import Generator
from contextlib import contextmanager
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_tenant
class DAL:
"""Base Data Access Layer.
Holds a SQLAlchemy session and provides transaction control helpers.
Subclasses add domain-specific query methods.
"""
def __init__(self, db_session: Session) -> None:
self._session = db_session
@property
def session(self) -> Session:
"""Direct access to the underlying session for advanced use cases."""
return self._session
def commit(self) -> None:
self._session.commit()
def flush(self) -> None:
self._session.flush()
def rollback(self) -> None:
self._session.rollback()
@classmethod
@contextmanager
def from_tenant(cls, tenant_id: str) -> Generator["DAL", None, None]:
"""Create a DAL with a self-managed session for the given tenant.
The session is automatically closed when the context manager exits.
The caller must explicitly call ``commit()`` to persist changes.
"""
with get_session_with_tenant(tenant_id=tenant_id) as session:
yield cls(session)
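A minimal hypothetical subclass, to make the two modes above concrete (UserDAL, get_by_email, and the User.email column are illustrative assumptions, not part of this diff):
from sqlalchemy import select

from onyx.db.dal import DAL
from onyx.db.models import User  # assumes a mapped User model with an email column


class UserDAL(DAL):
    """Hypothetical DAL grouping user lookups."""

    def get_by_email(self, email: str) -> User | None:
        return self.session.scalar(select(User).where(User.email == email))


# FastAPI mode: the session lifecycle is owned by dependency injection.
#   def get_user_dal(db_session: Session = Depends(get_session)) -> UserDAL:
#       return UserDAL(db_session)
#
# Celery/script mode: self-managed, tenant-aware session.
#   with UserDAL.from_tenant("tenant_abc") as dal:
#       user = dal.get_by_email("someone@example.com")
#       dal.commit()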

View File

@@ -232,6 +232,12 @@ class BuildSessionStatus(str, PyEnum):
IDLE = "idle"
class SharingScope(str, PyEnum):
PRIVATE = "private"
PUBLIC_ORG = "public_org"
PUBLIC_GLOBAL = "public_global"
class SandboxStatus(str, PyEnum):
PROVISIONING = "provisioning"
RUNNING = "running"

View File

@@ -430,7 +430,7 @@ def fetch_existing_models(
def fetch_existing_llm_providers(
db_session: Session,
flow_types: list[LLMModelFlowType],
flow_type_filter: list[LLMModelFlowType],
only_public: bool = False,
exclude_image_generation_providers: bool = True,
) -> list[LLMProviderModel]:
@@ -438,30 +438,27 @@ def fetch_existing_llm_providers(
Args:
db_session: Database session
flow_types: List of flow types to filter by
flow_type_filter: List of flow types to filter by, empty list for no filter
only_public: If True, only return public providers
exclude_image_generation_providers: If True, exclude providers that are
used for image generation configs
"""
providers_with_flows = (
select(ModelConfiguration.llm_provider_id)
.join(LLMModelFlow)
.where(LLMModelFlow.llm_model_flow_type.in_(flow_types))
.distinct()
)
stmt = select(LLMProviderModel)
if flow_type_filter:
providers_with_flows = (
select(ModelConfiguration.llm_provider_id)
.join(LLMModelFlow)
.where(LLMModelFlow.llm_model_flow_type.in_(flow_type_filter))
.distinct()
)
stmt = stmt.where(LLMProviderModel.id.in_(providers_with_flows))
if exclude_image_generation_providers:
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
)
else:
image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
ImageGenerationConfig
)
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
| LLMProviderModel.id.in_(image_gen_provider_ids)
)
stmt = stmt.where(~LLMProviderModel.id.in_(image_gen_provider_ids))
stmt = stmt.options(
selectinload(LLMProviderModel.model_configurations),
@@ -797,13 +794,15 @@ def sync_auto_mode_models(
changes += 1
else:
# Add new model - all models from GitHub config are visible
new_model = ModelConfiguration(
insert_new_model_configuration__no_commit(
db_session=db_session,
llm_provider_id=provider.id,
name=model_config.name,
display_name=model_config.display_name,
model_name=model_config.name,
supported_flows=[LLMModelFlowType.CHAT],
is_visible=True,
max_input_tokens=None,
display_name=model_config.display_name,
)
db_session.add(new_model)
changes += 1
# In Auto mode, default model is always set from GitHub config

View File

@@ -77,6 +77,7 @@ from onyx.db.enums import (
ThemePreference,
DefaultAppMode,
SwitchoverType,
SharingScope,
)
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
@@ -1040,7 +1041,9 @@ class OpenSearchTenantMigrationRecord(Base):
nullable=False,
)
# Opaque continuation token from Vespa's Visit API.
# NULL means "not started" or "visit completed".
# NULL means "not started".
# Otherwise contains a serialized mapping between slice ID and continuation
# token for that slice.
vespa_visit_continuation_token: Mapped[str | None] = mapped_column(
Text, nullable=True
)
@@ -1064,6 +1067,9 @@ class OpenSearchTenantMigrationRecord(Base):
enable_opensearch_retrieval: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
approx_chunk_count_in_vespa: Mapped[int | None] = mapped_column(
Integer, nullable=True
)
class KGEntityType(Base):
@@ -4712,6 +4718,12 @@ class BuildSession(Base):
demo_data_enabled: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default=text("true")
)
sharing_scope: Mapped[SharingScope] = mapped_column(
String,
nullable=False,
default=SharingScope.PRIVATE,
server_default="private",
)
# Relationships
user: Mapped[User | None] = relationship("User", foreign_keys=[user_id])

View File

@@ -4,6 +4,7 @@ This module provides functions to track the progress of migrating documents
from Vespa to OpenSearch.
"""
import json
from datetime import datetime
from datetime import timezone
@@ -12,6 +13,9 @@ from sqlalchemy import text
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_SLICE_COUNT,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
)
@@ -243,29 +247,37 @@ def should_document_migration_be_permanently_failed(
def get_vespa_visit_state(
db_session: Session,
) -> tuple[str | None, int]:
) -> tuple[dict[int, str | None], int]:
"""Gets the current Vespa migration state from the tenant migration record.
Requires the OpenSearchTenantMigrationRecord to exist.
Returns:
Tuple of (continuation_token, total_chunks_migrated). continuation_token
is None if not started or completed.
Tuple of (continuation_token_map, total_chunks_migrated).
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
return (
record.vespa_visit_continuation_token,
record.total_chunks_migrated,
)
if record.vespa_visit_continuation_token is None:
continuation_token_map: dict[int, str | None] = {
slice_id: None for slice_id in range(GET_VESPA_CHUNKS_SLICE_COUNT)
}
else:
json_loaded_continuation_token_map = json.loads(
record.vespa_visit_continuation_token
)
continuation_token_map = {
int(key): value for key, value in json_loaded_continuation_token_map.items()
}
return continuation_token_map, record.total_chunks_migrated
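The int(key) conversion is needed because JSON object keys are always strings. A small illustrative round-trip (token values are made up):
import json

# Illustrative only: per-slice continuation tokens with fabricated values.
continuation_token_map = {0: "token-slice-0", 1: None, 2: "token-slice-2", 3: None}

serialized = json.dumps(continuation_token_map)
# -> '{"0": "token-slice-0", "1": null, "2": "token-slice-2", "3": null}'

loaded = json.loads(serialized)            # keys come back as strings
restored = {int(key): value for key, value in loaded.items()}
assert restored == continuation_token_map  # int keys restored, None values preserved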
def update_vespa_visit_progress_with_commit(
db_session: Session,
continuation_token: str | None,
continuation_token_map: dict[int, str | None],
chunks_processed: int,
chunks_errored: int,
approx_chunk_count_in_vespa: int | None,
) -> None:
"""Updates the Vespa migration progress and commits.
@@ -273,19 +285,26 @@ def update_vespa_visit_progress_with_commit(
Args:
db_session: SQLAlchemy session.
continuation_token: The new continuation token. None means the visit
is complete.
continuation_token_map: The new continuation token map. A None entry means
the visit is complete for that slice.
chunks_processed: Number of chunks processed in this batch (added to
the running total).
chunks_errored: Number of chunks errored in this batch (added to the
running errored total).
approx_chunk_count_in_vespa: Approximate number of chunks in Vespa. If
None, the existing value is used.
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
record.vespa_visit_continuation_token = continuation_token
record.vespa_visit_continuation_token = json.dumps(continuation_token_map)
record.total_chunks_migrated += chunks_processed
record.total_chunks_errored += chunks_errored
record.approx_chunk_count_in_vespa = (
approx_chunk_count_in_vespa
if approx_chunk_count_in_vespa is not None
else record.approx_chunk_count_in_vespa
)
db_session.commit()
@@ -353,25 +372,27 @@ def build_sanitized_to_original_doc_id_mapping(
def get_opensearch_migration_state(
db_session: Session,
) -> tuple[int, datetime | None, datetime | None]:
) -> tuple[int, datetime | None, datetime | None, int | None]:
"""Returns the state of the Vespa to OpenSearch migration.
If the tenant migration record is not found, returns defaults of 0, None,
None.
None, None.
Args:
db_session: SQLAlchemy session.
Returns:
Tuple of (total_chunks_migrated, created_at, migration_completed_at).
Tuple of (total_chunks_migrated, created_at, migration_completed_at,
approx_chunk_count_in_vespa).
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
return 0, None, None
return 0, None, None, None
return (
record.total_chunks_migrated,
record.created_at,
record.migration_completed_at,
record.approx_chunk_count_in_vespa,
)

View File

@@ -54,6 +54,9 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
# Maps schema property name to a list of highlighted snippets with match
# terms wrapped in tags (e.g. "something <hi>keyword</hi> other thing").
match_highlights: dict[str, list[str]] = {}
# Score explanation from OpenSearch when "explain": true is set in the query.
# Contains detailed breakdown of how the score was calculated.
explanation: dict[str, Any] | None = None
def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
@@ -706,10 +709,12 @@ class OpenSearchClient:
)
document_chunk_score = hit.get("_score", None)
match_highlights: dict[str, list[str]] = hit.get("highlight", {})
explanation: dict[str, Any] | None = hit.get("_explanation", None)
search_hit = SearchHit[DocumentChunk](
document_chunk=DocumentChunk.model_validate(document_chunk_source),
score=document_chunk_score,
match_highlights=match_highlights,
explanation=explanation,
)
search_hits.append(search_hit)
logger.debug(

View File

@@ -10,31 +10,31 @@ EF_CONSTRUCTION = 256
# quality but increase memory footprint. Values typically range between 12 - 48.
M = 32 # Set relatively high for better accuracy.
# Number of vectors to examine for top k neighbors for the HNSW method.
# Should be >= DEFAULT_K_NUM_CANDIDATES for good recall; higher = better accuracy, slower search.
# Bumping this to 1000 on a dataset of roughly 10,000 docs did not improve recall.
EF_SEARCH = 256
# When performing hybrid search, we need to consider more candidates than the number of results to be returned.
# This is because the scoring is hybrid and the results are reordered due to the hybrid scoring.
# Higher = more candidates for hybrid fusion = better retrieval accuracy, but results in more computation per query.
# Imagine a simple case with a single keyword query and a single vector query and we want 10 final docs.
# If we only fetch 10 candidates from each of keyword and vector, they would have to have perfect overlap to get a good hybrid
# ranking for the 10 results. If we fetch 1000 candidates from each, we have a much higher chance of all 10 of the final desired
# docs showing up and getting scored. In the worst case, some of the truly best docs never enter the candidate pool at all
# (worse than just a miss at the reranking step).
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = 750
# The default number of neighbors to consider for knn vector similarity search.
# We need this higher than the number of results because the scoring is hybrid.
# If there is only 1 query, setting k equal to the number of results is enough,
# but since there is heavy reordering due to hybrid scoring, we need to set k higher.
# Higher = more candidates for hybrid fusion = better retrieval accuracy, more query cost.
DEFAULT_K_NUM_CANDIDATES = 50 # TODO likely need to bump this way higher
# Number of vectors to examine for top k neighbors for the HNSW method.
EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
# Since the titles are included in the contents, they are heavily downweighted as they act as a boost
# rather than an independent scoring component.
SEARCH_TITLE_VECTOR_WEIGHT = 0.1
SEARCH_TITLE_KEYWORD_WEIGHT = 0.1
SEARCH_CONTENT_VECTOR_WEIGHT = 0.4
SEARCH_CONTENT_KEYWORD_WEIGHT = 0.4
SEARCH_CONTENT_VECTOR_WEIGHT = 0.45
# Single keyword weight for both title and content (merged from former title keyword + content keyword).
SEARCH_KEYWORD_WEIGHT = 0.45
# NOTE: it is critical that the order of these weights matches the order of the sub-queries in the hybrid search.
HYBRID_SEARCH_NORMALIZATION_WEIGHTS = [
SEARCH_TITLE_VECTOR_WEIGHT,
SEARCH_TITLE_KEYWORD_WEIGHT,
SEARCH_CONTENT_VECTOR_WEIGHT,
SEARCH_CONTENT_KEYWORD_WEIGHT,
SEARCH_KEYWORD_WEIGHT,
]
assert sum(HYBRID_SEARCH_NORMALIZATION_WEIGHTS) == 1.0
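Conceptually, the normalization pipeline combines per-subquery scores roughly as sketched below. The real math runs server-side in OpenSearch's z-score normalization pipeline; this sketch only assumes the new three-weight layout (title vector, content vector, keyword) and illustrates why weight order must match subquery order and why the weights sum to 1.
# Rough conceptual sketch only; OpenSearch's normalization processor does this server-side.
def combine_hybrid_scores(normalized_subquery_scores: list[float]) -> float:
    weights = HYBRID_SEARCH_NORMALIZATION_WEIGHTS  # order must match subquery order
    assert len(normalized_subquery_scores) == len(weights)
    # Weighted sum; since the weights sum to 1.0, the result stays on the normalized scale.
    return sum(w * s for w, s in zip(weights, normalized_subquery_scores))

# e.g. title-vector 0.2, content-vector 0.9, keyword 0.6 for one chunk:
# 0.1 * 0.2 + 0.45 * 0.9 + 0.45 * 0.6 = 0.695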

View File

@@ -842,6 +842,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
body=query_body,
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
)
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights

View File

@@ -11,6 +11,7 @@ from pydantic import model_serializer
from pydantic import model_validator
from pydantic import SerializerFunctionWrapHandler
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
@@ -54,6 +55,11 @@ SECONDARY_OWNERS_FIELD_NAME = "secondary_owners"
ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME = "ancestor_hierarchy_node_ids"
# Faiss was also tried but it didn't have any benefits
# NMSLIB is deprecated, not recommended
OPENSEARCH_KNN_ENGINE = "lucene"
def get_opensearch_doc_chunk_id(
tenant_state: TenantState,
document_id: str,
@@ -343,6 +349,9 @@ class DocumentSchema:
"properties": {
TITLE_FIELD_NAME: {
"type": "text",
# Language analyzer (e.g. english) stems at index and search time for variant matching.
# Configure via OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing after a change.
"analyzer": OPENSEARCH_TEXT_ANALYZER,
"fields": {
# Subfield accessed as title.keyword. Not indexed for
# values longer than 256 chars.
@@ -357,9 +366,7 @@ class DocumentSchema:
CONTENT_FIELD_NAME: {
"type": "text",
"store": True,
# This makes highlighting text during queries more efficient
# at the cost of disk space. See
# https://docs.opensearch.org/latest/search-plugins/searching-data/highlight/#methods-of-obtaining-offsets
"analyzer": OPENSEARCH_TEXT_ANALYZER,
"index_options": "offsets",
},
TITLE_VECTOR_FIELD_NAME: {
@@ -368,7 +375,7 @@ class DocumentSchema:
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": "lucene",
"engine": OPENSEARCH_KNN_ENGINE,
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
},
},
@@ -380,7 +387,7 @@ class DocumentSchema:
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": "lucene",
"engine": OPENSEARCH_KNN_ENGINE,
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
},
},

View File

@@ -6,13 +6,16 @@ from typing import Any
from uuid import UUID
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
from onyx.configs.app_configs import OPENSEARCH_PROFILING_DISABLED
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import Tag
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import DEFAULT_K_NUM_CANDIDATES
from onyx.document_index.opensearch.constants import (
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
)
from onyx.document_index.opensearch.constants import HYBRID_SEARCH_NORMALIZATION_WEIGHTS
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME
@@ -240,6 +243,9 @@ class DocumentQuery:
Returns:
A dictionary representing the final hybrid search query.
"""
# WARNING: Profiling does not work with hybrid search; do not add it at
# this level. See https://github.com/opensearch-project/neural-search/issues/1255
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
raise ValueError(
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
@@ -247,7 +253,7 @@ class DocumentQuery:
)
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
query_text, query_vector, num_candidates=DEFAULT_K_NUM_CANDIDATES
query_text, query_vector
)
hybrid_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
@@ -275,25 +281,31 @@ class DocumentQuery:
hybrid_search_query: dict[str, Any] = {
"hybrid": {
"queries": hybrid_search_subqueries,
# Applied to all the sub-queries. Source:
# Max results per subquery per shard before aggregation. Ensures keyword and vector
# subqueries contribute equally to the candidate pool for hybrid fusion.
# Sources:
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
# https://opensearch.org/blog/navigating-pagination-in-hybrid-queries-with-the-pagination_depth-parameter/
"pagination_depth": DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
# Applied to each sub-query independently (this avoids subqueries having many of their results thrown out).
# Sources:
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
# https://opensearch.org/blog/introducing-common-filter-support-for-hybrid-search-queries
# Does AND for each filter in the list.
"filter": {"bool": {"filter": hybrid_search_filters}},
}
}
# NOTE: By default, hybrid search retrieves "size"-many results from
# each OpenSearch shard before aggregation. Source:
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
final_hybrid_search_body: dict[str, Any] = {
"query": hybrid_search_query,
"size": num_hits,
"highlight": match_highlights_configuration,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
}
# WARNING: Profiling does not work with hybrid search; do not add it at
# this level. See https://github.com/opensearch-project/neural-search/issues/1255
# Explain is for scoring breakdowns.
if OPENSEARCH_EXPLAIN_ENABLED:
final_hybrid_search_body["explain"] = True
return final_hybrid_search_body
@@ -355,7 +367,12 @@ class DocumentQuery:
@staticmethod
def _get_hybrid_search_subqueries(
query_text: str, query_vector: list[float], num_candidates: int
query_text: str,
query_vector: list[float],
# The default number of neighbors to consider for knn vector similarity search.
# This is higher than the number of results because the scoring is hybrid.
# For a detailed breakdown, see where the default value is set.
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
) -> list[dict[str, Any]]:
"""Returns subqueries for hybrid search.
@@ -367,9 +384,8 @@ class DocumentQuery:
Matches:
- Title vector
- Title keyword
- Content vector
- Content keyword + phrase
- Keyword (title + content, match and phrase)
Normalization is not performed here.
The weights of each of these subqueries should be configured in a search
@@ -390,9 +406,9 @@ class DocumentQuery:
NOTE: Options considered and rejected:
- minimum_should_match: Since it's hybrid search and users often provide semantic queries, there are often many terms
but very few meaningful keywords (and a low keyword ratio).
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). This is reasonable but in reality seeing the
user usage patterns, this is not very common and people tend to not be confused when a miss happens for this reason.
In testing datasets, this makes recall slightly worse.
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). It mostly helps with typos, since the analyzer ("english"
by default) already does stemming and tokenization. In testing datasets this makes recall slightly worse, and it is also less
performant, so there is no real reason to enable it.
Args:
query_text: The text of the query to search for.
@@ -401,64 +417,56 @@ class DocumentQuery:
similarity search.
"""
# Build sub-queries for hybrid search. Order must match normalization
# pipeline weights: title vector, title keyword, content vector,
# content keyword.
# pipeline weights: title vector, content vector, keyword (title + content).
hybrid_search_queries: list[dict[str, Any]] = [
# 1. Title vector search
{
"knn": {
TITLE_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": num_candidates,
"k": vector_candidates,
}
}
},
# 2. Title keyword + phrase search.
{
"bool": {
"should": [
{
"match": {
TITLE_FIELD_NAME: {
"query": query_text,
# operator "or" = match doc if any query term matches (default, explicit for clarity).
"operator": "or",
}
}
},
{
"match_phrase": {
TITLE_FIELD_NAME: {
"query": query_text,
# Slop = 1 allows one extra word or transposition in phrase match.
"slop": 1,
# Boost phrase over bag-of-words; exact phrase is a stronger signal.
"boost": 1.5,
}
}
},
]
}
},
# 3. Content vector search
# 2. Content vector search
{
"knn": {
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": num_candidates,
"k": vector_candidates,
}
}
},
# 4. Content keyword + phrase search.
# 3. Keyword (title + content) match and phrase search.
{
"bool": {
"should": [
{
"match": {
TITLE_FIELD_NAME: {
"query": query_text,
"operator": "or",
# The title field is strongly discounted since titles are already included in the content;
# it just acts as a minor boost.
"boost": 0.1,
}
}
},
{
"match_phrase": {
TITLE_FIELD_NAME: {
"query": query_text,
"slop": 1,
"boost": 0.2,
}
}
},
{
"match": {
CONTENT_FIELD_NAME: {
"query": query_text,
# operator "or" = match doc if any query term matches (default, explicit for clarity).
"operator": "or",
"boost": 1.0,
}
}
},
@@ -466,9 +474,7 @@ class DocumentQuery:
"match_phrase": {
CONTENT_FIELD_NAME: {
"query": query_text,
# Slop = 1 allows one extra word or transposition in phrase match.
"slop": 1,
# Boost phrase over bag-of-words; exact phrase is a stronger signal.
"boost": 1.5,
}
}

View File

@@ -10,6 +10,12 @@ from typing import cast
import httpx
from retry import retry
from onyx.background.celery.tasks.opensearch_migration.constants import (
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
)
from onyx.background.celery.tasks.opensearch_migration.transformer import (
FIELDS_NEEDED_FOR_TRANSFORMATION,
)
from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE
from onyx.context.search.models import IndexFilters
@@ -277,54 +283,139 @@ def get_chunks_via_visit_api(
def get_all_chunks_paginated(
index_name: str,
tenant_state: TenantState,
continuation_token: str | None = None,
page_size: int = 1_000,
) -> tuple[list[dict], str | None]:
continuation_token_map: dict[int, str | None],
page_size: int,
) -> tuple[list[dict], dict[int, str | None]]:
"""Gets all chunks in Vespa matching the filters, paginated.
Uses the Visit API with slicing. Each continuation token map entry is for a
different slice. The number of entries determines the number of slices.
Args:
index_name: The name of the Vespa index to visit.
tenant_state: The tenant state to filter by.
continuation_token: Token returned by Vespa representing a page offset.
None to start from the beginning. Defaults to None.
continuation_token_map: Map of slice ID to a token returned by Vespa
representing a page offset. None to start from the beginning of the
slice.
page_size: Best-effort batch size for the visit. Defaults to 1,000.
Returns:
Tuple of (list of chunk dicts, next continuation token map). A slice's token is
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN once its visit is complete.
"""
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
selection: str = f"{index_name}.large_chunk_reference_ids == null"
if MULTI_TENANT:
selection += f" and {index_name}.tenant_id=='{tenant_state.tenant_id}'"
def _get_all_chunks_paginated_for_slice(
index_name: str,
tenant_state: TenantState,
slice_id: int,
total_slices: int,
continuation_token: str | None,
page_size: int,
) -> tuple[list[dict], str | None]:
if continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN:
logger.debug(
f"Slice {slice_id} has finished visiting. Returning empty list and {FINISHED_VISITING_SLICE_CONTINUATION_TOKEN}."
)
return [], FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
params: dict[str, str | int | None] = {
"selection": selection,
"wantedDocumentCount": page_size,
"format.tensors": "short-value",
}
if continuation_token is not None:
params["continuation"] = continuation_token
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
try:
with get_vespa_http_client() as http_client:
response = http_client.get(url, params=params)
response.raise_for_status()
except httpx.HTTPError as e:
error_base = "Failed to get chunks in Vespa."
logger.exception(
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
selection: str = f"{index_name}.large_chunk_reference_ids == null"
if MULTI_TENANT:
selection += f" and {index_name}.tenant_id=='{tenant_state.tenant_id}'"
field_set = f"{index_name}:" + ",".join(FIELDS_NEEDED_FOR_TRANSFORMATION)
params: dict[str, str | int | None] = {
"selection": selection,
"fieldSet": field_set,
"wantedDocumentCount": page_size,
"format.tensors": "short-value",
"slices": total_slices,
"sliceId": slice_id,
}
if continuation_token is not None:
params["continuation"] = continuation_token
response: httpx.Response | None = None
try:
with get_vespa_http_client() as http_client:
response = http_client.get(url, params=params)
response.raise_for_status()
except httpx.HTTPError as e:
error_base = f"Failed to get chunks from Vespa slice {slice_id} with continuation token {continuation_token}."
logger.exception(
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
)
error_message = (
response.json().get("message") if response else "No response"
)
logger.error("Error message from response: %s", error_message)
raise httpx.HTTPError(error_base) from e
response_data = response.json()
# NOTE: If we see a falsy value for "continuation" in the response we
# assume we are done and return
# FINISHED_VISITING_SLICE_CONTINUATION_TOKEN instead.
next_continuation_token = (
response_data.get("continuation")
or FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
)
raise httpx.HTTPError(error_base) from e
chunks = [chunk["fields"] for chunk in response_data.get("documents", [])]
if next_continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN:
logger.debug(
f"Slice {slice_id} has finished visiting. Returning {len(chunks)} chunks and {next_continuation_token}."
)
return chunks, next_continuation_token
response_data = response.json()
total_slices = len(continuation_token_map)
if total_slices < 1:
raise ValueError("continuation_token_map must have at least one entry.")
# We want to guarantee that these invocations are ordered by slice_id,
# because we read in the same order below when parsing parallel_results.
functions_with_args: list[tuple[Callable, tuple]] = [
(
_get_all_chunks_paginated_for_slice,
(
index_name,
tenant_state,
slice_id,
total_slices,
continuation_token,
page_size,
),
)
for slice_id, continuation_token in sorted(continuation_token_map.items())
]
return [
chunk["fields"] for chunk in response_data.get("documents", [])
], response_data.get("continuation") or None
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=True
)
if len(parallel_results) != total_slices:
raise RuntimeError(
f"Expected {total_slices} parallel results, but got {len(parallel_results)}."
)
chunks: list[dict] = []
next_continuation_token_map: dict[int, str | None] = {
key: value for key, value in continuation_token_map.items()
}
for i, parallel_result in enumerate(parallel_results):
if i not in next_continuation_token_map:
raise RuntimeError(f"Slice {i} is not in the continuation token map.")
if parallel_result is None:
logger.error(
f"Failed to get chunks for slice {i} of {total_slices}. "
"The continuation token for this slice will not be updated."
)
continue
chunks.extend(parallel_result[0])
next_continuation_token_map[i] = parallel_result[1]
return chunks, next_continuation_token_map
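As an illustration (token values made up), a two-slice visit evolves like this: the map starts as {0: None, 1: None}; a successful pass might return {0: "token-a", 1: "token-b"}; once slice 1 is exhausted a later pass returns {0: "token-c", 1: FINISHED_VISITING_SLICE_CONTINUATION_TOKEN}; and if a slice's request fails, its previous token is left unchanged so that page is retried on the next pass.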
# TODO(rkuo): candidate for removal if not being used

View File

@@ -56,6 +56,7 @@ from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import SEARCH_ENDPOINT
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
from onyx.document_index.vespa_constants import YQL_BASE
from onyx.indexing.models import DocMetadataAwareIndexChunk
@@ -652,9 +653,9 @@ class VespaDocumentIndex(DocumentIndex):
def get_all_raw_document_chunks_paginated(
self,
continuation_token: str | None,
continuation_token_map: dict[int, str | None],
page_size: int,
) -> tuple[list[dict[str, Any]], str | None]:
) -> tuple[list[dict[str, Any]], dict[int, str | None]]:
"""Gets all the chunks in Vespa, paginated.
Used in the chunk-level Vespa-to-OpenSearch migration task.
@@ -662,21 +663,21 @@ class VespaDocumentIndex(DocumentIndex):
Args:
continuation_token_map: Map of slice ID to a token returned by Vespa
representing a page offset for that slice. None entries start from the beginning.
page_size: Best-effort batch size for the visit. Defaults to 1,000.
page_size: Best-effort batch size for the visit.
Returns:
Tuple of (list of chunk dicts, next continuation token map).
"""
raw_chunks, next_continuation_token = get_all_chunks_paginated(
raw_chunks, next_continuation_token_map = get_all_chunks_paginated(
index_name=self._index_name,
tenant_state=TenantState(
tenant_id=self._tenant_id, multitenant=MULTI_TENANT
),
continuation_token=continuation_token,
continuation_token_map=continuation_token_map,
page_size=page_size,
)
return raw_chunks, next_continuation_token
return raw_chunks, next_continuation_token_map
def index_raw_chunks(self, chunks: list[dict[str, Any]]) -> None:
"""Indexes raw document chunks into Vespa.
@@ -702,3 +703,32 @@ class VespaDocumentIndex(DocumentIndex):
json={"fields": chunk},
)
response.raise_for_status()
def get_chunk_count(self) -> int:
"""Returns the exact number of document chunks in Vespa for this tenant.
Uses the Vespa Search API with `limit 0` and `ranking.profile=unranked`
to get an exact count without fetching any document data.
Includes large chunks. There is no way to filter these out using the
Search API.
"""
where_clause = (
f'tenant_id contains "{self._tenant_id}"' if self._multitenant else "true"
)
yql = (
f"select documentid from {self._index_name} "
f"where {where_clause} "
f"limit 0"
)
params: dict[str, str | int] = {
"yql": yql,
"ranking.profile": "unranked",
"timeout": VESPA_TIMEOUT,
}
with get_vespa_http_client() as http_client:
response = http_client.post(SEARCH_ENDPOINT, json=params)
response.raise_for_status()
response_data = response.json()
return response_data["root"]["fields"]["totalCount"]

View File

@@ -215,6 +215,23 @@ def model_configurations_for_provider(
) -> list[ModelConfigurationView]:
recommended_visible_models = llm_recommendations.get_visible_models(provider_name)
recommended_visible_models_names = [m.name for m in recommended_visible_models]
# Preserve provider-defined ordering while de-duplicating.
model_names: list[str] = []
seen_model_names: set[str] = set()
for model_name in (
fetch_models_for_provider(provider_name) + recommended_visible_models_names
):
if model_name in seen_model_names:
continue
seen_model_names.add(model_name)
model_names.append(model_name)
# Vertex model list can be large and mixed-vendor; alphabetical ordering
# makes model discovery easier in admin selection UIs.
if provider_name == VERTEXAI_PROVIDER_NAME:
model_names = sorted(model_names, key=str.lower)
return [
ModelConfigurationView(
name=model_name,
@@ -222,8 +239,7 @@ def model_configurations_for_provider(
max_input_tokens=get_max_input_tokens(model_name, provider_name),
supports_image_input=model_supports_image_input(model_name, provider_name),
)
for model_name in set(fetch_models_for_provider(provider_name))
| set(recommended_visible_models_names)
for model_name in model_names
]

View File

@@ -21,7 +21,6 @@ from fastapi.routing import APIRoute
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import BASE_SCOPES
from httpx_oauth.clients.openid import OpenID
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from starlette.types import Lifespan
@@ -53,6 +52,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import POSTGRES_WEB_APP_NAME
from onyx.db.engine.async_sql_engine import get_sqlalchemy_async_engine
from onyx.db.engine.connection_warmup import warm_up_connections
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
@@ -64,7 +64,7 @@ from onyx.server.documents.connector import router as connector_router
from onyx.server.documents.credential import router as credential_router
from onyx.server.documents.document import router as document_router
from onyx.server.documents.standard_oauth import router as standard_oauth_router
from onyx.server.features.build.api.api import nextjs_assets_router
from onyx.server.features.build.api.api import public_build_router
from onyx.server.features.build.api.api import router as build_router
from onyx.server.features.default_assistant.api import (
router as default_assistant_router,
@@ -115,6 +115,10 @@ from onyx.server.manage.users import router as user_router
from onyx.server.manage.web_search.api import (
admin_router as web_search_admin_router,
)
from onyx.server.metrics.postgres_connection_pool import (
setup_postgres_connection_pool_metrics,
)
from onyx.server.metrics.prometheus_setup import setup_prometheus_metrics
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
@@ -138,6 +142,7 @@ from onyx.setup import setup_onyx
from onyx.tracing.setup import setup_tracing
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_endpoint_context_middleware
from onyx.utils.middleware import add_onyx_request_id_middleware
from onyx.utils.telemetry import get_or_generate_uuid
from onyx.utils.telemetry import optional_telemetry
@@ -266,6 +271,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
max_overflow=POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW,
)
# Register pool metrics now that engines are created.
# HTTP instrumentation is set up earlier in get_application() since it
# adds middleware (which Starlette forbids after the app has started).
setup_postgres_connection_pool_metrics(
engines={
"sync": SqlEngine.get_engine(),
"async": get_sqlalchemy_async_engine(),
"readonly": SqlEngine.get_readonly_engine(),
},
)
verify_auth = fetch_versioned_implementation(
"onyx.auth.users", "verify_auth_setting"
)
@@ -378,8 +394,8 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, cc_pair_router)
include_router_with_global_prefix_prepended(application, projects_router)
include_router_with_global_prefix_prepended(application, public_build_router)
include_router_with_global_prefix_prepended(application, build_router)
include_router_with_global_prefix_prepended(application, nextjs_assets_router)
include_router_with_global_prefix_prepended(application, document_set_router)
include_router_with_global_prefix_prepended(application, hierarchy_router)
include_router_with_global_prefix_prepended(application, search_settings_router)
@@ -560,12 +576,18 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
add_onyx_request_id_middleware(application, "API", logger)
# Set endpoint context for per-endpoint DB pool attribution metrics.
# Must be registered after all routes are added.
add_endpoint_context_middleware(application)
# HTTP request metrics (latency histograms, in-progress gauge, slow request
# counter). Must be called here — before the app starts — because the
# instrumentator adds middleware via app.add_middleware().
setup_prometheus_metrics(application)
# Ensure all routes have auth enabled or are explicitly marked as public
check_router_auth(application)
# Initialize and instrument the app
Instrumentator().instrument(application).expose(application)
use_route_function_names_as_operation_ids(application)
return application

View File

@@ -69,6 +69,12 @@ Very briefly describe the image(s) generated. Do not include any links or attach
""".strip()
FILE_REMINDER = """
Your code execution generated file(s) with download links.
If you reference or share these files, use the exact markdown format [filename](file_link) with the file_link from the execution result.
""".strip()
# Specifically for OpenAI models, this prefix needs to be in place for the model to output markdown and correct styling
CODE_BLOCK_MARKDOWN = "Formatting re-enabled. "

View File

@@ -59,6 +59,9 @@ PUBLIC_ENDPOINT_SPECS = [
# anonymous user on cloud
("/tenants/anonymous-user", {"POST"}),
("/metrics", {"GET"}), # added by prometheus_fastapi_instrumentator
# craft webapp proxy — access enforced per-session via sharing_scope in handler
("/build/sessions/{session_id}/webapp", {"GET"}),
("/build/sessions/{session_id}/webapp/{path:path}", {"GET"}),
]
@@ -102,6 +105,9 @@ def check_router_auth(
current_cloud_superuser = fetch_ee_implementation_or_noop(
"onyx.auth.users", "current_cloud_superuser"
)
verify_scim_token = fetch_ee_implementation_or_noop(
"onyx.server.scim.auth", "verify_scim_token"
)
for route in application.routes:
# explicitly marked as public
@@ -125,6 +131,7 @@ def check_router_auth(
or depends_fn == current_chat_accessible_user
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
or depends_fn == verify_scim_token
):
found_auth = True
break

View File

@@ -1,4 +1,5 @@
from collections.abc import Iterator
from pathlib import Path
from uuid import UUID
import httpx
@@ -7,16 +8,19 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from fastapi.responses import RedirectResponse
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.auth.users import optional_user
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import ProcessingMode
from onyx.db.enums import SharingScope
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.models import BuildSession
from onyx.db.models import User
@@ -217,12 +221,15 @@ def get_build_connectors(
return BuildConnectorListResponse(connectors=connectors)
# Headers to skip when proxying (hop-by-hop headers)
# Headers to skip when proxying.
# Hop-by-hop headers must not be forwarded, and set-cookie is stripped to
# prevent LLM-generated apps from setting cookies on the parent Onyx domain.
EXCLUDED_HEADERS = {
"content-encoding",
"content-length",
"transfer-encoding",
"connection",
"set-cookie",
}
@@ -280,7 +287,7 @@ def _get_sandbox_url(session_id: UUID, db_session: Session) -> str:
db_session: Database session
Returns:
The internal URL to proxy requests to
Internal URL to proxy requests to
Raises:
HTTPException: If session not found, port not allocated, or sandbox not found
@@ -294,12 +301,10 @@ def _get_sandbox_url(session_id: UUID, db_session: Session) -> str:
if session.user_id is None:
raise HTTPException(status_code=404, detail="User not found")
# Get the user's sandbox to get the sandbox_id
sandbox = get_sandbox_by_user_id(db_session, session.user_id)
if sandbox is None:
raise HTTPException(status_code=404, detail="Sandbox not found")
# Use sandbox manager to get the correct internal URL
sandbox_manager = get_sandbox_manager()
return sandbox_manager.get_webapp_url(sandbox.id, session.nextjs_port)
@@ -365,71 +370,73 @@ def _proxy_request(
raise HTTPException(status_code=502, detail="Bad gateway")
@router.get("/sessions/{session_id}/webapp", response_model=None)
def get_webapp_root(
def _check_webapp_access(
session_id: UUID, user: User | None, db_session: Session
) -> BuildSession:
"""Check if user can access a session's webapp.
- public_global: accessible by anyone (no auth required)
- public_org: accessible by any authenticated user
- private: only accessible by the session owner
"""
session = db_session.get(BuildSession, session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if session.sharing_scope == SharingScope.PUBLIC_GLOBAL:
return session
if user is None:
raise HTTPException(status_code=401, detail="Authentication required")
if session.sharing_scope == SharingScope.PRIVATE and session.user_id != user.id:
raise HTTPException(status_code=404, detail="Session not found")
return session
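As a quick reference for the scopes handled above (illustration only, mirroring the branches in _check_webapp_access and the redirect behavior of the route below):

# public_global + anonymous request      -> session returned, no auth required
# public_org    + anonymous request      -> 401 (route below redirects to /auth/login)
# public_org    + any authenticated user -> session returned
# private       + non-owner              -> 404 (existence is not revealed)
# private       + owner                  -> session returned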
_OFFLINE_HTML_PATH = Path(__file__).parent / "templates" / "webapp_offline.html"
def _offline_html_response() -> Response:
"""Return a branded Craft HTML page when the sandbox is not reachable.
Design mirrors the default Craft web template (outputs/web/app/page.tsx):
terminal window aesthetic with Minecraft-themed typing animation.
"""
html = _OFFLINE_HTML_PATH.read_text()
return Response(content=html, status_code=503, media_type="text/html")
# Public router for webapp proxy — no authentication required
# (access controlled per-session via sharing_scope)
public_build_router = APIRouter(prefix="/build")
@public_build_router.get("/sessions/{session_id}/webapp", response_model=None)
@public_build_router.get(
"/sessions/{session_id}/webapp/{path:path}", response_model=None
)
def get_webapp(
session_id: UUID,
request: Request,
_: User = Depends(current_user),
path: str = "",
user: User | None = Depends(optional_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy the root path of the webapp for a specific session."""
return _proxy_request("", request, session_id, db_session)
"""Proxy the webapp for a specific session (root and subpaths).
@router.get("/sessions/{session_id}/webapp/{path:path}", response_model=None)
def get_webapp_path(
session_id: UUID,
path: str,
request: Request,
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy any subpath of the webapp (static assets, etc.) for a specific session."""
return _proxy_request(path, request, session_id, db_session)
# Separate router for Next.js static assets at /_next/*
# This is needed because Next.js apps may reference assets with root-relative paths
# that don't get rewritten. The session_id is extracted from the Referer header.
nextjs_assets_router = APIRouter()
def _extract_session_from_referer(request: Request) -> UUID | None:
"""Extract session_id from the Referer header.
Expects Referer to contain /api/build/sessions/{session_id}/webapp
Accessible without authentication when sharing_scope is public_global.
Returns a friendly offline page when the sandbox is not running.
"""
import re
referer = request.headers.get("referer", "")
match = re.search(r"/api/build/sessions/([a-f0-9-]+)/webapp", referer)
if match:
try:
return UUID(match.group(1))
except ValueError:
return None
return None
@nextjs_assets_router.get("/_next/{path:path}", response_model=None)
def get_nextjs_assets(
path: str,
request: Request,
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy Next.js static assets requested at root /_next/ path.
The session_id is extracted from the Referer header since these requests
come from within the iframe context.
"""
session_id = _extract_session_from_referer(request)
if not session_id:
raise HTTPException(
status_code=400,
detail="Could not determine session from request context",
)
return _proxy_request(f"_next/{path}", request, session_id, db_session)
try:
_check_webapp_access(session_id, user, db_session)
except HTTPException as e:
if e.status_code == 401:
return RedirectResponse(url="/auth/login", status_code=302)
raise
try:
return _proxy_request(path, request, session_id, db_session)
except HTTPException as e:
if e.status_code in (502, 503, 504):
return _offline_html_response()
raise
# =============================================================================

View File

@@ -10,6 +10,7 @@ from onyx.configs.constants import MessageType
from onyx.db.enums import ArtifactType
from onyx.db.enums import BuildSessionStatus
from onyx.db.enums import SandboxStatus
from onyx.db.enums import SharingScope
from onyx.server.features.build.sandbox.models import (
FilesystemEntry as FileSystemEntry,
)
@@ -107,6 +108,7 @@ class SessionResponse(BaseModel):
nextjs_port: int | None
sandbox: SandboxResponse | None
artifacts: list[ArtifactResponse]
sharing_scope: SharingScope
@classmethod
def from_model(
@@ -129,6 +131,7 @@ class SessionResponse(BaseModel):
nextjs_port=session.nextjs_port,
sandbox=(SandboxResponse.from_model(sandbox) if sandbox else None),
artifacts=[ArtifactResponse.from_model(a) for a in session.artifacts],
sharing_scope=session.sharing_scope,
)
@@ -159,6 +162,19 @@ class SessionListResponse(BaseModel):
sessions: list[SessionResponse]
class SetSessionSharingRequest(BaseModel):
"""Request to set the sharing scope of a session."""
sharing_scope: SharingScope
class SetSessionSharingResponse(BaseModel):
"""Response after setting session sharing scope."""
session_id: str
sharing_scope: SharingScope
# ===== Message Models =====
class MessageRequest(BaseModel):
"""Request to send a message to the CLI agent."""
@@ -244,6 +260,7 @@ class WebappInfo(BaseModel):
webapp_url: str | None # URL to access the webapp (e.g., http://localhost:3015)
status: str # Sandbox status (running, terminated, etc.)
ready: bool # Whether the NextJS dev server is actually responding
sharing_scope: SharingScope
# ===== File Upload Models =====

View File

@@ -30,6 +30,8 @@ from onyx.server.features.build.api.models import SessionListResponse
from onyx.server.features.build.api.models import SessionNameGenerateResponse
from onyx.server.features.build.api.models import SessionResponse
from onyx.server.features.build.api.models import SessionUpdateRequest
from onyx.server.features.build.api.models import SetSessionSharingRequest
from onyx.server.features.build.api.models import SetSessionSharingResponse
from onyx.server.features.build.api.models import SuggestionBubble
from onyx.server.features.build.api.models import SuggestionTheme
from onyx.server.features.build.api.models import UploadResponse
@@ -38,6 +40,7 @@ from onyx.server.features.build.configs import SANDBOX_BACKEND
from onyx.server.features.build.configs import SandboxBackend
from onyx.server.features.build.db.build_session import allocate_nextjs_port
from onyx.server.features.build.db.build_session import get_build_session
from onyx.server.features.build.db.build_session import set_build_session_sharing_scope
from onyx.server.features.build.db.sandbox import get_latest_snapshot_for_session
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
from onyx.server.features.build.db.sandbox import update_sandbox_heartbeat
@@ -294,6 +297,25 @@ def update_session_name(
return SessionResponse.from_model(session, sandbox)
@router.patch("/{session_id}/public")
def set_session_public(
session_id: UUID,
request: SetSessionSharingRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SetSessionSharingResponse:
"""Set the sharing scope of a build session's webapp."""
updated = set_build_session_sharing_scope(
session_id, user.id, request.sharing_scope, db_session
)
if not updated:
raise HTTPException(status_code=404, detail="Session not found")
return SetSessionSharingResponse(
session_id=str(session_id),
sharing_scope=updated.sharing_scope,
)
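A hedged example of calling this endpoint; the /api/build/sessions prefix is inferred from the webapp proxy paths elsewhere in this diff, and the serialized enum value is assumed to be lowercase:

import httpx

# Illustrative only — host, path prefix, and auth handling are assumptions.
resp = httpx.patch(
    f"http://localhost:8080/api/build/sessions/{session_id}/public",
    json={"sharing_scope": "public_global"},
)
resp.raise_for_status()
# Expected shape: {"session_id": "...", "sharing_scope": "public_global"}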
@router.delete("/{session_id}", response_model=None)
def delete_session(
session_id: UUID,

View File

@@ -0,0 +1,110 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta http-equiv="refresh" content="15" />
<title>Craft — Starting up</title>
<style>
*,
*::before,
*::after {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: ui-monospace, SFMono-Regular, "SF Mono", Menlo, Consolas,
monospace;
background: linear-gradient(to bottom right, #030712, #111827, #030712);
min-height: 100vh;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
gap: 1.5rem;
padding: 2rem;
}
.terminal {
width: 100%;
max-width: 580px;
border: 2px solid #374151;
border-radius: 2px;
}
.titlebar {
background: #1f2937;
padding: 0.5rem 0.75rem;
display: flex;
align-items: center;
gap: 0.5rem;
border-bottom: 1px solid #374151;
}
.btn {
width: 12px;
height: 12px;
border-radius: 2px;
flex-shrink: 0;
}
.btn-red {
background: #ef4444;
}
.btn-yellow {
background: #eab308;
}
.btn-green {
background: #22c55e;
}
.title-label {
flex: 1;
text-align: center;
font-size: 0.75rem;
color: #6b7280;
margin-right: 36px;
}
.body {
background: #111827;
padding: 1.5rem;
min-height: 200px;
font-size: 0.875rem;
color: #d1d5db;
display: flex;
align-items: flex-start;
gap: 0.375rem;
}
.prompt {
color: #10b981;
user-select: none;
}
.tagline {
font-size: 0.8125rem;
color: #4b5563;
text-align: center;
}
</style>
</head>
<body>
<div class="terminal">
<div class="titlebar">
<div class="btn btn-red"></div>
<div class="btn btn-yellow"></div>
<div class="btn btn-green"></div>
<span class="title-label">crafting_table</span>
</div>
<div class="body">
<span class="prompt">/&gt;</span>
<span>Sandbox is asleep...</span>
</div>
</div>
<p class="tagline">
Ask the owner to open their Craft session to wake it up.
</p>
</body>
</html>

View File

@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import MessageType
from onyx.db.enums import BuildSessionStatus
from onyx.db.enums import SandboxStatus
from onyx.db.enums import SharingScope
from onyx.db.models import Artifact
from onyx.db.models import BuildMessage
from onyx.db.models import BuildSession
@@ -159,6 +160,26 @@ def update_session_status(
logger.info(f"Updated build session {session_id} status to {status}")
def set_build_session_sharing_scope(
session_id: UUID,
user_id: UUID,
sharing_scope: SharingScope,
db_session: Session,
) -> BuildSession | None:
"""Set the sharing scope of a build session.
Only the session owner can change this setting.
Returns the updated session, or None if not found/unauthorized.
"""
session = get_build_session(session_id, user_id, db_session)
if not session:
return None
session.sharing_scope = sharing_scope
db_session.commit()
logger.info(f"Set build session {session_id} sharing_scope={sharing_scope}")
return session
def delete_build_session__no_commit(
session_id: UUID,
user_id: UUID,

View File

@@ -474,6 +474,23 @@ class SandboxManager(ABC):
"""
...
def ensure_nextjs_running(
self,
sandbox_id: UUID,
session_id: UUID,
nextjs_port: int,
) -> None:
"""Ensure the Next.js server is running for a session.
Default is a no-op — only meaningful for local backends that manage
process lifecycles directly (e.g., LocalSandboxManager).
Args:
sandbox_id: The sandbox ID
session_id: The session ID
nextjs_port: The port the Next.js server should be listening on
"""
# Singleton instance cache for the factory
_sandbox_manager_instance: SandboxManager | None = None

View File

@@ -4,8 +4,9 @@ This client runs `opencode acp` directly in the sandbox pod via kubernetes exec,
using stdin/stdout for JSON-RPC communication. This bypasses the HTTP server
and uses the native ACP subprocess protocol.
This module includes comprehensive logging for debugging ACP communication.
Enable logging by setting LOG_LEVEL=DEBUG or BUILD_PACKET_LOGGING=true.
Each message creates an ephemeral client (start → resume_or_create_session →
send_message → stop) to prevent concurrent processes from corrupting
opencode's flat file session storage.
Usage:
client = ACPExecClient(
@@ -13,12 +14,14 @@ Usage:
namespace="onyx-sandboxes",
)
client.start(cwd="/workspace")
for event in client.send_message("What files are here?"):
session_id = client.resume_or_create_session(cwd="/workspace/sessions/abc")
for event in client.send_message("What files are here?", session_id=session_id):
print(event)
client.stop()
"""
import json
import shlex
import threading
import time
from collections.abc import Generator
@@ -27,6 +30,7 @@ from dataclasses import field
from queue import Empty
from queue import Queue
from typing import Any
from typing import cast
from acp.schema import AgentMessageChunk
from acp.schema import AgentPlanUpdate
@@ -40,6 +44,7 @@ from kubernetes import client # type: ignore
from kubernetes import config
from kubernetes.stream import stream as k8s_stream # type: ignore
from kubernetes.stream.ws_client import WSClient # type: ignore
from pydantic import BaseModel
from pydantic import ValidationError
from onyx.server.features.build.api.packet_logger import get_packet_logger
@@ -100,7 +105,7 @@ class ACPClientState:
"""Internal state for the ACP client."""
initialized: bool = False
current_session: ACPSession | None = None
sessions: dict[str, ACPSession] = field(default_factory=dict)
next_request_id: int = 0
agent_capabilities: dict[str, Any] = field(default_factory=dict)
agent_info: dict[str, Any] = field(default_factory=dict)
@@ -155,16 +160,16 @@ class ACPExecClient:
self._k8s_client = client.CoreV1Api()
return self._k8s_client
def start(self, cwd: str = "/workspace", timeout: float = 30.0) -> str:
"""Start the agent process via exec and initialize a session.
def start(self, cwd: str = "/workspace", timeout: float = 30.0) -> None:
"""Start the agent process via exec and initialize the ACP connection.
Only performs the ACP `initialize` handshake. Sessions are created
separately via `resume_or_create_session()`.
Args:
cwd: Working directory for the agent
cwd: Working directory for the `opencode acp` process
timeout: Timeout for initialization
Returns:
The session ID
Raises:
RuntimeError: If startup fails
"""
@@ -173,8 +178,19 @@ class ACPExecClient:
k8s = self._get_k8s_client()
# Start opencode acp via exec
exec_command = ["opencode", "acp", "--cwd", cwd]
# Start opencode acp via exec.
# Set XDG_DATA_HOME so opencode stores session data on the shared
# workspace volume (accessible from file-sync container for snapshots)
# instead of the container-local ~/.local/share/ filesystem.
data_dir = shlex.quote(f"{cwd}/.opencode-data")
safe_cwd = shlex.quote(cwd)
exec_command = [
"/bin/sh",
"-c",
f"XDG_DATA_HOME={data_dir} exec opencode acp --cwd {safe_cwd}",
]
logger.info(f"[ACP] Starting client: pod={self._pod_name} cwd={cwd}")
try:
self._ws_client = k8s_stream(
@@ -201,15 +217,12 @@ class ACPExecClient:
# Give process a moment to start
time.sleep(0.5)
# Initialize ACP connection
# Initialize ACP connection (no session creation)
self._initialize(timeout=timeout)
# Create session
session_id = self._create_session(cwd=cwd, timeout=timeout)
return session_id
logger.info(f"[ACP] Client started: pod={self._pod_name}")
except Exception as e:
logger.error(f"[ACP] Client start failed: pod={self._pod_name} error={e}")
self.stop()
raise RuntimeError(f"Failed to start ACP exec client: {e}") from e
@@ -224,56 +237,52 @@ class ACPExecClient:
try:
if self._ws_client.is_open():
# Read available data
self._ws_client.update(timeout=0.1)
# Read stdout (channel 1)
# Read stderr - log any agent errors
stderr_data = self._ws_client.read_stderr(timeout=0.01)
if stderr_data:
logger.warning(
f"[ACP] stderr pod={self._pod_name}: "
f"{stderr_data.strip()[:500]}"
)
# Read stdout
data = self._ws_client.read_stdout(timeout=0.1)
if data:
buffer += data
# Process complete lines
while "\n" in buffer:
line, buffer = buffer.split("\n", 1)
line = line.strip()
if line:
try:
message = json.loads(line)
# Log the raw incoming message
packet_logger.log_jsonrpc_raw_message(
"IN", message, context="k8s"
)
self._response_queue.put(message)
except json.JSONDecodeError:
packet_logger.log_raw(
"JSONRPC-PARSE-ERROR-K8S",
{
"raw_line": line[:500],
"error": "JSON decode failed",
},
)
logger.warning(
f"Invalid JSON from agent: {line[:100]}"
f"[ACP] Invalid JSON from agent: "
f"{line[:100]}"
)
else:
packet_logger.log_raw(
"K8S-WEBSOCKET-CLOSED",
{"pod": self._pod_name, "namespace": self._namespace},
)
logger.warning(f"[ACP] WebSocket closed: pod={self._pod_name}")
break
except Exception as e:
if not self._stop_reader.is_set():
packet_logger.log_raw(
"K8S-READER-ERROR",
{"error": str(e), "pod": self._pod_name},
)
logger.debug(f"Reader error: {e}")
logger.warning(f"[ACP] Reader error: {e}, pod={self._pod_name}")
break
def stop(self) -> None:
"""Stop the exec session and clean up."""
session_ids = list(self._state.sessions.keys())
logger.info(
f"[ACP] Stopping client: pod={self._pod_name} " f"sessions={session_ids}"
)
self._stop_reader.set()
if self._ws_client is not None:
@@ -400,42 +409,150 @@ class ACPExecClient:
if not session_id:
raise RuntimeError("No session ID returned from session/new")
self._state.current_session = ACPSession(session_id=session_id, cwd=cwd)
self._state.sessions[session_id] = ACPSession(session_id=session_id, cwd=cwd)
logger.info(f"[ACP] Created session: acp_session={session_id} cwd={cwd}")
return session_id
def _list_sessions(self, cwd: str, timeout: float = 10.0) -> list[dict[str, Any]]:
"""List available ACP sessions, filtered by working directory.
Returns:
List of session info dicts with keys like 'sessionId', 'cwd', 'title'.
Empty list if session/list is not supported or fails.
"""
try:
request_id = self._send_request("session/list", {"cwd": cwd})
result = self._wait_for_response(request_id, timeout)
sessions = result.get("sessions", [])
logger.info(f"[ACP] session/list: {len(sessions)} sessions for cwd={cwd}")
return sessions
except Exception as e:
logger.info(f"[ACP] session/list unavailable: {e}")
return []
def _resume_session(self, session_id: str, cwd: str, timeout: float = 30.0) -> str:
"""Resume an existing ACP session.
Args:
session_id: The ACP session ID to resume
cwd: Working directory for the session
timeout: Timeout for the resume request
Returns:
The session ID
Raises:
RuntimeError: If resume fails
"""
params = {
"sessionId": session_id,
"cwd": cwd,
"mcpServers": [],
}
request_id = self._send_request("session/resume", params)
result = self._wait_for_response(request_id, timeout)
# The response should contain the session ID
resumed_id = result.get("sessionId", session_id)
self._state.sessions[resumed_id] = ACPSession(session_id=resumed_id, cwd=cwd)
logger.info(f"[ACP] Resumed session: acp_session={resumed_id} cwd={cwd}")
return resumed_id
def _try_resume_existing_session(self, cwd: str, timeout: float) -> str | None:
"""Try to find and resume an existing session for this workspace.
When multiple API server replicas connect to the same sandbox pod,
a previous replica may have already created an ACP session for this
workspace. This method discovers and resumes that session so the
agent retains conversation context.
Args:
cwd: Working directory to search for sessions
timeout: Timeout for ACP requests
Returns:
The resumed session ID, or None if no session could be resumed
"""
# List sessions for this workspace directory
sessions = self._list_sessions(cwd, timeout=min(timeout, 10.0))
if not sessions:
return None
# Pick the most recent session (first in list, assuming sorted)
target = sessions[0]
target_id = target.get("sessionId")
if not target_id:
logger.warning("[ACP] session/list returned session without sessionId")
return None
logger.info(
f"[ACP] Resuming existing session: acp_session={target_id} "
f"(found {len(sessions)})"
)
try:
return self._resume_session(target_id, cwd, timeout)
except Exception as e:
logger.warning(
f"[ACP] session/resume failed for {target_id}: {e}, "
f"falling back to session/new"
)
return None
def resume_or_create_session(self, cwd: str, timeout: float = 30.0) -> str:
"""Resume a session from opencode's on-disk storage, or create a new one.
With ephemeral clients (one process per message), this always hits disk.
Tries resume first to preserve conversation context, falls back to new.
Args:
cwd: Working directory for the session
timeout: Timeout for ACP requests
Returns:
The ACP session ID
"""
if not self._state.initialized:
raise RuntimeError("Client not initialized. Call start() first.")
# Try to resume from opencode's persisted storage
resumed_id = self._try_resume_existing_session(cwd, timeout)
if resumed_id:
return resumed_id
# Create a new session
return self._create_session(cwd=cwd, timeout=timeout)
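Putting the pieces together, a hedged sketch of the ephemeral per-message lifecycle that the sandbox manager code below follows (pod name and paths are placeholders):

client = ACPExecClient(
    pod_name="sandbox-abc123",
    namespace="onyx-sandboxes",
    container="sandbox",
)
client.start(cwd="/workspace/sessions/abc")
try:
    acp_session_id = client.resume_or_create_session(cwd="/workspace/sessions/abc")
    for event in client.send_message("What files are here?", session_id=acp_session_id):
        print(event)
finally:
    # Always stop so no stale opencode process lingers in the sandbox container.
    client.stop()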
def send_message(
self,
message: str,
session_id: str,
timeout: float = ACP_MESSAGE_TIMEOUT,
) -> Generator[ACPEvent, None, None]:
"""Send a message and stream response events.
"""Send a message to a specific session and stream response events.
Args:
message: The message content to send
session_id: The ACP session ID to send the message to
timeout: Maximum time to wait for complete response (defaults to ACP_MESSAGE_TIMEOUT env var)
Yields:
Typed ACP schema event objects
"""
if self._state.current_session is None:
raise RuntimeError("No active session. Call start() first.")
session_id = self._state.current_session.session_id
if session_id not in self._state.sessions:
raise RuntimeError(
f"Unknown session {session_id}. "
f"Known sessions: {list(self._state.sessions.keys())}"
)
packet_logger = get_packet_logger()
# Log the start of message processing
packet_logger.log_raw(
"ACP-SEND-MESSAGE-START-K8S",
{
"session_id": session_id,
"pod": self._pod_name,
"namespace": self._namespace,
"message_preview": (
message[:200] + "..." if len(message) > 200 else message
),
"timeout": timeout,
},
logger.info(
f"[ACP] Sending prompt: "
f"acp_session={session_id} pod={self._pod_name} "
f"queue_backlog={self._response_queue.qsize()}"
)
prompt_content = [{"type": "text", "text": message}]
@@ -446,44 +563,53 @@ class ACPExecClient:
request_id = self._send_request("session/prompt", params)
start_time = time.time()
last_event_time = time.time() # Track time since last event for keepalive
last_event_time = time.time()
events_yielded = 0
keepalive_count = 0
completion_reason = "unknown"
while True:
remaining = timeout - (time.time() - start_time)
if remaining <= 0:
packet_logger.log_raw(
"ACP-TIMEOUT-K8S",
{
"session_id": session_id,
"elapsed_ms": (time.time() - start_time) * 1000,
},
completion_reason = "timeout"
logger.warning(
f"[ACP] Prompt timeout: "
f"acp_session={session_id} events={events_yielded}, "
f"sending session/cancel"
)
try:
self.cancel(session_id=session_id)
except Exception as cancel_err:
logger.warning(
f"[ACP] session/cancel failed on timeout: {cancel_err}"
)
yield Error(code=-1, message="Timeout waiting for response")
break
try:
message_data = self._response_queue.get(timeout=min(remaining, 1.0))
last_event_time = time.time() # Reset keepalive timer on event
last_event_time = time.time()
except Empty:
# Check if we need to send an SSE keepalive
# Send SSE keepalive if idle
idle_time = time.time() - last_event_time
if idle_time >= SSE_KEEPALIVE_INTERVAL:
packet_logger.log_raw(
"SSE-KEEPALIVE-YIELD",
{
"session_id": session_id,
"idle_seconds": idle_time,
},
)
keepalive_count += 1
yield SSEKeepalive()
last_event_time = time.time() # Reset after yielding keepalive
last_event_time = time.time()
continue
# Check for response to our prompt request
if message_data.get("id") == request_id:
# Check for JSON-RPC response to our prompt request.
msg_id = message_data.get("id")
is_response = "method" not in message_data and (
msg_id == request_id
or (msg_id is not None and str(msg_id) == str(request_id))
)
if is_response:
completion_reason = "jsonrpc_response"
if "error" in message_data:
error_data = message_data["error"]
completion_reason = "jsonrpc_error"
logger.warning(f"[ACP] Prompt error: {error_data}")
packet_logger.log_jsonrpc_response(
request_id, error=error_data, context="k8s"
)
@@ -498,26 +624,16 @@ class ACPExecClient:
)
try:
prompt_response = PromptResponse.model_validate(result)
packet_logger.log_acp_event_yielded(
"prompt_response", prompt_response
)
events_yielded += 1
yield prompt_response
except ValidationError as e:
packet_logger.log_raw(
"ACP-VALIDATION-ERROR-K8S",
{"type": "prompt_response", "error": str(e)},
)
logger.error(f"[ACP] PromptResponse validation failed: {e}")
# Log completion summary
elapsed_ms = (time.time() - start_time) * 1000
packet_logger.log_raw(
"ACP-SEND-MESSAGE-COMPLETE-K8S",
{
"session_id": session_id,
"events_yielded": events_yielded,
"elapsed_ms": elapsed_ms,
},
logger.info(
f"[ACP] Prompt complete: "
f"reason={completion_reason} acp_session={session_id} "
f"events={events_yielded} elapsed={elapsed_ms:.0f}ms"
)
break
@@ -526,25 +642,29 @@ class ACPExecClient:
params_data = message_data.get("params", {})
update = params_data.get("update", {})
# Log the notification
packet_logger.log_jsonrpc_notification(
"session/update",
{"update_type": update.get("sessionUpdate")},
context="k8s",
)
prompt_complete = False
for event in self._process_session_update(update):
events_yielded += 1
# Log each yielded event
event_type = self._get_event_type_name(event)
packet_logger.log_acp_event_yielded(event_type, event)
yield event
if isinstance(event, PromptResponse):
prompt_complete = True
break
if prompt_complete:
completion_reason = "prompt_response_via_notification"
elapsed_ms = (time.time() - start_time) * 1000
logger.info(
f"[ACP] Prompt complete: "
f"reason={completion_reason} acp_session={session_id} "
f"events={events_yielded} elapsed={elapsed_ms:.0f}ms"
)
break
# Handle requests from agent - send error response
elif "method" in message_data and "id" in message_data:
packet_logger.log_raw(
"ACP-UNSUPPORTED-REQUEST-K8S",
{"method": message_data["method"], "id": message_data["id"]},
logger.debug(
f"[ACP] Unsupported agent request: "
f"method={message_data['method']}"
)
self._send_error_response(
message_data["id"],
@@ -552,113 +672,49 @@ class ACPExecClient:
f"Method not supported: {message_data['method']}",
)
def _get_event_type_name(self, event: ACPEvent) -> str:
"""Get the type name for an ACP event."""
if isinstance(event, AgentMessageChunk):
return "agent_message_chunk"
elif isinstance(event, AgentThoughtChunk):
return "agent_thought_chunk"
elif isinstance(event, ToolCallStart):
return "tool_call_start"
elif isinstance(event, ToolCallProgress):
return "tool_call_progress"
elif isinstance(event, AgentPlanUpdate):
return "agent_plan_update"
elif isinstance(event, CurrentModeUpdate):
return "current_mode_update"
elif isinstance(event, PromptResponse):
return "prompt_response"
elif isinstance(event, Error):
return "error"
elif isinstance(event, SSEKeepalive):
return "sse_keepalive"
return "unknown"
else:
logger.warning(
f"[ACP] Unhandled message: "
f"id={message_data.get('id')} "
f"method={message_data.get('method')} "
f"keys={list(message_data.keys())}"
)
def _process_session_update(
self, update: dict[str, Any]
) -> Generator[ACPEvent, None, None]:
"""Process a session/update notification and yield typed ACP schema objects."""
update_type = update.get("sessionUpdate")
packet_logger = get_packet_logger()
if not isinstance(update_type, str):
return
if update_type == "agent_message_chunk":
# Map update types to their ACP schema classes.
# Note: prompt_response is included because ACP sometimes sends it as a
# notification WITHOUT a corresponding JSON-RPC response. We accept
# either signal as turn completion (first one wins).
type_map: dict[str, type[BaseModel]] = {
"agent_message_chunk": AgentMessageChunk,
"agent_thought_chunk": AgentThoughtChunk,
"tool_call": ToolCallStart,
"tool_call_update": ToolCallProgress,
"plan": AgentPlanUpdate,
"current_mode_update": CurrentModeUpdate,
"prompt_response": PromptResponse,
}
model_class = type_map.get(update_type)
if model_class is not None:
try:
yield AgentMessageChunk.model_validate(update)
yield cast(ACPEvent, model_class.model_validate(update))
except ValidationError as e:
packet_logger.log_raw(
"ACP-VALIDATION-ERROR-K8S",
{"update_type": update_type, "error": str(e), "update": update},
)
elif update_type == "agent_thought_chunk":
try:
yield AgentThoughtChunk.model_validate(update)
except ValidationError as e:
packet_logger.log_raw(
"ACP-VALIDATION-ERROR-K8S",
{"update_type": update_type, "error": str(e), "update": update},
)
elif update_type == "user_message_chunk":
# Echo of user message - skip but log
packet_logger.log_raw(
"ACP-SKIPPED-UPDATE-K8S", {"type": "user_message_chunk"}
)
elif update_type == "tool_call":
try:
yield ToolCallStart.model_validate(update)
except ValidationError as e:
packet_logger.log_raw(
"ACP-VALIDATION-ERROR-K8S",
{"update_type": update_type, "error": str(e), "update": update},
)
elif update_type == "tool_call_update":
try:
yield ToolCallProgress.model_validate(update)
except ValidationError as e:
packet_logger.log_raw(
"ACP-VALIDATION-ERROR-K8S",
{"update_type": update_type, "error": str(e), "update": update},
)
elif update_type == "plan":
try:
yield AgentPlanUpdate.model_validate(update)
except ValidationError as e:
packet_logger.log_raw(
"ACP-VALIDATION-ERROR-K8S",
{"update_type": update_type, "error": str(e), "update": update},
)
elif update_type == "current_mode_update":
try:
yield CurrentModeUpdate.model_validate(update)
except ValidationError as e:
packet_logger.log_raw(
"ACP-VALIDATION-ERROR-K8S",
{"update_type": update_type, "error": str(e), "update": update},
)
elif update_type == "available_commands_update":
# Skip command updates
packet_logger.log_raw(
"ACP-SKIPPED-UPDATE-K8S", {"type": "available_commands_update"}
)
elif update_type == "session_info_update":
# Skip session info updates
packet_logger.log_raw(
"ACP-SKIPPED-UPDATE-K8S", {"type": "session_info_update"}
)
else:
# Unknown update types are logged
packet_logger.log_raw(
"ACP-UNKNOWN-UPDATE-TYPE-K8S",
{"update_type": update_type, "update": update},
)
logger.warning(f"[ACP] Validation error for {update_type}: {e}")
elif update_type not in (
"user_message_chunk",
"available_commands_update",
"session_info_update",
"usage_update",
):
logger.debug(f"[ACP] Unknown update type: {update_type}")
def _send_error_response(self, request_id: int, code: int, message: str) -> None:
"""Send an error response to an agent request."""
@@ -673,15 +729,24 @@ class ACPExecClient:
self._ws_client.write_stdin(json.dumps(response) + "\n")
def cancel(self) -> None:
"""Cancel the current operation."""
if self._state.current_session is None:
return
def cancel(self, session_id: str | None = None) -> None:
"""Cancel the current operation on a session.
self._send_notification(
"session/cancel",
{"sessionId": self._state.current_session.session_id},
)
Args:
session_id: The ACP session ID to cancel. If None, cancels all sessions.
"""
if session_id:
if session_id in self._state.sessions:
self._send_notification(
"session/cancel",
{"sessionId": session_id},
)
else:
for sid in self._state.sessions:
self._send_notification(
"session/cancel",
{"sessionId": sid},
)
def health_check(self, timeout: float = 5.0) -> bool: # noqa: ARG002
"""Check if we can exec into the pod."""
@@ -707,13 +772,6 @@ class ACPExecClient:
"""Check if the exec session is running."""
return self._ws_client is not None and self._ws_client.is_open()
@property
def session_id(self) -> str | None:
"""Get the current session ID, if any."""
if self._state.current_session:
return self._state.current_session.session_id
return None
def __enter__(self) -> "ACPExecClient":
"""Context manager entry."""
return self

View File

@@ -50,6 +50,7 @@ from pathlib import Path
from uuid import UUID
from uuid import uuid4
from acp.schema import PromptResponse
from kubernetes import client # type: ignore
from kubernetes import config
from kubernetes.client.rest import ApiException # type: ignore
@@ -97,6 +98,10 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# API server pod hostname — used to identify which replica is handling a request.
# In K8s, HOSTNAME is set to the pod name (e.g., "api-server-dpgg7").
_API_SERVER_HOSTNAME = os.environ.get("HOSTNAME", "unknown")
# Constants for pod configuration
# Note: Next.js ports are dynamically allocated from SANDBOX_NEXTJS_PORT_START to
# SANDBOX_NEXTJS_PORT_END range, with one port per session.
@@ -1156,7 +1161,9 @@ done
def terminate(self, sandbox_id: UUID) -> None:
"""Terminate a sandbox and clean up Kubernetes resources.
Deletes the Service and Pod for the sandbox.
Removes session mappings for this sandbox, then deletes the
Service and Pod. ACP clients are ephemeral (created per message),
so there's nothing to stop here.
Args:
sandbox_id: The sandbox ID to terminate
@@ -1395,7 +1402,8 @@ echo "Session workspace setup complete"
) -> None:
"""Clean up a session workspace (on session delete).
Executes kubectl exec to remove the session directory.
Removes the ACP session mapping and executes kubectl exec to remove
the session directory. The shared ACP client persists for other sessions.
Args:
sandbox_id: The sandbox ID
@@ -1464,6 +1472,7 @@ echo "Session cleanup complete"
the snapshot and upload to S3. Captures:
- sessions/$session_id/outputs/ (generated artifacts, web apps)
- sessions/$session_id/attachments/ (user uploaded files)
- sessions/$session_id/.opencode-data/ (opencode session data for resumption)
Args:
sandbox_id: The sandbox ID
@@ -1488,9 +1497,10 @@ echo "Session cleanup complete"
f"{session_id_str}/{snapshot_id}.tar.gz"
)
# Exec into pod to create and upload snapshot (outputs + attachments)
# Uses s5cmd pipe to stream tar.gz directly to S3
# Only snapshot if outputs/ exists. Include attachments/ only if non-empty.
# Create tar and upload to S3 via file-sync container.
# .opencode-data/ is already on the shared workspace volume because we set
# XDG_DATA_HOME to the session directory when starting opencode (see
# ACPExecClient.start()). No cross-container copy needed.
exec_command = [
"/bin/sh",
"-c",
@@ -1503,6 +1513,7 @@ if [ ! -d outputs ]; then
fi
dirs="outputs"
[ -d attachments ] && [ "$(ls -A attachments 2>/dev/null)" ] && dirs="$dirs attachments"
[ -d .opencode-data ] && [ "$(ls -A .opencode-data 2>/dev/null)" ] && dirs="$dirs .opencode-data"
tar -czf - $dirs | /s5cmd pipe {s3_path}
echo "SNAPSHOT_CREATED"
""",
@@ -1624,6 +1635,7 @@ echo "SNAPSHOT_CREATED"
Steps:
1. Exec s5cmd cat in file-sync container to stream snapshot from S3
2. Pipe directly to tar for extraction in the shared workspace volume
(.opencode-data/ is restored automatically since XDG_DATA_HOME points here)
3. Regenerate configuration files (AGENTS.md, opencode.json, files symlink)
4. Start the NextJS dev server
@@ -1807,6 +1819,41 @@ echo "Session config regeneration complete"
)
return exec_client.health_check(timeout=timeout)
def _create_ephemeral_acp_client(
self, sandbox_id: UUID, session_path: str
) -> ACPExecClient:
"""Create a new ephemeral ACP client for a single message exchange.
Each call starts a fresh `opencode acp` process in the sandbox pod.
The process is short-lived — stopped after the message completes.
This prevents the bug where multiple long-lived processes (one per
API replica) operate on the same session's flat file storage
concurrently, causing the JSON-RPC response to be silently lost.
Args:
sandbox_id: The sandbox ID
session_path: Working directory for the session (e.g. /workspace/sessions/{id}).
XDG_DATA_HOME is set relative to this so opencode's session data
lives inside the snapshot directory.
Returns:
A running ACPExecClient (caller must stop it when done)
"""
pod_name = self._get_pod_name(str(sandbox_id))
acp_client = ACPExecClient(
pod_name=pod_name,
namespace=self._namespace,
container="sandbox",
)
acp_client.start(cwd=session_path)
logger.info(
f"[SANDBOX-ACP] Created ephemeral ACP client: "
f"sandbox={sandbox_id} pod={pod_name} "
f"api_pod={_API_SERVER_HOSTNAME}"
)
return acp_client
def send_message(
self,
sandbox_id: UUID,
@@ -1815,8 +1862,12 @@ echo "Session config regeneration complete"
) -> Generator[ACPEvent, None, None]:
"""Send a message to the CLI agent and stream ACP events.
Runs `opencode acp` via kubectl exec in the sandbox pod.
The agent runs in the session-specific workspace.
Creates an ephemeral `opencode acp` process for each message.
The process resumes the session from opencode's on-disk storage,
handles the prompt, then is stopped. This ensures only one process
operates on a session's flat files at a time, preventing the bug
where multiple long-lived processes (one per API replica) corrupt
each other's in-memory state.
Args:
sandbox_id: The sandbox ID
@@ -1827,67 +1878,103 @@ echo "Session config regeneration complete"
Typed ACP schema event objects
"""
packet_logger = get_packet_logger()
pod_name = self._get_pod_name(str(sandbox_id))
session_path = f"/workspace/sessions/{session_id}"
# Log ACP client creation
packet_logger.log_acp_client_start(
sandbox_id, session_id, session_path, context="k8s"
)
# Create an ephemeral ACP client for this message
acp_client = self._create_ephemeral_acp_client(sandbox_id, session_path)
exec_client = ACPExecClient(
pod_name=pod_name,
namespace=self._namespace,
container="sandbox",
)
# Log the send_message call at sandbox manager level
packet_logger.log_session_start(session_id, sandbox_id, message)
events_count = 0
try:
exec_client.start(cwd=session_path)
for event in exec_client.send_message(message):
events_count += 1
yield event
# Resume (or create) the ACP session from opencode's on-disk storage
acp_session_id = acp_client.resume_or_create_session(cwd=session_path)
# Log successful completion
packet_logger.log_session_end(
session_id, success=True, events_count=events_count
logger.info(
f"[SANDBOX-ACP] Sending message: "
f"session={session_id} acp_session={acp_session_id} "
f"api_pod={_API_SERVER_HOSTNAME}"
)
except GeneratorExit:
# Generator was closed by consumer (client disconnect, timeout, broken pipe)
# This is the most common failure mode for SSE streaming
packet_logger.log_session_end(
session_id,
success=False,
error="GeneratorExit: Client disconnected or stream closed by consumer",
events_count=events_count,
)
raise
except Exception as e:
# Log failure from normal exceptions
packet_logger.log_session_end(
session_id,
success=False,
error=f"Exception: {str(e)}",
events_count=events_count,
)
raise
except BaseException as e:
# Log failure from other base exceptions (SystemExit, KeyboardInterrupt, etc.)
exception_type = type(e).__name__
packet_logger.log_session_end(
session_id,
success=False,
error=f"{exception_type}: {str(e) if str(e) else 'System-level interruption'}",
events_count=events_count,
)
raise
# Log the send_message call at sandbox manager level
packet_logger.log_session_start(session_id, sandbox_id, message)
events_count = 0
got_prompt_response = False
try:
for event in acp_client.send_message(
message, session_id=acp_session_id
):
events_count += 1
if isinstance(event, PromptResponse):
got_prompt_response = True
yield event
logger.info(
f"[SANDBOX-ACP] send_message completed: "
f"session={session_id} events={events_count} "
f"got_prompt_response={got_prompt_response}"
)
packet_logger.log_session_end(
session_id, success=True, events_count=events_count
)
except GeneratorExit:
logger.warning(
f"[SANDBOX-ACP] GeneratorExit: session={session_id} "
f"events={events_count}, sending session/cancel"
)
try:
acp_client.cancel(session_id=acp_session_id)
except Exception as cancel_err:
logger.warning(
f"[SANDBOX-ACP] session/cancel failed on GeneratorExit: "
f"{cancel_err}"
)
packet_logger.log_session_end(
session_id,
success=False,
error="GeneratorExit: Client disconnected or stream closed by consumer",
events_count=events_count,
)
raise
except Exception as e:
logger.error(
f"[SANDBOX-ACP] Exception: session={session_id} "
f"events={events_count} error={e}, sending session/cancel"
)
try:
acp_client.cancel(session_id=acp_session_id)
except Exception as cancel_err:
logger.warning(
f"[SANDBOX-ACP] session/cancel failed on Exception: "
f"{cancel_err}"
)
packet_logger.log_session_end(
session_id,
success=False,
error=f"Exception: {str(e)}",
events_count=events_count,
)
raise
except BaseException as e:
logger.error(
f"[SANDBOX-ACP] {type(e).__name__}: session={session_id} "
f"error={e}"
)
packet_logger.log_session_end(
session_id,
success=False,
error=f"{type(e).__name__}: {str(e) if str(e) else 'System-level interruption'}",
events_count=events_count,
)
raise
finally:
exec_client.stop()
# Log client stop
packet_logger.log_acp_client_stop(sandbox_id, session_id, context="k8s")
# Always stop the ephemeral ACP client to kill the opencode process.
# This ensures no stale processes linger in the sandbox container.
try:
acp_client.stop()
except Exception as e:
logger.warning(
f"[SANDBOX-ACP] Failed to stop ephemeral ACP client: "
f"session={session_id} error={e}"
)
def list_directory(
self, sandbox_id: UUID, session_id: UUID, path: str

View File

@@ -15,6 +15,8 @@ from collections.abc import Generator
from pathlib import Path
from uuid import UUID
import httpx
from onyx.db.enums import SandboxStatus
from onyx.file_store.file_store import get_default_file_store
from onyx.server.features.build.configs import DEMO_DATA_PATH
@@ -35,6 +37,7 @@ from onyx.server.features.build.sandbox.models import LLMProviderConfig
from onyx.server.features.build.sandbox.models import SandboxInfo
from onyx.server.features.build.sandbox.models import SnapshotResult
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import ThreadSafeSet
logger = setup_logger()
@@ -89,9 +92,17 @@ class LocalSandboxManager(SandboxManager):
self._acp_clients: dict[tuple[UUID, UUID], ACPAgentClient] = {}
# Track Next.js processes - keyed by (sandbox_id, session_id) tuple
# Used for clean shutdown when sessions are deleted
# Used for clean shutdown when sessions are deleted.
# Mutated from background threads; all access must hold _nextjs_lock.
self._nextjs_processes: dict[tuple[UUID, UUID], subprocess.Popen[bytes]] = {}
# Track sessions currently being (re)started - prevents concurrent restarts.
# ThreadSafeSet allows atomic check-and-add without holding _nextjs_lock.
self._nextjs_starting: ThreadSafeSet[tuple[UUID, UUID]] = ThreadSafeSet()
# Lock guarding _nextjs_processes (shared across sessions; hold briefly only)
self._nextjs_lock = threading.Lock()
# Validate templates exist (raises RuntimeError if missing)
self._validate_templates()
@@ -326,16 +337,18 @@ class LocalSandboxManager(SandboxManager):
RuntimeError: If termination fails
"""
# Stop all Next.js processes for this sandbox (keyed by (sandbox_id, session_id))
processes_to_stop = [
(key, process)
for key, process in self._nextjs_processes.items()
if key[0] == sandbox_id
]
with self._nextjs_lock:
processes_to_stop = [
(key, process)
for key, process in self._nextjs_processes.items()
if key[0] == sandbox_id
]
for key, process in processes_to_stop:
session_id = key[1]
try:
self._stop_nextjs_process(process, session_id)
del self._nextjs_processes[key]
with self._nextjs_lock:
self._nextjs_processes.pop(key, None)
except Exception as e:
logger.warning(
f"Failed to stop Next.js for sandbox {sandbox_id}, "
@@ -516,7 +529,8 @@ class LocalSandboxManager(SandboxManager):
web_dir, nextjs_port
)
# Store process for clean shutdown on session delete
self._nextjs_processes[(sandbox_id, session_id)] = nextjs_process
with self._nextjs_lock:
self._nextjs_processes[(sandbox_id, session_id)] = nextjs_process
logger.info("Next.js server started successfully")
# Setup venv and AGENTS.md
@@ -575,7 +589,8 @@ class LocalSandboxManager(SandboxManager):
"""
# Stop Next.js dev server - try stored process first, then fallback to port lookup
process_key = (sandbox_id, session_id)
nextjs_process = self._nextjs_processes.pop(process_key, None)
with self._nextjs_lock:
nextjs_process = self._nextjs_processes.pop(process_key, None)
if nextjs_process is not None:
self._stop_nextjs_process(nextjs_process, session_id)
elif nextjs_port is not None:
@@ -766,6 +781,85 @@ class LocalSandboxManager(SandboxManager):
outputs_path = session_path / "outputs"
return outputs_path.exists()
def ensure_nextjs_running(
self,
sandbox_id: UUID,
session_id: UUID,
nextjs_port: int,
) -> None:
"""Start Next.js server for a session if not already running.
Called when the server is detected as unreachable (e.g., after API server restart).
Returns immediately — the actual startup runs in a background daemon thread.
A per-session guard prevents concurrent restarts from racing.
Lock design: _nextjs_lock is shared across ALL sessions. Holding it during
httpx (1s) or start_nextjs_server (several seconds) would block every other
session's status checks and restarts. We only hold the lock for fast
in-memory ops (dict get, check_and_add). The slow I/O runs in the background
thread without holding any lock.
Args:
sandbox_id: The sandbox ID
session_id: The session ID
nextjs_port: The port number for the Next.js server
"""
process_key = (sandbox_id, session_id)
with self._nextjs_lock:
existing = self._nextjs_processes.get(process_key)
if existing is not None and existing.poll() is None:
return
# Atomic check-and-add: returns True if already in set (another thread is starting)
if self._nextjs_starting.check_and_add(process_key):
return
def _start_in_background() -> None:
try:
# Port check in background to avoid blocking the main thread
try:
with httpx.Client(timeout=1.0) as client:
client.get(f"http://localhost:{nextjs_port}")
logger.info(
f"Port {nextjs_port} already alive for session {session_id} "
"(orphan process) — skipping restart"
)
return
except Exception:
pass # Port is dead; proceed with restart
logger.info(
f"Starting Next.js for session {session_id} on port {nextjs_port}"
)
sandbox_path = self._get_sandbox_path(sandbox_id)
web_dir = self._directory_manager.get_web_path(
sandbox_path, str(session_id)
)
if not web_dir.exists():
logger.warning(
f"Web dir missing for session {session_id}: {web_dir}"
"cannot restart Next.js"
)
return
process = self._process_manager.start_nextjs_server(
web_dir, nextjs_port
)
with self._nextjs_lock:
self._nextjs_processes[process_key] = process
logger.info(
f"Auto-restarted Next.js for session {session_id} "
f"on port {nextjs_port}"
)
except Exception as e:
logger.error(
f"Failed to auto-restart Next.js for session {session_id}: {e}"
)
finally:
self._nextjs_starting.discard(process_key)
threading.Thread(target=_start_in_background, daemon=True).start()
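The guard above leans on ThreadSafeSet.check_and_add and .discard; a minimal sketch of the semantics assumed here (the real class lives in onyx.utils.threadpool_concurrency and may differ):

import threading
from typing import Hashable

class IllustrativeThreadSafeSet:
    """Illustration only — atomic membership test plus insert."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._items: set[Hashable] = set()

    def check_and_add(self, item: Hashable) -> bool:
        # Returns True if the item was already present; otherwise adds it.
        with self._lock:
            if item in self._items:
                return True
            self._items.add(item)
            return False

    def discard(self, item: Hashable) -> None:
        with self._lock:
            self._items.discard(item)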
def restore_snapshot(
self,
sandbox_id: UUID,

View File

@@ -1,10 +0,0 @@
"""Celery tasks for sandbox management."""
from onyx.server.features.build.sandbox.tasks.tasks import (
cleanup_idle_sandboxes_task,
) # noqa: F401
from onyx.server.features.build.sandbox.tasks.tasks import (
sync_sandbox_files,
) # noqa: F401
__all__ = ["cleanup_idle_sandboxes_task", "sync_sandbox_files"]

View File

@@ -1765,6 +1765,7 @@ class SessionManager:
"webapp_url": None,
"status": "no_sandbox",
"ready": False,
"sharing_scope": session.sharing_scope,
}
# Return the proxy URL - the proxy handles routing to the correct sandbox
@@ -1777,11 +1778,21 @@ class SessionManager:
# Quick health check: can the API server reach the NextJS dev server?
ready = self._check_nextjs_ready(sandbox.id, session.nextjs_port)
# If not ready, ask the sandbox manager to ensure Next.js is running.
# For the local backend this triggers a background restart so that the
# frontend poll loop eventually sees ready=True without the user having
# to manually recreate the session.
if not ready:
self._sandbox_manager.ensure_nextjs_running(
sandbox.id, session_id, session.nextjs_port
)
return {
"has_webapp": session.nextjs_port is not None,
"webapp_url": webapp_url,
"status": sandbox.status.value,
"ready": ready,
"sharing_scope": session.sharing_scope,
}
def _check_nextjs_ready(self, sandbox_id: UUID, port: int) -> bool:

View File

@@ -30,17 +30,28 @@ OPENSEARCH_NOT_ENABLED_MESSAGE = (
"OpenSearch indexing must be enabled to use this feature."
)
MIGRATION_STATUS_MESSAGE = (
"Our records indicate that the transition to OpenSearch is still in progress. "
"OpenSearch retrieval is necessary to use this feature. "
"You can still use Document Sets, though! "
"If you would like to manually switch to OpenSearch, "
'Go to the "Document Index Migration" section in the Admin panel.'
)
router = APIRouter(prefix=HIERARCHY_NODES_PREFIX)
def _require_opensearch(db_session: Session) -> None:
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX or not get_opensearch_retrieval_state(
db_session
):
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
raise HTTPException(
status_code=403,
detail=OPENSEARCH_NOT_ENABLED_MESSAGE,
)
if not get_opensearch_retrieval_state(db_session):
raise HTTPException(
status_code=403,
detail=MIGRATION_STATUS_MESSAGE,
)
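A hedged sketch of how this guard is typically invoked at the top of a route in this router, assuming the module's usual FastAPI imports (the handler name and return type are placeholders):

# Illustrative only — everything except _require_opensearch is a placeholder.
@router.get("/example")
def example_hierarchy_endpoint(
    db_session: Session = Depends(get_session),
) -> dict:
    _require_opensearch(db_session)  # raises 403 with one of the messages above
    return {"ok": True}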
def _get_user_access_info(

View File

@@ -8,6 +8,7 @@ import httpx
from sqlalchemy.orm import Session
from onyx import __version__
from onyx.configs.app_configs import INSTANCE_TYPE
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.release_notes import create_release_notifications_for_versions
from onyx.redis.redis_pool import get_shared_redis_client
@@ -56,7 +57,7 @@ def is_version_gte(v1: str, v2: str) -> bool:
def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry]:
"""Parse MDX content into ReleaseNoteEntry objects for versions >= __version__."""
"""Parse MDX content into ReleaseNoteEntry objects."""
all_entries = []
update_pattern = (
@@ -82,6 +83,12 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
if not all_entries:
raise ValueError("Could not parse any release note entries from MDX.")
if INSTANCE_TYPE == "cloud":
# Cloud often runs ahead of docs release tags; always notify on latest release.
return sorted(
all_entries, key=lambda x: parse_version_tuple(x.version), reverse=True
)[:1]
# Filter to valid versions >= __version__
if __version__ and is_valid_version(__version__):
entries = [

View File

@@ -310,7 +310,7 @@ def list_llm_providers(
llm_provider_list: list[LLMProviderView] = []
for llm_provider_model in fetch_existing_llm_providers(
db_session=db_session,
flow_types=[LLMModelFlowType.CHAT, LLMModelFlowType.VISION],
flow_type_filter=[],
exclude_image_generation_providers=not include_image_gen,
):
from_model_start = datetime.now(timezone.utc)
@@ -568,9 +568,7 @@ def list_llm_provider_basics(
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch user-accessible LLM providers")
all_providers = fetch_existing_llm_providers(
db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
)
all_providers = fetch_existing_llm_providers(db_session, [])
user_group_ids = fetch_user_group_ids(db_session, user)
is_admin = user.role == UserRole.ADMIN

View File

@@ -26,13 +26,17 @@ def get_opensearch_migration_status(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> OpenSearchMigrationStatusResponse:
total_chunks_migrated, created_at, migration_completed_at = (
get_opensearch_migration_state(db_session)
)
(
total_chunks_migrated,
created_at,
migration_completed_at,
approx_chunk_count_in_vespa,
) = get_opensearch_migration_state(db_session)
return OpenSearchMigrationStatusResponse(
total_chunks_migrated=total_chunks_migrated,
created_at=created_at,
migration_completed_at=migration_completed_at,
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)

View File

@@ -8,6 +8,7 @@ class OpenSearchMigrationStatusResponse(BaseModel):
total_chunks_migrated: int
created_at: datetime | None
migration_completed_at: datetime | None
approx_chunk_count_in_vespa: int | None
class OpenSearchRetrievalStatusRequest(BaseModel):

View File

@@ -608,7 +608,8 @@ def list_all_users_basic_info(
return [
MinimalUserSnapshot(id=user.id, email=user.email)
for user in users
if include_api_keys or not is_api_key_email_address(user.email)
if user.role != UserRole.SLACK_USER
and (include_api_keys or not is_api_key_email_address(user.email))
]

View File


@@ -0,0 +1,241 @@
"""SQLAlchemy connection pool Prometheus metrics.
Provides production-grade visibility into database connection pool state:
- Pool state gauges (checked-out, idle, overflow, configured size)
- Pool lifecycle counters (checkouts, checkins, creates, invalidations, timeouts)
- Per-endpoint connection attribution (which endpoints hold connections, for how long)
Metrics are collected via two mechanisms:
1. A custom Prometheus Collector that reads pool snapshots on each /metrics scrape
2. SQLAlchemy pool event listeners (checkout, checkin, connect, invalidate) for
counters, histograms, and attribution
"""
import time
from fastapi import Request
from fastapi.responses import JSONResponse
from prometheus_client import Counter
from prometheus_client import Gauge
from prometheus_client import Histogram
from prometheus_client.core import GaugeMetricFamily
from prometheus_client.registry import Collector
from prometheus_client.registry import REGISTRY
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlalchemy.engine.interfaces import DBAPIConnection
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.pool import ConnectionPoolEntry
from sqlalchemy.pool import PoolProxiedConnection
from sqlalchemy.pool import QueuePool
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_ENDPOINT_CONTEXTVAR
logger = setup_logger()
# --- Pool lifecycle counters (event-driven) ---
_checkout_total = Counter(
"onyx_db_pool_checkout_total",
"Total connection checkouts from the pool",
["engine"],
)
_checkin_total = Counter(
"onyx_db_pool_checkin_total",
"Total connection checkins to the pool",
["engine"],
)
_connections_created_total = Counter(
"onyx_db_pool_connections_created_total",
"Total new database connections created",
["engine"],
)
_invalidations_total = Counter(
"onyx_db_pool_invalidations_total",
"Total connection invalidations",
["engine"],
)
_checkout_timeout_total = Counter(
"onyx_db_pool_checkout_timeout_total",
"Total connection checkout timeouts",
["engine"],
)
# --- Per-endpoint attribution (event-driven) ---
_connections_held = Gauge(
"onyx_db_connections_held_by_endpoint",
"Number of DB connections currently held, by endpoint and engine",
["handler", "engine"],
)
_hold_seconds = Histogram(
"onyx_db_connection_hold_seconds",
"Duration a DB connection is held by an endpoint",
["handler", "engine"],
)
def pool_timeout_handler(
request: Request, # noqa: ARG001
exc: Exception,
) -> JSONResponse:
"""Increment the checkout timeout counter and return 503."""
_checkout_timeout_total.labels(engine="unknown").inc()
return JSONResponse(
status_code=503,
content={
"detail": "Database connection pool timeout",
"error": str(exc),
},
)
class PoolStateCollector(Collector):
"""Custom Prometheus collector that reads QueuePool state on each scrape.
Uses pool.checkedout(), pool.checkedin(), pool.overflow(), and pool.size()
for an atomic snapshot of pool state. Registered engines are stored as
(label, pool) tuples to avoid holding references to the full Engine.
"""
def __init__(self) -> None:
self._pools: list[tuple[str, QueuePool]] = []
def add_pool(self, label: str, pool: QueuePool) -> None:
self._pools.append((label, pool))
def collect(self) -> list[GaugeMetricFamily]:
checked_out = GaugeMetricFamily(
"onyx_db_pool_checked_out",
"Currently checked-out connections",
labels=["engine"],
)
checked_in = GaugeMetricFamily(
"onyx_db_pool_checked_in",
"Idle connections available in the pool",
labels=["engine"],
)
overflow = GaugeMetricFamily(
"onyx_db_pool_overflow",
"Current overflow connections beyond pool_size",
labels=["engine"],
)
size = GaugeMetricFamily(
"onyx_db_pool_size",
"Configured pool size",
labels=["engine"],
)
for label, pool in self._pools:
checked_out.add_metric([label], pool.checkedout())
checked_in.add_metric([label], pool.checkedin())
overflow.add_metric([label], pool.overflow())
size.add_metric([label], pool.size())
return [checked_out, checked_in, overflow, size]
def describe(self) -> list[GaugeMetricFamily]:
# Return empty to mark this as an "unchecked" collector. Prometheus
# skips upfront descriptor validation and just calls collect() at
# scrape time. Required because our metrics are dynamic (engine
# labels depend on which engines are registered at runtime).
return []
def _register_pool_events(engine: Engine, label: str) -> None:
"""Attach pool event listeners for metrics collection.
Listens to checkout, checkin, connect, and invalidate events.
Stores per-connection metadata on connection_record.info for attribution.
"""
@event.listens_for(engine, "checkout")
def on_checkout(
dbapi_conn: DBAPIConnection, # noqa: ARG001
conn_record: ConnectionPoolEntry,
conn_proxy: PoolProxiedConnection, # noqa: ARG001
) -> None:
handler = CURRENT_ENDPOINT_CONTEXTVAR.get() or "unknown"
conn_record.info["_metrics_endpoint"] = handler
conn_record.info["_metrics_checkout_time"] = time.monotonic()
_checkout_total.labels(engine=label).inc()
_connections_held.labels(handler=handler, engine=label).inc()
@event.listens_for(engine, "checkin")
def on_checkin(
dbapi_conn: DBAPIConnection, # noqa: ARG001
conn_record: ConnectionPoolEntry,
) -> None:
handler = conn_record.info.pop("_metrics_endpoint", "unknown")
start = conn_record.info.pop("_metrics_checkout_time", None)
_checkin_total.labels(engine=label).inc()
_connections_held.labels(handler=handler, engine=label).dec()
if start is not None:
_hold_seconds.labels(handler=handler, engine=label).observe(
time.monotonic() - start
)
@event.listens_for(engine, "connect")
def on_connect(
dbapi_conn: DBAPIConnection, # noqa: ARG001
conn_record: ConnectionPoolEntry, # noqa: ARG001
) -> None:
_connections_created_total.labels(engine=label).inc()
@event.listens_for(engine, "invalidate")
def on_invalidate(
dbapi_conn: DBAPIConnection, # noqa: ARG001
conn_record: ConnectionPoolEntry,
exception: BaseException | None, # noqa: ARG001
) -> None:
_invalidations_total.labels(engine=label).inc()
# Defensively clean up the held-connections gauge in case checkin
# doesn't fire after invalidation (e.g. hard pool shutdown).
handler = conn_record.info.pop("_metrics_endpoint", None)
start = conn_record.info.pop("_metrics_checkout_time", None)
if handler:
_connections_held.labels(handler=handler, engine=label).dec()
if start is not None:
_hold_seconds.labels(handler=handler or "unknown", engine=label).observe(
time.monotonic() - start
)
def setup_postgres_connection_pool_metrics(
engines: dict[str, Engine | AsyncEngine],
) -> None:
"""Register pool metrics for all provided engines.
Args:
engines: Mapping of engine label to Engine or AsyncEngine.
Example: {"sync": sync_engine, "async": async_engine, "readonly": ro_engine}
Engines using NullPool are skipped (no pool state to monitor).
For AsyncEngine, events are registered on the underlying sync_engine.
"""
collector = PoolStateCollector()
for label, engine in engines.items():
# Resolve async engines to their underlying sync engine
sync_engine = engine.sync_engine if isinstance(engine, AsyncEngine) else engine
pool = sync_engine.pool
if not isinstance(pool, QueuePool):
logger.info(
f"Skipping pool metrics for engine '{label}' "
f"({type(pool).__name__} — no pool state)"
)
continue
collector.add_pool(label, pool)
_register_pool_events(sync_engine, label)
logger.info(f"Registered pool metrics for engine '{label}'")
REGISTRY.register(collector)
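For reference, a minimal sketch of calling this during application startup; the connection URLs, pool sizes, labels, and the helper name are placeholders rather than Onyx's real engine wiring:
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine
def register_db_pool_metrics() -> None:
    sync_engine = create_engine(
        "postgresql+psycopg2://onyx:onyx@localhost:5432/onyx",
        pool_size=5,
        max_overflow=10,
    )
    async_engine = create_async_engine(
        "postgresql+asyncpg://onyx:onyx@localhost:5432/onyx",
        pool_size=5,
        max_overflow=10,
    )
    # Call once per process: the PoolStateCollector is registered with the
    # global prometheus_client REGISTRY, and re-registering would raise.
    setup_postgres_connection_pool_metrics(
        {"sync": sync_engine, "async": async_engine}
    )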

View File

@@ -0,0 +1,64 @@
"""Prometheus metrics setup for the Onyx API server.
Orchestrates HTTP request instrumentation via ``prometheus-fastapi-instrumentator``:
- Request count, latency histograms, in-progress gauges
- Pool checkout timeout exception handler
- Custom metric callbacks (e.g. slow request counting)
SQLAlchemy connection pool metrics are registered separately via
``setup_postgres_connection_pool_metrics`` during application lifespan
(after engines are created).
"""
from prometheus_fastapi_instrumentator import Instrumentator
from sqlalchemy.exc import TimeoutError as SATimeoutError
from starlette.applications import Starlette
from onyx.server.metrics.postgres_connection_pool import pool_timeout_handler
from onyx.server.metrics.slow_requests import slow_request_callback
_EXCLUDED_HANDLERS = [
"/health",
"/metrics",
"/openapi.json",
]
# Denser buckets for per-handler latency histograms. The instrumentator's
# default (0.1, 0.5, 1) is too coarse for meaningful P95/P99 computation.
_LATENCY_BUCKETS = (
0.01,
0.025,
0.05,
0.1,
0.25,
0.5,
1.0,
2.5,
5.0,
10.0,
)
def setup_prometheus_metrics(app: Starlette) -> None:
"""Initialize HTTP request metrics for the Onyx API server.
Must be called in ``get_application()`` BEFORE the app starts, because
the instrumentator adds middleware via ``app.add_middleware()``.
Args:
app: The FastAPI/Starlette application to instrument.
"""
app.add_exception_handler(SATimeoutError, pool_timeout_handler)
instrumentator = Instrumentator(
should_group_status_codes=False,
should_ignore_untemplated=False,
should_group_untemplated=True,
should_instrument_requests_inprogress=True,
inprogress_labels=True,
excluded_handlers=_EXCLUDED_HANDLERS,
)
instrumentator.add(slow_request_callback)
instrumentator.instrument(app, latency_lowr_buckets=_LATENCY_BUCKETS).expose(app)
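A sketch of the call order the docstring describes; the factory body is illustrative, not Onyx's actual get_application():
from fastapi import FastAPI
def get_application() -> FastAPI:
    app = FastAPI()
    # Adds the instrumentator middleware and the /metrics route before startup.
    setup_prometheus_metrics(app)
    # ... include routers, other middleware, lifespan wiring here ...
    return app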

View File

@@ -0,0 +1,31 @@
"""Slow request counter metric.
Increments a counter whenever a request exceeds a configurable duration
threshold. Useful for identifying endpoints that regularly take too long.
"""
import os
from prometheus_client import Counter
from prometheus_fastapi_instrumentator.metrics import Info
SLOW_REQUEST_THRESHOLD_SECONDS: float = max(
0.0,
float(os.environ.get("SLOW_REQUEST_THRESHOLD_SECONDS", "1.0")),
)
_slow_requests = Counter(
"onyx_api_slow_requests_total",
"Total requests exceeding the slow request threshold",
["method", "handler", "status"],
)
def slow_request_callback(info: Info) -> None:
"""Increment slow request counter when duration exceeds threshold."""
if info.modified_duration > SLOW_REQUEST_THRESHOLD_SECONDS:
_slow_requests.labels(
method=info.method,
handler=info.modified_handler,
status=info.modified_status,
).inc()
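Because the threshold is read once at import, any override has to be in the environment before this module is imported; a minimal sketch (the value 0.25 is an arbitrary example):
import os
# Must be set before onyx.server.metrics.slow_requests is imported; negative
# values are clamped to 0.0 by the module-level max() above.
os.environ.setdefault("SLOW_REQUEST_THRESHOLD_SECONDS", "0.25")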

View File

@@ -349,6 +349,7 @@ def get_chat_session(
shared_status=chat_session.shared_status,
current_temperature_override=chat_session.temperature_override,
deleted=chat_session.deleted,
owner_name=chat_session.user.personal_name if chat_session.user else None,
# Packets are now directly serialized as Packet Pydantic models
packets=replay_packet_lists,
)

View File

@@ -224,6 +224,7 @@ class ChatSessionDetailResponse(BaseModel):
current_alternate_model: str | None
current_temperature_override: float | None
deleted: bool = False
owner_name: str | None = None
packets: list[list[Packet]]

View File

@@ -55,6 +55,7 @@ class Settings(BaseModel):
gpu_enabled: bool | None = None
application_status: ApplicationStatus = ApplicationStatus.ACTIVE
anonymous_user_enabled: bool | None = None
invite_only_enabled: bool = False
deep_research_enabled: bool | None = None
# Enterprise features flag - set by license enforcement at runtime

View File

@@ -47,6 +47,7 @@ from onyx.tools.tool_implementations.web_search.utils import (
from onyx.tools.tool_implementations.web_search.utils import MAX_CHARS_PER_URL
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.url import normalize_url as normalize_web_content_url
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -801,7 +802,9 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
for url in all_urls:
doc_id = url_to_doc_id.get(url)
indexed_section = indexed_by_doc_id.get(doc_id) if doc_id else None
crawled_section = crawled_by_url.get(url)
# WebContent.link is normalized (query/fragment stripped). Match on the
# same normalized form to avoid dropping successful crawl results.
crawled_section = crawled_by_url.get(normalize_web_content_url(url))
if indexed_section and indexed_section.combined_content:
# Prefer indexed

View File

@@ -0,0 +1,260 @@
from __future__ import annotations
from typing import Any
import requests
from fastapi import HTTPException
from onyx.tools.tool_implementations.web_search.models import (
WebSearchProvider,
)
from onyx.tools.tool_implementations.web_search.models import WebSearchResult
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
BRAVE_WEB_SEARCH_URL = "https://api.search.brave.com/res/v1/web/search"
BRAVE_MAX_RESULTS_PER_REQUEST = 20
BRAVE_SAFESEARCH_OPTIONS = {"off", "moderate", "strict"}
BRAVE_FRESHNESS_OPTIONS = {"pd", "pw", "pm", "py"}
class RetryableBraveSearchError(Exception):
"""Error type used to trigger retry for transient Brave search failures."""
class BraveClient(WebSearchProvider):
def __init__(
self,
api_key: str,
*,
num_results: int = 10,
timeout_seconds: int = 10,
country: str | None = None,
search_lang: str | None = None,
ui_lang: str | None = None,
safesearch: str | None = None,
freshness: str | None = None,
) -> None:
if timeout_seconds <= 0:
raise ValueError("Brave provider config 'timeout_seconds' must be > 0.")
self._headers = {
"Accept": "application/json",
"X-Subscription-Token": api_key,
}
logger.debug(f"Count of results passed to BraveClient: {num_results}")
self._num_results = max(1, min(num_results, BRAVE_MAX_RESULTS_PER_REQUEST))
self._timeout_seconds = timeout_seconds
self._country = _normalize_country(country)
self._search_lang = _normalize_language_code(
search_lang, field_name="search_lang"
)
self._ui_lang = _normalize_language_code(ui_lang, field_name="ui_lang")
self._safesearch = _normalize_option(
safesearch,
field_name="safesearch",
allowed_values=BRAVE_SAFESEARCH_OPTIONS,
)
self._freshness = _normalize_option(
freshness,
field_name="freshness",
allowed_values=BRAVE_FRESHNESS_OPTIONS,
)
def _build_search_params(self, query: str) -> dict[str, str]:
params = {
"q": query,
"count": str(self._num_results),
}
if self._country:
params["country"] = self._country
if self._search_lang:
params["search_lang"] = self._search_lang
if self._ui_lang:
params["ui_lang"] = self._ui_lang
if self._safesearch:
params["safesearch"] = self._safesearch
if self._freshness:
params["freshness"] = self._freshness
return params
@retry_builder(
tries=3,
delay=1,
backoff=2,
exceptions=(RetryableBraveSearchError,),
)
def _search_with_retries(self, query: str) -> list[WebSearchResult]:
params = self._build_search_params(query)
try:
response = requests.get(
BRAVE_WEB_SEARCH_URL,
headers=self._headers,
params=params,
timeout=self._timeout_seconds,
)
except requests.RequestException as exc:
raise RetryableBraveSearchError(
f"Brave search request failed: {exc}"
) from exc
try:
response.raise_for_status()
except requests.HTTPError as exc:
error_msg = _build_error_message(response)
if _is_retryable_status(response.status_code):
raise RetryableBraveSearchError(error_msg) from exc
raise ValueError(error_msg) from exc
data = response.json()
web_results = (data.get("web") or {}).get("results") or []
results: list[WebSearchResult] = []
for result in web_results:
if not isinstance(result, dict):
continue
link = _clean_string(result.get("url"))
if not link:
continue
title = _clean_string(result.get("title"))
description = _clean_string(result.get("description"))
results.append(
WebSearchResult(
title=title,
link=link,
snippet=description,
author=None,
published_date=None,
)
)
return results
def search(self, query: str) -> list[WebSearchResult]:
try:
return self._search_with_retries(query)
except RetryableBraveSearchError as exc:
raise ValueError(str(exc)) from exc
def test_connection(self) -> dict[str, str]:
try:
test_results = self.search("test")
if not test_results or not any(result.link for result in test_results):
raise HTTPException(
status_code=400,
detail="Brave API key validation failed: search returned no results.",
)
except HTTPException:
raise
except (ValueError, requests.RequestException) as e:
error_msg = str(e)
lower = error_msg.lower()
if (
"status 401" in lower
or "status 403" in lower
or "api key" in lower
or "auth" in lower
):
raise HTTPException(
status_code=400,
detail=f"Invalid Brave API key: {error_msg}",
) from e
if "status 429" in lower or "rate limit" in lower:
raise HTTPException(
status_code=400,
detail=f"Brave API rate limit exceeded: {error_msg}",
) from e
raise HTTPException(
status_code=400,
detail=f"Brave API key validation failed: {error_msg}",
) from e
logger.info("Web search provider test succeeded for Brave.")
return {"status": "ok"}
def _build_error_message(response: requests.Response) -> str:
return (
"Brave search failed "
f"(status {response.status_code}): {_extract_error_detail(response)}"
)
def _extract_error_detail(response: requests.Response) -> str:
try:
payload: Any = response.json()
except Exception:
text = response.text.strip()
return text[:200] if text else "No error details"
if isinstance(payload, dict):
error = payload.get("error")
if isinstance(error, dict):
detail = error.get("detail") or error.get("message")
if isinstance(detail, str):
return detail
if isinstance(error, str):
return error
message = payload.get("message")
if isinstance(message, str):
return message
return str(payload)[:200]
def _is_retryable_status(status_code: int) -> bool:
return status_code == 429 or status_code >= 500
def _clean_string(value: Any) -> str:
return value.strip() if isinstance(value, str) else ""
def _normalize_country(country: str | None) -> str | None:
if country is None:
return None
normalized = country.strip().upper()
if not normalized:
return None
if len(normalized) != 2 or not normalized.isalpha():
raise ValueError(
"Brave provider config 'country' must be a 2-letter ISO country code."
)
return normalized
def _normalize_language_code(value: str | None, *, field_name: str) -> str | None:
if value is None:
return None
normalized = value.strip()
if not normalized:
return None
if len(normalized) > 20:
raise ValueError(f"Brave provider config '{field_name}' is too long.")
return normalized
def _normalize_option(
value: str | None,
*,
field_name: str,
allowed_values: set[str],
) -> str | None:
if value is None:
return None
normalized = value.strip().lower()
if not normalized:
return None
if normalized not in allowed_values:
allowed = ", ".join(sorted(allowed_values))
raise ValueError(
f"Brave provider config '{field_name}' must be one of: {allowed}."
)
return normalized
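As a usage sketch (the subscription token and query are placeholders), the client can be constructed directly and queried:
client = BraveClient(
    api_key="<brave-subscription-token>",
    num_results=5,
    country="us",  # normalized to "US"
    safesearch="moderate",
    freshness="pw",  # past week
)
for result in client.search("onyx self-hosted search"):
    print(result.link, "-", result.title)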

View File

@@ -13,6 +13,9 @@ from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
DEFAULT_MAX_PDF_SIZE_BYTES,
)
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import OnyxWebCrawler
from onyx.tools.tool_implementations.web_search.clients.brave_client import (
BraveClient,
)
from onyx.tools.tool_implementations.web_search.clients.exa_client import (
ExaClient,
)
@@ -35,6 +38,28 @@ from shared_configs.enums import WebSearchProviderType
logger = setup_logger()
def _parse_positive_int_config(
*,
raw_value: str | None,
default: int,
provider_name: str,
config_key: str,
) -> int:
if not raw_value:
return default
try:
value = int(raw_value)
except ValueError as exc:
raise ValueError(
f"{provider_name} provider config '{config_key}' must be an integer."
) from exc
if value <= 0:
raise ValueError(
f"{provider_name} provider config '{config_key}' must be greater than 0."
)
return value
def provider_requires_api_key(provider_type: WebSearchProviderType) -> bool:
"""Return True if the given provider type requires an API key.
This list is most likely just going to contain SEARXNG. The way it works is that it uses public search engines that do not
@@ -67,6 +92,22 @@ def build_search_provider_from_config(
if provider_type == WebSearchProviderType.EXA:
return ExaClient(api_key=api_key, num_results=num_results)
if provider_type == WebSearchProviderType.BRAVE:
return BraveClient(
api_key=api_key,
num_results=num_results,
timeout_seconds=_parse_positive_int_config(
raw_value=config.get("timeout_seconds"),
default=10,
provider_name="Brave",
config_key="timeout_seconds",
),
country=config.get("country"),
search_lang=config.get("search_lang"),
ui_lang=config.get("ui_lang"),
safesearch=config.get("safesearch"),
freshness=config.get("freshness"),
)
if provider_type == WebSearchProviderType.SERPER:
return SerperClient(api_key=api_key, num_results=num_results)
if provider_type == WebSearchProviderType.GOOGLE_PSE:
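Based on the keys read from config in the BRAVE branch above, a provider configuration would be shaped roughly like this (values are placeholders; how the dict is stored and passed in lies outside this hunk):
brave_provider_config = {
    "timeout_seconds": "15",  # parsed via _parse_positive_int_config
    "country": "US",
    "search_lang": "en",
    "ui_lang": "en-US",
    "safesearch": "moderate",
    "freshness": "pm",
}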

View File

@@ -1,6 +1,7 @@
import base64
import hashlib
import logging
import re
import uuid
from collections.abc import Awaitable
from collections.abc import Callable
@@ -10,7 +11,9 @@ from datetime import timezone
from fastapi import FastAPI
from fastapi import Request
from fastapi import Response
from fastapi.routing import APIRoute
from shared_configs.contextvars import CURRENT_ENDPOINT_CONTEXTVAR
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
@@ -76,3 +79,50 @@ def _make_onyx_request_id(prefix: str, hash_input: str) -> str:
hash_str = base64.urlsafe_b64encode(hash_bytes).decode("utf-8").rstrip("=")
onyx_request_id = f"{prefix}:{hash_str}"
return onyx_request_id
def _build_route_map(app: FastAPI) -> list[tuple[re.Pattern[str], str]]:
"""Build a list of (compiled regex, route template) from the app's routes.
Used by endpoint context middleware to resolve request paths to route
templates, avoiding high-cardinality raw paths in metrics labels.
"""
route_map: list[tuple[re.Pattern[str], str]] = []
for route in app.routes:
if isinstance(route, APIRoute):
route_map.append((route.path_regex, route.path))
return route_map
def _match_route(route_map: list[tuple[re.Pattern[str], str]], path: str) -> str | None:
"""Match a request path against the route map and return the template."""
for pattern, template in route_map:
if pattern.match(path):
return template
return None
def add_endpoint_context_middleware(app: FastAPI) -> None:
"""Set CURRENT_ENDPOINT_CONTEXTVAR so Prometheus pool metrics can
attribute DB connections to the endpoint that checked them out.
Used by ``onyx_db_connections_held_by_endpoint`` and
``onyx_db_connection_hold_seconds`` in the pool event listeners.
Resolves request paths to route templates (e.g. /api/chat/{chat_id}
instead of /api/chat/abc-123) to keep metric label cardinality low.
Must be registered AFTER all routes are added to the app.
"""
route_map = _build_route_map(app)
@app.middleware("http")
async def set_endpoint_context(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
handler = _match_route(route_map, request.url.path)
token = CURRENT_ENDPOINT_CONTEXTVAR.set(handler or "unmatched")
try:
return await call_next(request)
finally:
CURRENT_ENDPOINT_CONTEXTVAR.reset(token)
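A registration-order sketch (the app factory and route below are illustrative, not Onyx's real ones): routes are added first so the route map built by this helper can see them.
from fastapi import FastAPI
def build_instrumented_app() -> FastAPI:
    app = FastAPI()
    @app.get("/api/chat/{chat_id}")
    async def get_chat(chat_id: str) -> dict[str, str]:
        return {"chat_id": chat_id}
    # Registered last so _build_route_map() snapshots every route template.
    add_endpoint_context_middleware(app)
    return app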

View File

@@ -26,6 +26,12 @@ INDEX_ATTEMPT_INFO_CONTEXTVAR: contextvars.ContextVar[tuple[int, int] | None] =
contextvars.ContextVar("index_attempt_info", default=None)
)
# Set by endpoint context middleware — used for per-endpoint DB pool attribution
CURRENT_ENDPOINT_CONTEXTVAR: contextvars.ContextVar[str | None] = (
contextvars.ContextVar("current_endpoint", default=None)
)
"""Utils related to contextvars"""

View File

@@ -26,6 +26,7 @@ class WebSearchProviderType(str, Enum):
SERPER = "serper"
EXA = "exa"
SEARXNG = "searxng"
BRAVE = "brave"
class WebContentProviderType(str, Enum):

View File

@@ -2,6 +2,7 @@ import time
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import field
from dataclasses import replace
from urllib.parse import urlparse
from onyx.connectors.google_drive.connector import GoogleDriveConnector
@@ -134,25 +135,25 @@ EXPECTED_SHARED_DRIVE_1_HIERARCHY = ExpectedHierarchyNode(
children=[
ExpectedHierarchyNode(
raw_node_id=RESTRICTED_ACCESS_FOLDER_ID,
display_name="restricted_access_folder",
display_name="restricted_access",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=SHARED_DRIVE_1_ID,
),
ExpectedHierarchyNode(
raw_node_id=FOLDER_1_ID,
display_name="folder_1",
display_name="folder 1",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=SHARED_DRIVE_1_ID,
children=[
ExpectedHierarchyNode(
raw_node_id=FOLDER_1_1_ID,
display_name="folder_1_1",
display_name="folder 1-1",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=FOLDER_1_ID,
),
ExpectedHierarchyNode(
raw_node_id=FOLDER_1_2_ID,
display_name="folder_1_2",
display_name="folder 1-2",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=FOLDER_1_ID,
),
@@ -170,25 +171,25 @@ EXPECTED_SHARED_DRIVE_2_HIERARCHY = ExpectedHierarchyNode(
children=[
ExpectedHierarchyNode(
raw_node_id=SECTIONS_FOLDER_ID,
display_name="sections_folder",
display_name="sections",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=SHARED_DRIVE_2_ID,
),
ExpectedHierarchyNode(
raw_node_id=FOLDER_2_ID,
display_name="folder_2",
display_name="folder 2",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=SHARED_DRIVE_2_ID,
children=[
ExpectedHierarchyNode(
raw_node_id=FOLDER_2_1_ID,
display_name="folder_2_1",
display_name="folder 2-1",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=FOLDER_2_ID,
),
ExpectedHierarchyNode(
raw_node_id=FOLDER_2_2_ID,
display_name="folder_2_2",
display_name="folder 2-2",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=FOLDER_2_ID,
),
@@ -208,27 +209,23 @@ def flatten_hierarchy(
return result
def _node(
raw_node_id: str,
display_name: str,
node_type: HierarchyNodeType,
raw_parent_id: str | None = None,
) -> ExpectedHierarchyNode:
return ExpectedHierarchyNode(
raw_node_id=raw_node_id,
display_name=display_name,
node_type=node_type,
raw_parent_id=raw_parent_id,
)
# Flattened maps for easy lookup
EXPECTED_SHARED_DRIVE_1_NODES = flatten_hierarchy(EXPECTED_SHARED_DRIVE_1_HIERARCHY)
EXPECTED_SHARED_DRIVE_2_NODES = flatten_hierarchy(EXPECTED_SHARED_DRIVE_2_HIERARCHY)
ALL_EXPECTED_SHARED_DRIVE_NODES = {
**EXPECTED_SHARED_DRIVE_1_NODES,
**EXPECTED_SHARED_DRIVE_2_NODES,
}
# Map of folder ID to its expected parent ID
EXPECTED_PARENT_MAPPING: dict[str, str | None] = {
SHARED_DRIVE_1_ID: None,
RESTRICTED_ACCESS_FOLDER_ID: SHARED_DRIVE_1_ID,
FOLDER_1_ID: SHARED_DRIVE_1_ID,
FOLDER_1_1_ID: FOLDER_1_ID,
FOLDER_1_2_ID: FOLDER_1_ID,
SHARED_DRIVE_2_ID: None,
SECTIONS_FOLDER_ID: SHARED_DRIVE_2_ID,
FOLDER_2_ID: SHARED_DRIVE_2_ID,
FOLDER_2_1_ID: FOLDER_2_ID,
FOLDER_2_2_ID: FOLDER_2_ID,
}
EXTERNAL_SHARED_FOLDER_URL = (
"https://drive.google.com/drive/folders/1sWC7Oi0aQGgifLiMnhTjvkhRWVeDa-XS"
@@ -286,7 +283,7 @@ TEST_USER_1_MY_DRIVE_FOLDER_ID = (
)
TEST_USER_1_DRIVE_B_ID = (
"0AFskk4zfZm86Uk9PVA" # My_super_special_shared_drive_suuuuuuper_private
"0AFskk4zfZm86Uk9PVA" # My_super_special_shared_drive_suuuper_private
)
TEST_USER_1_DRIVE_B_FOLDER_ID = (
"1oIj7nigzvP5xI2F8BmibUA8R_J3AbBA-" # Child folder (silliness)
@@ -325,6 +322,106 @@ PERM_SYNC_DRIVE_ACCESS_MAPPING: dict[str, set[str]] = {
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID: {ADMIN_EMAIL, TEST_USER_1_EMAIL},
}
# ============================================================================
# NON-SHARED-DRIVE HIERARCHY NODES
# ============================================================================
# These cover My Drive roots, perm sync drives, extra shared drives,
# and standalone folders that appear in various tests.
# Display names must match what the Google Drive API actually returns.
# ============================================================================
EXPECTED_FOLDER_3 = _node(
FOLDER_3_ID, "Folder 3", HierarchyNodeType.FOLDER, ADMIN_MY_DRIVE_ID
)
EXPECTED_ADMIN_MY_DRIVE = _node(ADMIN_MY_DRIVE_ID, "My Drive", HierarchyNodeType.FOLDER)
EXPECTED_TEST_USER_1_MY_DRIVE = _node(
TEST_USER_1_MY_DRIVE_ID, "My Drive", HierarchyNodeType.FOLDER
)
EXPECTED_TEST_USER_1_MY_DRIVE_FOLDER = _node(
TEST_USER_1_MY_DRIVE_FOLDER_ID,
"partial_sharing",
HierarchyNodeType.FOLDER,
TEST_USER_1_MY_DRIVE_ID,
)
EXPECTED_TEST_USER_2_MY_DRIVE = _node(
TEST_USER_2_MY_DRIVE, "My Drive", HierarchyNodeType.FOLDER
)
EXPECTED_TEST_USER_3_MY_DRIVE = _node(
TEST_USER_3_MY_DRIVE_ID, "My Drive", HierarchyNodeType.FOLDER
)
EXPECTED_PERM_SYNC_DRIVE_ADMIN_ONLY = _node(
PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
"perm_sync_drive_0dc9d8b5-e243-4c2f-8678-2235958f7d7c",
HierarchyNodeType.SHARED_DRIVE,
)
EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A = _node(
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
"perm_sync_drive_785db121-0823-4ebe-8689-ad7f52405e32",
HierarchyNodeType.SHARED_DRIVE,
)
EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B = _node(
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
"perm_sync_drive_d8dc3649-3f65-4392-b87f-4b20e0389673",
HierarchyNodeType.SHARED_DRIVE,
)
EXPECTED_TEST_USER_1_DRIVE_B = _node(
TEST_USER_1_DRIVE_B_ID,
"My_super_special_shared_drive_suuuper_private",
HierarchyNodeType.SHARED_DRIVE,
)
EXPECTED_TEST_USER_1_DRIVE_B_FOLDER = _node(
TEST_USER_1_DRIVE_B_FOLDER_ID,
"silliness",
HierarchyNodeType.FOLDER,
TEST_USER_1_DRIVE_B_ID,
)
EXPECTED_TEST_USER_1_EXTRA_DRIVE_1 = _node(
TEST_USER_1_EXTRA_DRIVE_1_ID,
"Okay_Admin_fine_I_will_share",
HierarchyNodeType.SHARED_DRIVE,
)
EXPECTED_TEST_USER_1_EXTRA_DRIVE_2 = _node(
TEST_USER_1_EXTRA_DRIVE_2_ID, "reee test", HierarchyNodeType.SHARED_DRIVE
)
EXPECTED_TEST_USER_1_EXTRA_FOLDER = _node(
TEST_USER_1_EXTRA_FOLDER_ID,
"read only no download test",
HierarchyNodeType.FOLDER,
)
EXPECTED_PILL_FOLDER = _node(
PILL_FOLDER_ID, "pill_folder", HierarchyNodeType.FOLDER, ADMIN_MY_DRIVE_ID
)
EXPECTED_EXTERNAL_SHARED_FOLDER = _node(
EXTERNAL_SHARED_FOLDER_ID, "Onyx-test", HierarchyNodeType.FOLDER
)
# Comprehensive mapping of ALL known hierarchy nodes.
# Every retrieved node is checked against this for display_name and node_type.
ALL_EXPECTED_HIERARCHY_NODES: dict[str, ExpectedHierarchyNode] = {
**EXPECTED_SHARED_DRIVE_1_NODES,
**EXPECTED_SHARED_DRIVE_2_NODES,
FOLDER_3_ID: EXPECTED_FOLDER_3,
ADMIN_MY_DRIVE_ID: EXPECTED_ADMIN_MY_DRIVE,
TEST_USER_1_MY_DRIVE_ID: EXPECTED_TEST_USER_1_MY_DRIVE,
TEST_USER_1_MY_DRIVE_FOLDER_ID: EXPECTED_TEST_USER_1_MY_DRIVE_FOLDER,
TEST_USER_2_MY_DRIVE: EXPECTED_TEST_USER_2_MY_DRIVE,
TEST_USER_3_MY_DRIVE_ID: EXPECTED_TEST_USER_3_MY_DRIVE,
PERM_SYNC_DRIVE_ADMIN_ONLY_ID: EXPECTED_PERM_SYNC_DRIVE_ADMIN_ONLY,
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID: EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A,
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID: EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B,
TEST_USER_1_DRIVE_B_ID: EXPECTED_TEST_USER_1_DRIVE_B,
TEST_USER_1_DRIVE_B_FOLDER_ID: EXPECTED_TEST_USER_1_DRIVE_B_FOLDER,
TEST_USER_1_EXTRA_DRIVE_1_ID: EXPECTED_TEST_USER_1_EXTRA_DRIVE_1,
TEST_USER_1_EXTRA_DRIVE_2_ID: EXPECTED_TEST_USER_1_EXTRA_DRIVE_2,
TEST_USER_1_EXTRA_FOLDER_ID: EXPECTED_TEST_USER_1_EXTRA_FOLDER,
PILL_FOLDER_ID: EXPECTED_PILL_FOLDER,
EXTERNAL_SHARED_FOLDER_ID: EXPECTED_EXTERNAL_SHARED_FOLDER,
}
# Dictionary for access permissions
# All users have access to their own My Drive as well as public files
ACCESS_MAPPING: dict[str, list[int]] = {
@@ -508,28 +605,29 @@ def load_connector_outputs(
def assert_hierarchy_nodes_match_expected(
retrieved_nodes: list[HierarchyNode],
expected_node_ids: set[str],
expected_parent_mapping: dict[str, str | None] | None = None,
expected_nodes: dict[str, ExpectedHierarchyNode],
ignorable_node_ids: set[str] | None = None,
) -> None:
"""
Assert that retrieved hierarchy nodes match expected structure.
Checks node IDs, display names, node types, and parent relationships
for EVERY retrieved node (global checks).
Args:
retrieved_nodes: List of HierarchyNode objects from the connector
expected_node_ids: Set of expected raw_node_ids
expected_parent_mapping: Optional dict mapping node_id -> parent_id for parent verification
ignorable_node_ids: Optional set of node IDs that can be missing or extra without failing.
Useful for nodes that are non-deterministically returned by the connector.
expected_nodes: Dict mapping raw_node_id -> ExpectedHierarchyNode with
expected display_name, node_type, and raw_parent_id
ignorable_node_ids: Optional set of node IDs that can be missing or extra
without failing. Useful for non-deterministically returned nodes.
"""
expected_node_ids = set(expected_nodes.keys())
retrieved_node_ids = {node.raw_node_id for node in retrieved_nodes}
ignorable = ignorable_node_ids or set()
# Calculate differences, excluding ignorable nodes
missing = expected_node_ids - retrieved_node_ids - ignorable
extra = retrieved_node_ids - expected_node_ids - ignorable
# Print discrepancies for debugging
if missing or extra:
print("Expected hierarchy node IDs:")
print(sorted(expected_node_ids))
@@ -543,181 +641,146 @@ def assert_hierarchy_nodes_match_expected(
print("Ignorable node IDs:")
print(sorted(ignorable))
assert not missing and not extra, (
f"Hierarchy node mismatch. " f"Missing: {missing}, " f"Extra: {extra}"
)
assert (
not missing and not extra
), f"Hierarchy node mismatch. Missing: {missing}, Extra: {extra}"
# Verify parent relationships if provided
if expected_parent_mapping is not None:
for node in retrieved_nodes:
if node.raw_node_id not in expected_parent_mapping:
continue
expected_parent = expected_parent_mapping[node.raw_node_id]
assert node.raw_parent_id == expected_parent, (
for node in retrieved_nodes:
if node.raw_node_id in ignorable and node.raw_node_id not in expected_nodes:
continue
assert (
node.raw_node_id in expected_nodes
), f"Node {node.raw_node_id} ({node.display_name}) not found in expected_nodes"
expected = expected_nodes[node.raw_node_id]
assert node.display_name == expected.display_name, (
f"Display name mismatch for node {node.raw_node_id}: "
f"expected '{expected.display_name}', got '{node.display_name}'"
)
assert node.node_type == expected.node_type, (
f"Node type mismatch for node {node.raw_node_id}: "
f"expected '{expected.node_type}', got '{node.node_type}'"
)
if expected.raw_parent_id is not None:
assert node.raw_parent_id == expected.raw_parent_id, (
f"Parent mismatch for node {node.raw_node_id} ({node.display_name}): "
f"expected parent={expected_parent}, got parent={node.raw_parent_id}"
f"expected parent={expected.raw_parent_id}, got parent={node.raw_parent_id}"
)
def _pick(
*node_ids: str,
) -> dict[str, ExpectedHierarchyNode]:
"""Pick nodes from ALL_EXPECTED_HIERARCHY_NODES by their IDs."""
return {nid: ALL_EXPECTED_HIERARCHY_NODES[nid] for nid in node_ids}
def _clear_parents(
nodes: dict[str, ExpectedHierarchyNode],
*node_ids: str,
) -> dict[str, ExpectedHierarchyNode]:
"""Return a shallow copy of nodes with the specified nodes' parents set to None.
Useful for OAuth tests where the user can't resolve certain parents
(e.g. a folder in another user's My Drive)."""
result = dict(nodes)
for nid in node_ids:
result[nid] = replace(result[nid], raw_parent_id=None)
return result
def get_expected_hierarchy_for_shared_drives(
include_drive_1: bool = True,
include_drive_2: bool = True,
include_restricted_folder: bool = True,
) -> tuple[set[str], dict[str, str | None]]:
"""
Get expected hierarchy node IDs and parent mapping for shared drives.
Returns:
Tuple of (expected_node_ids, expected_parent_mapping)
"""
expected_ids: set[str] = set()
expected_parents: dict[str, str | None] = {}
) -> dict[str, ExpectedHierarchyNode]:
"""Get expected hierarchy nodes for shared drives."""
result: dict[str, ExpectedHierarchyNode] = {}
if include_drive_1:
expected_ids.add(SHARED_DRIVE_1_ID)
expected_parents[SHARED_DRIVE_1_ID] = None
if include_restricted_folder:
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
expected_parents[RESTRICTED_ACCESS_FOLDER_ID] = SHARED_DRIVE_1_ID
expected_ids.add(FOLDER_1_ID)
expected_parents[FOLDER_1_ID] = SHARED_DRIVE_1_ID
expected_ids.add(FOLDER_1_1_ID)
expected_parents[FOLDER_1_1_ID] = FOLDER_1_ID
expected_ids.add(FOLDER_1_2_ID)
expected_parents[FOLDER_1_2_ID] = FOLDER_1_ID
result.update(EXPECTED_SHARED_DRIVE_1_NODES)
if not include_restricted_folder:
result.pop(RESTRICTED_ACCESS_FOLDER_ID, None)
if include_drive_2:
expected_ids.add(SHARED_DRIVE_2_ID)
expected_parents[SHARED_DRIVE_2_ID] = None
result.update(EXPECTED_SHARED_DRIVE_2_NODES)
expected_ids.add(SECTIONS_FOLDER_ID)
expected_parents[SECTIONS_FOLDER_ID] = SHARED_DRIVE_2_ID
expected_ids.add(FOLDER_2_ID)
expected_parents[FOLDER_2_ID] = SHARED_DRIVE_2_ID
expected_ids.add(FOLDER_2_1_ID)
expected_parents[FOLDER_2_1_ID] = FOLDER_2_ID
expected_ids.add(FOLDER_2_2_ID)
expected_parents[FOLDER_2_2_ID] = FOLDER_2_ID
return expected_ids, expected_parents
return result
def get_expected_hierarchy_for_folder_1() -> tuple[set[str], dict[str, str | None]]:
def get_expected_hierarchy_for_folder_1() -> dict[str, ExpectedHierarchyNode]:
"""Get expected hierarchy for folder_1 and its children only."""
return (
{FOLDER_1_ID, FOLDER_1_1_ID, FOLDER_1_2_ID},
{
FOLDER_1_ID: SHARED_DRIVE_1_ID,
FOLDER_1_1_ID: FOLDER_1_ID,
FOLDER_1_2_ID: FOLDER_1_ID,
},
)
return _pick(FOLDER_1_ID, FOLDER_1_1_ID, FOLDER_1_2_ID)
def get_expected_hierarchy_for_folder_2() -> tuple[set[str], dict[str, str | None]]:
def get_expected_hierarchy_for_folder_2() -> dict[str, ExpectedHierarchyNode]:
"""Get expected hierarchy for folder_2 and its children only."""
return (
{FOLDER_2_ID, FOLDER_2_1_ID, FOLDER_2_2_ID},
{
FOLDER_2_ID: SHARED_DRIVE_2_ID,
FOLDER_2_1_ID: FOLDER_2_ID,
FOLDER_2_2_ID: FOLDER_2_ID,
},
)
return _pick(FOLDER_2_ID, FOLDER_2_1_ID, FOLDER_2_2_ID)
def get_expected_hierarchy_for_test_user_1() -> tuple[set[str], dict[str, str | None]]:
def get_expected_hierarchy_for_test_user_1() -> dict[str, ExpectedHierarchyNode]:
"""
Get expected hierarchy for test_user_1's full access.
Get expected hierarchy for test_user_1's full access (OAuth).
test_user_1 has access to:
- shared_drive_1 and its contents (folder_1, folder_1_1, folder_1_2)
- folder_3 (shared from admin's My Drive)
- PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A and PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B
- Additional drives/folders the user has access to
NOTE: Folder 3 lives in the admin's My Drive. When running as an OAuth
connector for test_user_1, the Google Drive API won't return the parent
for Folder 3 because the user can't access the admin's My Drive root.
"""
# Start with shared_drive_1 hierarchy
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
result = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=False,
include_restricted_folder=False,
)
# folder_3 is shared from admin's My Drive
expected_ids.add(FOLDER_3_ID)
# Perm sync drives that test_user_1 has access to
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
expected_parents[PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID] = None
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
expected_parents[PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID] = None
# Additional drives/folders test_user_1 has access to
expected_ids.add(TEST_USER_1_MY_DRIVE_ID)
expected_parents[TEST_USER_1_MY_DRIVE_ID] = None
expected_ids.add(TEST_USER_1_MY_DRIVE_FOLDER_ID)
expected_parents[TEST_USER_1_MY_DRIVE_FOLDER_ID] = TEST_USER_1_MY_DRIVE_ID
expected_ids.add(TEST_USER_1_DRIVE_B_ID)
expected_parents[TEST_USER_1_DRIVE_B_ID] = None
expected_ids.add(TEST_USER_1_DRIVE_B_FOLDER_ID)
expected_parents[TEST_USER_1_DRIVE_B_FOLDER_ID] = TEST_USER_1_DRIVE_B_ID
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
expected_parents[TEST_USER_1_EXTRA_DRIVE_1_ID] = None
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
expected_parents[TEST_USER_1_EXTRA_DRIVE_2_ID] = None
expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)
# Parent unknown, skip adding to expected_parents
return expected_ids, expected_parents
result.update(
_pick(
FOLDER_3_ID,
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
TEST_USER_1_MY_DRIVE_ID,
TEST_USER_1_MY_DRIVE_FOLDER_ID,
TEST_USER_1_DRIVE_B_ID,
TEST_USER_1_DRIVE_B_FOLDER_ID,
TEST_USER_1_EXTRA_DRIVE_1_ID,
TEST_USER_1_EXTRA_DRIVE_2_ID,
TEST_USER_1_EXTRA_FOLDER_ID,
)
)
return _clear_parents(result, FOLDER_3_ID)
def get_expected_hierarchy_for_test_user_1_shared_drives_only() -> (
tuple[set[str], dict[str, str | None]]
dict[str, ExpectedHierarchyNode]
):
"""Expected hierarchy nodes when test_user_1 runs with include_shared_drives=True only."""
expected_ids, expected_parents = get_expected_hierarchy_for_test_user_1()
# This mode should not include My Drive roots/folders.
expected_ids.discard(TEST_USER_1_MY_DRIVE_ID)
expected_ids.discard(TEST_USER_1_MY_DRIVE_FOLDER_ID)
# don't include shared with me
expected_ids.discard(FOLDER_3_ID)
expected_ids.discard(TEST_USER_1_EXTRA_FOLDER_ID)
return expected_ids, expected_parents
result = get_expected_hierarchy_for_test_user_1()
for nid in (
TEST_USER_1_MY_DRIVE_ID,
TEST_USER_1_MY_DRIVE_FOLDER_ID,
FOLDER_3_ID,
TEST_USER_1_EXTRA_FOLDER_ID,
):
result.pop(nid, None)
return result
def get_expected_hierarchy_for_test_user_1_shared_with_me_only() -> (
tuple[set[str], dict[str, str | None]]
dict[str, ExpectedHierarchyNode]
):
"""Expected hierarchy nodes when test_user_1 runs with include_files_shared_with_me=True only."""
expected_ids: set[str] = {FOLDER_3_ID, TEST_USER_1_EXTRA_FOLDER_ID}
expected_parents: dict[str, str | None] = {}
return expected_ids, expected_parents
return _clear_parents(
_pick(FOLDER_3_ID, TEST_USER_1_EXTRA_FOLDER_ID),
FOLDER_3_ID,
)
def get_expected_hierarchy_for_test_user_1_my_drive_only() -> (
tuple[set[str], dict[str, str | None]]
dict[str, ExpectedHierarchyNode]
):
"""Expected hierarchy nodes when test_user_1 runs with include_my_drives=True only."""
expected_ids: set[str] = {TEST_USER_1_MY_DRIVE_ID, TEST_USER_1_MY_DRIVE_FOLDER_ID}
expected_parents: dict[str, str | None] = {
TEST_USER_1_MY_DRIVE_ID: None,
TEST_USER_1_MY_DRIVE_FOLDER_ID: TEST_USER_1_MY_DRIVE_ID,
}
return expected_ids, expected_parents
return _pick(TEST_USER_1_MY_DRIVE_ID, TEST_USER_1_MY_DRIVE_FOLDER_ID)

View File

@@ -3,12 +3,11 @@ from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from tests.daily.connectors.google_drive.consts_and_utils import _pick
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
ADMIN_MY_DRIVE_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_MY_DRIVE_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_expected_docs_in_retrieved_docs,
)
@@ -16,21 +15,15 @@ from tests.daily.connectors.google_drive.consts_and_utils import (
assert_hierarchy_nodes_match_expected,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
@@ -47,18 +40,15 @@ from tests.daily.connectors.google_drive.consts_and_utils import (
from tests.daily.connectors.google_drive.consts_and_utils import (
PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
PILL_FOLDER_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import PILL_FOLDER_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
RESTRICTED_ACCESS_FOLDER_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER_ID
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
TEST_USER_1_EXTRA_DRIVE_1_ID,
)
@@ -90,7 +80,6 @@ def test_include_all(
)
output = load_connector_outputs(connector)
# Should get everything in shared and admin's My Drive with oauth
expected_file_ids = (
ADMIN_FILE_IDS
+ ADMIN_FOLDER_3_FILE_IDS
@@ -109,33 +98,28 @@ def test_include_all(
expected_file_ids=expected_file_ids,
)
# Verify hierarchy nodes for shared drives
# When include_shared_drives=True, we get ALL shared drives the admin has access to
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
expected_nodes = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=True,
# Restricted folder may not always be retrieved due to access limitations
include_restricted_folder=False,
)
# Add additional shared drives that admin has access to
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
expected_ids.add(ADMIN_MY_DRIVE_ID)
expected_ids.add(PILL_FOLDER_ID)
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)
# My Drive folders
expected_ids.add(FOLDER_3_ID)
expected_nodes.update(
_pick(
PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
TEST_USER_1_EXTRA_DRIVE_1_ID,
TEST_USER_1_EXTRA_DRIVE_2_ID,
ADMIN_MY_DRIVE_ID,
PILL_FOLDER_ID,
RESTRICTED_ACCESS_FOLDER_ID,
TEST_USER_1_EXTRA_FOLDER_ID,
FOLDER_3_ID,
)
)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
expected_nodes=expected_nodes,
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
)
@@ -160,7 +144,6 @@ def test_include_shared_drives_only(
)
output = load_connector_outputs(connector)
# Should only get shared drives
expected_file_ids = (
SHARED_DRIVE_1_FILE_IDS
+ FOLDER_1_FILE_IDS
@@ -177,26 +160,24 @@ def test_include_shared_drives_only(
expected_file_ids=expected_file_ids,
)
# Verify hierarchy nodes - should include both shared drives and their folders
# When include_shared_drives=True, we get ALL shared drives admin has access to
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
expected_nodes = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=True,
include_restricted_folder=False,
)
# Add additional shared drives that admin has access to
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
expected_nodes.update(
_pick(
PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
TEST_USER_1_EXTRA_DRIVE_1_ID,
TEST_USER_1_EXTRA_DRIVE_2_ID,
RESTRICTED_ACCESS_FOLDER_ID,
)
)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
expected_nodes=expected_nodes,
)
@@ -220,24 +201,21 @@ def test_include_my_drives_only(
)
output = load_connector_outputs(connector)
# Should only get primary_admins My Drive because we are impersonating them
expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS
assert_expected_docs_in_retrieved_docs(
retrieved_docs=output.documents,
expected_file_ids=expected_file_ids,
)
# Verify hierarchy nodes - My Drive should yield folder_3 as a hierarchy node
# Also includes admin's My Drive root and folders shared with admin
expected_ids = {
expected_nodes = _pick(
FOLDER_3_ID,
ADMIN_MY_DRIVE_ID,
PILL_FOLDER_ID,
TEST_USER_1_EXTRA_FOLDER_ID,
}
)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
expected_node_ids=expected_ids,
expected_nodes=expected_nodes,
)
@@ -273,17 +251,14 @@ def test_drive_one_only(
expected_file_ids=expected_file_ids,
)
# Verify hierarchy nodes - should only include shared_drive_1 and its folders
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
expected_nodes = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=False,
include_restricted_folder=False,
)
# Restricted folder is non-deterministically returned by the connector
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
expected_nodes=expected_nodes,
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
)
@@ -324,33 +299,15 @@ def test_folder_and_shared_drive(
expected_file_ids=expected_file_ids,
)
# Verify hierarchy nodes - shared_drive_1 and folder_2 with children
# SHARED_DRIVE_2_ID is included because folder_2's parent is shared_drive_2
expected_ids = {
SHARED_DRIVE_1_ID,
FOLDER_1_ID,
FOLDER_1_1_ID,
FOLDER_1_2_ID,
SHARED_DRIVE_2_ID,
FOLDER_2_ID,
FOLDER_2_1_ID,
FOLDER_2_2_ID,
}
expected_parents = {
SHARED_DRIVE_1_ID: None,
FOLDER_1_ID: SHARED_DRIVE_1_ID,
FOLDER_1_1_ID: FOLDER_1_ID,
FOLDER_1_2_ID: FOLDER_1_ID,
SHARED_DRIVE_2_ID: None,
FOLDER_2_ID: SHARED_DRIVE_2_ID,
FOLDER_2_1_ID: FOLDER_2_ID,
FOLDER_2_2_ID: FOLDER_2_ID,
}
# Restricted folder is non-deterministically returned
expected_nodes = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=True,
include_restricted_folder=False,
)
expected_nodes.pop(SECTIONS_FOLDER_ID, None)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
expected_nodes=expected_nodes,
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
)
@@ -370,7 +327,6 @@ def test_folders_only(
FOLDER_2_2_URL,
FOLDER_3_URL,
]
# This should get converted to a drive request and spit out a warning in the logs
shared_drive_urls = [
FOLDER_1_1_URL,
]
@@ -397,23 +353,16 @@ def test_folders_only(
expected_file_ids=expected_file_ids,
)
# Verify hierarchy nodes - specific folders requested plus their parent nodes
# The connector walks up the hierarchy to include parent drives/folders
expected_ids = {
SHARED_DRIVE_1_ID,
FOLDER_1_ID,
FOLDER_1_1_ID,
FOLDER_1_2_ID,
SHARED_DRIVE_2_ID,
FOLDER_2_ID,
FOLDER_2_1_ID,
FOLDER_2_2_ID,
ADMIN_MY_DRIVE_ID,
FOLDER_3_ID,
}
expected_nodes = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=True,
include_restricted_folder=False,
)
expected_nodes.pop(SECTIONS_FOLDER_ID, None)
expected_nodes.update(_pick(ADMIN_MY_DRIVE_ID, FOLDER_3_ID))
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
expected_node_ids=expected_ids,
expected_nodes=expected_nodes,
)
@@ -446,9 +395,8 @@ def test_personal_folders_only(
expected_file_ids=expected_file_ids,
)
# Verify hierarchy nodes - folder_3 and its parent (admin's My Drive root)
expected_ids = {FOLDER_3_ID, ADMIN_MY_DRIVE_ID}
expected_nodes = _pick(FOLDER_3_ID, ADMIN_MY_DRIVE_ID)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
expected_node_ids=expected_ids,
expected_nodes=expected_nodes,
)

View File

@@ -14,11 +14,10 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from tests.daily.connectors.google_drive.consts_and_utils import _pick
from tests.daily.connectors.google_drive.consts_and_utils import ACCESS_MAPPING
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import (
ADMIN_MY_DRIVE_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_MY_DRIVE_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_hierarchy_nodes_match_expected,
)
@@ -262,37 +261,35 @@ def test_gdrive_perm_sync_with_real_data(
hierarchy_connector = _build_connector(google_drive_service_acct_connector_factory)
output = load_connector_outputs(hierarchy_connector, include_permissions=True)
# Verify the expected shared drives hierarchy
# When include_shared_drives=True and include_my_drives=True, we get ALL drives
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
expected_nodes = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=True,
include_restricted_folder=False,
)
# Add additional shared drives in the organization
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
expected_ids.add(TEST_USER_1_MY_DRIVE_ID)
expected_ids.add(TEST_USER_1_MY_DRIVE_FOLDER_ID)
expected_ids.add(TEST_USER_1_DRIVE_B_ID)
expected_ids.add(TEST_USER_1_DRIVE_B_FOLDER_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
expected_ids.add(ADMIN_MY_DRIVE_ID)
expected_ids.add(TEST_USER_2_MY_DRIVE)
expected_ids.add(TEST_USER_3_MY_DRIVE_ID)
expected_ids.add(PILL_FOLDER_ID)
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)
expected_ids.add(EXTERNAL_SHARED_FOLDER_ID)
expected_ids.add(FOLDER_3_ID)
expected_nodes.update(
_pick(
PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
TEST_USER_1_MY_DRIVE_ID,
TEST_USER_1_MY_DRIVE_FOLDER_ID,
TEST_USER_1_DRIVE_B_ID,
TEST_USER_1_DRIVE_B_FOLDER_ID,
TEST_USER_1_EXTRA_DRIVE_1_ID,
TEST_USER_1_EXTRA_DRIVE_2_ID,
ADMIN_MY_DRIVE_ID,
TEST_USER_2_MY_DRIVE,
TEST_USER_3_MY_DRIVE_ID,
PILL_FOLDER_ID,
RESTRICTED_ACCESS_FOLDER_ID,
TEST_USER_1_EXTRA_FOLDER_ID,
EXTERNAL_SHARED_FOLDER_ID,
FOLDER_3_ID,
)
)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
expected_nodes=expected_nodes,
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
)

Some files were not shown because too many files have changed in this diff.