Compare commits

..

9 Commits

Author SHA1 Message Date
Raunak Bhagat
c20e5205ca feat(opal): Add AuxiliaryTag component and resync colors.css with Figma
Add AuxiliaryTag component (green/blue/purple/amber/gray) with icon support.

Resync all color tokens in colors.css against Figma source of truth (20 fixes):
- Fix neon alpha variant naming: -60/-30 → -a60/-a30 to disambiguate from
  Figma scale levels (e.g. --neon-amber-60 is now scale Neon/Amber/60,
  --neon-amber-a60 is the 40-at-60%-opacity highlight variant)
- Add neon scale primitives for yellow, lime, cyan, sky, magenta (50, 20, 05
  for light mode; 80, 90 for dark mode)
- Fix all neon-based theme tokens (yellow, lime, cyan, sky, magenta) from
  alpha variants to correct Figma solid swatches
- Fix background-neutral-inverted-04 (grey-75→grey-60) and -03 (grey-80→grey-75)
- Fix background-tint-inverted-04 (tint-80→tint-60)
- Fix theme-gradient-00 (grey-00→grey-100)
- Fix mask-01 (alpha-grey-100→alpha-grey-00)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 00:11:54 -08:00
Raunak Bhagat
19151c2c44 fix(opal): Auto-grow LabelLayout edit input to match content width
- Replace flex-1 input with inline-grid sizer pattern: a hidden mirror
  span and the input share the same grid cell, so the input grows
  horizontally as content is typed
- Set input size=1 to eliminate browser default intrinsic width
- Accessories (Optional, auxIcon) stay beside the input during editing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 23:37:27 -08:00
Raunak Bhagat
72b1771bc2 feat(opal): Add auxIcon accessory to LabelLayout
- Add `auxIcon?: "info-gray" | "info-blue" | "warning" | "error"` prop
- Renders a status icon beside the title with p-0.5 padding
  (icon size = lineHeight - 4px, auto-scales per preset)
- Icon/color mapping: info-gray (AlertCircle/text-02),
  info-blue (AlertCircle/status-info-05), warning (AlertTriangle/status-warning-05),
  error (XOctagon/status-error-05)
- Title row order: [title, (Optional), aux-icon, edit-button]
- Update storybook with auxIcon examples and combined accessories

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 23:29:42 -08:00
Raunak Bhagat
8778266521 feat(opal): Add optional indicator accessory to LabelLayout
- Add `optional?: boolean` prop to LabelContentProps (LabelLayout only)
- Renders "(Optional)" beside the title in the muted font variant with text-03
- Muted font mapping: main-content → font-main-content-muted,
  main-ui → font-main-ui-muted, secondary → font-secondary-action (no muted variant)
- Update storybook with optional indicator examples for all LabelLayout presets

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 23:13:41 -08:00
Raunak Bhagat
a8cbba86b4 feat(opal): Add BodyLayout component with orientation and prominence
- New BodyLayout for main-content/main-ui/secondary presets with body variant
- Three orientations: inline (icon left), vertical (icon top), reverse (title left)
- Two prominences: default (text-04) and muted (text-03)
- Read-only layout — no editing or descriptions supported
- Wire up Content router to dispatch variant="body" to BodyLayout

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 22:54:26 -08:00
Raunak Bhagat
fb21c6cae5 refactor(opal): Inline presets, optional props, and edit UX improvements
- Delete presets.ts; inline HeadingLayout config into HeadingLayout.tsx
  and shared types into components.tsx (eliminates confusing unused entries)
- Make LabelLayout description optional (was required)
- Auto-select text when entering edit mode in both HeadingLayout and LabelLayout
- Update README with separate HeadingLayout/LabelLayout preset tables

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 22:25:23 -08:00
Raunak Bhagat
922154f3d6 feat(opal): Add LabelLayout component and 2xs button size
Add LabelLayout for main-content/main-ui/secondary presets with
mandatory description, per-preset icon color, and editable support.
Add 2xs interactive container size (1rem/16px) with mini rounding
variant to support secondary-scale edit buttons without layout flash.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 22:03:03 -08:00
Raunak Bhagat
5e9ea9edef refactor(opal): Rename LineItemLayout to Content with two-axis architecture
Restructure the component into a Content router that dispatches to
internal layout components based on sizePreset and variant axes.
Implement HeadingLayout with parameterized sizing from presets config,
scaled edit button sizes to prevent layout flash on edit toggle.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 20:08:47 -08:00
Raunak Bhagat
dfc31e5d37 feat(opal): Add LineItemLayout component with editable headline variant
Introduces the LineItemLayout component in the opal design system library,
matching the Figma "Content Container" spec. Supports headline variant with
icon placement (left/top) and inline title editing with hover-revealed edit button.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 18:33:05 -08:00
420 changed files with 8769 additions and 20833 deletions

View File

@@ -1 +0,0 @@
../.cursor/skills

View File

@@ -1,248 +0,0 @@
---
name: playwright-e2e-tests
description: Write and maintain Playwright end-to-end tests for the Onyx application. Use when creating new E2E tests, debugging test failures, adding test coverage, or when the user mentions Playwright, E2E tests, or browser testing.
---
# Playwright E2E Tests
## Project Layout
- **Tests**: `web/tests/e2e/` — organized by feature (`auth/`, `admin/`, `chat/`, `assistants/`, `connectors/`, `mcp/`)
- **Config**: `web/playwright.config.ts`
- **Utilities**: `web/tests/e2e/utils/`
- **Constants**: `web/tests/e2e/constants.ts`
- **Global setup**: `web/tests/e2e/global-setup.ts`
- **Output**: `web/output/playwright/`
## Imports
Always use absolute imports with the `@tests/e2e/` prefix — never relative paths (`../`, `../../`). The alias is defined in `web/tsconfig.json` and resolves to `web/tests/`.
```typescript
import { loginAs } from "@tests/e2e/utils/auth";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
import { TEST_ADMIN_CREDENTIALS } from "@tests/e2e/constants";
```
All new files should be `.ts`, not `.js`.
## Running Tests
```bash
# Run a specific test file
npx playwright test web/tests/e2e/chat/default_assistant.spec.ts
# Run a specific project
npx playwright test --project admin
npx playwright test --project exclusive
```
## Test Projects
| Project | Description | Parallelism |
|---------|-------------|-------------|
| `admin` | Standard tests (excludes `@exclusive`) | Parallel |
| `exclusive` | Serial, slower tests (tagged `@exclusive`) | 1 worker |
All tests use `admin_auth.json` storage state by default (pre-authenticated admin session).
## Authentication
Global setup (`global-setup.ts`) runs automatically before all tests and handles:
- Server readiness check (polls health endpoint, 60s timeout)
- Provisioning test users: admin, admin2, and a **pool of worker users** (`worker0@example.com` through `worker7@example.com`) (idempotent)
- API login + saving storage states: `admin_auth.json`, `admin2_auth.json`, and `worker{N}_auth.json` for each worker user
- Setting display name to `"worker"` for each worker user
- Promoting admin2 to admin role
- Ensuring a public LLM provider exists
Both test projects set `storageState: "admin_auth.json"`, so **every test starts pre-authenticated as admin with no login code needed**.
When a test needs a different user, use API-based login — never drive the login UI:
```typescript
import { loginAs } from "@tests/e2e/utils/auth";
await page.context().clearCookies();
await loginAs(page, "admin2");
// Log in as the worker-specific user (preferred for test isolation):
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
```
## Test Structure
Tests start pre-authenticated as admin — navigate and test directly:
```typescript
import { test, expect } from "@playwright/test";
test.describe("Feature Name", () => {
test("should describe expected behavior clearly", async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
// Already authenticated as admin — go straight to testing
});
});
```
**User isolation** — tests that modify visible app state (creating assistants, sending chat messages, pinning items) should run as a **worker-specific user** and clean up resources in `afterAll`. Global setup provisions a pool of worker users (`worker0@example.com` through `worker7@example.com`). `loginAsWorkerUser` maps `testInfo.workerIndex` to a pool slot via modulo, so retry workers (which get incrementing indices beyond the pool size) safely reuse existing users. This ensures parallel workers never share user state, keeps usernames deterministic for screenshots, and avoids cross-contamination:
```typescript
import { test } from "@playwright/test";
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
test.beforeEach(async ({ page }, testInfo) => {
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
});
```
If the test requires admin privileges *and* modifies visible state, use `"admin2"` instead — it's a pre-provisioned admin account that keeps the primary `"admin"` clean for other parallel tests. Switch to `"admin"` only for privileged setup (creating providers, configuring tools), then back to the worker user for the actual test. See `chat/default_assistant.spec.ts` for a full example.
`loginAsRandomUser` exists for the rare case where the test requires a brand-new user (e.g. onboarding flows). Avoid it elsewhere — it produces non-deterministic usernames that complicate screenshots.
**API resource setup** — only when tests need to create backend resources (image gen configs, web search providers, MCP servers). Use `beforeAll`/`afterAll` with `OnyxApiClient` to create and clean up. See `chat/default_assistant.spec.ts` or `mcp/mcp_oauth_flow.spec.ts` for examples. This is uncommon (~4 of 37 test files).
## Key Utilities
### `OnyxApiClient` (`@tests/e2e/utils/onyxApiClient`)
Backend API client for test setup/teardown. Key methods:
- **Connectors**: `createFileConnector()`, `deleteCCPair()`, `pauseConnector()`
- **LLM Providers**: `ensurePublicProvider()`, `createRestrictedProvider()`, `setProviderAsDefault()`
- **Assistants**: `createAssistant()`, `deleteAssistant()`, `findAssistantByName()`
- **User Groups**: `createUserGroup()`, `deleteUserGroup()`, `setUserRole()`
- **Tools**: `createWebSearchProvider()`, `createImageGenerationConfig()`
- **Chat**: `createChatSession()`, `deleteChatSession()`
### `chatActions` (`@tests/e2e/utils/chatActions`)
- `sendMessage(page, message)` — sends a message and waits for AI response
- `startNewChat(page)` — clicks new-chat button and waits for intro
- `verifyDefaultAssistantIsChosen(page)` — checks Onyx logo is visible
- `verifyAssistantIsChosen(page, name)` — checks assistant name display
- `switchModel(page, modelName)` — switches LLM model via popover
### `visualRegression` (`@tests/e2e/utils/visualRegression`)
- `expectScreenshot(page, { name, mask?, hide?, fullPage? })`
- `expectElementScreenshot(locator, { name, mask?, hide? })`
- Controlled by `VISUAL_REGRESSION=true` env var
### `theme` (`@tests/e2e/utils/theme`)
- `THEMES``["light", "dark"] as const` array for iterating over both themes
- `setThemeBeforeNavigation(page, theme)` — sets `next-themes` theme via `localStorage` before navigation
When tests need light/dark screenshots, loop over `THEMES` at the `test.describe` level and call `setThemeBeforeNavigation` in `beforeEach` **before** any `page.goto()`. Include the theme in screenshot names. See `admin/admin_pages.spec.ts` or `chat/chat_message_rendering.spec.ts` for examples:
```typescript
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";
for (const theme of THEMES) {
test.describe(`Feature (${theme} mode)`, () => {
test.beforeEach(async ({ page }) => {
await setThemeBeforeNavigation(page, theme);
});
test("renders correctly", async ({ page }) => {
await page.goto("/app");
await expectScreenshot(page, { name: `feature-${theme}` });
});
});
}
```
### `tools` (`@tests/e2e/utils/tools`)
- `TOOL_IDS` — centralized `data-testid` selectors for tool options
- `openActionManagement(page)` — opens the tool management popover
## Locator Strategy
Use locators in this priority order:
1. **`data-testid` / `aria-label`** — preferred for Onyx components
```typescript
page.getByTestId("AppSidebar/new-session")
page.getByLabel("admin-page-title")
```
2. **Role-based** — for standard HTML elements
```typescript
page.getByRole("button", { name: "Create" })
page.getByRole("dialog")
```
3. **Text/Label** — for visible text content
```typescript
page.getByText("Custom Assistant")
page.getByLabel("Email")
```
4. **CSS selectors** — last resort, only when above won't work
```typescript
page.locator('input[name="name"]')
page.locator("#onyx-chat-input-textarea")
```
**Never use** `page.locator` with complex CSS/XPath when a built-in locator works.
## Assertions
Use web-first assertions — they auto-retry until the condition is met:
```typescript
// Visibility
await expect(page.getByTestId("onyx-logo")).toBeVisible({ timeout: 5000 });
// Text content
await expect(page.getByTestId("assistant-name-display")).toHaveText("My Assistant");
// Count
await expect(page.locator('[data-testid="onyx-ai-message"]')).toHaveCount(2, { timeout: 30000 });
// URL
await expect(page).toHaveURL(/chatId=/);
// Element state
await expect(toggle).toBeChecked();
await expect(button).toBeEnabled();
```
**Never use** `assert` statements or hardcoded `page.waitForTimeout()`.
## Waiting Strategy
```typescript
// Wait for load state after navigation
await page.goto("/app");
await page.waitForLoadState("networkidle");
// Wait for specific element
await page.getByTestId("chat-intro").waitFor({ state: "visible", timeout: 10000 });
// Wait for URL change
await page.waitForFunction(() => window.location.href.includes("chatId="), null, { timeout: 10000 });
// Wait for network response
await page.waitForResponse(resp => resp.url().includes("/api/chat") && resp.status() === 200);
```
## Best Practices
1. **Descriptive test names** — clearly state expected behavior: `"should display greeting message when opening new chat"`
2. **API-first setup** — use `OnyxApiClient` for backend state; reserve UI interactions for the behavior under test
3. **User isolation** — tests that modify visible app state (sidebar, chat history) should run as the worker-specific user via `loginAsWorkerUser(page, testInfo.workerIndex)` (not admin) and clean up resources in `afterAll`. Each parallel worker gets its own user, preventing cross-contamination. Reserve `loginAsRandomUser` for flows that require a brand-new user (e.g. onboarding)
4. **DRY helpers** — extract reusable logic into `utils/` with JSDoc comments
5. **No hardcoded waits** — use `waitFor`, `waitForLoadState`, or web-first assertions
6. **Parallel-safe** — no shared mutable state between tests. Prefer static, human-readable names (e.g. `"E2E-CMD Chat 1"`) and clean up resources by ID in `afterAll`. This keeps screenshots deterministic and avoids needing to mask/hide dynamic text. Only fall back to timestamps (`\`test-${Date.now()}\``) when resources cannot be reliably cleaned up or when name collisions across parallel workers would cause functional failures
7. **Error context** — catch and re-throw with useful debug info (page text, URL, etc.)
8. **Tag slow tests** — mark serial/slow tests with `@exclusive` in the test title
9. **Visual regression** — use `expectScreenshot()` for UI consistency checks
10. **Minimal comments** — only comment to clarify non-obvious intent; never restate what the next line of code does

View File

@@ -33,7 +33,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Build chart dependencies

View File

@@ -0,0 +1,151 @@
# Scan for problematic software licenses
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
name: 'Nightly - Scan licenses'
on:
# schedule:
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
workflow_dispatch: # Allows manual triggering
permissions:
actions: read
contents: read
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
timeout-minutes: 45
permissions:
actions: read
contents: read
security-events: write
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- name: Get explicit and transitive dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
pip freeze > requirements-all.txt
- name: Check python
id: license_check_report
uses: pilosus/action-pip-license-checker@e909b0226ff49d3235c99c4585bc617f49fff16a # ratchet:pilosus/action-pip-license-checker@v3
with:
requirements: 'requirements-all.txt'
fail: 'Copyleft'
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: always()
env:
REPORT: ${{ steps.license_check_report.outputs.report }}
run: echo "$REPORT"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
# be careful enabling the sarif and upload as it may spam the security tab
# with a huge amount of items. Work out the issues before enabling upload.
# - name: Run Trivy vulnerability scanner in repo mode
# if: always()
# uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
# with:
# scan-type: fs
# scan-ref: .
# scanners: license
# format: table
# severity: HIGH,CRITICAL
# # format: sarif
# # output: trivy-results.sarif
#
# # - name: Upload Trivy scan results to GitHub Security tab
# # uses: github/codeql-action/upload-sarif@v3
# # with:
# # sarif_file: trivy-results.sarif
scan-trivy:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# Backend
- name: Pull backend docker image
run: docker pull onyxdotapp/onyx-backend:latest
- name: Run Trivy vulnerability scanner on backend
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-backend:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
# Web server
- name: Pull web server docker image
run: docker pull onyxdotapp/onyx-web-server:latest
- name: Run Trivy vulnerability scanner on web server
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-web-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# Model server
- name: Pull model server docker image
run: docker pull onyxdotapp/onyx-model-server:latest
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-model-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0

View File

@@ -45,6 +45,9 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000
# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -115,10 +118,9 @@ jobs:
- name: Create .env file for Docker Compose
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
COMPOSE_PROFILES=s3-filestore
CODE_INTERPRETER_BETA_ENABLED=true
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF
- name: Set up Standard Dependencies
@@ -127,6 +129,7 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \

View File

@@ -41,7 +41,8 @@ jobs:
version: v3.19.0
- name: Set up chart-testing
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
with:
uv_version: "0.9.9"
@@ -91,7 +92,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Install Redis operator

View File

@@ -22,9 +22,6 @@ env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
GOOGLE_PSE_API_KEY: ${{ secrets.GOOGLE_PSE_API_KEY }}
GOOGLE_PSE_SEARCH_ENGINE_ID: ${{ secrets.GOOGLE_PSE_SEARCH_ENGINE_ID }}
# for federated slack tests
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
@@ -303,7 +300,6 @@ jobs:
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
INTEGRATION_TESTS_MODE=true
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
EXA_API_KEY=${EXA_API_KEY_VALUE}
REQUIRE_EMAIL_VERIFICATION=false
@@ -593,10 +589,7 @@ jobs:
# Post a single combined visual regression comment after all matrix jobs finish
visual-regression-comment:
needs: [playwright-tests]
if: >-
always() &&
github.event_name == 'pull_request' &&
needs.playwright-tests.result != 'cancelled'
if: always() && github.event_name == 'pull_request'
runs-on: ubuntu-slim
timeout-minutes: 5
permissions:

View File

@@ -5,8 +5,6 @@ on:
branches: ["main"]
pull_request:
branches: ["**"]
paths:
- ".github/**"
permissions: {}
@@ -23,18 +21,29 @@ jobs:
with:
persist-credentials: false
- name: Detect changes
id: filter
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
with:
filters: |
zizmor:
- '.github/**'
- name: Install the latest version of uv
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Run zizmor
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload SARIF file
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
with:
sarif_file: results.sarif

1
.gitignore vendored
View File

@@ -7,7 +7,6 @@
.zed
.cursor
!/.cursor/mcp.json
!/.cursor/skills/
# macos
.DS_store

View File

@@ -1,28 +0,0 @@
"""add scim_username to scim_user_mapping
Revision ID: 0bb4558f35df
Revises: 631fd2504136
Create Date: 2026-02-20 10:45:30.340188
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0bb4558f35df"
down_revision = "631fd2504136"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("scim_username", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_username")

View File

@@ -1,32 +0,0 @@
"""add approx_chunk_count_in_vespa to opensearch tenant migration
Revision ID: 631fd2504136
Revises: c7f2e1b4a9d3
Create Date: 2026-02-18 21:07:52.831215
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "631fd2504136"
down_revision = "c7f2e1b4a9d3"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"approx_chunk_count_in_vespa",
sa.Integer(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("opensearch_tenant_migration_record", "approx_chunk_count_in_vespa")

View File

@@ -1,31 +0,0 @@
"""add sharing_scope to build_session
Revision ID: c7f2e1b4a9d3
Revises: 19c0ccb01687
Create Date: 2026-02-17 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "c7f2e1b4a9d3"
down_revision = "19c0ccb01687"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"build_session",
sa.Column(
"sharing_scope",
sa.String(),
nullable=False,
server_default="private",
),
)
def downgrade() -> None:
op.drop_column("build_session", "sharing_scope")

View File

@@ -263,15 +263,9 @@ def refresh_license_cache(
try:
payload = verify_license_signature(license_record.license_data)
# Derive source from payload: manual licenses lack stripe_customer_id
source: LicenseSource = (
LicenseSource.AUTO_FETCH
if payload.stripe_customer_id
else LicenseSource.MANUAL_UPLOAD
)
return update_license_cache(
payload,
source=source,
source=LicenseSource.AUTO_FETCH,
tenant_id=tenant_id,
)
except ValueError as e:

View File

@@ -1,604 +0,0 @@
"""SCIM Data Access Layer.
All database operations for SCIM provisioning — token management, user
mappings, and group mappings. Extends the base DAL (see ``onyx.db.dal``).
Usage from FastAPI::
def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
@router.post("/tokens")
def create_token(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
token = dal.create_token(name=..., hashed_token=..., ...)
dal.commit()
return token
Usage from background tasks::
with ScimDAL.from_tenant("tenant_abc") as dal:
mapping = dal.create_user_mapping(external_id="idp-123", user_id=uid)
dal.commit()
"""
from __future__ import annotations
from uuid import UUID
from sqlalchemy import delete as sa_delete
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import SQLColumnExpression
from sqlalchemy.dialects.postgresql import insert as pg_insert
from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from onyx.db.dal import DAL
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ScimDAL(DAL):
"""Data Access Layer for SCIM provisioning operations.
Methods mutate but do NOT commit — call ``dal.commit()`` explicitly
when you want to persist changes. This follows the existing ``_no_commit``
convention and lets callers batch multiple operations into one transaction.
"""
# ------------------------------------------------------------------
# Token operations
# ------------------------------------------------------------------
def create_token(
self,
name: str,
hashed_token: str,
token_display: str,
created_by_id: UUID,
) -> ScimToken:
"""Create a new SCIM bearer token.
Only one token is active at a time — this method automatically revokes
all existing active tokens before creating the new one.
"""
# Revoke any currently active tokens
active_tokens = list(
self._session.scalars(
select(ScimToken).where(ScimToken.is_active.is_(True))
).all()
)
for t in active_tokens:
t.is_active = False
token = ScimToken(
name=name,
hashed_token=hashed_token,
token_display=token_display,
created_by_id=created_by_id,
)
self._session.add(token)
self._session.flush()
return token
def get_active_token(self) -> ScimToken | None:
"""Return the single currently active token, or None."""
return self._session.scalar(
select(ScimToken).where(ScimToken.is_active.is_(True))
)
def get_token_by_hash(self, hashed_token: str) -> ScimToken | None:
"""Look up a token by its SHA-256 hash."""
return self._session.scalar(
select(ScimToken).where(ScimToken.hashed_token == hashed_token)
)
def revoke_token(self, token_id: int) -> None:
"""Deactivate a token by ID.
Raises:
ValueError: If the token does not exist.
"""
token = self._session.get(ScimToken, token_id)
if not token:
raise ValueError(f"SCIM token with id {token_id} not found")
token.is_active = False
def update_token_last_used(self, token_id: int) -> None:
"""Update the last_used_at timestamp for a token."""
token = self._session.get(ScimToken, token_id)
if token:
token.last_used_at = func.now() # type: ignore[assignment]
# ------------------------------------------------------------------
# User mapping operations
# ------------------------------------------------------------------
def create_user_mapping(
self,
external_id: str,
user_id: UUID,
) -> ScimUserMapping:
"""Create a mapping between a SCIM externalId and an Onyx user."""
mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
self._session.add(mapping)
self._session.flush()
return mapping
def get_user_mapping_by_external_id(
self, external_id: str
) -> ScimUserMapping | None:
"""Look up a user mapping by the IdP's external identifier."""
return self._session.scalar(
select(ScimUserMapping).where(ScimUserMapping.external_id == external_id)
)
def get_user_mapping_by_user_id(self, user_id: UUID) -> ScimUserMapping | None:
"""Look up a user mapping by the Onyx user ID."""
return self._session.scalar(
select(ScimUserMapping).where(ScimUserMapping.user_id == user_id)
)
def list_user_mappings(
self,
start_index: int = 1,
count: int = 100,
) -> tuple[list[ScimUserMapping], int]:
"""List user mappings with SCIM-style pagination.
Args:
start_index: 1-based start index (SCIM convention).
count: Maximum number of results to return.
Returns:
A tuple of (mappings, total_count).
"""
total = (
self._session.scalar(select(func.count()).select_from(ScimUserMapping)) or 0
)
offset = max(start_index - 1, 0)
mappings = list(
self._session.scalars(
select(ScimUserMapping)
.order_by(ScimUserMapping.id)
.offset(offset)
.limit(count)
).all()
)
return mappings, total
def update_user_mapping_external_id(
self,
mapping_id: int,
external_id: str,
) -> ScimUserMapping:
"""Update the external ID on a user mapping.
Raises:
ValueError: If the mapping does not exist.
"""
mapping = self._session.get(ScimUserMapping, mapping_id)
if not mapping:
raise ValueError(f"SCIM user mapping with id {mapping_id} not found")
mapping.external_id = external_id
return mapping
def delete_user_mapping(self, mapping_id: int) -> None:
"""Delete a user mapping by ID. No-op if already deleted."""
mapping = self._session.get(ScimUserMapping, mapping_id)
if not mapping:
logger.warning("SCIM user mapping %d not found during delete", mapping_id)
return
self._session.delete(mapping)
# ------------------------------------------------------------------
# User query operations
# ------------------------------------------------------------------
def get_user(self, user_id: UUID) -> User | None:
"""Fetch a user by ID."""
return self._session.scalar(
select(User).where(User.id == user_id) # type: ignore[arg-type]
)
def get_user_by_email(self, email: str) -> User | None:
"""Fetch a user by email (case-insensitive)."""
return self._session.scalar(
select(User).where(func.lower(User.email) == func.lower(email))
)
def add_user(self, user: User) -> None:
"""Add a new user to the session and flush to assign an ID."""
self._session.add(user)
self._session.flush()
def update_user(
self,
user: User,
*,
email: str | None = None,
is_active: bool | None = None,
personal_name: str | None = None,
) -> None:
"""Update user attributes. Only sets fields that are provided."""
if email is not None:
user.email = email
if is_active is not None:
user.is_active = is_active
if personal_name is not None:
user.personal_name = personal_name
def deactivate_user(self, user: User) -> None:
"""Mark a user as inactive."""
user.is_active = False
def list_users(
self,
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[User, str | None]], int]:
"""Query users with optional SCIM filter and pagination.
Returns:
A tuple of (list of (user, external_id) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
"""
query = select(User).where(
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
)
if scim_filter:
attr = scim_filter.attribute.lower()
if attr == "username":
# arg-type: fastapi-users types User.email as str, not a column expression
# assignment: union return type widens but query is still Select[tuple[User]]
query = _apply_scim_string_op(query, User.email, scim_filter) # type: ignore[arg-type, assignment]
elif attr == "active":
query = query.where(
User.is_active.is_(scim_filter.value.lower() == "true") # type: ignore[attr-defined]
)
elif attr == "externalid":
mapping = self.get_user_mapping_by_external_id(scim_filter.value)
if not mapping:
return [], 0
query = query.where(User.id == mapping.user_id) # type: ignore[arg-type]
else:
raise ValueError(
f"Unsupported filter attribute: {scim_filter.attribute}"
)
# Count total matching rows first, then paginate. SCIM uses 1-based
# indexing (RFC 7644 §3.4.2), so we convert to a 0-based offset.
total = (
self._session.scalar(select(func.count()).select_from(query.subquery()))
or 0
)
offset = max(start_index - 1, 0)
users = list(
self._session.scalars(
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
).all()
)
# Batch-fetch external IDs to avoid N+1 queries
ext_id_map = self._get_user_external_ids([u.id for u in users])
return [(u, ext_id_map.get(u.id)) for u in users], total
def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
"""Create, update, or delete the external ID mapping for a user."""
mapping = self.get_user_mapping_by_user_id(user_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
else:
self.create_user_mapping(external_id=new_external_id, user_id=user_id)
elif mapping:
self.delete_user_mapping(mapping.id)
def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
"""Batch-fetch external IDs for a list of user IDs."""
if not user_ids:
return {}
mappings = self._session.scalars(
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
).all()
return {m.user_id: m.external_id for m in mappings}
# ------------------------------------------------------------------
# Group mapping operations
# ------------------------------------------------------------------
def create_group_mapping(
self,
external_id: str,
user_group_id: int,
) -> ScimGroupMapping:
"""Create a mapping between a SCIM externalId and an Onyx user group."""
mapping = ScimGroupMapping(external_id=external_id, user_group_id=user_group_id)
self._session.add(mapping)
self._session.flush()
return mapping
def get_group_mapping_by_external_id(
self, external_id: str
) -> ScimGroupMapping | None:
"""Look up a group mapping by the IdP's external identifier."""
return self._session.scalar(
select(ScimGroupMapping).where(ScimGroupMapping.external_id == external_id)
)
def get_group_mapping_by_group_id(
self, user_group_id: int
) -> ScimGroupMapping | None:
"""Look up a group mapping by the Onyx user group ID."""
return self._session.scalar(
select(ScimGroupMapping).where(
ScimGroupMapping.user_group_id == user_group_id
)
)
def list_group_mappings(
self,
start_index: int = 1,
count: int = 100,
) -> tuple[list[ScimGroupMapping], int]:
"""List group mappings with SCIM-style pagination.
Args:
start_index: 1-based start index (SCIM convention).
count: Maximum number of results to return.
Returns:
A tuple of (mappings, total_count).
"""
total = (
self._session.scalar(select(func.count()).select_from(ScimGroupMapping))
or 0
)
offset = max(start_index - 1, 0)
mappings = list(
self._session.scalars(
select(ScimGroupMapping)
.order_by(ScimGroupMapping.id)
.offset(offset)
.limit(count)
).all()
)
return mappings, total
def delete_group_mapping(self, mapping_id: int) -> None:
"""Delete a group mapping by ID. No-op if already deleted."""
mapping = self._session.get(ScimGroupMapping, mapping_id)
if not mapping:
logger.warning("SCIM group mapping %d not found during delete", mapping_id)
return
self._session.delete(mapping)
# ------------------------------------------------------------------
# Group query operations
# ------------------------------------------------------------------
def get_group(self, group_id: int) -> UserGroup | None:
"""Fetch a group by ID, returning None if deleted or missing."""
group = self._session.get(UserGroup, group_id)
if group and group.is_up_for_deletion:
return None
return group
def get_group_by_name(self, name: str) -> UserGroup | None:
"""Fetch a group by exact name."""
return self._session.scalar(select(UserGroup).where(UserGroup.name == name))
def add_group(self, group: UserGroup) -> None:
"""Add a new group to the session and flush to assign an ID."""
self._session.add(group)
self._session.flush()
def update_group(
self,
group: UserGroup,
*,
name: str | None = None,
) -> None:
"""Update group attributes and set the modification timestamp."""
if name is not None:
group.name = name
group.time_last_modified_by_user = func.now()
def delete_group(self, group: UserGroup) -> None:
"""Delete a group from the session."""
self._session.delete(group)
def list_groups(
self,
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[UserGroup, str | None]], int]:
"""Query groups with optional SCIM filter and pagination.
Returns:
A tuple of (list of (group, external_id) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
"""
query = select(UserGroup).where(UserGroup.is_up_for_deletion.is_(False))
if scim_filter:
attr = scim_filter.attribute.lower()
if attr == "displayname":
# assignment: union return type widens but query is still Select[tuple[UserGroup]]
query = _apply_scim_string_op(query, UserGroup.name, scim_filter) # type: ignore[assignment]
elif attr == "externalid":
mapping = self.get_group_mapping_by_external_id(scim_filter.value)
if not mapping:
return [], 0
query = query.where(UserGroup.id == mapping.user_group_id)
else:
raise ValueError(
f"Unsupported filter attribute: {scim_filter.attribute}"
)
total = (
self._session.scalar(select(func.count()).select_from(query.subquery()))
or 0
)
offset = max(start_index - 1, 0)
groups = list(
self._session.scalars(
query.order_by(UserGroup.id).offset(offset).limit(count)
).all()
)
ext_id_map = self._get_group_external_ids([g.id for g in groups])
return [(g, ext_id_map.get(g.id)) for g in groups], total
def get_group_members(self, group_id: int) -> list[tuple[UUID, str | None]]:
"""Get group members as (user_id, email) pairs."""
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
).all()
user_ids = [r.user_id for r in rels if r.user_id]
if not user_ids:
return []
users = self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
).all()
users_by_id = {u.id: u for u in users}
return [
(
r.user_id,
users_by_id[r.user_id].email if r.user_id in users_by_id else None,
)
for r in rels
if r.user_id
]
def validate_member_ids(self, uuids: list[UUID]) -> list[UUID]:
"""Return the subset of UUIDs that don't exist as users.
Returns an empty list if all IDs are valid.
"""
if not uuids:
return []
existing_users = self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
).all()
existing_ids = {u.id for u in existing_users}
return [uid for uid in uuids if uid not in existing_ids]
def upsert_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Add user-group relationships, ignoring duplicates."""
if not user_ids:
return
self._session.execute(
pg_insert(User__UserGroup)
.values([{"user_id": uid, "user_group_id": group_id} for uid in user_ids])
.on_conflict_do_nothing(
index_elements=[
User__UserGroup.user_group_id,
User__UserGroup.user_id,
]
)
)
def replace_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Replace all members of a group."""
self._session.execute(
sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
)
self.upsert_group_members(group_id, user_ids)
def remove_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Remove specific members from a group."""
if not user_ids:
return
self._session.execute(
sa_delete(User__UserGroup).where(
User__UserGroup.user_group_id == group_id,
User__UserGroup.user_id.in_(user_ids),
)
)
def delete_group_with_members(self, group: UserGroup) -> None:
"""Remove all member relationships and delete the group."""
self._session.execute(
sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group.id)
)
self._session.delete(group)
def sync_group_external_id(
self, group_id: int, new_external_id: str | None
) -> None:
"""Create, update, or delete the external ID mapping for a group."""
mapping = self.get_group_mapping_by_group_id(group_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
else:
self.create_group_mapping(
external_id=new_external_id, user_group_id=group_id
)
elif mapping:
self.delete_group_mapping(mapping.id)
def _get_group_external_ids(self, group_ids: list[int]) -> dict[int, str]:
"""Batch-fetch external IDs for a list of group IDs."""
if not group_ids:
return {}
mappings = self._session.scalars(
select(ScimGroupMapping).where(
ScimGroupMapping.user_group_id.in_(group_ids)
)
).all()
return {m.user_group_id: m.external_id for m in mappings}
# ---------------------------------------------------------------------------
# Module-level helpers (used by DAL methods above)
# ---------------------------------------------------------------------------
def _apply_scim_string_op(
query: Select[tuple[User]] | Select[tuple[UserGroup]],
column: SQLColumnExpression[str],
scim_filter: ScimFilter,
) -> Select[tuple[User]] | Select[tuple[UserGroup]]:
"""Apply a SCIM string filter operator using SQLAlchemy column operators.
Handles eq (case-insensitive exact), co (contains), and sw (starts with).
SQLAlchemy's operators handle LIKE-pattern escaping internally.
"""
val = scim_filter.value
if scim_filter.operator == ScimFilterOperator.EQUAL:
return query.where(func.lower(column) == val.lower())
elif scim_filter.operator == ScimFilterOperator.CONTAINS:
return query.where(column.icontains(val, autoescape=True))
elif scim_filter.operator == ScimFilterOperator.STARTS_WITH:
return query.where(column.istartswith(val, autoescape=True))
else:
raise ValueError(f"Unsupported string filter operator: {scim_filter.operator}")

View File

@@ -9,7 +9,6 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
@@ -19,15 +18,11 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
@@ -200,60 +195,8 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
return db_session.scalar(stmt)
def _add_user_group_snapshot_eager_loads(
stmt: Select,
) -> Select:
"""Add eager loading options needed by UserGroup.from_model snapshot creation."""
return stmt.options(
selectinload(UserGroup.users),
selectinload(UserGroup.user_group_relationships),
selectinload(UserGroup.cc_pair_relationships)
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
.options(
selectinload(ConnectorCredentialPair.connector),
selectinload(ConnectorCredentialPair.credential).selectinload(
Credential.user
),
),
selectinload(UserGroup.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(UserGroup.personas).options(
selectinload(Persona.tools),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
selectinload(Persona.groups),
),
)
def fetch_user_groups(
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
db_session: Session, only_up_to_date: bool = True
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -266,8 +209,6 @@ def fetch_user_groups(
db_session (Session): The SQLAlchemy session used to query the database.
only_up_to_date (bool, optional): Flag to determine whether to filter the results
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -275,16 +216,11 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def fetch_user_groups_for_user(
db_session: Session,
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
db_session: Session, user_id: UUID, only_curator_groups: bool = False
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -294,9 +230,7 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def construct_document_id_select_by_usergroup(

View File

@@ -1,13 +1,9 @@
from collections.abc import Generator
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -47,27 +43,14 @@ def sharepoint_group_sync(
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
enumerate_all = connector_config.get(
"exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
)
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
sp_domain_suffix = connector.sharepoint_domain_suffix
# Process each site
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
)
ctx = connector._create_rest_client_context(site_descriptor.url)
external_groups = get_sharepoint_external_groups(
ctx,
connector.graph_client,
graph_api_base=connector.graph_api_base,
get_access_token=connector._get_graph_access_token,
enumerate_all_ad_groups=enumerate_all,
)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
# Yield each group
for group in external_groups:

View File

@@ -1,13 +1,9 @@
import re
import time
from collections import deque
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse
import requests as _requests
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
@@ -18,10 +14,7 @@ from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger
@@ -40,70 +33,6 @@ LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
AD_GROUP_ENUMERATION_THRESHOLD = 100_000
def _graph_api_get(
url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> dict[str, Any]:
"""Authenticated Graph API GET with retry on transient errors."""
for attempt in range(GRAPH_API_MAX_RETRIES + 1):
access_token = get_access_token()
headers = {"Authorization": f"Bearer {access_token}"}
try:
resp = _requests.get(
url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
)
if (
resp.status_code in GRAPH_API_RETRYABLE_STATUSES
and attempt < GRAPH_API_MAX_RETRIES
):
wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
logger.warning(
f"Graph API {resp.status_code} on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
resp.raise_for_status()
return resp.json()
except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
if attempt < GRAPH_API_MAX_RETRIES:
wait = min(2**attempt, 60)
logger.warning(
f"Graph API connection error on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
raise
raise RuntimeError(
f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
)
def _iter_graph_collection(
initial_url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> Generator[dict[str, Any], None, None]:
"""Paginate through a Graph API collection, yielding items one at a time."""
url: str | None = initial_url
while url:
data = _graph_api_get(url, get_access_token, params)
params = None
yield from data.get("value", [])
url = data.get("@odata.nextLink")
def _normalize_email(email: str) -> str:
if MICROSOFT_DOMAIN in email:
return email.replace(MICROSOFT_DOMAIN, "")
return email
class SharepointGroup(BaseModel):
model_config = {"frozen": True}
@@ -643,65 +572,8 @@ def get_external_access_from_sharepoint(
)
def _enumerate_ad_groups_paginated(
get_access_token: Callable[[], str],
already_resolved: set[str],
graph_api_base: str,
) -> Generator[ExternalUserGroup, None, None]:
"""Paginate through all Azure AD groups and yield ExternalUserGroup for each.
Skips groups whose suffixed name is already in *already_resolved*.
Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
"""
groups_url = f"{graph_api_base}/groups"
groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
total_groups = 0
for group_json in _iter_graph_collection(
groups_url, get_access_token, groups_params
):
group_id: str = group_json.get("id", "")
display_name: str = group_json.get("displayName", "")
if not group_id or not display_name:
continue
total_groups += 1
if total_groups > AD_GROUP_ENUMERATION_THRESHOLD:
logger.warning(
f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
"groups — stopping to avoid excessive memory/API usage. "
"Remaining groups will be resolved from role assignments only."
)
return
name = f"{display_name}_{group_id}"
if name in already_resolved:
continue
member_emails: list[str] = []
members_url = f"{graph_api_base}/groups/{group_id}/members"
members_params: dict[str, str] = {
"$select": "userPrincipalName,mail",
"$top": "999",
}
for member_json in _iter_graph_collection(
members_url, get_access_token, members_params
):
email = member_json.get("userPrincipalName") or member_json.get("mail")
if email:
member_emails.append(_normalize_email(email))
yield ExternalUserGroup(id=name, user_emails=member_emails)
logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API")
def get_sharepoint_external_groups(
client_context: ClientContext,
graph_client: GraphClient,
graph_api_base: str,
get_access_token: Callable[[], str] | None = None,
enumerate_all_ad_groups: bool = False,
client_context: ClientContext, graph_client: GraphClient
) -> list[ExternalUserGroup]:
groups: set[SharepointGroup] = set()
@@ -757,22 +629,57 @@ def get_sharepoint_external_groups(
client_context, graph_client, groups, is_group_sync=True
)
external_user_groups: list[ExternalUserGroup] = [
ExternalUserGroup(id=group_name, user_emails=list(emails))
for group_name, emails in groups_and_members.groups_to_emails.items()
]
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
azure_ad_groups = sleep_and_retry(
graph_client.groups.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups",
)
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
ad_groups_to_emails: dict[str, set[str]] = {}
for group in azure_ad_groups:
# If the group is already identified, we don't need to get the members
if group.display_name in identified_groups:
continue
# AD groups allows same display name for multiple groups, so we need to add the GUID to the name
name = group.display_name
name = _get_group_name_with_suffix(group.id, name, graph_client)
if not enumerate_all_ad_groups or get_access_token is None:
logger.info(
"Skipping exhaustive Azure AD group enumeration. "
"Only groups found in site role assignments are included."
members = sleep_and_retry(
group.members.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
)
return external_user_groups
for member in members:
member_data = member.to_json()
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
if not ad_groups_to_emails.get(name):
ad_groups_to_emails[name] = set()
if user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
ad_groups_to_emails[name].add(user_principal_name)
elif mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
ad_groups_to_emails[name].add(mail)
already_resolved = set(groups_and_members.groups_to_emails.keys())
for group in _enumerate_ad_groups_paginated(
get_access_token, already_resolved, graph_api_base
):
external_user_groups.append(group)
external_user_groups: list[ExternalUserGroup] = []
for group_name, emails in groups_and_members.groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
for group_name, emails in ad_groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
return external_user_groups

View File

@@ -31,7 +31,6 @@ from ee.onyx.server.query_and_chat.query_backend import (
from ee.onyx.server.query_and_chat.search_backend import router as search_router
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.scim.api import scim_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.token_rate_limits.api import (
@@ -163,11 +162,6 @@ def get_application() -> FastAPI:
# Tenant management
include_router_with_global_prefix_prepended(application, tenants_router)
# SCIM 2.0 — protocol endpoints (unauthenticated by Onyx session auth;
# they use their own SCIM bearer token auth).
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
application.include_router(scim_router)
# Ensure all routes have auth enabled or are explicitly marked as public
check_ee_router_auth(application)

View File

@@ -5,11 +5,6 @@ from onyx.server.auth_check import PUBLIC_ENDPOINT_SPECS
EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
# SCIM 2.0 service discovery — unauthenticated so IdPs can probe
# before bearer token configuration is complete
("/scim/v2/ServiceProviderConfig", {"GET"}),
("/scim/v2/ResourceTypes", {"GET"}),
("/scim/v2/Schemas", {"GET"}),
# needs to be accessible prior to user login
("/enterprise-settings", {"GET"}),
("/enterprise-settings/logo", {"GET"}),

View File

@@ -13,7 +13,6 @@ from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.store import get_logo_filename
@@ -23,10 +22,6 @@ from ee.onyx.server.enterprise_settings.store import load_settings
from ee.onyx.server.enterprise_settings.store import store_analytics_script
from ee.onyx.server.enterprise_settings.store import store_settings
from ee.onyx.server.enterprise_settings.store import upload_logo
from ee.onyx.server.scim.auth import generate_scim_token
from ee.onyx.server.scim.models import ScimTokenCreate
from ee.onyx.server.scim.models import ScimTokenCreatedResponse
from ee.onyx.server.scim.models import ScimTokenResponse
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
@@ -203,63 +198,3 @@ def upload_custom_analytics_script(
@basic_router.get("/custom-analytics-script")
def fetch_custom_analytics_script() -> str | None:
return load_analytics_script()
# ---------------------------------------------------------------------------
# SCIM token management
# ---------------------------------------------------------------------------
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
@admin_router.get("/scim/token")
def get_active_scim_token(
_: User = Depends(current_admin_user),
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenResponse:
"""Return the currently active SCIM token's metadata, or 404 if none."""
token = dal.get_active_token()
if not token:
raise HTTPException(status_code=404, detail="No active SCIM token")
return ScimTokenResponse(
id=token.id,
name=token.name,
token_display=token.token_display,
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
)
@admin_router.post("/scim/token", status_code=201)
def create_scim_token(
body: ScimTokenCreate,
user: User = Depends(current_admin_user),
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenCreatedResponse:
"""Create a new SCIM bearer token.
Only one token is active at a time — creating a new token automatically
revokes all previous tokens. The raw token value is returned exactly once
in the response; it cannot be retrieved again.
"""
raw_token, hashed_token, token_display = generate_scim_token()
token = dal.create_token(
name=body.name,
hashed_token=hashed_token,
token_display=token_display,
created_by_id=user.id,
)
dal.commit()
return ScimTokenCreatedResponse(
id=token.id,
name=token.name,
token_display=token.token_display,
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
raw_token=raw_token,
)

View File

@@ -1,689 +0,0 @@
"""SCIM 2.0 API endpoints (RFC 7644).
This module provides the FastAPI router for SCIM service discovery,
User CRUD, and Group CRUD. Identity providers (Okta, Azure AD) call
these endpoints to provision and manage users and groups.
Service discovery endpoints are unauthenticated — IdPs may probe them
before bearer token configuration is complete. All other endpoints
require a valid SCIM bearer token.
"""
from __future__ import annotations
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Query
from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
# changed to kebab-case.
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
# ---------------------------------------------------------------------------
# Service Discovery Endpoints (unauthenticated)
# ---------------------------------------------------------------------------
@scim_router.get("/ServiceProviderConfig")
def get_service_provider_config() -> ScimServiceProviderConfig:
"""Advertise supported SCIM features (RFC 7643 §5)."""
return SERVICE_PROVIDER_CONFIG
@scim_router.get("/ResourceTypes")
def get_resource_types() -> list[ScimResourceType]:
"""List available SCIM resource types (RFC 7643 §6)."""
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
@scim_router.get("/Schemas")
def get_schemas() -> list[ScimSchemaDefinition]:
"""Return SCIM schema definitions (RFC 7643 §7)."""
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _scim_error_response(status: int, detail: str) -> JSONResponse:
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
body = ScimError(status=str(status), detail=detail)
return JSONResponse(
status_code=status,
content=body.model_dump(exclude_none=True),
)
def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
"""Convert an Onyx User to a SCIM User resource representation."""
name = None
if user.personal_name:
parts = user.personal_name.split(" ", 1)
name = ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
return ScimUserResource(
id=str(user.id),
externalId=external_id,
userName=user.email,
name=name,
emails=[ScimEmail(value=user.email, type="work", primary=True)],
active=user.is_active,
meta=ScimMeta(resourceType="User"),
)
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None."""
check_fn = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)
if check_fn is None:
return None
result = check_fn(dal.session, seats_needed=1)
if not result.available:
return result.error_message or "Seat limit reached"
return None
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
try:
uid = UUID(user_id)
except ValueError:
return _scim_error_response(404, f"User {user_id} not found")
user = dal.get_user(uid)
if not user:
return _scim_error_response(404, f"User {user_id} not found")
return user
def _scim_name_to_str(name: ScimName | None) -> str | None:
"""Extract a display name string from a SCIM name object.
Returns None if no name is provided, so the caller can decide
whether to update the user's personal_name.
"""
if not name:
return None
return name.formatted or " ".join(
part for part in [name.givenName, name.familyName] if part
)
# ---------------------------------------------------------------------------
# User CRUD (RFC 7644 §3)
# ---------------------------------------------------------------------------
@scim_router.get("/Users", response_model=None)
def list_users(
filter: str | None = Query(None),
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List users with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
try:
scim_filter = parse_scim_filter(filter)
except ValueError as e:
return _scim_error_response(400, str(e))
try:
users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
resources: list[ScimUserResource | ScimGroupResource] = [
_user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
]
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
)
@scim_router.get("/Users/{user_id}", response_model=None)
def get_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Get a single user by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
return _user_to_scim(user, mapping.external_id if mapping else None)
@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Create a new user from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
email = user_resource.userName.strip().lower()
# externalId is how the IdP correlates this user on subsequent requests.
# Without it, the IdP can't find the user and will try to re-create,
# hitting a 409 conflict — so we require it up front.
if not user_resource.externalId:
return _scim_error_response(400, "externalId is required")
# Enforce seat limit
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
# Check for existing user
if dal.get_user_by_email(email):
return _scim_error_response(409, f"User with email {email} already exists")
# Create user with a random password (SCIM users authenticate via IdP)
personal_name = _scim_name_to_str(user_resource.name)
user = User(
email=email,
hashed_password=_pw_helper.hash(_pw_helper.generate()),
role=UserRole.BASIC,
is_active=user_resource.active,
is_verified=True,
personal_name=personal_name,
)
try:
dal.add_user(user)
except IntegrityError:
dal.rollback()
return _scim_error_response(409, f"User with email {email} already exists")
# Create SCIM mapping (externalId is validated above, always present)
external_id = user_resource.externalId
dal.create_user_mapping(external_id=external_id, user_id=user.id)
dal.commit()
return _user_to_scim(user, external_id)
@scim_router.put("/Users/{user_id}", response_model=None)
def replace_user(
user_id: str,
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Replace a user entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
# Handle activation (need seat check) / deactivation
if user_resource.active and not user.is_active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
dal.update_user(
user,
email=user_resource.userName.strip().lower(),
is_active=user_resource.active,
personal_name=_scim_name_to_str(user_resource.name),
)
new_external_id = user_resource.externalId
dal.sync_user_external_id(user.id, new_external_id)
dal.commit()
return _user_to_scim(user, new_external_id)
@scim_router.patch("/Users/{user_id}", response_model=None)
def patch_user(
user_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Partially update a user (RFC 7644 §3.5.2).
This is the primary endpoint for user deprovisioning — Okta sends
``PATCH {"active": false}`` rather than DELETE.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
external_id = mapping.external_id if mapping else None
current = _user_to_scim(user, external_id)
try:
patched = apply_user_patch(patch_request.Operations, current)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
# Apply changes back to the DB model
if patched.active != user.is_active:
if patched.active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
dal.update_user(
user,
email=(
patched.userName.strip().lower()
if patched.userName.lower() != user.email
else None
),
is_active=patched.active if patched.active != user.is_active else None,
personal_name=_scim_name_to_str(patched.name),
)
dal.sync_user_external_id(user.id, patched.externalId)
dal.commit()
return _user_to_scim(user, patched.externalId)
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
def delete_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | JSONResponse:
"""Delete a user (RFC 7644 §3.6).
Deactivates the user and removes the SCIM mapping. Note that Okta
typically uses PATCH active=false instead of DELETE.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
dal.deactivate_user(user)
mapping = dal.get_user_mapping_by_user_id(user.id)
if mapping:
dal.delete_user_mapping(mapping.id)
dal.commit()
return Response(status_code=204)
# ---------------------------------------------------------------------------
# Group helpers
# ---------------------------------------------------------------------------
def _group_to_scim(
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Convert an Onyx UserGroup to a SCIM Group resource."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
"""Parse *group_id* as int, look up the group, or return a 404 error."""
try:
gid = int(group_id)
except ValueError:
return _scim_error_response(404, f"Group {group_id} not found")
group = dal.get_group(gid)
if not group:
return _scim_error_response(404, f"Group {group_id} not found")
return group
def _parse_member_uuids(
members: list[ScimGroupMember],
) -> tuple[list[UUID], str | None]:
"""Parse member value strings to UUIDs.
Returns (uuid_list, error_message). error_message is None on success.
"""
uuids: list[UUID] = []
for m in members:
try:
uuids.append(UUID(m.value))
except ValueError:
return [], f"Invalid member ID: {m.value}"
return uuids, None
def _validate_and_parse_members(
members: list[ScimGroupMember], dal: ScimDAL
) -> tuple[list[UUID], str | None]:
"""Parse and validate member UUIDs exist in the database.
Returns (uuid_list, error_message). error_message is None on success.
"""
uuids, err = _parse_member_uuids(members)
if err:
return [], err
if uuids:
missing = dal.validate_member_ids(uuids)
if missing:
return [], f"Member(s) not found: {', '.join(str(u) for u in missing)}"
return uuids, None
# ---------------------------------------------------------------------------
# Group CRUD (RFC 7644 §3)
# ---------------------------------------------------------------------------
@scim_router.get("/Groups", response_model=None)
def list_groups(
filter: str | None = Query(None),
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List groups with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
try:
scim_filter = parse_scim_filter(filter)
except ValueError as e:
return _scim_error_response(400, str(e))
try:
groups_with_ext_ids, total = dal.list_groups(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
resources: list[ScimUserResource | ScimGroupResource] = [
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
for group, ext_id in groups_with_ext_ids
]
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
)
@scim_router.get("/Groups/{group_id}", response_model=None)
def get_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Get a single group by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, mapping.external_id if mapping else None)
@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Create a new group from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
if dal.get_group_by_name(group_resource.displayName):
return _scim_error_response(
409, f"Group with name '{group_resource.displayName}' already exists"
)
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
if err:
return _scim_error_response(400, err)
db_group = UserGroup(
name=group_resource.displayName,
is_up_to_date=True,
time_last_modified_by_user=func.now(),
)
try:
dal.add_group(db_group)
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"Group with name '{group_resource.displayName}' already exists"
)
dal.upsert_group_members(db_group.id, member_uuids)
external_id = group_resource.externalId
if external_id:
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
dal.commit()
members = dal.get_group_members(db_group.id)
return _group_to_scim(db_group, members, external_id)
@scim_router.put("/Groups/{group_id}", response_model=None)
def replace_group(
group_id: str,
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Replace a group entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
if err:
return _scim_error_response(400, err)
dal.update_group(group, name=group_resource.displayName)
dal.replace_group_members(group.id, member_uuids)
dal.sync_group_external_id(group.id, group_resource.externalId)
dal.commit()
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, group_resource.externalId)
@scim_router.patch("/Groups/{group_id}", response_model=None)
def patch_group(
group_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Partially update a group (RFC 7644 §3.5.2).
Handles member add/remove operations from Okta and Azure AD.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
external_id = mapping.external_id if mapping else None
current_members = dal.get_group_members(group.id)
current = _group_to_scim(group, current_members, external_id)
try:
patched, added_ids, removed_ids = apply_group_patch(
patch_request.Operations, current
)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
new_name = patched.displayName if patched.displayName != group.name else None
dal.update_group(group, name=new_name)
if added_ids:
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
if add_uuids:
missing = dal.validate_member_ids(add_uuids)
if missing:
return _scim_error_response(
400,
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
)
dal.upsert_group_members(group.id, add_uuids)
if removed_ids:
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
dal.remove_group_members(group.id, remove_uuids)
dal.sync_group_external_id(group.id, patched.externalId)
dal.commit()
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, patched.externalId)
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
def delete_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | JSONResponse:
"""Delete a group (RFC 7644 §3.6)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
if mapping:
dal.delete_group_mapping(mapping.id)
dal.delete_group_with_members(group)
dal.commit()
return Response(status_code=204)
def _is_valid_uuid(value: str) -> bool:
"""Check if a string is a valid UUID."""
try:
UUID(value)
return True
except ValueError:
return False

View File

@@ -1,104 +0,0 @@
"""SCIM bearer token authentication.
SCIM endpoints are authenticated via bearer tokens that admins create in the
Onyx UI. This module provides:
- ``verify_scim_token``: FastAPI dependency that extracts, hashes, and
validates the token from the Authorization header.
- ``generate_scim_token``: Creates a new cryptographically random token
and returns the raw value, its SHA-256 hash, and a display suffix.
Token format: ``onyx_scim_<random>`` where ``<random>`` is 48 bytes of
URL-safe base64 from ``secrets.token_urlsafe``.
The hash is stored in the ``scim_token`` table; the raw value is shown to
the admin exactly once at creation time.
"""
import hashlib
import secrets
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from onyx.auth.utils import get_hashed_bearer_token_from_request
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
SCIM_TOKEN_PREFIX = "onyx_scim_"
SCIM_TOKEN_LENGTH = 48
def _hash_scim_token(token: str) -> str:
"""SHA-256 hash a SCIM token. No salt needed — tokens are random."""
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def generate_scim_token() -> tuple[str, str, str]:
"""Generate a new SCIM bearer token.
Returns:
A tuple of ``(raw_token, hashed_token, token_display)`` where
``token_display`` is a masked version showing only the last 4 chars.
"""
raw_token = SCIM_TOKEN_PREFIX + secrets.token_urlsafe(SCIM_TOKEN_LENGTH)
hashed_token = _hash_scim_token(raw_token)
token_display = SCIM_TOKEN_PREFIX + "****" + raw_token[-4:]
return raw_token, hashed_token, token_display
def _get_hashed_scim_token_from_request(request: Request) -> str | None:
"""Extract and hash a SCIM token from the request Authorization header."""
return get_hashed_bearer_token_from_request(
request,
valid_prefixes=[SCIM_TOKEN_PREFIX],
hash_fn=_hash_scim_token,
)
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
def verify_scim_token(
request: Request,
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimToken:
"""FastAPI dependency that authenticates SCIM requests.
Extracts the bearer token from the Authorization header, hashes it,
looks it up in the database, and verifies it is active.
Note:
This dependency does NOT update ``last_used_at`` — the endpoint
should do that via ``ScimDAL.update_token_last_used()`` so the
timestamp write is part of the endpoint's transaction.
Raises:
HTTPException(401): If the token is missing, invalid, or inactive.
"""
hashed = _get_hashed_scim_token_from_request(request)
if not hashed:
raise HTTPException(
status_code=401,
detail="Missing or invalid SCIM bearer token",
)
token = dal.get_token_by_hash(hashed)
if not token:
raise HTTPException(
status_code=401,
detail="Invalid SCIM bearer token",
)
if not token.is_active:
raise HTTPException(
status_code=401,
detail="SCIM token has been revoked",
)
return token

View File

@@ -30,7 +30,6 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
)
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
# ---------------------------------------------------------------------------
@@ -196,39 +195,10 @@ class ScimServiceProviderConfig(BaseModel):
)
class ScimSchemaAttribute(BaseModel):
"""Attribute definition within a SCIM Schema (RFC 7643 §7)."""
name: str
type: str
multiValued: bool = False
required: bool = False
description: str = ""
caseExact: bool = False
mutability: str = "readWrite"
returned: str = "default"
uniqueness: str = "none"
subAttributes: list["ScimSchemaAttribute"] = Field(default_factory=list)
class ScimSchemaDefinition(BaseModel):
"""SCIM Schema definition (RFC 7643 §7).
Served at GET /scim/v2/Schemas. Describes the attributes available
on each resource type so IdPs know which fields they can provision.
"""
schemas: list[str] = Field(default_factory=lambda: [SCIM_SCHEMA_SCHEMA])
id: str
name: str
description: str
attributes: list[ScimSchemaAttribute] = Field(default_factory=list)
class ScimSchemaExtension(BaseModel):
"""Schema extension reference within ResourceType (RFC 7643 §6)."""
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
model_config = ConfigDict(populate_by_name=True)
schema_: str = Field(alias="schema")
required: bool
@@ -241,7 +211,7 @@ class ScimResourceType(BaseModel):
types are available (Users, Groups) and their respective endpoints.
"""
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
model_config = ConfigDict(populate_by_name=True)
schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
id: str

View File

@@ -1,144 +0,0 @@
"""Static SCIM service discovery responses (RFC 7643 §5, §6, §7).
Pre-built at import time — these never change at runtime. Separated from
api.py to keep the endpoint module focused on request handling.
"""
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaAttribute
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
SERVICE_PROVIDER_CONFIG = ScimServiceProviderConfig()
USER_RESOURCE_TYPE = ScimResourceType.model_validate(
{
"id": "User",
"name": "User",
"endpoint": "/scim/v2/Users",
"description": "SCIM User resource",
"schema": SCIM_USER_SCHEMA,
}
)
GROUP_RESOURCE_TYPE = ScimResourceType.model_validate(
{
"id": "Group",
"name": "Group",
"endpoint": "/scim/v2/Groups",
"description": "SCIM Group resource",
"schema": SCIM_GROUP_SCHEMA,
}
)
USER_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_USER_SCHEMA,
name="User",
description="SCIM core User schema",
attributes=[
ScimSchemaAttribute(
name="userName",
type="string",
required=True,
uniqueness="server",
description="Unique identifier for the user, typically an email address.",
),
ScimSchemaAttribute(
name="name",
type="complex",
description="The components of the user's name.",
subAttributes=[
ScimSchemaAttribute(
name="givenName",
type="string",
description="The user's first name.",
),
ScimSchemaAttribute(
name="familyName",
type="string",
description="The user's last name.",
),
ScimSchemaAttribute(
name="formatted",
type="string",
description="The full name, including all middle names and titles.",
),
],
),
ScimSchemaAttribute(
name="emails",
type="complex",
multiValued=True,
description="Email addresses for the user.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="Email address value.",
),
ScimSchemaAttribute(
name="type",
type="string",
description="Label for this email (e.g. 'work').",
),
ScimSchemaAttribute(
name="primary",
type="boolean",
description="Whether this is the primary email.",
),
],
),
ScimSchemaAttribute(
name="active",
type="boolean",
description="Whether the user account is active.",
),
ScimSchemaAttribute(
name="externalId",
type="string",
description="Identifier from the provisioning client (IdP).",
caseExact=True,
),
],
)
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_GROUP_SCHEMA,
name="Group",
description="SCIM core Group schema",
attributes=[
ScimSchemaAttribute(
name="displayName",
type="string",
required=True,
description="Human-readable name for the group.",
),
ScimSchemaAttribute(
name="members",
type="complex",
multiValued=True,
description="Members of the group.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="User ID of the group member.",
),
ScimSchemaAttribute(
name="display",
type="string",
mutability="readOnly",
description="Display name of the group member.",
),
],
),
ScimSchemaAttribute(
name="externalId",
type="string",
description="Identifier from the provisioning client (IdP).",
caseExact=True,
),
],
)

View File

@@ -1,13 +1,10 @@
"""EE Settings API - provides license-aware settings override."""
from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.db.license import refresh_license_cache
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -44,14 +41,6 @@ def check_ee_features_enabled() -> bool:
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if not metadata:
# Cache miss — warm from DB so cold-start doesn't block EE features
try:
with get_session_with_current_tenant() as db_session:
metadata = refresh_license_cache(db_session, tenant_id)
except SQLAlchemyError as db_error:
logger.warning(f"Failed to load license from DB: {db_error}")
if metadata and metadata.status != _BLOCKING_STATUS:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
return True
@@ -93,18 +82,6 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if not metadata:
# Cache miss (e.g. after TTL expiry). Fall back to DB so
# the /settings request doesn't falsely return GATED_ACCESS
# while the cache is cold.
try:
with get_session_with_current_tenant() as db_session:
metadata = refresh_license_cache(db_session, tenant_id)
except SQLAlchemyError as db_error:
logger.warning(
f"Failed to load license from DB for settings: {db_error}"
)
if metadata:
if metadata.status == _BLOCKING_STATUS:
settings.application_status = metadata.status
@@ -113,7 +90,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
settings.ee_features_enabled = True
else:
# No license found in cache or DB.
# No license found.
if ENTERPRISE_EDITION_ENABLED:
# Legacy EE flag is set → prior EE usage (e.g. permission
# syncing) means indexed data may need protection.

View File

@@ -37,15 +37,12 @@ def list_user_groups(
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_for_snapshot=True
)
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]

View File

@@ -53,8 +53,7 @@ class UserGroup(BaseModel):
id=cc_pair_relationship.cc_pair.id,
name=cc_pair_relationship.cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_relationship.cc_pair.connector,
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
cc_pair_relationship.cc_pair.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential

View File

@@ -121,7 +121,6 @@ from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -138,8 +137,6 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"
def is_user_admin(user: User) -> bool:
return user.role == UserRole.ADMIN
@@ -211,34 +208,22 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
return int(value.decode("utf-8")) == 1
def workspace_invite_only_enabled() -> bool:
settings = load_settings()
return settings.invite_only_enabled
def verify_email_is_invited(email: str) -> None:
if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}:
# SSO providers manage membership; allow JIT provisioning regardless of invites
return
if not workspace_invite_only_enabled():
whitelist = get_invited_users()
if not whitelist:
return
whitelist = get_invited_users()
if not email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email must be specified"},
)
raise PermissionError("Email must be specified")
try:
email_info = validate_email(email, check_deliverability=False)
except EmailUndeliverableError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email is not valid"},
)
raise PermissionError("Email is not valid")
for email_whitelist in whitelist:
try:
@@ -255,13 +240,7 @@ def verify_email_is_invited(email: str) -> None:
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"code": REGISTER_INVITE_ONLY_CODE,
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
},
)
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
@@ -1671,10 +1650,7 @@ def get_oauth_router(
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
# Use WEB_DOMAIN instead of request.url_for() to prevent host
# header poisoning — request.url_for() trusts the Host header.
callback_path = request.app.url_path_for(callback_route_name)
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")

View File

@@ -0,0 +1,10 @@
"""Celery tasks for hierarchy fetching."""
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
check_for_hierarchy_fetching,
)
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
connector_hierarchy_fetching_task,
)
__all__ = ["check_for_hierarchy_fetching", "connector_hierarchy_fetching_task"]

View File

@@ -41,14 +41,3 @@ assert (
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S = 30 # 30 seconds.
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
# WARNING: Do not change these values without knowing what changes also need to
# be made to OpenSearchTenantMigrationRecord.
GET_VESPA_CHUNKS_PAGE_SIZE = 500
GET_VESPA_CHUNKS_SLICE_COUNT = 4
# String used to indicate in the vespa_visit_continuation_token mapping that the
# slice has finished and there is nothing left to visit.
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN = (
"FINISHED_VISITING_SLICE_CONTINUATION_TOKEN"
)

View File

@@ -8,12 +8,6 @@ from celery import Task
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.opensearch_migration.constants import (
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_PAGE_SIZE,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
@@ -53,13 +47,7 @@ from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
def is_continuation_token_done_for_all_slices(
continuation_token_map: dict[int, str | None],
) -> bool:
return all(
continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
for continuation_token in continuation_token_map.values()
)
GET_VESPA_CHUNKS_PAGE_SIZE = 1000
# shared_task allows this task to be shared across celery app instances.
@@ -88,15 +76,11 @@ def migrate_chunks_from_vespa_to_opensearch_task(
Uses Vespa's Visit API to iterate through ALL chunks in bulk (not
per-document), transform them, and index them into OpenSearch. Progress is
tracked via a continuation token map stored in the
tracked via a continuation token stored in the
OpenSearchTenantMigrationRecord.
The first time we see no continuation token map and non-zero chunks
migrated, we consider the migration complete and all subsequent invocations
are no-ops.
We divide the index into GET_VESPA_CHUNKS_SLICE_COUNT independent slices
where progress is tracked for each slice.
The first time we see no continuation token and non-zero chunks migrated, we
consider the migration complete and all subsequent invocations are no-ops.
Returns:
None if OpenSearch migration is not enabled, or if the lock could not be
@@ -169,28 +153,15 @@ def migrate_chunks_from_vespa_to_opensearch_task(
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
)
approx_chunk_count_in_vespa: int | None = None
get_chunk_count_start_time = time.monotonic()
try:
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
except Exception:
task_logger.exception(
"Error getting approximate chunk count in Vespa. Moving on..."
)
task_logger.debug(
f"Took {time.monotonic() - get_chunk_count_start_time:.3f} seconds to attempt to get "
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
)
while (
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
(
continuation_token_map,
continuation_token,
total_chunks_migrated,
) = get_vespa_visit_state(db_session)
if is_continuation_token_done_for_all_slices(continuation_token_map):
if continuation_token is None and total_chunks_migrated > 0:
task_logger.info(
f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
f"Total chunks migrated: {total_chunks_migrated}."
@@ -199,19 +170,19 @@ def migrate_chunks_from_vespa_to_opensearch_task(
break
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token map: {continuation_token_map}"
f"Continuation token: {continuation_token}"
)
get_vespa_chunks_start_time = time.monotonic()
raw_vespa_chunks, next_continuation_token_map = (
raw_vespa_chunks, next_continuation_token = (
vespa_document_index.get_all_raw_document_chunks_paginated(
continuation_token_map=continuation_token_map,
continuation_token=continuation_token,
page_size=GET_VESPA_CHUNKS_PAGE_SIZE,
)
)
task_logger.debug(
f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} "
f"seconds. Next continuation token map: {next_continuation_token_map}"
f"seconds. Next continuation token: {next_continuation_token}"
)
opensearch_document_chunks, errored_chunks = (
@@ -241,11 +212,14 @@ def migrate_chunks_from_vespa_to_opensearch_task(
total_chunks_errored_this_task += len(errored_chunks)
update_vespa_visit_progress_with_commit(
db_session,
continuation_token_map=next_continuation_token_map,
continuation_token=next_continuation_token,
chunks_processed=len(opensearch_document_chunks),
chunks_errored=len(errored_chunks),
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)
if next_continuation_token is None and len(raw_vespa_chunks) == 0:
task_logger.info("Vespa reported no more chunks to migrate.")
break
except Exception:
traceback.print_exc()
task_logger.exception("Error in the OpenSearch migration task.")

View File

@@ -37,35 +37,6 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
DOCUMENT_ID,
CHUNK_ID,
TITLE,
TITLE_EMBEDDING,
CONTENT,
EMBEDDINGS,
SOURCE_TYPE,
METADATA_LIST,
DOC_UPDATED_AT,
HIDDEN,
BOOST,
SEMANTIC_IDENTIFIER,
IMAGE_FILE_NAME,
SOURCE_LINKS,
BLURB,
DOC_SUMMARY,
CHUNK_CONTEXT,
METADATA_SUFFIX,
DOCUMENT_SETS,
USER_PROJECT,
PRIMARY_OWNERS,
SECONDARY_OWNERS,
ACCESS_CONTROL_LIST,
]
if MULTI_TENANT:
FIELDS_NEEDED_FOR_TRANSFORMATION.append(TENANT_ID)
def _extract_content_vector(embeddings: Any) -> list[float]:
"""Extracts the full chunk embedding vector from Vespa's embeddings tensor.

View File

@@ -0,0 +1,8 @@
"""Celery tasks for connector pruning."""
from onyx.background.celery.tasks.pruning.tasks import check_for_pruning # noqa: F401
from onyx.background.celery.tasks.pruning.tasks import ( # noqa: F401
connector_pruning_generator_task,
)
__all__ = ["check_for_pruning", "connector_pruning_generator_task"]

View File

@@ -13,7 +13,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import DISABLE_VECTOR_DB
@@ -22,14 +21,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -60,17 +57,6 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -134,24 +120,7 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
Uses direct Redis locks to avoid overlapping runs.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -166,21 +135,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -193,35 +148,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -229,8 +161,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
@@ -373,12 +304,6 @@ def process_single_user_file(
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,

View File

@@ -1,4 +1,3 @@
import json
import re
from collections.abc import Callable
from typing import cast
@@ -46,7 +45,6 @@ from onyx.utils.timing import log_function_time
logger = setup_logger()
IMAGE_GENERATION_TOOL_NAME = "generate_image"
def create_chat_session_from_request(
@@ -424,40 +422,6 @@ def convert_chat_history_basic(
return list(reversed(trimmed_reversed))
def _build_tool_call_response_history_message(
tool_name: str,
generated_images: list[dict] | None,
tool_call_response: str | None,
) -> str:
if tool_name != IMAGE_GENERATION_TOOL_NAME:
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
if generated_images:
llm_image_context: list[dict[str, str]] = []
for image in generated_images:
file_id = image.get("file_id")
revised_prompt = image.get("revised_prompt")
if not isinstance(file_id, str):
continue
llm_image_context.append(
{
"file_id": file_id,
"revised_prompt": (
revised_prompt if isinstance(revised_prompt, str) else ""
),
}
)
if llm_image_context:
return json.dumps(llm_image_context)
if tool_call_response:
return tool_call_response
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
def convert_chat_history(
chat_history: list[ChatMessage],
files: list[ChatLoadedFile],
@@ -618,24 +582,10 @@ def convert_chat_history(
# Add TOOL_CALL_RESPONSE messages for each tool call in this turn
for tool_call in turn_tool_calls:
tool_name = tool_id_to_name_map.get(
tool_call.tool_id, "unknown"
)
tool_response_message = (
_build_tool_call_response_history_message(
tool_name=tool_name,
generated_images=tool_call.generated_images,
tool_call_response=tool_call.tool_call_response,
)
)
simple_messages.append(
ChatMessageSimple(
message=tool_response_message,
token_count=(
token_counter(tool_response_message)
if tool_name == IMAGE_GENERATION_TOOL_NAME
else 20
),
message=TOOL_CALL_RESPONSE_CROSS_MESSAGE,
token_count=20, # Tiny overestimate
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,

View File

@@ -57,7 +57,6 @@ from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -69,18 +68,6 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
"""Detect XML-style marshaled tool calls emitted as plain text."""
if not text:
return False
lowered = text.lower()
return (
"<function_calls" in lowered
and "<invoke" in lowered
and "<parameter" in lowered
)
def _should_keep_bedrock_tool_definitions(
llm: object, simple_chat_history: list[ChatMessageSimple]
) -> bool:
@@ -135,56 +122,38 @@ def _try_fallback_tool_extraction(
reasoning_but_no_answer_or_tools = (
llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls
)
xml_tool_call_text_detected = no_tool_calls and (
_looks_like_xml_tool_call_payload(llm_step_result.answer)
or _looks_like_xml_tool_call_payload(llm_step_result.raw_answer)
or _looks_like_xml_tool_call_payload(llm_step_result.reasoning)
)
should_try_fallback = (
(tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls)
or reasoning_but_no_answer_or_tools
or xml_tool_call_text_detected
)
tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls
) or reasoning_but_no_answer_or_tools
if not should_try_fallback:
return llm_step_result, False
# Try to extract from answer first, then fall back to reasoning
extracted_tool_calls: list[ToolCallKickoff] = []
if llm_step_result.answer:
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.answer,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if (
not extracted_tool_calls
and llm_step_result.raw_answer
and llm_step_result.raw_answer != llm_step_result.answer
):
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.raw_answer,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if not extracted_tool_calls and llm_step_result.reasoning:
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.reasoning,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if extracted_tool_calls:
logger.info(
f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
"as fallback"
f"as fallback (tool_choice was REQUIRED but no tool calls returned)"
)
return (
LlmStepResult(
reasoning=llm_step_result.reasoning,
answer=llm_step_result.answer,
tool_calls=extracted_tool_calls,
raw_answer=llm_step_result.raw_answer,
),
True,
)
@@ -482,42 +451,7 @@ def construct_message_history(
if reminder_message:
result.append(reminder_message)
return _drop_orphaned_tool_call_responses(result)
def _drop_orphaned_tool_call_responses(
messages: list[ChatMessageSimple],
) -> list[ChatMessageSimple]:
"""Drop tool response messages whose tool_call_id is not in prior assistant tool calls.
This can happen when history truncation drops an ASSISTANT tool-call message but
leaves a later TOOL_CALL_RESPONSE message in context. Some providers (e.g. Ollama)
reject such history with an "unexpected tool call id" error.
"""
known_tool_call_ids: set[str] = set()
sanitized: list[ChatMessageSimple] = []
for msg in messages:
if msg.message_type == MessageType.ASSISTANT and msg.tool_calls:
for tool_call in msg.tool_calls:
known_tool_call_ids.add(tool_call.tool_call_id)
sanitized.append(msg)
continue
if msg.message_type == MessageType.TOOL_CALL_RESPONSE:
if msg.tool_call_id and msg.tool_call_id in known_tool_call_ids:
sanitized.append(msg)
else:
logger.debug(
"Dropping orphaned tool response with tool_call_id=%s while "
"constructing message history",
msg.tool_call_id,
)
continue
sanitized.append(msg)
return sanitized
return result
def _create_file_tool_metadata_message(
@@ -652,7 +586,6 @@ def run_llm_loop(
ran_image_gen: bool = False
just_ran_web_search: bool = False
has_called_search_tool: bool = False
code_interpreter_file_generated: bool = False
fallback_extraction_attempted: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
@@ -763,7 +696,6 @@ def run_llm_loop(
),
include_citation_reminder=should_cite_documents
or always_cite_documents,
include_file_reminder=code_interpreter_file_generated,
is_last_cycle=out_of_cycles,
)
@@ -903,18 +835,6 @@ def run_llm_loop(
if tool_call.tool_name == SearchTool.NAME:
has_called_search_tool = True
# Track if code interpreter generated files with download links
if (
tool_call.tool_name == PythonTool.NAME
and not code_interpreter_file_generated
):
try:
parsed = json.loads(tool_response.llm_facing_response)
if parsed.get("generated_files"):
code_interpreter_file_generated = True
except (json.JSONDecodeError, AttributeError):
pass
# Build a mapping of tool names to tool objects for getting tool_id
tools_by_name = {tool.name: tool for tool in final_tools}

View File

@@ -1,12 +1,10 @@
import json
import re
import time
import uuid
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Mapping
from collections.abc import Sequence
from html import unescape
from typing import Any
from typing import cast
@@ -20,7 +18,6 @@ from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import ChatFileType
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LanguageModelInput
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
@@ -59,112 +56,6 @@ from onyx.utils.text_processing import find_all_json_objects
logger = setup_logger()
_XML_INVOKE_BLOCK_RE = re.compile(
r"<invoke\b(?P<attrs>[^>]*)>(?P<body>.*?)</invoke>",
re.IGNORECASE | re.DOTALL,
)
_XML_PARAMETER_RE = re.compile(
r"<parameter\b(?P<attrs>[^>]*)>(?P<value>.*?)</parameter>",
re.IGNORECASE | re.DOTALL,
)
_FUNCTION_CALLS_OPEN_MARKER = "<function_calls"
_FUNCTION_CALLS_CLOSE_MARKER = "</function_calls>"
class _XmlToolCallContentFilter:
"""Streaming filter that strips XML-style tool call payload blocks from text."""
def __init__(self) -> None:
self._pending = ""
self._inside_function_calls_block = False
def process(self, content: str) -> str:
if not content:
return ""
self._pending += content
output_parts: list[str] = []
while self._pending:
pending_lower = self._pending.lower()
if self._inside_function_calls_block:
end_idx = pending_lower.find(_FUNCTION_CALLS_CLOSE_MARKER)
if end_idx == -1:
# Keep buffering until we see the close marker.
return "".join(output_parts)
# Drop the whole function_calls block.
self._pending = self._pending[
end_idx + len(_FUNCTION_CALLS_CLOSE_MARKER) :
]
self._inside_function_calls_block = False
continue
start_idx = _find_function_calls_open_marker(pending_lower)
if start_idx == -1:
# Keep only a possible prefix of "<function_calls" in the buffer so
# marker splits across chunks are handled correctly.
tail_len = _matching_open_marker_prefix_len(self._pending)
emit_upto = len(self._pending) - tail_len
if emit_upto > 0:
output_parts.append(self._pending[:emit_upto])
self._pending = self._pending[emit_upto:]
return "".join(output_parts)
if start_idx > 0:
output_parts.append(self._pending[:start_idx])
# Enter block-stripping mode and keep scanning for close marker.
self._pending = self._pending[start_idx:]
self._inside_function_calls_block = True
return "".join(output_parts)
def flush(self) -> str:
if self._inside_function_calls_block:
# Drop any incomplete block at stream end.
self._pending = ""
self._inside_function_calls_block = False
return ""
remaining = self._pending
self._pending = ""
return remaining
def _matching_open_marker_prefix_len(text: str) -> int:
"""Return longest suffix of text that matches prefix of "<function_calls"."""
max_len = min(len(text), len(_FUNCTION_CALLS_OPEN_MARKER) - 1)
text_lower = text.lower()
marker_lower = _FUNCTION_CALLS_OPEN_MARKER
for candidate_len in range(max_len, 0, -1):
if text_lower.endswith(marker_lower[:candidate_len]):
return candidate_len
return 0
def _is_valid_function_calls_open_follower(char: str | None) -> bool:
return char is None or char in {">", " ", "\t", "\n", "\r"}
def _find_function_calls_open_marker(text_lower: str) -> int:
"""Find '<function_calls' with a valid tag boundary follower."""
search_from = 0
while True:
idx = text_lower.find(_FUNCTION_CALLS_OPEN_MARKER, search_from)
if idx == -1:
return -1
follower_pos = idx + len(_FUNCTION_CALLS_OPEN_MARKER)
follower = text_lower[follower_pos] if follower_pos < len(text_lower) else None
if _is_valid_function_calls_open_follower(follower):
return idx
search_from = idx + 1
def _sanitize_llm_output(value: str) -> str:
"""Remove characters that PostgreSQL's text/JSONB types cannot store.
@@ -381,7 +272,14 @@ def _extract_tool_call_kickoffs(
tab_index_calculated = 0
for tool_call_data in id_to_tool_call_map.values():
if tool_call_data.get("id") and tool_call_data.get("name"):
tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
try:
tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
except json.JSONDecodeError:
# If parsing fails, try empty dict, most tools would fail though
logger.error(
f"Failed to parse tool call arguments: {tool_call_data['arguments']}"
)
tool_args = {}
tool_calls.append(
ToolCallKickoff(
@@ -409,9 +307,8 @@ def extract_tool_calls_from_response_text(
"""Extract tool calls from LLM response text by matching JSON against tool definitions.
This is a fallback mechanism for when the LLM was expected to return tool calls
but didn't use the proper tool call format. It searches for tool calls embedded
in response text (JSON first, then XML-like invoke blocks) that match available
tool definitions.
but didn't use the proper tool call format. It searches for JSON objects in the
response text that match the structure of available tools.
Args:
response_text: The LLM's text response to search for tool calls
@@ -436,9 +333,10 @@ def extract_tool_calls_from_response_text(
if not tool_name_to_def:
return []
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
# Find all JSON objects in the response text
json_objects = find_all_json_objects(response_text)
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
prev_json_obj: dict[str, Any] | None = None
prev_tool_call: tuple[str, dict[str, Any]] | None = None
@@ -466,14 +364,6 @@ def extract_tool_calls_from_response_text(
prev_json_obj = json_obj
prev_tool_call = matched_tool_call
# Some providers/models emit XML-style function calls instead of JSON objects.
# Keep this as a fallback behind JSON extraction to preserve current behavior.
if not matched_tool_calls:
matched_tool_calls = _extract_xml_tool_calls_from_response_text(
response_text=response_text,
tool_name_to_def=tool_name_to_def,
)
tool_calls: list[ToolCallKickoff] = []
for tab_index, (tool_name, tool_args) in enumerate(matched_tool_calls):
tool_calls.append(
@@ -496,88 +386,6 @@ def extract_tool_calls_from_response_text(
return tool_calls
def _extract_xml_tool_calls_from_response_text(
response_text: str,
tool_name_to_def: dict[str, dict],
) -> list[tuple[str, dict[str, Any]]]:
"""Extract XML-style tool calls from response text.
Supports formats such as:
<function_calls>
<invoke name="internal_search">
<parameter name="queries" string="false">["foo"]</parameter>
</invoke>
</function_calls>
"""
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
for invoke_match in _XML_INVOKE_BLOCK_RE.finditer(response_text):
invoke_attrs = invoke_match.group("attrs")
tool_name = _extract_xml_attribute(invoke_attrs, "name")
if not tool_name or tool_name not in tool_name_to_def:
continue
tool_args: dict[str, Any] = {}
invoke_body = invoke_match.group("body")
for parameter_match in _XML_PARAMETER_RE.finditer(invoke_body):
parameter_attrs = parameter_match.group("attrs")
parameter_name = _extract_xml_attribute(parameter_attrs, "name")
if not parameter_name:
continue
string_attr = _extract_xml_attribute(parameter_attrs, "string")
tool_args[parameter_name] = _parse_xml_parameter_value(
raw_value=parameter_match.group("value"),
string_attr=string_attr,
)
matched_tool_calls.append((tool_name, tool_args))
return matched_tool_calls
def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None:
"""Extract a single XML-style attribute value from a tag attribute string."""
attr_match = re.search(
rf"""\b{re.escape(attr_name)}\s*=\s*(['"])(.*?)\1""",
attrs,
flags=re.IGNORECASE | re.DOTALL,
)
if not attr_match:
return None
return _sanitize_llm_output(unescape(attr_match.group(2).strip()))
def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any:
"""Parse a parameter value from XML-style tool call payloads."""
value = _sanitize_llm_output(unescape(raw_value).strip())
if string_attr and string_attr.lower() == "true":
return value
try:
return json.loads(value)
except json.JSONDecodeError:
return value
def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None:
"""Extract and parse an arguments/parameters value from a tool-call-like object.
Looks for "arguments" or "parameters" keys, handles JSON-string values,
and returns a dict if successful, or None otherwise.
"""
arguments = obj.get("arguments", obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
return arguments
return None
def _try_match_json_to_tool(
json_obj: dict[str, Any],
tool_name_to_def: dict[str, dict],
@@ -600,8 +408,13 @@ def _try_match_json_to_tool(
# Format 1: Direct tool call format {"name": "...", "arguments": {...}}
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
tool_name = json_obj["name"]
arguments = _resolve_tool_arguments(json_obj)
if arguments is not None:
arguments = json_obj.get("arguments", json_obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
return (tool_name, arguments)
# Format 2: Function call format {"function": {"name": "...", "arguments": {...}}}
@@ -609,8 +422,13 @@ def _try_match_json_to_tool(
func_obj = json_obj["function"]
if "name" in func_obj and func_obj["name"] in tool_name_to_def:
tool_name = func_obj["name"]
arguments = _resolve_tool_arguments(func_obj)
if arguments is not None:
arguments = func_obj.get("arguments", func_obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
return (tool_name, arguments)
# Format 3: Tool name as key {"tool_name": {...arguments...}}
@@ -677,107 +495,6 @@ def _extract_nested_arguments_obj(
return None
def _build_structured_assistant_message(msg: ChatMessageSimple) -> AssistantMessage:
tool_calls_list: list[ToolCall] | None = None
if msg.tool_calls:
tool_calls_list = [
ToolCall(
id=tc.tool_call_id,
type="function",
function=FunctionCall(
name=tc.tool_name,
arguments=json.dumps(tc.tool_arguments),
),
)
for tc in msg.tool_calls
]
return AssistantMessage(
role="assistant",
content=msg.message or None,
tool_calls=tool_calls_list,
)
def _build_structured_tool_response_message(msg: ChatMessageSimple) -> ToolMessage:
if not msg.tool_call_id:
raise ValueError(
"Tool call response message encountered but tool_call_id is not available. "
f"Message: {msg}"
)
return ToolMessage(
role="tool",
content=msg.message,
tool_call_id=msg.tool_call_id,
)
class _HistoryMessageFormatter:
def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
raise NotImplementedError
def format_tool_response_message(
self, msg: ChatMessageSimple
) -> ToolMessage | UserMessage:
raise NotImplementedError
class _DefaultHistoryMessageFormatter(_HistoryMessageFormatter):
def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
return _build_structured_assistant_message(msg)
def format_tool_response_message(self, msg: ChatMessageSimple) -> ToolMessage:
return _build_structured_tool_response_message(msg)
class _OllamaHistoryMessageFormatter(_HistoryMessageFormatter):
def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
if not msg.tool_calls:
return _build_structured_assistant_message(msg)
tool_call_lines = [
(
f"[Tool Call] name={tc.tool_name} "
f"id={tc.tool_call_id} args={json.dumps(tc.tool_arguments)}"
)
for tc in msg.tool_calls
]
assistant_content = (
"\n".join([msg.message, *tool_call_lines])
if msg.message
else "\n".join(tool_call_lines)
)
return AssistantMessage(
role="assistant",
content=assistant_content,
tool_calls=None,
)
def format_tool_response_message(self, msg: ChatMessageSimple) -> UserMessage:
if not msg.tool_call_id:
raise ValueError(
"Tool call response message encountered but tool_call_id is not available. "
f"Message: {msg}"
)
return UserMessage(
role="user",
content=f"[Tool Result] id={msg.tool_call_id}\n{msg.message}",
)
_DEFAULT_HISTORY_MESSAGE_FORMATTER = _DefaultHistoryMessageFormatter()
_OLLAMA_HISTORY_MESSAGE_FORMATTER = _OllamaHistoryMessageFormatter()
def _get_history_message_formatter(llm_config: LLMConfig) -> _HistoryMessageFormatter:
if llm_config.model_provider == LlmProviderNames.OLLAMA_CHAT:
return _OLLAMA_HISTORY_MESSAGE_FORMATTER
return _DEFAULT_HISTORY_MESSAGE_FORMATTER
def translate_history_to_llm_format(
history: list[ChatMessageSimple],
llm_config: LLMConfig,
@@ -788,10 +505,6 @@ def translate_history_to_llm_format(
handling different message types and image files for multimodal support.
"""
messages: list[ChatCompletionMessage] = []
history_message_formatter = _get_history_message_formatter(llm_config)
# Note: cacheability is computed from pre-translation ChatMessageSimple types.
# Some providers flatten tool history into plain assistant/user text, so this split
# may be less semantically meaningful, but it remains safe and order-preserving.
last_cacheable_msg_idx = -1
all_previous_msgs_cacheable = True
@@ -873,10 +586,39 @@ def translate_history_to_llm_format(
messages.append(reminder_msg)
elif msg.message_type == MessageType.ASSISTANT:
messages.append(history_message_formatter.format_assistant_message(msg))
tool_calls_list: list[ToolCall] | None = None
if msg.tool_calls:
tool_calls_list = [
ToolCall(
id=tc.tool_call_id,
type="function",
function=FunctionCall(
name=tc.tool_name,
arguments=json.dumps(tc.tool_arguments),
),
)
for tc in msg.tool_calls
]
assistant_msg = AssistantMessage(
role="assistant",
content=msg.message or None,
tool_calls=tool_calls_list,
)
messages.append(assistant_msg)
elif msg.message_type == MessageType.TOOL_CALL_RESPONSE:
messages.append(history_message_formatter.format_tool_response_message(msg))
if not msg.tool_call_id:
raise ValueError(
f"Tool call response message encountered but tool_call_id is not available. Message: {msg}"
)
tool_msg = ToolMessage(
role="tool",
content=msg.message,
tool_call_id=msg.tool_call_id,
)
messages.append(tool_msg)
else:
logger.warning(
@@ -956,8 +698,7 @@ def run_llm_step_pkt_generator(
tool_definitions: List of tool definitions available to the LLM.
tool_choice: Tool choice configuration (e.g., "auto", "required", "none").
llm: Language model interface to use for generation.
placement: Placement info (turn_index, tab_index, sub_turn_index) for
positioning packets in the conversation UI.
turn_index: Current turn index in the conversation.
state_container: Container for storing chat state (reasoning, answers).
citation_processor: Optional processor for extracting and formatting citations
from the response. If provided, processes tokens to identify citations.
@@ -969,14 +710,7 @@ def run_llm_step_pkt_generator(
custom_token_processor: Optional callable that processes each token delta
before yielding. Receives (delta, processor_state) and returns
(modified_delta, new_processor_state). Can return None for delta to skip.
max_tokens: Optional maximum number of tokens for the LLM response.
use_existing_tab_index: If True, use the tab_index from placement for all
tool calls instead of auto-incrementing.
is_deep_research: If True, treat content before tool calls as reasoning
when tool_choice is REQUIRED.
pre_answer_processing_time: Optional time spent processing before the
answer started, recorded in state_container for analytics.
timeout_override: Optional timeout override for the LLM call.
sub_turn_index: Optional sub-turn index for nested tool/agent calls.
Yields:
Packet: Streaming packets containing:
@@ -1002,15 +736,8 @@ def run_llm_step_pkt_generator(
tab_index = placement.tab_index
sub_turn_index = placement.sub_turn_index
def _current_placement() -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
)
llm_msg_history = translate_history_to_llm_format(history, llm.config)
has_reasoned = False
has_reasoned = 0
if LOG_ONYX_MODEL_INTERACTIONS:
logger.debug(
@@ -1022,8 +749,6 @@ def run_llm_step_pkt_generator(
answer_start = False
accumulated_reasoning = ""
accumulated_answer = ""
accumulated_raw_answer = ""
xml_tool_call_content_filter = _XmlToolCallContentFilter()
processor_state: Any = None
@@ -1039,112 +764,6 @@ def run_llm_step_pkt_generator(
)
stream_start_time = time.monotonic()
first_action_recorded = False
def _emit_citation_results(
results: Generator[str | CitationInfo, None, None],
) -> Generator[Packet, None, None]:
"""Yield packets for citation processor results (str or CitationInfo)."""
nonlocal accumulated_answer
for result in results:
if isinstance(result, str):
accumulated_answer += result
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=_current_placement(),
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
placement=_current_placement(),
obj=result,
)
if state_container:
state_container.add_emitted_citation(result.citation_number)
def _close_reasoning_if_active() -> Generator[Packet, None, None]:
"""Emit ReasoningDone and increment turns if reasoning is in progress."""
nonlocal reasoning_start
nonlocal has_reasoned
nonlocal turn_index
nonlocal sub_turn_index
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = True
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
def _emit_content_chunk(content_chunk: str) -> Generator[Packet, None, None]:
nonlocal accumulated_answer
nonlocal accumulated_reasoning
nonlocal answer_start
nonlocal reasoning_start
nonlocal turn_index
nonlocal sub_turn_index
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
# about which tool to call, not an actual answer to the user.
# Treat this content as reasoning instead of answer.
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
accumulated_reasoning += content_chunk
if state_container:
state_container.set_reasoning_tokens(accumulated_reasoning)
if not reasoning_start:
yield Packet(
placement=_current_placement(),
obj=ReasoningStart(),
)
yield Packet(
placement=_current_placement(),
obj=ReasoningDelta(reasoning=content_chunk),
)
reasoning_start = True
return
# Normal flow for AUTO or NONE tool choice
yield from _close_reasoning_if_active()
if not answer_start:
# Store pre-answer processing time in state container for save_chat
if state_container and pre_answer_processing_time is not None:
state_container.set_pre_answer_processing_time(
pre_answer_processing_time
)
yield Packet(
placement=_current_placement(),
obj=AgentResponseStart(
final_documents=final_documents,
pre_answer_processing_seconds=pre_answer_processing_time,
),
)
answer_start = True
if citation_processor:
yield from _emit_citation_results(
citation_processor.process_token(content_chunk)
)
else:
accumulated_answer += content_chunk
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=_current_placement(),
obj=AgentResponseDelta(content=content_chunk),
)
for packet in llm.stream(
prompt=llm_msg_history,
tools=tool_definitions,
@@ -1203,34 +822,152 @@ def run_llm_step_pkt_generator(
state_container.set_reasoning_tokens(accumulated_reasoning)
if not reasoning_start:
yield Packet(
placement=_current_placement(),
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningStart(),
)
yield Packet(
placement=_current_placement(),
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDelta(reasoning=delta.reasoning_content),
)
reasoning_start = True
if delta.content:
# Keep raw content for fallback extraction. Display content can be
# filtered and, in deep-research REQUIRED mode, routed as reasoning.
accumulated_raw_answer += delta.content
filtered_content = xml_tool_call_content_filter.process(delta.content)
if filtered_content:
yield from _emit_content_chunk(filtered_content)
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
# about which tool to call, not an actual answer to the user.
# Treat this content as reasoning instead of answer.
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
# Treat content as reasoning when we know tool calls are coming
accumulated_reasoning += delta.content
if state_container:
state_container.set_reasoning_tokens(accumulated_reasoning)
if not reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningStart(),
)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDelta(reasoning=delta.content),
)
reasoning_start = True
else:
# Normal flow for AUTO or NONE tool choice
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
if not answer_start:
# Store pre-answer processing time in state container for save_chat
if state_container and pre_answer_processing_time is not None:
state_container.set_pre_answer_processing_time(
pre_answer_processing_time
)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseStart(
final_documents=final_documents,
pre_answer_processing_seconds=pre_answer_processing_time,
),
)
answer_start = True
if citation_processor:
for result in citation_processor.process_token(delta.content):
if isinstance(result, str):
accumulated_answer += result
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(
accumulated_answer
)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=result,
)
# Track emitted citation for saving
if state_container:
state_container.add_emitted_citation(
result.citation_number
)
else:
# When citation_processor is None, use delta.content directly without modification
accumulated_answer += delta.content
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=delta.content),
)
if delta.tool_calls:
yield from _close_reasoning_if_active()
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
# Flush any tail text buffered while checking for split "<function_calls" markers.
filtered_content_tail = xml_tool_call_content_filter.flush()
if filtered_content_tail:
yield from _emit_content_chunk(filtered_content_tail)
# Flush custom token processor to get any final tool calls
if custom_token_processor:
flush_delta, processor_state = custom_token_processor(None, processor_state)
@@ -1286,14 +1023,50 @@ def run_llm_step_pkt_generator(
# This may happen if the custom token processor is used to modify other packets into reasoning
# Then there won't necessarily be anything else to come after the reasoning tokens
yield from _close_reasoning_if_active()
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(turn_index, sub_turn_index)
reasoning_start = False
# Flush any remaining content from citation processor
# Reasoning is always first so this should use the post-incremented value of turn_index
# Note that this doesn't need to handle any sub-turns as those docs will not have citations
# as clickable items and will be stripped out instead.
if citation_processor:
yield from _emit_citation_results(citation_processor.process_token(None))
for result in citation_processor.process_token(None):
if isinstance(result, str):
accumulated_answer += result
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=result,
)
# Track emitted citation for saving
if state_container:
state_container.add_emitted_citation(result.citation_number)
# Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
# Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)
@@ -1315,9 +1088,8 @@ def run_llm_step_pkt_generator(
reasoning=accumulated_reasoning if accumulated_reasoning else None,
answer=accumulated_answer if accumulated_answer else None,
tool_calls=tool_calls if tool_calls else None,
raw_answer=accumulated_raw_answer if accumulated_raw_answer else None,
),
has_reasoned,
bool(has_reasoned),
)
@@ -1372,4 +1144,4 @@ def run_llm_step(
emitter.emit(packet)
except StopIteration as e:
llm_step_result, has_reasoned = e.value
return llm_step_result, has_reasoned
return llm_step_result, bool(has_reasoned)

View File

@@ -185,6 +185,3 @@ class LlmStepResult(BaseModel):
reasoning: str | None
answer: str | None
tool_calls: list[ToolCallKickoff] | None
# Raw LLM text before any display-oriented filtering/sanitization.
# Used for fallback tool-call extraction when providers emit calls as text.
raw_answer: str | None = None

View File

@@ -10,7 +10,6 @@ from onyx.db.user_file import calculate_user_files_token_count
from onyx.file_store.models import FileDescriptor
from onyx.prompts.chat_prompts import CITATION_REMINDER
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
from onyx.prompts.chat_prompts import FILE_REMINDER
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
from onyx.prompts.prompt_utils import get_company_context
@@ -126,7 +125,6 @@ def calculate_reserved_tokens(
def build_reminder_message(
reminder_text: str | None,
include_citation_reminder: bool,
include_file_reminder: bool,
is_last_cycle: bool,
) -> str | None:
reminder = reminder_text.strip() if reminder_text else ""
@@ -134,8 +132,6 @@ def build_reminder_message(
reminder += "\n\n" + LAST_CYCLE_CITATION_REMINDER
if include_citation_reminder:
reminder += "\n\n" + CITATION_REMINDER
if include_file_reminder:
reminder += "\n\n" + FILE_REMINDER
reminder = reminder.strip()
return reminder if reminder else None

View File

@@ -251,9 +251,7 @@ DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int(
os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50
)
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get(
"OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!"
)
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
USING_AWS_MANAGED_OPENSEARCH = (
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
)
@@ -265,18 +263,6 @@ OPENSEARCH_PROFILING_DISABLED = (
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
)
# When enabled, OpenSearch returns detailed score breakdowns for each hit.
# Useful for debugging and tuning search relevance. Has ~10-30% performance overhead according to documentation.
# Seems for Hybrid Search in practice, the impact is actually more like 1000x slower.
OPENSEARCH_EXPLAIN_ENABLED = (
os.environ.get("OPENSEARCH_EXPLAIN_ENABLED", "").lower() == "true"
)
# Analyzer used for full-text fields (title, content). Use OpenSearch built-in analyzer
# names (e.g. "english", "standard", "german"). Affects stemming and tokenization;
# existing indices need reindexing after a change.
OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "english"
# This is the "base" config for now, the idea is that at least for our dev
# environments we always want to be dual indexing into both OpenSearch and Vespa
# to stress test the new codepaths. Only enable this if there is some instance
@@ -284,9 +270,6 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
)
# NOTE: This effectively does nothing anymore, admins can now toggle whether
# retrieval is through OpenSearch. This value is only used as a final fallback
# in case that doesn't work for whatever reason.
# Given that the "base" config above is true, this enables whether we want to
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
# in the event we see issues with OpenSearch retrieval in our dev environments.
@@ -642,14 +625,6 @@ SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
)
# When True, group sync enumerates every Azure AD group in the tenant (expensive).
# When False (default), only groups found in site role assignments are synced.
# Can be overridden per-connector via the "exhaustive_ad_enumeration" key in
# connector_specific_config.
SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION = (
os.environ.get("SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION", "").lower() == "true"
)
BLOB_STORAGE_SIZE_THRESHOLD = int(
os.environ.get("BLOB_STORAGE_SIZE_THRESHOLD", 20 * 1024 * 1024)
)

View File

@@ -157,17 +157,6 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
@@ -454,9 +443,6 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"

View File

@@ -46,7 +46,6 @@ from onyx.connectors.google_drive.file_retrieval import get_external_access_for_
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.file_retrieval import get_shared_drive_name
from onyx.connectors.google_drive.file_retrieval import has_link_only_permission
from onyx.connectors.google_drive.models import DriveRetrievalStage
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint
@@ -157,7 +156,10 @@ def _is_shared_drive_root(folder: GoogleDriveFileType) -> bool:
return False
# For shared drive content, the root has id == driveId
return bool(drive_id and folder_id == drive_id)
if drive_id and folder_id == drive_id:
return True
return False
def _public_access() -> ExternalAccess:
@@ -614,16 +616,6 @@ class GoogleDriveConnector(
# empty parents due to permission limitations)
# Check shared drive root first (simple ID comparison)
if _is_shared_drive_root(folder):
# files().get() returns 'Drive' for shared drive roots;
# fetch the real name via drives().get().
# Try both the retriever and admin since the admin may
# not have access to private shared drives.
drive_name = self._get_shared_drive_name(
current_id, file.user_email
)
if drive_name:
node.display_name = drive_name
node.node_type = HierarchyNodeType.SHARED_DRIVE
reached_terminal = True
break
@@ -699,15 +691,6 @@ class GoogleDriveConnector(
)
return None
def _get_shared_drive_name(self, drive_id: str, retriever_email: str) -> str | None:
"""Fetch the name of a shared drive, trying both the retriever and admin."""
for email in {retriever_email, self.primary_admin_email}:
svc = get_drive_service(self.creds, email)
name = get_shared_drive_name(svc, drive_id)
if name:
return name
return None
def get_all_drive_ids(self) -> set[str]:
return self._get_all_drives_for_user(self.primary_admin_email)

View File

@@ -154,26 +154,6 @@ def _get_hierarchy_fields_for_file_type(field_type: DriveFileFieldType) -> str:
return HIERARCHY_FIELDS
def get_shared_drive_name(
service: Resource,
drive_id: str,
) -> str | None:
"""Fetch the actual name of a shared drive via the drives().get() API.
The files().get() API returns 'Drive' as the name for shared drive root
folders. Only drives().get() returns the real user-assigned name.
"""
try:
drive = service.drives().get(driveId=drive_id, fields="name").execute()
return drive.get("name")
except HttpError as e:
if e.resp.status in (403, 404):
logger.debug(f"Cannot access drive {drive_id}: {e}")
else:
raise
return None
def get_external_access_for_folder(
folder: GoogleDriveFileType,
google_domain: str,

File diff suppressed because it is too large Load Diff

View File

@@ -50,15 +50,12 @@ class TeamsCheckpoint(ConnectorCheckpoint):
todo_team_ids: list[str] | None = None
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
class TeamsConnector(
CheckpointedConnectorWithPermSync[TeamsCheckpoint],
SlimConnectorWithPermSync,
):
MAX_WORKERS = 10
AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"
def __init__(
self,
@@ -66,15 +63,11 @@ class TeamsConnector(
# are not necessarily guaranteed to be unique
teams: list[str] = [],
max_workers: int = MAX_WORKERS,
authority_host: str = DEFAULT_AUTHORITY_HOST,
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
) -> None:
self.graph_client: GraphClient | None = None
self.msal_app: msal.ConfidentialClientApplication | None = None
self.max_workers = max_workers
self.requested_team_list: list[str] = teams
self.authority_host = authority_host.rstrip("/")
self.graph_api_host = graph_api_host.rstrip("/")
# impls for BaseConnector
@@ -83,7 +76,7 @@ class TeamsConnector(
teams_client_secret = credentials["teams_client_secret"]
teams_directory_id = credentials["teams_directory_id"]
authority_url = f"{self.authority_host}/{teams_directory_id}"
authority_url = f"{TeamsConnector.AUTHORITY_URL_PREFIX}{teams_directory_id}"
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=teams_client_id,
@@ -98,7 +91,7 @@ class TeamsConnector(
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
scopes=[f"{self.graph_api_host}/.default"]
scopes=["https://graph.microsoft.com/.default"]
)
if not isinstance(token, dict):

View File

@@ -32,7 +32,6 @@ from onyx.context.search.federated.slack_search_utils import should_include_mess
from onyx.context.search.models import ChunkIndexRequest
from onyx.context.search.models import InferenceChunk
from onyx.db.document import DocumentSource
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.document_index_utils import (
get_multipass_config,
@@ -906,15 +905,13 @@ def convert_slack_score(slack_score: float) -> float:
def slack_retrieval(
query: ChunkIndexRequest,
access_token: str,
db_session: Session | None = None,
db_session: Session,
connector: FederatedConnectorDetail | None = None, # noqa: ARG001
entities: dict[str, Any] | None = None,
limit: int | None = None,
slack_event_context: SlackContext | None = None,
bot_token: str | None = None, # Add bot token parameter
team_id: str | None = None,
# Pre-fetched data — when provided, avoids DB query (no session needed)
search_settings: SearchSettings | None = None,
) -> list[InferenceChunk]:
"""
Main entry point for Slack federated search with entity filtering.
@@ -928,7 +925,7 @@ def slack_retrieval(
Args:
query: Search query object
access_token: User OAuth access token
db_session: Database session (optional if search_settings provided)
db_session: Database session
connector: Federated connector detail (unused, kept for backwards compat)
entities: Connector-level config (entity filtering configuration)
limit: Maximum number of results
@@ -1156,10 +1153,7 @@ def slack_retrieval(
# chunk index docs into doc aware chunks
# a single index doc can get split into multiple chunks
if search_settings is None:
if db_session is None:
raise ValueError("Either db_session or search_settings must be provided")
search_settings = get_current_search_settings(db_session)
search_settings = get_current_search_settings(db_session)
embedder = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
)

View File

@@ -18,10 +18,8 @@ from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.llm.interfaces import LLM
from onyx.natural_language_processing.english_stopwords import strip_stopwords
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.secondary_llm_flows.source_filter import extract_source_filter
from onyx.secondary_llm_flows.time_filter import extract_time_filter
from onyx.utils.logger import setup_logger
@@ -43,7 +41,7 @@ def _build_index_filters(
user_file_ids: list[UUID] | None,
persona_document_sets: list[str] | None,
persona_time_cutoff: datetime | None,
db_session: Session | None = None,
db_session: Session,
auto_detect_filters: bool = False,
query: str | None = None,
llm: LLM | None = None,
@@ -51,8 +49,6 @@ def _build_index_filters(
# Assistant knowledge filters
attached_document_ids: list[str] | None = None,
hierarchy_node_ids: list[int] | None = None,
# Pre-fetched ACL filters (skips DB query when provided)
acl_filters: list[str] | None = None,
) -> IndexFilters:
if auto_detect_filters and (llm is None or query is None):
raise RuntimeError("LLM and query are required for auto detect filters")
@@ -107,14 +103,9 @@ def _build_index_filters(
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
logger.debug("Added USER_FILE to source_filter for user knowledge search")
if bypass_acl:
user_acl_filters = None
elif acl_filters is not None:
user_acl_filters = acl_filters
else:
if db_session is None:
raise ValueError("Either db_session or acl_filters must be provided")
user_acl_filters = build_access_filters_for_user(user, db_session)
user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
final_filters = IndexFilters(
user_file_ids=user_file_ids,
@@ -261,15 +252,11 @@ def search_pipeline(
user: User,
# Used for default filters and settings
persona: Persona | None,
db_session: Session | None = None,
db_session: Session,
auto_detect_filters: bool = False,
llm: LLM | None = None,
# If a project ID is provided, it will be exclusively scoped to that project
project_id: int | None = None,
# Pre-fetched data — when provided, avoids DB queries (no session needed)
acl_filters: list[str] | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
) -> list[InferenceChunk]:
user_uploaded_persona_files: list[UUID] | None = (
[user_file.id for user_file in persona.user_files] if persona else None
@@ -310,7 +297,6 @@ def search_pipeline(
bypass_acl=chunk_search_request.bypass_acl,
attached_document_ids=attached_document_ids,
hierarchy_node_ids=hierarchy_node_ids,
acl_filters=acl_filters,
)
query_keywords = strip_stopwords(chunk_search_request.query)
@@ -329,8 +315,6 @@ def search_pipeline(
user_id=user.id if user else None,
document_index=document_index,
db_session=db_session,
embedding_model=embedding_model,
prefetched_federated_retrieval_infos=prefetched_federated_retrieval_infos,
)
# For some specific connectors like Salesforce, a user that has access to an object doesn't mean

View File

@@ -14,11 +14,9 @@ from onyx.context.search.utils import get_query_embedding
from onyx.context.search.utils import inference_section_from_chunks
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.federated_connectors.federated_retrieval import (
get_federated_retrieval_functions,
)
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -52,14 +50,9 @@ def combine_retrieval_results(
def _embed_and_search(
query_request: ChunkIndexRequest,
document_index: DocumentIndex,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
db_session: Session,
) -> list[InferenceChunk]:
query_embedding = get_query_embedding(
query_request.query,
db_session=db_session,
embedding_model=embedding_model,
)
query_embedding = get_query_embedding(query_request.query, db_session)
hybrid_alpha = query_request.hybrid_alpha or HYBRID_ALPHA
@@ -85,9 +78,7 @@ def search_chunks(
query_request: ChunkIndexRequest,
user_id: UUID | None,
document_index: DocumentIndex,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
db_session: Session,
) -> list[InferenceChunk]:
run_queries: list[tuple[Callable, tuple]] = []
@@ -97,22 +88,14 @@ def search_chunks(
else None
)
# Federated retrieval — use pre-fetched if available, otherwise query DB
if prefetched_federated_retrieval_infos is not None:
federated_retrieval_infos = prefetched_federated_retrieval_infos
else:
if db_session is None:
raise ValueError(
"Either db_session or prefetched_federated_retrieval_infos "
"must be provided"
)
federated_retrieval_infos = get_federated_retrieval_functions(
db_session=db_session,
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
# Federated retrieval
federated_retrieval_infos = get_federated_retrieval_functions(
db_session=db_session,
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
federated_sources = set(
federated_retrieval_info.source.to_non_federated_source()
@@ -131,10 +114,7 @@ def search_chunks(
if normal_search_enabled:
run_queries.append(
(
_embed_and_search,
(query_request, document_index, db_session, embedding_model),
)
(_embed_and_search, (query_request, document_index, db_session))
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)

View File

@@ -64,34 +64,23 @@ def inference_section_from_single_chunk(
)
def get_query_embeddings(
queries: list[str],
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> list[Embedding]:
if embedding_model is None:
if db_session is None:
raise ValueError("Either db_session or embedding_model must be provided")
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
search_settings = get_current_search_settings(db_session)
query_embedding = embedding_model.encode(queries, text_type=EmbedTextType.QUERY)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
return query_embedding
@log_function_time(print_only=True, debug_only=True)
def get_query_embedding(
query: str,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> Embedding:
return get_query_embeddings(
[query], db_session=db_session, embedding_model=embedding_model
)[0]
def get_query_embedding(query: str, db_session: Session) -> Embedding:
return get_query_embeddings([query], db_session)[0]
def convert_inference_sections_to_search_docs(

View File

@@ -4,7 +4,6 @@ from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.api_key import ApiKeyDescriptor
@@ -55,7 +54,6 @@ async def fetch_user_for_api_key(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
.options(selectinload(User.memories))
)

View File

@@ -13,7 +13,6 @@ from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
@@ -98,11 +97,6 @@ async def get_user_count(only_admin_users: bool = False) -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def _get_user(self, statement: Select) -> UP | None:
statement = statement.options(selectinload(User.memories))
results = await self.session.execute(statement)
return results.unique().scalar_one_or_none()
async def create(
self,
create_dict: Dict[str, Any],

View File

@@ -116,15 +116,12 @@ def get_connector_credential_pairs_for_user(
order_by_desc: bool = False,
source: DocumentSource | None = None,
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
"""Get connector credential pairs for a user.
Args:
processing_mode: Filter by processing mode. Defaults to REGULAR to hide
FILE_SYSTEM connectors from standard admin UI. Pass None to get all.
defer_connector_config: If True, skips loading Connector.connector_specific_config
to avoid fetching large JSONB blobs when they aren't needed.
"""
if eager_load_user:
assert (
@@ -133,10 +130,7 @@ def get_connector_credential_pairs_for_user(
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
connector_load = selectinload(ConnectorCredentialPair.connector)
if defer_connector_config:
connector_load = connector_load.defer(Connector.connector_specific_config)
stmt = stmt.options(connector_load)
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
if eager_load_credential:
load_opts = selectinload(ConnectorCredentialPair.credential)
@@ -176,7 +170,6 @@ def get_connector_credential_pairs_for_user_parallel(
order_by_desc: bool = False,
source: DocumentSource | None = None,
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
with get_session_with_current_tenant() as db_session:
return get_connector_credential_pairs_for_user(
@@ -190,7 +183,6 @@ def get_connector_credential_pairs_for_user_parallel(
order_by_desc=order_by_desc,
source=source,
processing_mode=processing_mode,
defer_connector_config=defer_connector_config,
)

View File

@@ -1,74 +0,0 @@
"""Base Data Access Layer (DAL) for database operations.
The DAL pattern groups related database operations into cohesive classes
with explicit session management. It supports two usage modes:
1. **External session** (FastAPI endpoints) — the caller provides a session
whose lifecycle is managed by FastAPI's dependency injection.
2. **Self-managed session** (Celery tasks, scripts) — the DAL creates its
own session via the tenant-aware session factory.
Subclasses add domain-specific query methods while inheriting session
management. See ``ee.onyx.db.scim.ScimDAL`` for a concrete example.
Example (FastAPI)::
def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
@router.get("/users")
def list_users(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
return dal.list_user_mappings(...)
Example (Celery)::
with ScimDAL.from_tenant("tenant_abc") as dal:
dal.create_user_mapping(...)
dal.commit()
"""
from __future__ import annotations
from collections.abc import Generator
from contextlib import contextmanager
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_tenant
class DAL:
"""Base Data Access Layer.
Holds a SQLAlchemy session and provides transaction control helpers.
Subclasses add domain-specific query methods.
"""
def __init__(self, db_session: Session) -> None:
self._session = db_session
@property
def session(self) -> Session:
"""Direct access to the underlying session for advanced use cases."""
return self._session
def commit(self) -> None:
self._session.commit()
def flush(self) -> None:
self._session.flush()
def rollback(self) -> None:
self._session.rollback()
@classmethod
@contextmanager
def from_tenant(cls, tenant_id: str) -> Generator["DAL", None, None]:
"""Create a DAL with a self-managed session for the given tenant.
The session is automatically closed when the context manager exits.
The caller must explicitly call ``commit()`` to persist changes.
"""
with get_session_with_tenant(tenant_id=tenant_id) as session:
yield cls(session)

View File

@@ -554,19 +554,10 @@ def fetch_all_document_sets_for_user(
stmt = (
select(DocumentSetDBModel)
.distinct()
.options(
selectinload(DocumentSetDBModel.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSetDBModel.users),
selectinload(DocumentSetDBModel.groups),
selectinload(DocumentSetDBModel.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
)
.options(selectinload(DocumentSetDBModel.federated_connectors))
)
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def fetch_documents_for_document_set_paginated(

View File

@@ -232,12 +232,6 @@ class BuildSessionStatus(str, PyEnum):
IDLE = "idle"
class SharingScope(str, PyEnum):
PRIVATE = "private"
PUBLIC_ORG = "public_org"
PUBLIC_GLOBAL = "public_global"
class SandboxStatus(str, PyEnum):
PROVISIONING = "provisioning"
RUNNING = "running"

View File

@@ -430,7 +430,7 @@ def fetch_existing_models(
def fetch_existing_llm_providers(
db_session: Session,
flow_type_filter: list[LLMModelFlowType],
flow_types: list[LLMModelFlowType],
only_public: bool = False,
exclude_image_generation_providers: bool = True,
) -> list[LLMProviderModel]:
@@ -438,27 +438,30 @@ def fetch_existing_llm_providers(
Args:
db_session: Database session
flow_type_filter: List of flow types to filter by, empty list for no filter
flow_types: List of flow types to filter by
only_public: If True, only return public providers
exclude_image_generation_providers: If True, exclude providers that are
used for image generation configs
"""
stmt = select(LLMProviderModel)
if flow_type_filter:
providers_with_flows = (
select(ModelConfiguration.llm_provider_id)
.join(LLMModelFlow)
.where(LLMModelFlow.llm_model_flow_type.in_(flow_type_filter))
.distinct()
)
stmt = stmt.where(LLMProviderModel.id.in_(providers_with_flows))
providers_with_flows = (
select(ModelConfiguration.llm_provider_id)
.join(LLMModelFlow)
.where(LLMModelFlow.llm_model_flow_type.in_(flow_types))
.distinct()
)
if exclude_image_generation_providers:
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
)
else:
image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
ImageGenerationConfig
)
stmt = stmt.where(~LLMProviderModel.id.in_(image_gen_provider_ids))
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
| LLMProviderModel.id.in_(image_gen_provider_ids)
)
stmt = stmt.options(
selectinload(LLMProviderModel.model_configurations),
@@ -794,15 +797,13 @@ def sync_auto_mode_models(
changes += 1
else:
# Add new model - all models from GitHub config are visible
insert_new_model_configuration__no_commit(
db_session=db_session,
new_model = ModelConfiguration(
llm_provider_id=provider.id,
model_name=model_config.name,
supported_flows=[LLMModelFlowType.CHAT],
is_visible=True,
max_input_tokens=None,
name=model_config.name,
display_name=model_config.display_name,
is_visible=True,
)
db_session.add(new_model)
changes += 1
# In Auto mode, default model is always set from GitHub config

View File

@@ -77,7 +77,6 @@ from onyx.db.enums import (
ThemePreference,
DefaultAppMode,
SwitchoverType,
SharingScope,
)
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
@@ -287,7 +286,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user"
"Credential", back_populates="user", lazy="joined"
)
chat_sessions: Mapped[list["ChatSession"]] = relationship(
"ChatSession", back_populates="user"
@@ -321,6 +320,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
"Memory",
back_populates="user",
cascade="all, delete-orphan",
lazy="selectin",
order_by="desc(Memory.id)",
)
oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
@@ -1040,9 +1040,7 @@ class OpenSearchTenantMigrationRecord(Base):
nullable=False,
)
# Opaque continuation token from Vespa's Visit API.
# NULL means "not started".
# Otherwise contains a serialized mapping between slice ID and continuation
# token for that slice.
# NULL means "not started" or "visit completed".
vespa_visit_continuation_token: Mapped[str | None] = mapped_column(
Text, nullable=True
)
@@ -1066,9 +1064,6 @@ class OpenSearchTenantMigrationRecord(Base):
enable_opensearch_retrieval: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
approx_chunk_count_in_vespa: Mapped[int | None] = mapped_column(
Integer, nullable=True
)
class KGEntityType(Base):
@@ -4717,12 +4712,6 @@ class BuildSession(Base):
demo_data_enabled: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default=text("true")
)
sharing_scope: Mapped[SharingScope] = mapped_column(
String,
nullable=False,
default=SharingScope.PRIVATE,
server_default="private",
)
# Relationships
user: Mapped[User | None] = relationship("User", foreign_keys=[user_id])
@@ -4939,7 +4928,6 @@ class ScimUserMapping(Base):
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
scim_username: Mapped[str | None] = mapped_column(String, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False

View File

@@ -4,7 +4,6 @@ This module provides functions to track the progress of migrating documents
from Vespa to OpenSearch.
"""
import json
from datetime import datetime
from datetime import timezone
@@ -13,9 +12,6 @@ from sqlalchemy import text
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_SLICE_COUNT,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
)
@@ -247,37 +243,29 @@ def should_document_migration_be_permanently_failed(
def get_vespa_visit_state(
db_session: Session,
) -> tuple[dict[int, str | None], int]:
) -> tuple[str | None, int]:
"""Gets the current Vespa migration state from the tenant migration record.
Requires the OpenSearchTenantMigrationRecord to exist.
Returns:
Tuple of (continuation_token_map, total_chunks_migrated).
Tuple of (continuation_token, total_chunks_migrated). continuation_token
is None if not started or completed.
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
if record.vespa_visit_continuation_token is None:
continuation_token_map: dict[int, str | None] = {
slice_id: None for slice_id in range(GET_VESPA_CHUNKS_SLICE_COUNT)
}
else:
json_loaded_continuation_token_map = json.loads(
record.vespa_visit_continuation_token
)
continuation_token_map = {
int(key): value for key, value in json_loaded_continuation_token_map.items()
}
return continuation_token_map, record.total_chunks_migrated
return (
record.vespa_visit_continuation_token,
record.total_chunks_migrated,
)
def update_vespa_visit_progress_with_commit(
db_session: Session,
continuation_token_map: dict[int, str | None],
continuation_token: str | None,
chunks_processed: int,
chunks_errored: int,
approx_chunk_count_in_vespa: int | None,
) -> None:
"""Updates the Vespa migration progress and commits.
@@ -285,26 +273,19 @@ def update_vespa_visit_progress_with_commit(
Args:
db_session: SQLAlchemy session.
continuation_token_map: The new continuation token map. None entry means
the visit is complete for that slice.
continuation_token: The new continuation token. None means the visit
is complete.
chunks_processed: Number of chunks processed in this batch (added to
the running total).
chunks_errored: Number of chunks errored in this batch (added to the
running errored total).
approx_chunk_count_in_vespa: Approximate number of chunks in Vespa. If
None, the existing value is used.
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
record.vespa_visit_continuation_token = json.dumps(continuation_token_map)
record.vespa_visit_continuation_token = continuation_token
record.total_chunks_migrated += chunks_processed
record.total_chunks_errored += chunks_errored
record.approx_chunk_count_in_vespa = (
approx_chunk_count_in_vespa
if approx_chunk_count_in_vespa is not None
else record.approx_chunk_count_in_vespa
)
db_session.commit()
@@ -372,27 +353,25 @@ def build_sanitized_to_original_doc_id_mapping(
def get_opensearch_migration_state(
db_session: Session,
) -> tuple[int, datetime | None, datetime | None, int | None]:
) -> tuple[int, datetime | None, datetime | None]:
"""Returns the state of the Vespa to OpenSearch migration.
If the tenant migration record is not found, returns defaults of 0, None,
None, None.
None.
Args:
db_session: SQLAlchemy session.
Returns:
Tuple of (total_chunks_migrated, created_at, migration_completed_at,
approx_chunk_count_in_vespa).
Tuple of (total_chunks_migrated, created_at, migration_completed_at).
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
return 0, None, None, None
return 0, None, None
return (
record.total_chunks_migrated,
record.created_at,
record.migration_completed_at,
record.approx_chunk_count_in_vespa,
)

View File

@@ -8,7 +8,6 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.pat import build_displayable_pat
@@ -32,59 +31,53 @@ async def fetch_user_for_pat(
NOTE: This is async since it's used during auth (which is necessarily async due to FastAPI Users).
NOTE: Expired includes both naturally expired and user-revoked tokens (revocation sets expires_at=NOW()).
Uses select(User) as primary entity so that joined-eager relationships (e.g. oauth_accounts)
are loaded correctly — matching the pattern in fetch_user_for_api_key.
"""
# Single joined query with all filters pushed to database
now = datetime.now(timezone.utc)
user = await async_db_session.scalar(
select(User)
.join(PersonalAccessToken, PersonalAccessToken.user_id == User.id)
result = await async_db_session.execute(
select(PersonalAccessToken, User)
.join(User, PersonalAccessToken.user_id == User.id)
.where(PersonalAccessToken.hashed_token == hashed_token)
.where(User.is_active) # type: ignore
.where(
(PersonalAccessToken.expires_at.is_(None))
| (PersonalAccessToken.expires_at > now)
)
.options(selectinload(User.memories))
.limit(1)
)
if not user:
row = result.first()
if not row:
return None
_schedule_pat_last_used_update(hashed_token, now)
return user
pat, user = row
# Throttle last_used_at updates to reduce DB load (5-minute granularity sufficient for auditing)
# For request-level auditing, use application logs or a dedicated audit table
should_update = (
pat.last_used_at is None or (now - pat.last_used_at).total_seconds() > 300
)
def _schedule_pat_last_used_update(hashed_token: str, now: datetime) -> None:
"""Fire-and-forget update of last_used_at, throttled to 5-minute granularity."""
async def _update() -> None:
try:
tenant_id = get_current_tenant_id()
async with get_async_session_context_manager(tenant_id) as session:
pat = await session.scalar(
select(PersonalAccessToken).where(
PersonalAccessToken.hashed_token == hashed_token
if should_update:
# Update in separate session to avoid transaction coupling (fire-and-forget)
async def _update_last_used() -> None:
try:
tenant_id = get_current_tenant_id()
async with get_async_session_context_manager(
tenant_id
) as separate_session:
await separate_session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.hashed_token == hashed_token)
.values(last_used_at=now)
)
)
if not pat:
return
if (
pat.last_used_at is not None
and (now - pat.last_used_at).total_seconds() <= 300
):
return
await session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.hashed_token == hashed_token)
.values(last_used_at=now)
)
await session.commit()
except Exception as e:
logger.warning(f"Failed to update last_used_at for PAT: {e}")
await separate_session.commit()
except Exception as e:
logger.warning(f"Failed to update last_used_at for PAT: {e}")
asyncio.create_task(_update())
asyncio.create_task(_update_last_used())
return user
def create_pat(

View File

@@ -28,7 +28,6 @@ from onyx.db.document_access import get_accessible_documents_by_ids
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import HierarchyNode
from onyx.db.models import Persona
from onyx.db.models import Persona__User
@@ -421,16 +420,9 @@ def get_minimal_persona_snapshots_for_user(
stmt = stmt.options(
selectinload(Persona.tools),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.connector),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
@@ -461,16 +453,7 @@ def get_persona_snapshots_for_user(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
@@ -567,16 +550,9 @@ def get_minimal_persona_snapshots_paginated(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.connector),
selectinload(Persona.user),
)
@@ -635,16 +611,7 @@ def get_persona_snapshots_paginated(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),

View File

@@ -54,9 +54,6 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
# Maps schema property name to a list of highlighted snippets with match
# terms wrapped in tags (e.g. "something <hi>keyword</hi> other thing").
match_highlights: dict[str, list[str]] = {}
# Score explanation from OpenSearch when "explain": true is set in the query.
# Contains detailed breakdown of how the score was calculated.
explanation: dict[str, Any] | None = None
def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
@@ -709,12 +706,10 @@ class OpenSearchClient:
)
document_chunk_score = hit.get("_score", None)
match_highlights: dict[str, list[str]] = hit.get("highlight", {})
explanation: dict[str, Any] | None = hit.get("_explanation", None)
search_hit = SearchHit[DocumentChunk](
document_chunk=DocumentChunk.model_validate(document_chunk_source),
score=document_chunk_score,
match_highlights=match_highlights,
explanation=explanation,
)
search_hits.append(search_hit)
logger.debug(

View File

@@ -10,31 +10,31 @@ EF_CONSTRUCTION = 256
# quality but increase memory footprint. Values typically range between 12 - 48.
M = 32 # Set relatively high for better accuracy.
# When performing hybrid search, we need to consider more candidates than the number of results to be returned.
# This is because the scoring is hybrid and the results are reordered due to the hybrid scoring.
# Higher = more candidates for hybrid fusion = better retrieval accuracy, but results in more computation per query.
# Imagine a simple case with a single keyword query and a single vector query and we want 10 final docs.
# If we only fetch 10 candidates from each of keyword and vector, they would have to have perfect overlap to get a good hybrid
# ranking for the 10 results. If we fetch 1000 candidates from each, we have a much higher chance of all 10 of the final desired
# docs showing up and getting scored. In worse situations, the final 10 docs don't even show up as the final 10 (worse than just
# a miss at the reranking step).
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = 750
# Number of vectors to examine for top k neighbors for the HNSW method.
EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
# Should be >= DEFAULT_K_NUM_CANDIDATES for good recall; higher = better accuracy, slower search.
# Bumped this to 1000, for dataset of low 10,000 docs, did not see improvement in recall.
EF_SEARCH = 256
# The default number of neighbors to consider for knn vector similarity search.
# We need this higher than the number of results because the scoring is hybrid.
# If there is only 1 query, setting k equal to the number of results is enough,
# but since there is heavy reordering due to hybrid scoring, we need to set k higher.
# Higher = more candidates for hybrid fusion = better retrieval accuracy, more query cost.
DEFAULT_K_NUM_CANDIDATES = 50 # TODO likely need to bump this way higher
# Since the titles are included in the contents, they are heavily downweighted as they act as a boost
# rather than an independent scoring component.
SEARCH_TITLE_VECTOR_WEIGHT = 0.1
SEARCH_CONTENT_VECTOR_WEIGHT = 0.45
# Single keyword weight for both title and content (merged from former title keyword + content keyword).
SEARCH_KEYWORD_WEIGHT = 0.45
SEARCH_TITLE_KEYWORD_WEIGHT = 0.1
SEARCH_CONTENT_VECTOR_WEIGHT = 0.4
SEARCH_CONTENT_KEYWORD_WEIGHT = 0.4
# NOTE: it is critical that the order of these weights matches the order of the sub-queries in the hybrid search.
HYBRID_SEARCH_NORMALIZATION_WEIGHTS = [
SEARCH_TITLE_VECTOR_WEIGHT,
SEARCH_TITLE_KEYWORD_WEIGHT,
SEARCH_CONTENT_VECTOR_WEIGHT,
SEARCH_KEYWORD_WEIGHT,
SEARCH_CONTENT_KEYWORD_WEIGHT,
]
assert sum(HYBRID_SEARCH_NORMALIZATION_WEIGHTS) == 1.0

View File

@@ -842,8 +842,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
body=query_body,
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
)
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights

View File

@@ -11,7 +11,6 @@ from pydantic import model_serializer
from pydantic import model_validator
from pydantic import SerializerFunctionWrapHandler
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
@@ -55,11 +54,6 @@ SECONDARY_OWNERS_FIELD_NAME = "secondary_owners"
ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME = "ancestor_hierarchy_node_ids"
# Faiss was also tried but it didn't have any benefits
# NMSLIB is deprecated, not recommended
OPENSEARCH_KNN_ENGINE = "lucene"
def get_opensearch_doc_chunk_id(
tenant_state: TenantState,
document_id: str,
@@ -349,9 +343,6 @@ class DocumentSchema:
"properties": {
TITLE_FIELD_NAME: {
"type": "text",
# Language analyzer (e.g. english) stems at index and search time for variant matching.
# Configure via OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing after a change.
"analyzer": OPENSEARCH_TEXT_ANALYZER,
"fields": {
# Subfield accessed as title.keyword. Not indexed for
# values longer than 256 chars.
@@ -366,7 +357,9 @@ class DocumentSchema:
CONTENT_FIELD_NAME: {
"type": "text",
"store": True,
"analyzer": OPENSEARCH_TEXT_ANALYZER,
# This makes highlighting text during queries more efficient
# at the cost of disk space. See
# https://docs.opensearch.org/latest/search-plugins/searching-data/highlight/#methods-of-obtaining-offsets
"index_options": "offsets",
},
TITLE_VECTOR_FIELD_NAME: {
@@ -375,7 +368,7 @@ class DocumentSchema:
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": OPENSEARCH_KNN_ENGINE,
"engine": "lucene",
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
},
},
@@ -387,7 +380,7 @@ class DocumentSchema:
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": OPENSEARCH_KNN_ENGINE,
"engine": "lucene",
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
},
},

View File

@@ -6,16 +6,13 @@ from typing import Any
from uuid import UUID
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
from onyx.configs.app_configs import OPENSEARCH_PROFILING_DISABLED
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import Tag
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import (
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
)
from onyx.document_index.opensearch.constants import DEFAULT_K_NUM_CANDIDATES
from onyx.document_index.opensearch.constants import HYBRID_SEARCH_NORMALIZATION_WEIGHTS
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME
@@ -243,9 +240,6 @@ class DocumentQuery:
Returns:
A dictionary representing the final hybrid search query.
"""
# WARNING: Profiling does not work with hybrid search; do not add it at
# this level. See https://github.com/opensearch-project/neural-search/issues/1255
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
raise ValueError(
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
@@ -253,7 +247,7 @@ class DocumentQuery:
)
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
query_text, query_vector
query_text, query_vector, num_candidates=DEFAULT_K_NUM_CANDIDATES
)
hybrid_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
@@ -281,31 +275,25 @@ class DocumentQuery:
hybrid_search_query: dict[str, Any] = {
"hybrid": {
"queries": hybrid_search_subqueries,
# Max results per subquery per shard before aggregation. Ensures keyword and vector
# subqueries contribute equally to the candidate pool for hybrid fusion.
# Sources:
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
# https://opensearch.org/blog/navigating-pagination-in-hybrid-queries-with-the-pagination_depth-parameter/
"pagination_depth": DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
# Applied to all the sub-queries independently (this avoids having subqueries having a lot of results thrown out).
# Sources:
# Applied to all the sub-queries. Source:
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
# https://opensearch.org/blog/introducing-common-filter-support-for-hybrid-search-queries
# Does AND for each filter in the list.
"filter": {"bool": {"filter": hybrid_search_filters}},
}
}
# NOTE: By default, hybrid search retrieves "size"-many results from
# each OpenSearch shard before aggregation. Source:
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
final_hybrid_search_body: dict[str, Any] = {
"query": hybrid_search_query,
"size": num_hits,
"highlight": match_highlights_configuration,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
}
# Explain is for scoring breakdowns.
if OPENSEARCH_EXPLAIN_ENABLED:
final_hybrid_search_body["explain"] = True
# WARNING: Profiling does not work with hybrid search; do not add it at
# this level. See https://github.com/opensearch-project/neural-search/issues/1255
return final_hybrid_search_body
@@ -367,12 +355,7 @@ class DocumentQuery:
@staticmethod
def _get_hybrid_search_subqueries(
query_text: str,
query_vector: list[float],
# The default number of neighbors to consider for knn vector similarity search.
# This is higher than the number of results because the scoring is hybrid.
# for a detailed breakdown, see where the default value is set.
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
query_text: str, query_vector: list[float], num_candidates: int
) -> list[dict[str, Any]]:
"""Returns subqueries for hybrid search.
@@ -384,8 +367,9 @@ class DocumentQuery:
Matches:
- Title vector
- Title keyword
- Content vector
- Keyword (title + content, match and phrase)
- Content keyword + phrase
Normalization is not performed here.
The weights of each of these subqueries should be configured in a search
@@ -406,9 +390,9 @@ class DocumentQuery:
NOTE: Options considered and rejected:
- minimum_should_match: Since it's hybrid search and users often provide semantic queries, there is often a lot of terms,
and very low number of meaningful keywords (and a low ratio of keywords).
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). It's mostly for typos as the analyzer ("english by
default") already does some stemming and tokenization. In testing datasets, this makes recall slightly worse. It also is
less performant so not really any reason to do it.
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). This is reasonable but in reality seeing the
user usage patterns, this is not very common and people tend to not be confused when a miss happens for this reason.
In testing datasets, this makes recall slightly worse.
Args:
query_text: The text of the query to search for.
@@ -417,27 +401,19 @@ class DocumentQuery:
similarity search.
"""
# Build sub-queries for hybrid search. Order must match normalization
# pipeline weights: title vector, content vector, keyword (title + content).
# pipeline weights: title vector, title keyword, content vector,
# content keyword.
hybrid_search_queries: list[dict[str, Any]] = [
# 1. Title vector search
{
"knn": {
TITLE_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": vector_candidates,
"k": num_candidates,
}
}
},
# 2. Content vector search
{
"knn": {
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": vector_candidates,
}
}
},
# 3. Keyword (title + content) match and phrase search.
# 2. Title keyword + phrase search.
{
"bool": {
"should": [
@@ -445,10 +421,8 @@ class DocumentQuery:
"match": {
TITLE_FIELD_NAME: {
"query": query_text,
# operator "or" = match doc if any query term matches (default, explicit for clarity).
"operator": "or",
# The title fields are strongly discounted as they are included in the content.
# It just acts as a minor boost
"boost": 0.1,
}
}
},
@@ -456,17 +430,35 @@ class DocumentQuery:
"match_phrase": {
TITLE_FIELD_NAME: {
"query": query_text,
# Slop = 1 allows one extra word or transposition in phrase match.
"slop": 1,
"boost": 0.2,
# Boost phrase over bag-of-words; exact phrase is a stronger signal.
"boost": 1.5,
}
}
},
]
}
},
# 3. Content vector search
{
"knn": {
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": num_candidates,
}
}
},
# 4. Content keyword + phrase search.
{
"bool": {
"should": [
{
"match": {
CONTENT_FIELD_NAME: {
"query": query_text,
# operator "or" = match doc if any query term matches (default, explicit for clarity).
"operator": "or",
"boost": 1.0,
}
}
},
@@ -474,7 +466,9 @@ class DocumentQuery:
"match_phrase": {
CONTENT_FIELD_NAME: {
"query": query_text,
# Slop = 1 allows one extra word or transposition in phrase match.
"slop": 1,
# Boost phrase over bag-of-words; exact phrase is a stronger signal.
"boost": 1.5,
}
}

View File

@@ -10,12 +10,6 @@ from typing import cast
import httpx
from retry import retry
from onyx.background.celery.tasks.opensearch_migration.constants import (
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
)
from onyx.background.celery.tasks.opensearch_migration.transformer import (
FIELDS_NEEDED_FOR_TRANSFORMATION,
)
from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE
from onyx.context.search.models import IndexFilters
@@ -283,139 +277,54 @@ def get_chunks_via_visit_api(
def get_all_chunks_paginated(
index_name: str,
tenant_state: TenantState,
continuation_token_map: dict[int, str | None],
page_size: int,
) -> tuple[list[dict], dict[int, str | None]]:
continuation_token: str | None = None,
page_size: int = 1_000,
) -> tuple[list[dict], str | None]:
"""Gets all chunks in Vespa matching the filters, paginated.
Uses the Visit API with slicing. Each continuation token map entry is for a
different slice. The number of entries determines the number of slices.
Args:
index_name: The name of the Vespa index to visit.
tenant_state: The tenant state to filter by.
continuation_token_map: Map of slice ID to a token returned by Vespa
representing a page offset. None to start from the beginning of the
slice.
continuation_token: Token returned by Vespa representing a page offset.
None to start from the beginning. Defaults to None.
page_size: Best-effort batch size for the visit. Defaults to 1,000.
Returns:
Tuple of (list of chunk dicts, next continuation token or None). The
continuation token is None when the visit is complete.
"""
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
def _get_all_chunks_paginated_for_slice(
index_name: str,
tenant_state: TenantState,
slice_id: int,
total_slices: int,
continuation_token: str | None,
page_size: int,
) -> tuple[list[dict], str | None]:
if continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN:
logger.debug(
f"Slice {slice_id} has finished visiting. Returning empty list and {FINISHED_VISITING_SLICE_CONTINUATION_TOKEN}."
)
return [], FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
selection: str = f"{index_name}.large_chunk_reference_ids == null"
if MULTI_TENANT:
selection += f" and {index_name}.tenant_id=='{tenant_state.tenant_id}'"
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
selection: str = f"{index_name}.large_chunk_reference_ids == null"
if MULTI_TENANT:
selection += f" and {index_name}.tenant_id=='{tenant_state.tenant_id}'"
field_set = f"{index_name}:" + ",".join(FIELDS_NEEDED_FOR_TRANSFORMATION)
params: dict[str, str | int | None] = {
"selection": selection,
"fieldSet": field_set,
"wantedDocumentCount": page_size,
"format.tensors": "short-value",
"slices": total_slices,
"sliceId": slice_id,
}
if continuation_token is not None:
params["continuation"] = continuation_token
response: httpx.Response | None = None
try:
with get_vespa_http_client() as http_client:
response = http_client.get(url, params=params)
response.raise_for_status()
except httpx.HTTPError as e:
error_base = f"Failed to get chunks from Vespa slice {slice_id} with continuation token {continuation_token}."
logger.exception(
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
)
error_message = (
response.json().get("message") if response else "No response"
)
logger.error("Error message from response: %s", error_message)
raise httpx.HTTPError(error_base) from e
response_data = response.json()
# NOTE: If we see a falsey value for "continuation" in the response we
# assume we are done and return
# FINISHED_VISITING_SLICE_CONTINUATION_TOKEN instead.
next_continuation_token = (
response_data.get("continuation")
or FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
)
chunks = [chunk["fields"] for chunk in response_data.get("documents", [])]
if next_continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN:
logger.debug(
f"Slice {slice_id} has finished visiting. Returning {len(chunks)} chunks and {next_continuation_token}."
)
return chunks, next_continuation_token
total_slices = len(continuation_token_map)
if total_slices < 1:
raise ValueError("continuation_token_map must have at least one entry.")
# We want to guarantee that these invocations are ordered by slice_id,
# because we read in the same order below when parsing parallel_results.
functions_with_args: list[tuple[Callable, tuple]] = [
(
_get_all_chunks_paginated_for_slice,
(
index_name,
tenant_state,
slice_id,
total_slices,
continuation_token,
page_size,
),
)
for slice_id, continuation_token in sorted(continuation_token_map.items())
]
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=True
)
if len(parallel_results) != total_slices:
raise RuntimeError(
f"Expected {total_slices} parallel results, but got {len(parallel_results)}."
)
chunks: list[dict] = []
next_continuation_token_map: dict[int, str | None] = {
key: value for key, value in continuation_token_map.items()
params: dict[str, str | int | None] = {
"selection": selection,
"wantedDocumentCount": page_size,
"format.tensors": "short-value",
}
for i, parallel_result in enumerate(parallel_results):
if i not in next_continuation_token_map:
raise RuntimeError(f"Slice {i} is not in the continuation token map.")
if parallel_result is None:
logger.error(
f"Failed to get chunks for slice {i} of {total_slices}. "
"The continuation token for this slice will not be updated."
)
continue
chunks.extend(parallel_result[0])
next_continuation_token_map[i] = parallel_result[1]
if continuation_token is not None:
params["continuation"] = continuation_token
return chunks, next_continuation_token_map
try:
with get_vespa_http_client() as http_client:
response = http_client.get(url, params=params)
response.raise_for_status()
except httpx.HTTPError as e:
error_base = "Failed to get chunks in Vespa."
logger.exception(
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
)
raise httpx.HTTPError(error_base) from e
response_data = response.json()
return [
chunk["fields"] for chunk in response_data.get("documents", [])
], response_data.get("continuation") or None
# TODO(rkuo): candidate for removal if not being used

View File

@@ -56,7 +56,6 @@ from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import SEARCH_ENDPOINT
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
from onyx.document_index.vespa_constants import YQL_BASE
from onyx.indexing.models import DocMetadataAwareIndexChunk
@@ -653,9 +652,9 @@ class VespaDocumentIndex(DocumentIndex):
def get_all_raw_document_chunks_paginated(
self,
continuation_token_map: dict[int, str | None],
continuation_token: str | None,
page_size: int,
) -> tuple[list[dict[str, Any]], dict[int, str | None]]:
) -> tuple[list[dict[str, Any]], str | None]:
"""Gets all the chunks in Vespa, paginated.
Used in the chunk-level Vespa-to-OpenSearch migration task.
@@ -663,21 +662,21 @@ class VespaDocumentIndex(DocumentIndex):
Args:
continuation_token: Token returned by Vespa representing a page
offset. None to start from the beginning. Defaults to None.
page_size: Best-effort batch size for the visit.
page_size: Best-effort batch size for the visit. Defaults to 1,000.
Returns:
Tuple of (list of chunk dicts, next continuation token or None). The
continuation token is None when the visit is complete.
"""
raw_chunks, next_continuation_token_map = get_all_chunks_paginated(
raw_chunks, next_continuation_token = get_all_chunks_paginated(
index_name=self._index_name,
tenant_state=TenantState(
tenant_id=self._tenant_id, multitenant=MULTI_TENANT
),
continuation_token_map=continuation_token_map,
continuation_token=continuation_token,
page_size=page_size,
)
return raw_chunks, next_continuation_token_map
return raw_chunks, next_continuation_token
def index_raw_chunks(self, chunks: list[dict[str, Any]]) -> None:
"""Indexes raw document chunks into Vespa.
@@ -703,32 +702,3 @@ class VespaDocumentIndex(DocumentIndex):
json={"fields": chunk},
)
response.raise_for_status()
def get_chunk_count(self) -> int:
"""Returns the exact number of document chunks in Vespa for this tenant.
Uses the Vespa Search API with `limit 0` and `ranking.profile=unranked`
to get an exact count without fetching any document data.
Includes large chunks. There is no way to filter these out using the
Search API.
"""
where_clause = (
f'tenant_id contains "{self._tenant_id}"' if self._multitenant else "true"
)
yql = (
f"select documentid from {self._index_name} "
f"where {where_clause} "
f"limit 0"
)
params: dict[str, str | int] = {
"yql": yql,
"ranking.profile": "unranked",
"timeout": VESPA_TIMEOUT,
}
with get_vespa_http_client() as http_client:
response = http_client.post(SEARCH_ENDPOINT, json=params)
response.raise_for_status()
response_data = response.json()
return response_data["root"]["fields"]["totalCount"]

View File

@@ -20,20 +20,7 @@ class ImageGenerationProviderCredentials(BaseModel):
custom_config: dict[str, str] | None = None
class ReferenceImage(BaseModel):
data: bytes
mime_type: str
class ImageGenerationProvider(abc.ABC):
@property
def supports_reference_images(self) -> bool:
return False
@property
def max_reference_images(self) -> int:
return 0
@classmethod
@abc.abstractmethod
def validate_credentials(
@@ -76,7 +63,6 @@ class ImageGenerationProvider(abc.ABC):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
"""Generates an image based on a prompt."""

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -60,7 +59,6 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -46,7 +45,6 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation

View File

@@ -1,8 +1,6 @@
from __future__ import annotations
import base64
import json
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
@@ -11,7 +9,6 @@ from pydantic import BaseModel
from onyx.image_gen.exceptions import ImageProviderCredentialsError
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -54,15 +51,6 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
vertex_credentials=vertex_credentials,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# Gemini image editing supports up to 14 input images.
return 14
def generate_image(
self,
prompt: str,
@@ -70,18 +58,8 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
if reference_images:
return self._generate_image_with_reference_images(
prompt=prompt,
model=model,
size=size,
n=n,
reference_images=reference_images,
)
from litellm import image_generation
return image_generation(
@@ -96,99 +74,6 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
**kwargs,
)
def _generate_image_with_reference_images(
self,
prompt: str,
model: str,
size: str,
n: int,
reference_images: list[ReferenceImage],
) -> ImageGenerationResponse:
from google import genai
from google.genai import types as genai_types
from google.oauth2 import service_account
from litellm.types.utils import ImageObject
from litellm.types.utils import ImageResponse
service_account_info = json.loads(self._vertex_credentials)
credentials = service_account.Credentials.from_service_account_info(
service_account_info,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
client = genai.Client(
vertexai=True,
project=self._vertex_project,
location=self._vertex_location,
credentials=credentials,
)
parts: list[genai_types.Part] = [
genai_types.Part.from_bytes(data=image.data, mime_type=image.mime_type)
for image in reference_images
]
parts.append(genai_types.Part.from_text(text=prompt))
config = genai_types.GenerateContentConfig(
response_modalities=["TEXT", "IMAGE"],
candidate_count=max(1, n),
image_config=genai_types.ImageConfig(
aspect_ratio=_map_size_to_aspect_ratio(size)
),
)
model_name = model.replace("vertex_ai/", "")
response = client.models.generate_content(
model=model_name,
contents=genai_types.Content(
role="user",
parts=parts,
),
config=config,
)
generated_data: list[ImageObject] = []
for candidate in response.candidates or []:
candidate_content = candidate.content
if not candidate_content:
continue
for part in candidate_content.parts or []:
inline_data = part.inline_data
if not inline_data or inline_data.data is None:
continue
if isinstance(inline_data.data, bytes):
b64_json = base64.b64encode(inline_data.data).decode("utf-8")
elif isinstance(inline_data.data, str):
b64_json = inline_data.data
else:
continue
generated_data.append(
ImageObject(
b64_json=b64_json,
revised_prompt=prompt,
)
)
if not generated_data:
raise RuntimeError("No image data returned from Vertex AI.")
return ImageResponse(
created=int(datetime.now().timestamp()),
data=generated_data,
)
def _map_size_to_aspect_ratio(size: str) -> str:
return {
"1024x1024": "1:1",
"1792x1024": "16:9",
"1024x1792": "9:16",
"1536x1024": "3:2",
"1024x1536": "2:3",
}.get(size, "1:1")
def _parse_to_vertex_credentials(
credentials: ImageGenerationProviderCredentials,

View File

@@ -64,6 +64,21 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"anthropic.claude-3-7-sonnet-20240620-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -144,6 +159,11 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"apac.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"apac.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -1300,6 +1320,11 @@
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"bedrock/us-gov-east-1/anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -1340,6 +1365,16 @@
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"bedrock/us-gov-west-1/anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"bedrock/us-gov-west-1/anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -1470,6 +1505,26 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-7-sonnet-20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"claude-3-7-sonnet-latest": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-7-sonnet@20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"claude-3-haiku-20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"claude-4-opus-20250514": {
"display_name": "Claude Opus 4",
"model_vendor": "anthropic",
@@ -1650,6 +1705,16 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"eu.anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"eu.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"eu.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -3161,6 +3226,15 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
},
"openrouter/anthropic/claude-3-haiku-20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"openrouter/anthropic/claude-3-sonnet": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic"
@@ -3175,6 +3249,16 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3.7-sonnet": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3.7-sonnet:beta": {
"display_name": "Claude Sonnet 3.7:beta",
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-haiku-4.5": {
"display_name": "Claude Haiku 4.5",
"model_vendor": "anthropic",
@@ -3666,6 +3750,16 @@
"model_vendor": "anthropic",
"model_version": "20241022"
},
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"us.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"us.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -3785,6 +3879,20 @@
"model_vendor": "anthropic",
"model_version": "20240620"
},
"vertex_ai/claude-3-7-sonnet@20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"vertex_ai/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
},
"vertex_ai/claude-3-haiku@20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"vertex_ai/claude-3-sonnet": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic"

View File

@@ -1,7 +1,5 @@
import json
import pathlib
import threading
import time
from onyx.llm.constants import LlmProviderNames
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
@@ -25,11 +23,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_RECOMMENDATIONS_CACHE_TTL_SECONDS = 300
_recommendations_cache_lock = threading.Lock()
_cached_recommendations: LLMRecommendations | None = None
_cached_recommendations_time: float = 0.0
def _get_provider_to_models_map() -> dict[str, list[str]]:
"""Lazy-load provider model mappings to avoid importing litellm at module level.
@@ -48,40 +41,19 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
}
def _load_bundled_recommendations() -> LLMRecommendations:
def get_recommendations() -> LLMRecommendations:
"""Get the recommendations from the GitHub config."""
recommendations_from_github = fetch_llm_recommendations_from_github()
if recommendations_from_github:
return recommendations_from_github
# Fall back to json bundled with code
json_path = pathlib.Path(__file__).parent / "recommended-models.json"
with open(json_path, "r") as f:
json_config = json.load(f)
return LLMRecommendations.model_validate(json_config)
def get_recommendations() -> LLMRecommendations:
"""Get the recommendations, with an in-memory cache to avoid
hitting GitHub on every API request."""
global _cached_recommendations, _cached_recommendations_time
now = time.monotonic()
if (
_cached_recommendations is not None
and (now - _cached_recommendations_time) < _RECOMMENDATIONS_CACHE_TTL_SECONDS
):
return _cached_recommendations
with _recommendations_cache_lock:
# Double-check after acquiring lock
if (
_cached_recommendations is not None
and (time.monotonic() - _cached_recommendations_time)
< _RECOMMENDATIONS_CACHE_TTL_SECONDS
):
return _cached_recommendations
recommendations_from_github = fetch_llm_recommendations_from_github()
result = recommendations_from_github or _load_bundled_recommendations()
_cached_recommendations = result
_cached_recommendations_time = time.monotonic()
return result
recommendations_from_json = LLMRecommendations.model_validate(json_config)
return recommendations_from_json
def is_obsolete_model(model_name: str, provider: str) -> bool:
@@ -243,23 +215,6 @@ def model_configurations_for_provider(
) -> list[ModelConfigurationView]:
recommended_visible_models = llm_recommendations.get_visible_models(provider_name)
recommended_visible_models_names = [m.name for m in recommended_visible_models]
# Preserve provider-defined ordering while de-duplicating.
model_names: list[str] = []
seen_model_names: set[str] = set()
for model_name in (
fetch_models_for_provider(provider_name) + recommended_visible_models_names
):
if model_name in seen_model_names:
continue
seen_model_names.add(model_name)
model_names.append(model_name)
# Vertex model list can be large and mixed-vendor; alphabetical ordering
# makes model discovery easier in admin selection UIs.
if provider_name == VERTEXAI_PROVIDER_NAME:
model_names = sorted(model_names, key=str.lower)
return [
ModelConfigurationView(
name=model_name,
@@ -267,7 +222,8 @@ def model_configurations_for_provider(
max_input_tokens=get_max_input_tokens(model_name, provider_name),
supports_image_input=model_supports_image_input(model_name, provider_name),
)
for model_name in model_names
for model_name in set(fetch_models_for_provider(provider_name))
| set(recommended_visible_models_names)
]

View File

@@ -21,6 +21,7 @@ from fastapi.routing import APIRoute
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import BASE_SCOPES
from httpx_oauth.clients.openid import OpenID
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from starlette.types import Lifespan
@@ -52,7 +53,6 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import POSTGRES_WEB_APP_NAME
from onyx.db.engine.async_sql_engine import get_sqlalchemy_async_engine
from onyx.db.engine.connection_warmup import warm_up_connections
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
@@ -64,7 +64,7 @@ from onyx.server.documents.connector import router as connector_router
from onyx.server.documents.credential import router as credential_router
from onyx.server.documents.document import router as document_router
from onyx.server.documents.standard_oauth import router as standard_oauth_router
from onyx.server.features.build.api.api import public_build_router
from onyx.server.features.build.api.api import nextjs_assets_router
from onyx.server.features.build.api.api import router as build_router
from onyx.server.features.default_assistant.api import (
router as default_assistant_router,
@@ -115,10 +115,6 @@ from onyx.server.manage.users import router as user_router
from onyx.server.manage.web_search.api import (
admin_router as web_search_admin_router,
)
from onyx.server.metrics.postgres_connection_pool import (
setup_postgres_connection_pool_metrics,
)
from onyx.server.metrics.prometheus_setup import setup_prometheus_metrics
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
@@ -142,7 +138,6 @@ from onyx.setup import setup_onyx
from onyx.tracing.setup import setup_tracing
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_endpoint_context_middleware
from onyx.utils.middleware import add_onyx_request_id_middleware
from onyx.utils.telemetry import get_or_generate_uuid
from onyx.utils.telemetry import optional_telemetry
@@ -271,17 +266,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
max_overflow=POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW,
)
# Register pool metrics now that engines are created.
# HTTP instrumentation is set up earlier in get_application() since it
# adds middleware (which Starlette forbids after the app has started).
setup_postgres_connection_pool_metrics(
engines={
"sync": SqlEngine.get_engine(),
"async": get_sqlalchemy_async_engine(),
"readonly": SqlEngine.get_readonly_engine(),
},
)
verify_auth = fetch_versioned_implementation(
"onyx.auth.users", "verify_auth_setting"
)
@@ -394,8 +378,8 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, cc_pair_router)
include_router_with_global_prefix_prepended(application, projects_router)
include_router_with_global_prefix_prepended(application, public_build_router)
include_router_with_global_prefix_prepended(application, build_router)
include_router_with_global_prefix_prepended(application, nextjs_assets_router)
include_router_with_global_prefix_prepended(application, document_set_router)
include_router_with_global_prefix_prepended(application, hierarchy_router)
include_router_with_global_prefix_prepended(application, search_settings_router)
@@ -576,18 +560,12 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
add_onyx_request_id_middleware(application, "API", logger)
# Set endpoint context for per-endpoint DB pool attribution metrics.
# Must be registered after all routes are added.
add_endpoint_context_middleware(application)
# HTTP request metrics (latency histograms, in-progress gauge, slow request
# counter). Must be called here — before the app starts — because the
# instrumentator adds middleware via app.add_middleware().
setup_prometheus_metrics(application)
# Ensure all routes have auth enabled or are explicitly marked as public
check_router_auth(application)
# Initialize and instrument the app
Instrumentator().instrument(application).expose(application)
use_route_function_names_as_operation_ids(application)
return application

View File

@@ -69,12 +69,6 @@ Very briefly describe the image(s) generated. Do not include any links or attach
""".strip()
FILE_REMINDER = """
Your code execution generated file(s) with download links.
If you reference or share these files, use the exact markdown format [filename](file_link) with the file_link from the execution result.
""".strip()
# Specifically for OpenAI models, this prefix needs to be in place for the model to output markdown and correct styling
CODE_BLOCK_MARKDOWN = "Formatting re-enabled. "

View File

@@ -70,8 +70,6 @@ GENERATE_IMAGE_GUIDANCE = """
## generate_image
NEVER use generate_image unless the user specifically requests an image.
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
the `file_id` values returned by earlier `generate_image` tool results.
"""
MEMORY_GUIDANCE = """

View File

@@ -109,7 +109,6 @@ class TenantRedis(redis.Redis):
"unlock",
"get",
"set",
"setex",
"delete",
"exists",
"incrby",

View File

@@ -59,9 +59,6 @@ PUBLIC_ENDPOINT_SPECS = [
# anonymous user on cloud
("/tenants/anonymous-user", {"POST"}),
("/metrics", {"GET"}), # added by prometheus_fastapi_instrumentator
# craft webapp proxy — access enforced per-session via sharing_scope in handler
("/build/sessions/{session_id}/webapp", {"GET"}),
("/build/sessions/{session_id}/webapp/{path:path}", {"GET"}),
]
@@ -105,9 +102,6 @@ def check_router_auth(
current_cloud_superuser = fetch_ee_implementation_or_noop(
"onyx.auth.users", "current_cloud_superuser"
)
verify_scim_token = fetch_ee_implementation_or_noop(
"onyx.server.scim.auth", "verify_scim_token"
)
for route in application.routes:
# explicitly marked as public
@@ -131,7 +125,6 @@ def check_router_auth(
or depends_fn == current_chat_accessible_user
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
or depends_fn == verify_scim_token
):
found_auth = True
break

View File

@@ -103,7 +103,6 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingMode
from onyx.db.enums import ProcessingMode
from onyx.db.federated import fetch_all_federated_connectors_parallel
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempts_by_status
@@ -988,7 +987,6 @@ def get_connector_status(
user=user,
eager_load_connector=True,
eager_load_credential=True,
eager_load_user=True,
get_editable=False,
)
@@ -1002,23 +1000,11 @@ def get_connector_status(
relationship.user_group_id
)
# Pre-compute credential_ids per connector to avoid N+1 lazy loads
connector_to_credential_ids: dict[int, list[int]] = {}
for cc_pair in cc_pairs:
connector_to_credential_ids.setdefault(cc_pair.connector_id, []).append(
cc_pair.credential_id
)
return [
ConnectorStatus(
cc_pair_id=cc_pair.id,
name=cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair.connector,
credential_ids=connector_to_credential_ids.get(
cc_pair.connector_id, []
),
),
connector=ConnectorSnapshot.from_connector_db_model(cc_pair.connector),
credential=CredentialSnapshot.from_credential_db_model(cc_pair.credential),
access_type=cc_pair.access_type,
groups=group_cc_pair_relationships_dict.get(cc_pair.id, []),
@@ -1073,27 +1059,15 @@ def get_connector_indexing_status(
parallel_functions: list[tuple[CallableProtocol, tuple[Any, ...]]] = [
# Get editable connector/credential pairs
(
lambda: get_connector_credential_pairs_for_user_parallel(
user, True, None, True, True, False, True, request.source
),
(),
get_connector_credential_pairs_for_user_parallel,
(user, True, None, True, True, True, True, request.source),
),
# Get federated connectors
(fetch_all_federated_connectors_parallel, ()),
# Get most recent index attempts
(
lambda: get_latest_index_attempts_parallel(
request.secondary_index, True, False
),
(),
),
(get_latest_index_attempts_parallel, (request.secondary_index, True, False)),
# Get most recent finished index attempts
(
lambda: get_latest_index_attempts_parallel(
request.secondary_index, True, True
),
(),
),
(get_latest_index_attempts_parallel, (request.secondary_index, True, True)),
]
if user and user.role == UserRole.ADMIN:
@@ -1110,10 +1084,8 @@ def get_connector_indexing_status(
parallel_functions.append(
# Get non-editable connector/credential pairs
(
lambda: get_connector_credential_pairs_for_user_parallel(
user, False, None, True, True, False, True, request.source
),
(),
get_connector_credential_pairs_for_user_parallel,
(user, False, None, True, True, True, True, request.source),
),
)
@@ -1952,8 +1924,6 @@ def get_basic_connector_indexing_status(
get_editable=False,
user=user,
)
# NOTE: This endpoint excludes Craft connectors
return [
BasicCCPairInfo(
has_successful_run=cc_pair.last_successful_index_time is not None,
@@ -1961,7 +1931,6 @@ def get_basic_connector_indexing_status(
)
for cc_pair in cc_pairs
if cc_pair.connector.source != DocumentSource.INGESTION_API
and cc_pair.processing_mode == ProcessingMode.REGULAR
]

View File

@@ -365,8 +365,7 @@ class CCPairFullInfo(BaseModel):
in_repeated_error_state=cc_pair_model.in_repeated_error_state,
num_docs_indexed=num_docs_indexed,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_model.connector,
credential_ids=[cc_pair_model.credential_id],
cc_pair_model.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_model.credential

View File

@@ -1,5 +1,4 @@
from collections.abc import Iterator
from pathlib import Path
from uuid import UUID
import httpx
@@ -8,19 +7,16 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from fastapi.responses import RedirectResponse
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.auth.users import optional_user
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import ProcessingMode
from onyx.db.enums import SharingScope
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.models import BuildSession
from onyx.db.models import User
@@ -221,15 +217,12 @@ def get_build_connectors(
return BuildConnectorListResponse(connectors=connectors)
# Headers to skip when proxying.
# Hop-by-hop headers must not be forwarded, and set-cookie is stripped to
# prevent LLM-generated apps from setting cookies on the parent Onyx domain.
# Headers to skip when proxying (hop-by-hop headers)
EXCLUDED_HEADERS = {
"content-encoding",
"content-length",
"transfer-encoding",
"connection",
"set-cookie",
}
@@ -287,7 +280,7 @@ def _get_sandbox_url(session_id: UUID, db_session: Session) -> str:
db_session: Database session
Returns:
Internal URL to proxy requests to
The internal URL to proxy requests to
Raises:
HTTPException: If session not found, port not allocated, or sandbox not found
@@ -301,10 +294,12 @@ def _get_sandbox_url(session_id: UUID, db_session: Session) -> str:
if session.user_id is None:
raise HTTPException(status_code=404, detail="User not found")
# Get the user's sandbox to get the sandbox_id
sandbox = get_sandbox_by_user_id(db_session, session.user_id)
if sandbox is None:
raise HTTPException(status_code=404, detail="Sandbox not found")
# Use sandbox manager to get the correct internal URL
sandbox_manager = get_sandbox_manager()
return sandbox_manager.get_webapp_url(sandbox.id, session.nextjs_port)
@@ -370,73 +365,71 @@ def _proxy_request(
raise HTTPException(status_code=502, detail="Bad gateway")
def _check_webapp_access(
session_id: UUID, user: User | None, db_session: Session
) -> BuildSession:
"""Check if user can access a session's webapp.
- public_global: accessible by anyone (no auth required)
- public_org: accessible by any authenticated user
- private: only accessible by the session owner
"""
session = db_session.get(BuildSession, session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if session.sharing_scope == SharingScope.PUBLIC_GLOBAL:
return session
if user is None:
raise HTTPException(status_code=401, detail="Authentication required")
if session.sharing_scope == SharingScope.PRIVATE and session.user_id != user.id:
raise HTTPException(status_code=404, detail="Session not found")
return session
_OFFLINE_HTML_PATH = Path(__file__).parent / "templates" / "webapp_offline.html"
def _offline_html_response() -> Response:
"""Return a branded Craft HTML page when the sandbox is not reachable.
Design mirrors the default Craft web template (outputs/web/app/page.tsx):
terminal window aesthetic with Minecraft-themed typing animation.
"""
html = _OFFLINE_HTML_PATH.read_text()
return Response(content=html, status_code=503, media_type="text/html")
# Public router for webapp proxy — no authentication required
# (access controlled per-session via sharing_scope)
public_build_router = APIRouter(prefix="/build")
@public_build_router.get("/sessions/{session_id}/webapp", response_model=None)
@public_build_router.get(
"/sessions/{session_id}/webapp/{path:path}", response_model=None
)
def get_webapp(
@router.get("/sessions/{session_id}/webapp", response_model=None)
def get_webapp_root(
session_id: UUID,
request: Request,
path: str = "",
user: User | None = Depends(optional_user),
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy the webapp for a specific session (root and subpaths).
"""Proxy the root path of the webapp for a specific session."""
return _proxy_request("", request, session_id, db_session)
Accessible without authentication when sharing_scope is public_global.
Returns a friendly offline page when the sandbox is not running.
@router.get("/sessions/{session_id}/webapp/{path:path}", response_model=None)
def get_webapp_path(
session_id: UUID,
path: str,
request: Request,
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy any subpath of the webapp (static assets, etc.) for a specific session."""
return _proxy_request(path, request, session_id, db_session)
# Separate router for Next.js static assets at /_next/*
# This is needed because Next.js apps may reference assets with root-relative paths
# that don't get rewritten. The session_id is extracted from the Referer header.
nextjs_assets_router = APIRouter()
def _extract_session_from_referer(request: Request) -> UUID | None:
"""Extract session_id from the Referer header.
Expects Referer to contain /api/build/sessions/{session_id}/webapp
"""
try:
_check_webapp_access(session_id, user, db_session)
except HTTPException as e:
if e.status_code == 401:
return RedirectResponse(url="/auth/login", status_code=302)
raise
try:
return _proxy_request(path, request, session_id, db_session)
except HTTPException as e:
if e.status_code in (502, 503, 504):
return _offline_html_response()
raise
import re
referer = request.headers.get("referer", "")
match = re.search(r"/api/build/sessions/([a-f0-9-]+)/webapp", referer)
if match:
try:
return UUID(match.group(1))
except ValueError:
return None
return None
@nextjs_assets_router.get("/_next/{path:path}", response_model=None)
def get_nextjs_assets(
path: str,
request: Request,
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy Next.js static assets requested at root /_next/ path.
The session_id is extracted from the Referer header since these requests
come from within the iframe context.
"""
session_id = _extract_session_from_referer(request)
if not session_id:
raise HTTPException(
status_code=400,
detail="Could not determine session from request context",
)
return _proxy_request(f"_next/{path}", request, session_id, db_session)
# =============================================================================

View File

@@ -10,7 +10,6 @@ from onyx.configs.constants import MessageType
from onyx.db.enums import ArtifactType
from onyx.db.enums import BuildSessionStatus
from onyx.db.enums import SandboxStatus
from onyx.db.enums import SharingScope
from onyx.server.features.build.sandbox.models import (
FilesystemEntry as FileSystemEntry,
)
@@ -108,7 +107,6 @@ class SessionResponse(BaseModel):
nextjs_port: int | None
sandbox: SandboxResponse | None
artifacts: list[ArtifactResponse]
sharing_scope: SharingScope
@classmethod
def from_model(
@@ -131,7 +129,6 @@ class SessionResponse(BaseModel):
nextjs_port=session.nextjs_port,
sandbox=(SandboxResponse.from_model(sandbox) if sandbox else None),
artifacts=[ArtifactResponse.from_model(a) for a in session.artifacts],
sharing_scope=session.sharing_scope,
)
@@ -162,19 +159,6 @@ class SessionListResponse(BaseModel):
sessions: list[SessionResponse]
class SetSessionSharingRequest(BaseModel):
"""Request to set the sharing scope of a session."""
sharing_scope: SharingScope
class SetSessionSharingResponse(BaseModel):
"""Response after setting session sharing scope."""
session_id: str
sharing_scope: SharingScope
# ===== Message Models =====
class MessageRequest(BaseModel):
"""Request to send a message to the CLI agent."""
@@ -260,7 +244,6 @@ class WebappInfo(BaseModel):
webapp_url: str | None # URL to access the webapp (e.g., http://localhost:3015)
status: str # Sandbox status (running, terminated, etc.)
ready: bool # Whether the NextJS dev server is actually responding
sharing_scope: SharingScope
# ===== File Upload Models =====

Some files were not shown because too many files have changed in this diff Show More