Compare commits

..

5 Commits

Author SHA1 Message Date
Jessica Singh
d0c9d36692 changes to ecs fargate 2026-02-18 12:17:13 -08:00
Jessica Singh
e03bf2a6a3 sign 2026-02-17 13:30:24 -08:00
Jessica Singh
7c8c7c9d91 add anon tests 2026-02-17 11:34:17 -08:00
Jessica Singh
89d8521f37 changes 2026-02-16 23:16:41 -08:00
Jessica Singh
24a0e08ee2 changes 2026-02-16 23:15:14 -08:00
1064 changed files with 17939 additions and 61751 deletions

View File

@@ -1 +0,0 @@
../.cursor/skills

View File

@@ -1,248 +0,0 @@
---
name: playwright-e2e-tests
description: Write and maintain Playwright end-to-end tests for the Onyx application. Use when creating new E2E tests, debugging test failures, adding test coverage, or when the user mentions Playwright, E2E tests, or browser testing.
---
# Playwright E2E Tests
## Project Layout
- **Tests**: `web/tests/e2e/` — organized by feature (`auth/`, `admin/`, `chat/`, `assistants/`, `connectors/`, `mcp/`)
- **Config**: `web/playwright.config.ts`
- **Utilities**: `web/tests/e2e/utils/`
- **Constants**: `web/tests/e2e/constants.ts`
- **Global setup**: `web/tests/e2e/global-setup.ts`
- **Output**: `web/output/playwright/`
## Imports
Always use absolute imports with the `@tests/e2e/` prefix — never relative paths (`../`, `../../`). The alias is defined in `web/tsconfig.json` and resolves to `web/tests/`.
```typescript
import { loginAs } from "@tests/e2e/utils/auth";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
import { TEST_ADMIN_CREDENTIALS } from "@tests/e2e/constants";
```
All new files should be `.ts`, not `.js`.
## Running Tests
```bash
# Run a specific test file
npx playwright test web/tests/e2e/chat/default_assistant.spec.ts
# Run a specific project
npx playwright test --project admin
npx playwright test --project exclusive
```
## Test Projects
| Project | Description | Parallelism |
|---------|-------------|-------------|
| `admin` | Standard tests (excludes `@exclusive`) | Parallel |
| `exclusive` | Serial, slower tests (tagged `@exclusive`) | 1 worker |
All tests use `admin_auth.json` storage state by default (pre-authenticated admin session).
## Authentication
Global setup (`global-setup.ts`) runs automatically before all tests and handles:
- Server readiness check (polls health endpoint, 60s timeout)
- Provisioning test users: admin, admin2, and a **pool of worker users** (`worker0@example.com` through `worker7@example.com`) (idempotent)
- API login + saving storage states: `admin_auth.json`, `admin2_auth.json`, and `worker{N}_auth.json` for each worker user
- Setting display name to `"worker"` for each worker user
- Promoting admin2 to admin role
- Ensuring a public LLM provider exists
Both test projects set `storageState: "admin_auth.json"`, so **every test starts pre-authenticated as admin with no login code needed**.
When a test needs a different user, use API-based login — never drive the login UI:
```typescript
import { loginAs } from "@tests/e2e/utils/auth";
await page.context().clearCookies();
await loginAs(page, "admin2");
// Log in as the worker-specific user (preferred for test isolation):
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
```
## Test Structure
Tests start pre-authenticated as admin — navigate and test directly:
```typescript
import { test, expect } from "@playwright/test";
test.describe("Feature Name", () => {
test("should describe expected behavior clearly", async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
// Already authenticated as admin — go straight to testing
});
});
```
**User isolation** — tests that modify visible app state (creating assistants, sending chat messages, pinning items) should run as a **worker-specific user** and clean up resources in `afterAll`. Global setup provisions a pool of worker users (`worker0@example.com` through `worker7@example.com`). `loginAsWorkerUser` maps `testInfo.workerIndex` to a pool slot via modulo, so retry workers (which get incrementing indices beyond the pool size) safely reuse existing users. This ensures parallel workers never share user state, keeps usernames deterministic for screenshots, and avoids cross-contamination:
```typescript
import { test } from "@playwright/test";
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
test.beforeEach(async ({ page }, testInfo) => {
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
});
```
If the test requires admin privileges *and* modifies visible state, use `"admin2"` instead — it's a pre-provisioned admin account that keeps the primary `"admin"` clean for other parallel tests. Switch to `"admin"` only for privileged setup (creating providers, configuring tools), then back to the worker user for the actual test. See `chat/default_assistant.spec.ts` for a full example.
`loginAsRandomUser` exists for the rare case where the test requires a brand-new user (e.g. onboarding flows). Avoid it elsewhere — it produces non-deterministic usernames that complicate screenshots.
**API resource setup** — only when tests need to create backend resources (image gen configs, web search providers, MCP servers). Use `beforeAll`/`afterAll` with `OnyxApiClient` to create and clean up. See `chat/default_assistant.spec.ts` or `mcp/mcp_oauth_flow.spec.ts` for examples. This is uncommon (~4 of 37 test files).
## Key Utilities
### `OnyxApiClient` (`@tests/e2e/utils/onyxApiClient`)
Backend API client for test setup/teardown. Key methods:
- **Connectors**: `createFileConnector()`, `deleteCCPair()`, `pauseConnector()`
- **LLM Providers**: `ensurePublicProvider()`, `createRestrictedProvider()`, `setProviderAsDefault()`
- **Assistants**: `createAssistant()`, `deleteAssistant()`, `findAssistantByName()`
- **User Groups**: `createUserGroup()`, `deleteUserGroup()`, `setUserRole()`
- **Tools**: `createWebSearchProvider()`, `createImageGenerationConfig()`
- **Chat**: `createChatSession()`, `deleteChatSession()`
### `chatActions` (`@tests/e2e/utils/chatActions`)
- `sendMessage(page, message)` — sends a message and waits for AI response
- `startNewChat(page)` — clicks new-chat button and waits for intro
- `verifyDefaultAssistantIsChosen(page)` — checks Onyx logo is visible
- `verifyAssistantIsChosen(page, name)` — checks assistant name display
- `switchModel(page, modelName)` — switches LLM model via popover
### `visualRegression` (`@tests/e2e/utils/visualRegression`)
- `expectScreenshot(page, { name, mask?, hide?, fullPage? })`
- `expectElementScreenshot(locator, { name, mask?, hide? })`
- Controlled by `VISUAL_REGRESSION=true` env var
### `theme` (`@tests/e2e/utils/theme`)
- `THEMES``["light", "dark"] as const` array for iterating over both themes
- `setThemeBeforeNavigation(page, theme)` — sets `next-themes` theme via `localStorage` before navigation
When tests need light/dark screenshots, loop over `THEMES` at the `test.describe` level and call `setThemeBeforeNavigation` in `beforeEach` **before** any `page.goto()`. Include the theme in screenshot names. See `admin/admin_pages.spec.ts` or `chat/chat_message_rendering.spec.ts` for examples:
```typescript
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";
for (const theme of THEMES) {
test.describe(`Feature (${theme} mode)`, () => {
test.beforeEach(async ({ page }) => {
await setThemeBeforeNavigation(page, theme);
});
test("renders correctly", async ({ page }) => {
await page.goto("/app");
await expectScreenshot(page, { name: `feature-${theme}` });
});
});
}
```
### `tools` (`@tests/e2e/utils/tools`)
- `TOOL_IDS` — centralized `data-testid` selectors for tool options
- `openActionManagement(page)` — opens the tool management popover
## Locator Strategy
Use locators in this priority order:
1. **`data-testid` / `aria-label`** — preferred for Onyx components
```typescript
page.getByTestId("AppSidebar/new-session")
page.getByLabel("admin-page-title")
```
2. **Role-based** — for standard HTML elements
```typescript
page.getByRole("button", { name: "Create" })
page.getByRole("dialog")
```
3. **Text/Label** — for visible text content
```typescript
page.getByText("Custom Assistant")
page.getByLabel("Email")
```
4. **CSS selectors** — last resort, only when above won't work
```typescript
page.locator('input[name="name"]')
page.locator("#onyx-chat-input-textarea")
```
**Never use** `page.locator` with complex CSS/XPath when a built-in locator works.
## Assertions
Use web-first assertions — they auto-retry until the condition is met:
```typescript
// Visibility
await expect(page.getByTestId("onyx-logo")).toBeVisible({ timeout: 5000 });
// Text content
await expect(page.getByTestId("assistant-name-display")).toHaveText("My Assistant");
// Count
await expect(page.locator('[data-testid="onyx-ai-message"]')).toHaveCount(2, { timeout: 30000 });
// URL
await expect(page).toHaveURL(/chatId=/);
// Element state
await expect(toggle).toBeChecked();
await expect(button).toBeEnabled();
```
**Never use** `assert` statements or hardcoded `page.waitForTimeout()`.
## Waiting Strategy
```typescript
// Wait for load state after navigation
await page.goto("/app");
await page.waitForLoadState("networkidle");
// Wait for specific element
await page.getByTestId("chat-intro").waitFor({ state: "visible", timeout: 10000 });
// Wait for URL change
await page.waitForFunction(() => window.location.href.includes("chatId="), null, { timeout: 10000 });
// Wait for network response
await page.waitForResponse(resp => resp.url().includes("/api/chat") && resp.status() === 200);
```
## Best Practices
1. **Descriptive test names** — clearly state expected behavior: `"should display greeting message when opening new chat"`
2. **API-first setup** — use `OnyxApiClient` for backend state; reserve UI interactions for the behavior under test
3. **User isolation** — tests that modify visible app state (sidebar, chat history) should run as the worker-specific user via `loginAsWorkerUser(page, testInfo.workerIndex)` (not admin) and clean up resources in `afterAll`. Each parallel worker gets its own user, preventing cross-contamination. Reserve `loginAsRandomUser` for flows that require a brand-new user (e.g. onboarding)
4. **DRY helpers** — extract reusable logic into `utils/` with JSDoc comments
5. **No hardcoded waits** — use `waitFor`, `waitForLoadState`, or web-first assertions
6. **Parallel-safe** — no shared mutable state between tests. Prefer static, human-readable names (e.g. `"E2E-CMD Chat 1"`) and clean up resources by ID in `afterAll`. This keeps screenshots deterministic and avoids needing to mask/hide dynamic text. Only fall back to timestamps (`\`test-${Date.now()}\``) when resources cannot be reliably cleaned up or when name collisions across parallel workers would cause functional failures
7. **Error context** — catch and re-throw with useful debug info (page text, URL, etc.)
8. **Tag slow tests** — mark serial/slow tests with `@exclusive` in the test title
9. **Visual regression** — use `expectScreenshot()` for UI consistency checks
10. **Minimal comments** — only comment to clarify non-obvious intent; never restate what the next line of code does

View File

@@ -1,73 +0,0 @@
name: "Build Backend Image"
description: "Builds and pushes the backend Docker image with cache reuse"
inputs:
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
ref-name:
description: "Git ref name used for cache suffix fallback"
required: true
pr-number:
description: "Optional PR number for cache suffix"
required: false
default: ""
github-sha:
description: "Commit SHA used for cache keys"
required: true
run-id:
description: "GitHub run ID used in output image tag"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
docker-no-cache:
description: "Set to 'true' to disable docker build cache"
required: false
default: "false"
runs:
using: "composite"
steps:
- name: Format branch name for cache
id: format-branch
shell: bash
env:
PR_NUMBER: ${{ inputs.pr-number }}
REF_NAME: ${{ inputs.ref-name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Build and push Backend Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
push: true
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-backend-${{ inputs.run-id }}
cache-from: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache,mode=max
no-cache: ${{ inputs.docker-no-cache == 'true' }}

View File

@@ -1,76 +0,0 @@
name: "Build Integration Image"
description: "Builds and pushes the integration test image with docker bake"
inputs:
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
ref-name:
description: "Git ref name used for cache suffix fallback"
required: true
pr-number:
description: "Optional PR number for cache suffix"
required: false
default: ""
github-sha:
description: "Commit SHA used for cache keys"
required: true
run-id:
description: "GitHub run ID used in output image tag"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
runs:
using: "composite"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Format branch name for cache
id: format-branch
shell: bash
env:
PR_NUMBER: ${{ inputs.pr-number }}
REF_NAME: ${{ inputs.ref-name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
- name: Build and push integration test image with Docker Bake
shell: bash
env:
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
INTEGRATION_REPOSITORY: ${{ inputs.runs-on-ecr-cache }}
TAG: nightly-llm-it-${{ inputs.run-id }}
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
HEAD_SHA: ${{ inputs.github-sha }}
run: |
docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration

View File

@@ -1,68 +0,0 @@
name: "Build Model Server Image"
description: "Builds and pushes the model server Docker image with cache reuse"
inputs:
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
ref-name:
description: "Git ref name used for cache suffix fallback"
required: true
pr-number:
description: "Optional PR number for cache suffix"
required: false
default: ""
github-sha:
description: "Commit SHA used for cache keys"
required: true
run-id:
description: "GitHub run ID used in output image tag"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
runs:
using: "composite"
steps:
- name: Format branch name for cache
id: format-branch
shell: bash
env:
PR_NUMBER: ${{ inputs.pr-number }}
REF_NAME: ${{ inputs.ref-name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Build and push Model Server Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-model-server-${{ inputs.run-id }}
cache-from: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache,mode=max

View File

@@ -1,130 +0,0 @@
name: "Run Nightly Provider Chat Test"
description: "Starts required compose services and runs nightly provider integration test"
inputs:
provider:
description: "Provider slug for NIGHTLY_LLM_PROVIDER"
required: true
models:
description: "Comma-separated model list for NIGHTLY_LLM_MODELS"
required: true
provider-api-key:
description: "API key for NIGHTLY_LLM_API_KEY"
required: false
default: ""
strict:
description: "String true/false for NIGHTLY_LLM_STRICT"
required: true
api-base:
description: "Optional NIGHTLY_LLM_API_BASE"
required: false
default: ""
api-version:
description: "Optional NIGHTLY_LLM_API_VERSION"
required: false
default: ""
deployment-name:
description: "Optional NIGHTLY_LLM_DEPLOYMENT_NAME"
required: false
default: ""
custom-config-json:
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
required: false
default: ""
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
run-id:
description: "GitHub run ID used in image tags"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
runs:
using: "composite"
steps:
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Create .env file for Docker Compose
shell: bash
env:
ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
RUN_ID: ${{ inputs.run-id }}
run: |
cat <<EOF2 > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
INTEGRATION_TESTS_MODE=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
AWS_REGION_NAME=us-west-2
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
EOF2
- name: Start Docker containers
shell: bash
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait \
relational_db \
index \
cache \
minio \
api_server \
inference_model_server
- name: Run nightly provider integration test
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
env:
MODELS: ${{ inputs.models }}
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
NIGHTLY_LLM_API_VERSION: ${{ inputs.api-version }}
NIGHTLY_LLM_DEPLOYMENT_NAME: ${{ inputs.deployment-name }}
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
RUN_ID: ${{ inputs.run-id }}
with:
timeout_minutes: 20
max_attempts: 2
retry_wait_seconds: 10
command: |
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e TEST_WEB_HOSTNAME=test-runner \
-e AWS_REGION_NAME=us-west-2 \
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
-e NIGHTLY_LLM_MODELS="${MODELS}" \
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
-e NIGHTLY_LLM_API_VERSION="${NIGHTLY_LLM_API_VERSION}" \
-e NIGHTLY_LLM_DEPLOYMENT_NAME="${NIGHTLY_LLM_DEPLOYMENT_NAME}" \
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
/app/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py

View File

@@ -8,5 +8,5 @@
## Additional Options
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
- [ ] [Optional] Override Linear Check

View File

@@ -426,9 +426,8 @@ jobs:
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
@@ -500,9 +499,8 @@ jobs:
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
@@ -648,8 +646,8 @@ jobs:
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
@@ -730,8 +728,8 @@ jobs:
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
@@ -864,9 +862,8 @@ jobs:
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
@@ -937,9 +934,8 @@ jobs:
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
@@ -1076,8 +1072,8 @@ jobs:
ONYX_VERSION=${{ github.ref_name }}
ENABLE_CRAFT=true
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64,mode=max
@@ -1149,8 +1145,8 @@ jobs:
ONYX_VERSION=${{ github.ref_name }}
ENABLE_CRAFT=true
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64,mode=max
@@ -1291,9 +1287,8 @@ jobs:
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
@@ -1371,9 +1366,8 @@ jobs:
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max

View File

@@ -33,7 +33,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Build chart dependencies

View File

@@ -1,50 +0,0 @@
name: Nightly LLM Provider Chat Tests
concurrency:
group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
on:
schedule:
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
- cron: "30 10 * * *"
workflow_dispatch:
permissions:
contents: read
jobs:
provider-chat-test:
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
secrets: inherit
permissions:
contents: read
id-token: write
with:
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
bedrock_models: ${{ vars.NIGHTLY_LLM_BEDROCK_MODELS }}
vertex_ai_models: ${{ vars.NIGHTLY_LLM_VERTEX_AI_MODELS }}
azure_models: ${{ vars.NIGHTLY_LLM_AZURE_MODELS }}
azure_api_base: ${{ vars.NIGHTLY_LLM_AZURE_API_BASE }}
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
strict: true
notify-slack-on-failure:
needs: [provider-chat-test]
if: failure() && github.event_name == 'schedule'
runs-on: ubuntu-slim
timeout-minutes: 5
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
failed-jobs: provider-chat-test
title: "🚨 Scheduled LLM Provider Chat Tests failed!"
ref-name: ${{ github.ref_name }}

View File

@@ -0,0 +1,151 @@
# Scan for problematic software licenses
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
name: 'Nightly - Scan licenses'
on:
# schedule:
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
workflow_dispatch: # Allows manual triggering
permissions:
actions: read
contents: read
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
timeout-minutes: 45
permissions:
actions: read
contents: read
security-events: write
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- name: Get explicit and transitive dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
pip freeze > requirements-all.txt
- name: Check python
id: license_check_report
uses: pilosus/action-pip-license-checker@e909b0226ff49d3235c99c4585bc617f49fff16a # ratchet:pilosus/action-pip-license-checker@v3
with:
requirements: 'requirements-all.txt'
fail: 'Copyleft'
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: always()
env:
REPORT: ${{ steps.license_check_report.outputs.report }}
run: echo "$REPORT"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
# be careful enabling the sarif and upload as it may spam the security tab
# with a huge amount of items. Work out the issues before enabling upload.
# - name: Run Trivy vulnerability scanner in repo mode
# if: always()
# uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
# with:
# scan-type: fs
# scan-ref: .
# scanners: license
# format: table
# severity: HIGH,CRITICAL
# # format: sarif
# # output: trivy-results.sarif
#
# # - name: Upload Trivy scan results to GitHub Security tab
# # uses: github/codeql-action/upload-sarif@v3
# # with:
# # sarif_file: trivy-results.sarif
scan-trivy:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# Backend
- name: Pull backend docker image
run: docker pull onyxdotapp/onyx-backend:latest
- name: Run Trivy vulnerability scanner on backend
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-backend:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
# Web server
- name: Pull web server docker image
run: docker pull onyxdotapp/onyx-web-server:latest
- name: Run Trivy vulnerability scanner on web server
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-web-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# Model server
- name: Pull model server docker image
run: docker pull onyxdotapp/onyx-model-server:latest
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-model-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0

View File

@@ -1,163 +0,0 @@
name: Post-Merge Beta Cherry-Pick
on:
push:
branches:
- main
permissions:
contents: write
pull-requests: write
jobs:
cherry-pick-to-latest-release:
outputs:
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
pr_number: ${{ steps.gate.outputs.pr_number }}
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Resolve merged PR and checkbox state
id: gate
env:
GH_TOKEN: ${{ github.token }}
run: |
# For the commit that triggered this workflow (HEAD on main), fetch all
# associated PRs and keep only the PR that was actually merged into main
# with this exact merge commit SHA.
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
if [ "${match_count}" -gt 1 ]; then
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
fi
if [ -z "$pr_number" ]; then
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
exit 0
fi
# Read the PR once so we can gate behavior and infer preferred actor.
pr_json="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}")"
pr_body="$(printf '%s' "$pr_json" | jq -r '.body // ""')"
merged_by="$(printf '%s' "$pr_json" | jq -r '.merged_by.login // ""')"
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
echo "merged_by=$merged_by" >> "$GITHUB_OUTPUT"
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox checked for PR #${pr_number}."
exit 0
fi
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
- name: Checkout repository
if: steps.gate.outputs.should_cherrypick == 'true'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: true
ref: main
- name: Install the latest version of uv
if: steps.gate.outputs.should_cherrypick == 'true'
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Configure git identity
if: steps.gate.outputs.should_cherrypick == 'true'
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Create cherry-pick PR to latest release
id: run_cherry_pick
if: steps.gate.outputs.should_cherrypick == 'true'
continue-on-error: true
env:
GH_TOKEN: ${{ github.token }}
GITHUB_TOKEN: ${{ github.token }}
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
run: |
set -o pipefail
output_file="$(mktemp)"
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
exit_code="${PIPESTATUS[0]}"
if [ "${exit_code}" -eq 0 ]; then
echo "status=success" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "status=failure" >> "$GITHUB_OUTPUT"
reason="command-failed"
if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then
reason="merge-conflict"
fi
echo "reason=${reason}" >> "$GITHUB_OUTPUT"
{
echo "details<<EOF"
tail -n 40 "$output_file"
echo "EOF"
} >> "$GITHUB_OUTPUT"
- name: Mark workflow as failed if cherry-pick failed
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
env:
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
run: |
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
exit 1
notify-slack-on-cherry-pick-failure:
needs:
- cherry-pick-to-latest-release
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
runs-on: ubuntu-slim
timeout-minutes: 10
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build cherry-pick failure summary
id: failure-summary
env:
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
run: |
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
reason_text="cherry-pick command failed"
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
reason_text="merge conflict during cherry-pick"
fi
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
if [ -n "${details_excerpt}" ]; then
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
fi
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
- name: Notify #cherry-pick-prs about cherry-pick failure
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
title: "🚨 Automated Cherry-Pick Failed"
ref-name: ${{ github.ref_name }}

View File

@@ -0,0 +1,28 @@
name: Require beta cherry-pick consideration
concurrency:
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
pull_request:
types: [opened, edited, reopened, synchronize]
permissions:
contents: read
jobs:
beta-cherrypick-check:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR body for beta cherry-pick consideration
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
echo "Cherry-pick consideration box is checked. Check passed."
exit 0
fi
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
exit 1

View File

@@ -45,6 +45,9 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000
# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -115,9 +118,9 @@ jobs:
- name: Create .env file for Docker Compose
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
COMPOSE_PROFILES=s3-filestore
CODE_INTERPRETER_BETA_ENABLED=true
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF
- name: Set up Standard Dependencies
@@ -126,6 +129,7 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \
@@ -160,7 +164,7 @@ jobs:
cd deployment/docker_compose
# Get list of running containers
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
# Collect logs from each container
for container in $containers; do

View File

@@ -41,7 +41,8 @@ jobs:
version: v3.19.0
- name: Set up chart-testing
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
with:
uv_version: "0.9.9"
@@ -91,7 +92,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Install Redis operator

View File

@@ -20,7 +20,6 @@ env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
SLACK_BOT_TOKEN_TEST_SPACE: ${{ secrets.SLACK_BOT_TOKEN_TEST_SPACE }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
@@ -335,6 +334,7 @@ jobs:
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF
fi
@@ -423,7 +423,6 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
@@ -444,7 +443,6 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
-e ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${{ matrix.edition == 'ee' && 'true' || 'false' }} \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/${{ matrix.test-dir.path }}
@@ -470,13 +468,13 @@ jobs:
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
onyx-lite-tests:
no-vectordb-tests:
needs: [build-backend-image, build-integration-image]
runs-on:
[
runs-on,
runner=4cpu-linux-arm64,
"run-id=${{ github.run_id }}-onyx-lite-tests",
"run-id=${{ github.run_id }}-no-vectordb-tests",
"extras=ecr-cache",
]
timeout-minutes: 45
@@ -494,12 +492,13 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create .env file for Onyx Lite Docker Compose
- name: Create .env file for no-vectordb Docker Compose
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
@@ -507,23 +506,28 @@ jobs:
POSTGRES_USE_NULL_POOL=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
DISABLE_VECTOR_DB=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
USE_LIGHTWEIGHT_BACKGROUND_WORKER=true
EOF
# Start only the services needed for Onyx Lite (Postgres + API server)
- name: Start Docker containers (onyx-lite)
# Start only the services needed for no-vectordb mode (no Vespa, no model servers)
- name: Start Docker containers (no-vectordb)
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up \
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml up \
relational_db \
cache \
minio \
api_server \
background \
-d
id: start_docker_onyx_lite
id: start_docker_no_vectordb
- name: Wait for services to be ready
run: |
echo "Starting wait-for-service script (onyx-lite)..."
echo "Starting wait-for-service script (no-vectordb)..."
start_time=$(date +%s)
timeout=300
while true; do
@@ -545,14 +549,14 @@ jobs:
sleep 5
done
- name: Run Onyx Lite Integration Tests
- name: Run No-VectorDB Integration Tests
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
timeout_minutes: 20
max_attempts: 3
retry_wait_seconds: 10
command: |
echo "Running onyx-lite integration tests..."
echo "Running no-vectordb integration tests..."
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
@@ -563,38 +567,39 @@ jobs:
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e TEST_WEB_HOSTNAME=test-runner \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/tests/no_vectordb
- name: Dump API server logs (onyx-lite)
- name: Dump API server logs (no-vectordb)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_onyx_lite.log || true
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_no_vectordb.log || true
- name: Dump all-container logs (onyx-lite)
- name: Dump all-container logs (no-vectordb)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \
logs --no-color > $GITHUB_WORKSPACE/docker-compose-onyx-lite.log || true
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
logs --no-color > $GITHUB_WORKSPACE/docker-compose-no-vectordb.log || true
- name: Upload logs (onyx-lite)
- name: Upload logs (no-vectordb)
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-onyx-lite
path: ${{ github.workspace }}/docker-compose-onyx-lite.log
name: docker-all-logs-no-vectordb
path: ${{ github.workspace }}/docker-compose-no-vectordb.log
- name: Stop Docker containers (onyx-lite)
- name: Stop Docker containers (no-vectordb)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml down -v
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml down -v
multitenant-tests:
needs:
@@ -696,7 +701,6 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \
@@ -736,7 +740,7 @@ jobs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [integration-tests, onyx-lite-tests, multitenant-tests]
needs: [integration-tests, no-vectordb-tests, multitenant-tests]
if: ${{ always() }}
steps:
- name: Check job status

View File

@@ -22,9 +22,6 @@ env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
GOOGLE_PSE_API_KEY: ${{ secrets.GOOGLE_PSE_API_KEY }}
GOOGLE_PSE_SEARCH_ENGINE_ID: ${{ secrets.GOOGLE_PSE_SEARCH_ENGINE_ID }}
# for federated slack tests
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
@@ -268,11 +265,10 @@ jobs:
persist-credentials: false
- name: Setup node
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning]
cache: "npm"
cache-dependency-path: ./web/package-lock.json
- name: Install node dependencies
@@ -280,7 +276,6 @@ jobs:
run: npm ci
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
@@ -305,7 +300,6 @@ jobs:
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
INTEGRATION_TESTS_MODE=true
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
EXA_API_KEY=${EXA_API_KEY_VALUE}
REQUIRE_EMAIL_VERIFICATION=false
@@ -592,122 +586,17 @@ jobs:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
playwright-tests-lite:
needs: [build-web-image, build-backend-image]
name: Playwright Tests (lite)
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-playwright-tests-lite"
- "extras=ecr-cache"
timeout-minutes: 30
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Setup node
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning]
cache-dependency-path: ./web/package-lock.json
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
restore-keys: |
${{ runner.os }}-playwright-npm-
- name: Install playwright browsers
working-directory: ./web
run: npx playwright install --with-deps
- name: Create .env file for Docker Compose
env:
OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
INTEGRATION_TESTS_MODE=true
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
MOCK_LLM_RESPONSE=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
EOF
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Start Docker containers (lite)
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up -d
id: start_docker
- name: Run Playwright tests (lite)
working-directory: ./web
run: npx playwright test --project lite
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
if: always()
with:
name: playwright-test-results-lite-${{ github.run_id }}
path: ./web/output/playwright/
retention-days: 30
- name: Save Docker logs
if: success() || failure()
env:
WORKSPACE: ${{ github.workspace }}
run: |
cd deployment/docker_compose
docker compose logs > docker-compose.log
mv docker-compose.log ${WORKSPACE}/docker-compose.log
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-logs-lite-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
# Post a single combined visual regression comment after all matrix jobs finish
visual-regression-comment:
needs: [playwright-tests]
if: >-
always() &&
github.event_name == 'pull_request' &&
needs.playwright-tests.result != 'cancelled'
if: always() && github.event_name == 'pull_request'
runs-on: ubuntu-slim
timeout-minutes: 5
permissions:
pull-requests: write
steps:
- name: Download visual diff summaries
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131
uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # ratchet:actions/download-artifact@v4
with:
pattern: screenshot-diff-summary-*
path: summaries/
@@ -790,7 +679,7 @@ jobs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [playwright-tests, playwright-tests-lite]
needs: [playwright-tests]
if: ${{ always() }}
steps:
- name: Check job status

View File

@@ -8,7 +8,7 @@ on:
pull_request:
branches:
- main
- "release/**"
- 'release/**'
push:
tags:
- "v*.*.*"
@@ -21,13 +21,7 @@ jobs:
# See https://runs-on.com/runners/linux/
# Note: Mypy seems quite optimized for x64 compared to arm64.
# Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
runs-on:
[
runs-on,
runner=2cpu-linux-x64,
"run-id=${{ github.run_id }}-mypy-check",
"extras=s3-cache",
]
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache"]
timeout-minutes: 45
steps:
@@ -58,14 +52,21 @@ jobs:
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: .mypy_cache
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}
path: backend/.mypy_cache
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-
- name: Run MyPy
working-directory: ./backend
env:
MYPY_FORCE_COLOR: 1
TERM: xterm-256color
run: mypy .
- name: Run MyPy (tools/)
env:
MYPY_FORCE_COLOR: 1
TERM: xterm-256color
run: mypy tools/

View File

@@ -89,10 +89,6 @@ env:
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
# Github
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}

View File

@@ -1,329 +0,0 @@
name: Reusable Nightly LLM Provider Chat Tests
on:
workflow_call:
inputs:
openai_models:
description: "Comma-separated models for openai"
required: false
default: ""
type: string
anthropic_models:
description: "Comma-separated models for anthropic"
required: false
default: ""
type: string
bedrock_models:
description: "Comma-separated models for bedrock"
required: false
default: ""
type: string
vertex_ai_models:
description: "Comma-separated models for vertex_ai"
required: false
default: ""
type: string
azure_models:
description: "Comma-separated models for azure"
required: false
default: ""
type: string
ollama_models:
description: "Comma-separated models for ollama_chat"
required: false
default: ""
type: string
openrouter_models:
description: "Comma-separated models for openrouter"
required: false
default: ""
type: string
azure_api_base:
description: "API base for azure provider"
required: false
default: ""
type: string
strict:
description: "Default NIGHTLY_LLM_STRICT passed to tests"
required: false
default: true
type: boolean
permissions:
contents: read
id-token: write
jobs:
build-backend-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-backend-image",
"extras=ecr-cache",
]
timeout-minutes: 45
environment: ci-protected
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, test/docker-username
DOCKER_TOKEN, test/docker-token
- name: Build backend image
uses: ./.github/actions/build-backend-image
with:
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
ref-name: ${{ github.ref_name }}
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ env.DOCKER_USERNAME }}
docker-token: ${{ env.DOCKER_TOKEN }}
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
build-model-server-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-model-server-image",
"extras=ecr-cache",
]
timeout-minutes: 45
environment: ci-protected
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, test/docker-username
DOCKER_TOKEN, test/docker-token
- name: Build model server image
uses: ./.github/actions/build-model-server-image
with:
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
ref-name: ${{ github.ref_name }}
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ env.DOCKER_USERNAME }}
docker-token: ${{ env.DOCKER_TOKEN }}
build-integration-image:
runs-on:
[
runs-on,
runner=2cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-integration-image",
"extras=ecr-cache",
]
timeout-minutes: 45
environment: ci-protected
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, test/docker-username
DOCKER_TOKEN, test/docker-token
- name: Build integration image
uses: ./.github/actions/build-integration-image
with:
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
ref-name: ${{ github.ref_name }}
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ env.DOCKER_USERNAME }}
docker-token: ${{ env.DOCKER_TOKEN }}
provider-chat-test:
needs:
[
build-backend-image,
build-model-server-image,
build-integration-image,
]
strategy:
fail-fast: false
matrix:
include:
- provider: openai
models: ${{ inputs.openai_models }}
api_key_env: OPENAI_API_KEY
custom_config_env: ""
api_base: ""
api_version: ""
deployment_name: ""
required: true
- provider: anthropic
models: ${{ inputs.anthropic_models }}
api_key_env: ANTHROPIC_API_KEY
custom_config_env: ""
api_base: ""
api_version: ""
deployment_name: ""
required: true
- provider: bedrock
models: ${{ inputs.bedrock_models }}
api_key_env: BEDROCK_API_KEY
custom_config_env: ""
api_base: ""
api_version: ""
deployment_name: ""
required: false
- provider: vertex_ai
models: ${{ inputs.vertex_ai_models }}
api_key_env: ""
custom_config_env: NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON
api_base: ""
api_version: ""
deployment_name: ""
required: false
- provider: azure
models: ${{ inputs.azure_models }}
api_key_env: AZURE_API_KEY
custom_config_env: ""
api_base: ${{ inputs.azure_api_base }}
api_version: "2025-04-01-preview"
deployment_name: ""
required: false
- provider: ollama_chat
models: ${{ inputs.ollama_models }}
api_key_env: OLLAMA_API_KEY
custom_config_env: ""
api_base: "https://ollama.com"
api_version: ""
deployment_name: ""
required: false
- provider: openrouter
models: ${{ inputs.openrouter_models }}
api_key_env: OPENROUTER_API_KEY
custom_config_env: ""
api_base: "https://openrouter.ai/api/v1"
api_version: ""
deployment_name: ""
required: false
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
- extras=ecr-cache
timeout-minutes: 45
environment: ci-protected
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
# Keep JSON values unparsed so vertex custom config is passed as raw JSON.
parse-json-secrets: false
secret-ids: |
DOCKER_USERNAME, test/docker-username
DOCKER_TOKEN, test/docker-token
OPENAI_API_KEY, test/openai-api-key
ANTHROPIC_API_KEY, test/anthropic-api-key
BEDROCK_API_KEY, test/bedrock-api-key
NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON, test/nightly-llm-vertex-ai-custom-config-json
AZURE_API_KEY, test/azure-api-key
OLLAMA_API_KEY, test/ollama-api-key
OPENROUTER_API_KEY, test/openrouter-api-key
- name: Run nightly provider chat test
uses: ./.github/actions/run-nightly-provider-chat-test
with:
provider: ${{ matrix.provider }}
models: ${{ matrix.models }}
provider-api-key: ${{ matrix.api_key_env && env[matrix.api_key_env] || '' }}
strict: ${{ inputs.strict && 'true' || 'false' }}
api-base: ${{ matrix.api_base }}
api-version: ${{ matrix.api_version }}
deployment-name: ${{ matrix.deployment_name }}
custom-config-json: ${{ matrix.custom_config_env && env[matrix.custom_config_env] || '' }}
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
run-id: ${{ github.run_id }}
docker-username: ${{ env.DOCKER_USERNAME }}
docker-token: ${{ env.DOCKER_TOKEN }}
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
path: |
${{ github.workspace }}/api_server.log
${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose down -v

View File

@@ -5,8 +5,6 @@ on:
branches: ["main"]
pull_request:
branches: ["**"]
paths:
- ".github/**"
permissions: {}
@@ -23,18 +21,29 @@ jobs:
with:
persist-credentials: false
- name: Detect changes
id: filter
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
with:
filters: |
zizmor:
- '.github/**'
- name: Install the latest version of uv
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Run zizmor
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload SARIF file
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
with:
sarif_file: results.sarif

1
.gitignore vendored
View File

@@ -7,7 +7,6 @@
.zed
.cursor
!/.cursor/mcp.json
!/.cursor/skills/
# macos
.DS_store

43
.vscode/launch.json vendored
View File

@@ -40,7 +40,19 @@
}
},
{
"name": "Celery",
"name": "Celery (lightweight mode)",
"configurations": [
"Celery primary",
"Celery background",
"Celery beat"
],
"presentation": {
"group": "1"
},
"stopAll": true
},
{
"name": "Celery (standard mode)",
"configurations": [
"Celery primary",
"Celery light",
@@ -241,6 +253,35 @@
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery background",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.background",
"worker",
"--pool=threads",
"--concurrency=20",
"--prefetch-multiplier=4",
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery background Console"
},
{
"name": "Celery heavy",
"type": "debugpy",

View File

@@ -86,6 +86,37 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Monitoring tasks (every 5 minutes)
- Cleanup tasks (hourly)
#### Worker Deployment Modes
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
- Runs a single consolidated `background` worker that handles all background tasks:
- Light worker tasks (Vespa operations, permissions sync, deletion)
- Document processing (indexing pipeline)
- Document fetching (connector data retrieval)
- Pruning operations (from `heavy` worker)
- Knowledge graph processing (from `kg_processing` worker)
- Monitoring tasks (from `monitoring` worker)
- User file processing (from `user_file_processing` worker)
- Lower resource footprint (fewer worker processes)
- Suitable for smaller deployments or development environments
- Default concurrency: 20 threads (increased to handle combined workload)
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
- Better isolation and scalability
- Can scale individual workers independently based on workload
- Suitable for production deployments with higher load
The deployment mode affects:
- **Backend**: Worker processes spawned by supervisord or dev scripts
- **Helm**: Which Kubernetes deployments are created
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
#### Key Features
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
@@ -517,7 +548,7 @@ class in the utils over directly calling the APIs with a library like `requests`
calling the utilities directly (e.g. do NOT create admin users with
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
A great example of this type of test is `backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py`.
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
To run them:
@@ -585,48 +616,3 @@ This is a minimal list - feel free to include more. Do NOT write code as part of
Keep it high level. You can reference certain files or functions though.
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
## Error Handling
**Always raise `OnyxError` from `onyx.error_handling.exceptions` instead of `HTTPException`.
Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.**
A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard
`{"error_code": "...", "message": "..."}` shape. This eliminates boilerplate and keeps error
handling consistent across the entire backend.
```python
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
# ✅ Good
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
# ✅ Good — no extra message needed
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED)
# ✅ Good — upstream service with dynamic status code
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)
# ❌ Bad — using HTTPException directly
raise HTTPException(status_code=404, detail="Session not found")
# ❌ Bad — starlette constant
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
```
Available error codes are defined in `backend/onyx/error_handling/error_codes.py`. If a new error
category is needed, add it there first — do not invent ad-hoc codes.
**Upstream service errors:** When forwarding errors from an upstream service where the HTTP
status code is dynamic (comes from the upstream response), use `status_code_override`:
```python
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=e.response.status_code)
```
## Best Practices
In addition to the other content in this file, best practices for contributing
to the codebase can be found at `contributing_guides/best_practices.md`.
Understand its contents and follow them.

View File

@@ -21,14 +21,15 @@ import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import NamedTuple
from typing import List, NamedTuple
from alembic.config import Config
from alembic.script import ScriptDirectory
from sqlalchemy import text
from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
from shared_configs.configs import TENANT_ID_PREFIX
@@ -104,6 +105,56 @@ def get_head_revision() -> str | None:
return script.get_current_head()
def get_schemas_needing_migration(
tenant_schemas: List[str], head_rev: str
) -> List[str]:
"""Return only schemas whose current alembic version is not at head."""
if not tenant_schemas:
return []
engine = SqlEngine.get_engine()
with engine.connect() as conn:
# Find which schemas actually have an alembic_version table
rows = conn.execute(
text(
"SELECT table_schema FROM information_schema.tables "
"WHERE table_name = 'alembic_version' "
"AND table_schema = ANY(:schemas)"
),
{"schemas": tenant_schemas},
)
schemas_with_table = set(row[0] for row in rows)
# Schemas without the table definitely need migration
needs_migration = [s for s in tenant_schemas if s not in schemas_with_table]
if not schemas_with_table:
return needs_migration
# Validate schema names before interpolating into SQL
for schema in schemas_with_table:
if not is_valid_schema_name(schema):
raise ValueError(f"Invalid schema name: {schema}")
# Single query to get every schema's current revision at once.
# Use integer tags instead of interpolating schema names into
# string literals to avoid quoting issues.
schema_list = list(schemas_with_table)
union_parts = [
f'SELECT {i} AS idx, version_num FROM "{schema}".alembic_version'
for i, schema in enumerate(schema_list)
]
rows = conn.execute(text(" UNION ALL ".join(union_parts)))
version_by_schema = {schema_list[row[0]]: row[1] for row in rows}
needs_migration.extend(
s for s in schemas_with_table if version_by_schema.get(s) != head_rev
)
return needs_migration
def run_migrations_parallel(
schemas: list[str],
max_workers: int,

View File

@@ -1,29 +0,0 @@
"""code interpreter seed
Revision ID: 07b98176f1de
Revises: 7cb492013621
Create Date: 2026-02-23 15:55:07.606784
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "07b98176f1de"
down_revision = "7cb492013621"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Seed the single instance of code_interpreter_server
# NOTE: There should only exist at most and at minimum 1 code_interpreter_server row
op.execute(
sa.text("INSERT INTO code_interpreter_server (server_enabled) VALUES (true)")
)
def downgrade() -> None:
op.execute(sa.text("DELETE FROM code_interpreter_server"))

View File

@@ -1,28 +0,0 @@
"""add scim_username to scim_user_mapping
Revision ID: 0bb4558f35df
Revises: 631fd2504136
Create Date: 2026-02-20 10:45:30.340188
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0bb4558f35df"
down_revision = "631fd2504136"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("scim_username", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_username")

View File

@@ -1,37 +0,0 @@
"""add cache_store table
Revision ID: 2664261bfaab
Revises: 4a1e4b1c89d2
Create Date: 2026-02-27 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2664261bfaab"
down_revision = "4a1e4b1c89d2"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"cache_store",
sa.Column("key", sa.String(), nullable=False),
sa.Column("value", sa.LargeBinary(), nullable=True),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("key"),
)
op.create_index(
"ix_cache_store_expires",
"cache_store",
["expires_at"],
postgresql_where=sa.text("expires_at IS NOT NULL"),
)
def downgrade() -> None:
op.drop_index("ix_cache_store_expires", table_name="cache_store")
op.drop_table("cache_store")

View File

@@ -1,51 +0,0 @@
"""Add INDEXING to UserFileStatus
Revision ID: 4a1e4b1c89d2
Revises: 6b3b4083c5aa
Create Date: 2026-02-28 00:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
revision = "4a1e4b1c89d2"
down_revision = "6b3b4083c5aa"
branch_labels = None
depends_on = None
TABLE = "user_file"
COLUMN = "status"
CONSTRAINT_NAME = "ck_user_file_status"
OLD_VALUES = ("PROCESSING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
NEW_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
def _drop_status_check_constraint() -> None:
"""Drop the existing CHECK constraint on user_file.status.
The constraint name is auto-generated by SQLAlchemy and unknown,
so we look it up via the inspector.
"""
inspector = sa.inspect(op.get_bind())
for constraint in inspector.get_check_constraints(TABLE):
if COLUMN in constraint.get("sqltext", ""):
constraint_name = constraint["name"]
if constraint_name is not None:
op.drop_constraint(constraint_name, TABLE, type_="check")
def upgrade() -> None:
_drop_status_check_constraint()
in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
def downgrade() -> None:
op.execute(
f"UPDATE {TABLE} SET {COLUMN} = 'PROCESSING' WHERE {COLUMN} = 'INDEXING'"
)
op.drop_constraint(CONSTRAINT_NAME, TABLE, type_="check")
in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")

View File

@@ -1,69 +0,0 @@
"""add python tool on default
Revision ID: 57122d037335
Revises: c0c937d5c9e5
Create Date: 2026-02-27 10:10:40.124925
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "57122d037335"
down_revision = "c0c937d5c9e5"
branch_labels = None
depends_on = None
PYTHON_TOOL_NAME = "python"
def upgrade() -> None:
conn = op.get_bind()
# Look up the PythonTool id
result = conn.execute(
sa.text("SELECT id FROM tool WHERE name = :name"),
{"name": PYTHON_TOOL_NAME},
).fetchone()
if not result:
return
tool_id = result[0]
# Attach to the default persona (id=0) if not already attached
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": tool_id},
)
def downgrade() -> None:
conn = op.get_bind()
result = conn.execute(
sa.text("SELECT id FROM tool WHERE name = :name"),
{"name": PYTHON_TOOL_NAME},
).fetchone()
if not result:
return
conn.execute(
sa.text(
"""
DELETE FROM persona__tool
WHERE persona_id = 0 AND tool_id = :tool_id
"""
),
{"tool_id": result[0]},
)

View File

@@ -1,32 +0,0 @@
"""add approx_chunk_count_in_vespa to opensearch tenant migration
Revision ID: 631fd2504136
Revises: c7f2e1b4a9d3
Create Date: 2026-02-18 21:07:52.831215
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "631fd2504136"
down_revision = "c7f2e1b4a9d3"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"approx_chunk_count_in_vespa",
sa.Integer(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("opensearch_tenant_migration_record", "approx_chunk_count_in_vespa")

View File

@@ -1,112 +0,0 @@
"""persona cleanup and featured
Revision ID: 6b3b4083c5aa
Revises: 57122d037335
Create Date: 2026-02-26 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6b3b4083c5aa"
down_revision = "57122d037335"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add featured column with nullable=True first
op.add_column("persona", sa.Column("featured", sa.Boolean(), nullable=True))
# Migrate data from is_default_persona to featured
op.execute("UPDATE persona SET featured = is_default_persona")
# Make featured non-nullable with default=False
op.alter_column(
"persona",
"featured",
existing_type=sa.Boolean(),
nullable=False,
server_default=sa.false(),
)
# Drop is_default_persona column
op.drop_column("persona", "is_default_persona")
# Drop unused columns
op.drop_column("persona", "num_chunks")
op.drop_column("persona", "chunks_above")
op.drop_column("persona", "chunks_below")
op.drop_column("persona", "llm_relevance_filter")
op.drop_column("persona", "llm_filter_extraction")
op.drop_column("persona", "recency_bias")
def downgrade() -> None:
# Add back recency_bias column
op.add_column(
"persona",
sa.Column(
"recency_bias",
sa.VARCHAR(),
nullable=False,
server_default="base_decay",
),
)
# Add back llm_filter_extraction column
op.add_column(
"persona",
sa.Column(
"llm_filter_extraction",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
# Add back llm_relevance_filter column
op.add_column(
"persona",
sa.Column(
"llm_relevance_filter",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
# Add back chunks_below column
op.add_column(
"persona",
sa.Column("chunks_below", sa.Integer(), nullable=False, server_default="0"),
)
# Add back chunks_above column
op.add_column(
"persona",
sa.Column("chunks_above", sa.Integer(), nullable=False, server_default="0"),
)
# Add back num_chunks column
op.add_column("persona", sa.Column("num_chunks", sa.Float(), nullable=True))
# Add back is_default_persona column
op.add_column(
"persona",
sa.Column(
"is_default_persona",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
# Migrate data from featured to is_default_persona
op.execute("UPDATE persona SET is_default_persona = featured")
# Drop featured column
op.drop_column("persona", "featured")

View File

@@ -1,48 +0,0 @@
"""add enterprise and name fields to scim_user_mapping
Revision ID: 7616121f6e97
Revises: 07b98176f1de
Create Date: 2026-02-23 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7616121f6e97"
down_revision = "07b98176f1de"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("department", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("manager", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("given_name", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("family_name", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("scim_emails_json", sa.Text(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_emails_json")
op.drop_column("scim_user_mapping", "family_name")
op.drop_column("scim_user_mapping", "given_name")
op.drop_column("scim_user_mapping", "manager")
op.drop_column("scim_user_mapping", "department")

View File

@@ -1,31 +0,0 @@
"""code interpreter server model
Revision ID: 7cb492013621
Revises: 0bb4558f35df
Create Date: 2026-02-22 18:54:54.007265
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7cb492013621"
down_revision = "0bb4558f35df"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"code_interpreter_server",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column(
"server_enabled", sa.Boolean, nullable=False, server_default=sa.true()
),
)
def downgrade() -> None:
op.drop_table("code_interpreter_server")

View File

@@ -1,33 +0,0 @@
"""add needs_persona_sync to user_file
Revision ID: 8ffcc2bcfc11
Revises: 7616121f6e97
Create Date: 2026-02-23 10:48:48.343826
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8ffcc2bcfc11"
down_revision = "7616121f6e97"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user_file",
sa.Column(
"needs_persona_sync",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)
def downgrade() -> None:
op.drop_column("user_file", "needs_persona_sync")

View File

@@ -1,34 +0,0 @@
"""make scim_user_mapping.external_id nullable
Revision ID: a3b8d9e2f1c4
Revises: 2664261bfaab
Create Date: 2026-03-02
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "a3b8d9e2f1c4"
down_revision = "2664261bfaab"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"scim_user_mapping",
"external_id",
nullable=True,
)
def downgrade() -> None:
# Delete any rows where external_id is NULL before re-applying NOT NULL
op.execute("DELETE FROM scim_user_mapping WHERE external_id IS NULL")
op.alter_column(
"scim_user_mapping",
"external_id",
nullable=False,
)

View File

@@ -1,70 +0,0 @@
"""llm provider deprecate fields
Revision ID: c0c937d5c9e5
Revises: 8ffcc2bcfc11
Create Date: 2026-02-25 17:35:46.125102
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c0c937d5c9e5"
down_revision = "8ffcc2bcfc11"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Make default_model_name nullable (was NOT NULL)
op.alter_column(
"llm_provider",
"default_model_name",
existing_type=sa.String(),
nullable=True,
)
# Drop unique constraint on is_default_provider (defaults now tracked via LLMModelFlow)
op.drop_constraint(
"llm_provider_is_default_provider_key",
"llm_provider",
type_="unique",
)
# Remove server_default from is_default_vision_provider (was server_default=false())
op.alter_column(
"llm_provider",
"is_default_vision_provider",
existing_type=sa.Boolean(),
server_default=None,
)
def downgrade() -> None:
# Restore default_model_name to NOT NULL (set empty string for any NULLs first)
op.execute(
"UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
)
op.alter_column(
"llm_provider",
"default_model_name",
existing_type=sa.String(),
nullable=False,
)
# Restore unique constraint on is_default_provider
op.create_unique_constraint(
"llm_provider_is_default_provider_key",
"llm_provider",
["is_default_provider"],
)
# Restore server_default for is_default_vision_provider
op.alter_column(
"llm_provider",
"is_default_vision_provider",
existing_type=sa.Boolean(),
server_default=sa.false(),
)

View File

@@ -1,31 +0,0 @@
"""add sharing_scope to build_session
Revision ID: c7f2e1b4a9d3
Revises: 19c0ccb01687
Create Date: 2026-02-17 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "c7f2e1b4a9d3"
down_revision = "19c0ccb01687"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"build_session",
sa.Column(
"sharing_scope",
sa.String(),
nullable=False,
server_default="private",
),
)
def downgrade() -> None:
op.drop_column("build_session", "sharing_scope")

View File

@@ -1,3 +1,4 @@
import os
from datetime import datetime
import jwt
@@ -20,7 +21,13 @@ logger = setup_logger()
def verify_auth_setting() -> None:
# All the Auth flows are valid for EE version
# All the Auth flows are valid for EE version, but warn about deprecated 'disabled'
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
if raw_auth_type == "disabled":
logger.warning(
"AUTH_TYPE='disabled' is no longer supported. "
"Using 'basic' instead. Please update your configuration."
)
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")

View File

@@ -0,0 +1,15 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.background import celery_app
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.tenant_provisioning",
"ee.onyx.background.celery.tasks.query_history",
]
)
)

View File

@@ -11,10 +11,11 @@ from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import LicenseSource
from onyx.auth.schemas import UserRole
from onyx.cache.factory import get_cache_backend
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.db.models import License
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -141,7 +142,7 @@ def get_used_seats(tenant_id: str | None = None) -> int:
def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
"""
Get license metadata from cache.
Get license metadata from Redis cache.
Args:
tenant_id: Tenant ID (for multi-tenant deployments)
@@ -149,34 +150,38 @@ def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata
Returns:
LicenseMetadata if cached, None otherwise
"""
cache = get_cache_backend(tenant_id=tenant_id)
cached = cache.get(LICENSE_METADATA_KEY)
if not cached:
return None
tenant = tenant_id or get_current_tenant_id()
redis_client = get_redis_replica_client(tenant_id=tenant)
try:
cached_str = (
cached.decode("utf-8") if isinstance(cached, bytes) else str(cached)
)
return LicenseMetadata.model_validate_json(cached_str)
except Exception as e:
logger.warning(f"Failed to parse cached license metadata: {e}")
return None
cached = redis_client.get(LICENSE_METADATA_KEY)
if cached:
try:
cached_str: str
if isinstance(cached, bytes):
cached_str = cached.decode("utf-8")
else:
cached_str = str(cached)
return LicenseMetadata.model_validate_json(cached_str)
except Exception as e:
logger.warning(f"Failed to parse cached license metadata: {e}")
return None
return None
def invalidate_license_cache(tenant_id: str | None = None) -> None:
"""
Invalidate the license metadata cache (not the license itself).
Deletes the cached LicenseMetadata. The actual license in the database
is not affected. Delete is idempotent if the key doesn't exist, this
is a no-op.
This deletes the cached LicenseMetadata from Redis. The actual license
in the database is not affected. Redis delete is idempotent - if the
key doesn't exist, this is a no-op.
Args:
tenant_id: Tenant ID (for multi-tenant deployments)
"""
cache = get_cache_backend(tenant_id=tenant_id)
cache.delete(LICENSE_METADATA_KEY)
tenant = tenant_id or get_current_tenant_id()
redis_client = get_redis_client(tenant_id=tenant)
redis_client.delete(LICENSE_METADATA_KEY)
logger.info("License cache invalidated")
@@ -187,7 +192,7 @@ def update_license_cache(
tenant_id: str | None = None,
) -> LicenseMetadata:
"""
Update the cache with license metadata.
Update the Redis cache with license metadata.
We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
1. Frontend needs status to show appropriate UI/banners
@@ -206,7 +211,7 @@ def update_license_cache(
from ee.onyx.utils.license import get_license_status
tenant = tenant_id or get_current_tenant_id()
cache = get_cache_backend(tenant_id=tenant_id)
redis_client = get_redis_client(tenant_id=tenant)
used_seats = get_used_seats(tenant)
status = get_license_status(payload, grace_period_end)
@@ -225,7 +230,7 @@ def update_license_cache(
stripe_subscription_id=payload.stripe_subscription_id,
)
cache.set(
redis_client.set(
LICENSE_METADATA_KEY,
metadata.model_dump_json(),
ex=LICENSE_CACHE_TTL_SECONDS,
@@ -258,15 +263,9 @@ def refresh_license_cache(
try:
payload = verify_license_signature(license_record.license_data)
# Derive source from payload: manual licenses lack stripe_customer_id
source: LicenseSource = (
LicenseSource.AUTO_FETCH
if payload.stripe_customer_id
else LicenseSource.MANUAL_UPLOAD
)
return update_license_cache(
payload,
source=source,
source=LicenseSource.AUTO_FETCH,
tenant_id=tenant_id,
)
except ValueError as e:

View File

@@ -1,721 +0,0 @@
"""SCIM Data Access Layer.
All database operations for SCIM provisioning — token management, user
mappings, and group mappings. Extends the base DAL (see ``onyx.db.dal``).
Usage from FastAPI::
def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
@router.post("/tokens")
def create_token(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
token = dal.create_token(name=..., hashed_token=..., ...)
dal.commit()
return token
Usage from background tasks::
with ScimDAL.from_tenant("tenant_abc") as dal:
mapping = dal.create_user_mapping(external_id="idp-123", user_id=uid)
dal.commit()
"""
from __future__ import annotations
from uuid import UUID
from sqlalchemy import delete as sa_delete
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import SQLColumnExpression
from sqlalchemy.dialects.postgresql import insert as pg_insert
from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from ee.onyx.server.scim.models import ScimMappingFields
from onyx.db.dal import DAL
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ScimDAL(DAL):
"""Data Access Layer for SCIM provisioning operations.
Methods mutate but do NOT commit — call ``dal.commit()`` explicitly
when you want to persist changes. This follows the existing ``_no_commit``
convention and lets callers batch multiple operations into one transaction.
"""
# ------------------------------------------------------------------
# Token operations
# ------------------------------------------------------------------
def create_token(
self,
name: str,
hashed_token: str,
token_display: str,
created_by_id: UUID,
) -> ScimToken:
"""Create a new SCIM bearer token.
Only one token is active at a time — this method automatically revokes
all existing active tokens before creating the new one.
"""
# Revoke any currently active tokens
active_tokens = list(
self._session.scalars(
select(ScimToken).where(ScimToken.is_active.is_(True))
).all()
)
for t in active_tokens:
t.is_active = False
token = ScimToken(
name=name,
hashed_token=hashed_token,
token_display=token_display,
created_by_id=created_by_id,
)
self._session.add(token)
self._session.flush()
return token
def get_active_token(self) -> ScimToken | None:
"""Return the single currently active token, or None."""
return self._session.scalar(
select(ScimToken).where(ScimToken.is_active.is_(True))
)
def get_token_by_hash(self, hashed_token: str) -> ScimToken | None:
"""Look up a token by its SHA-256 hash."""
return self._session.scalar(
select(ScimToken).where(ScimToken.hashed_token == hashed_token)
)
def revoke_token(self, token_id: int) -> None:
"""Deactivate a token by ID.
Raises:
ValueError: If the token does not exist.
"""
token = self._session.get(ScimToken, token_id)
if not token:
raise ValueError(f"SCIM token with id {token_id} not found")
token.is_active = False
def update_token_last_used(self, token_id: int) -> None:
"""Update the last_used_at timestamp for a token."""
token = self._session.get(ScimToken, token_id)
if token:
token.last_used_at = func.now() # type: ignore[assignment]
# ------------------------------------------------------------------
# User mapping operations
# ------------------------------------------------------------------
def create_user_mapping(
self,
external_id: str | None,
user_id: UUID,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> ScimUserMapping:
"""Create a SCIM mapping for a user.
``external_id`` may be ``None`` when the IdP omits it (RFC 7643
allows this). The mapping still marks the user as SCIM-managed.
"""
f = fields or ScimMappingFields()
mapping = ScimUserMapping(
external_id=external_id,
user_id=user_id,
scim_username=scim_username,
department=f.department,
manager=f.manager,
given_name=f.given_name,
family_name=f.family_name,
scim_emails_json=f.scim_emails_json,
)
self._session.add(mapping)
self._session.flush()
return mapping
def get_user_mapping_by_external_id(
self, external_id: str
) -> ScimUserMapping | None:
"""Look up a user mapping by the IdP's external identifier."""
return self._session.scalar(
select(ScimUserMapping).where(ScimUserMapping.external_id == external_id)
)
def get_user_mapping_by_user_id(self, user_id: UUID) -> ScimUserMapping | None:
"""Look up a user mapping by the Onyx user ID."""
return self._session.scalar(
select(ScimUserMapping).where(ScimUserMapping.user_id == user_id)
)
def list_user_mappings(
self,
start_index: int = 1,
count: int = 100,
) -> tuple[list[ScimUserMapping], int]:
"""List user mappings with SCIM-style pagination.
Args:
start_index: 1-based start index (SCIM convention).
count: Maximum number of results to return.
Returns:
A tuple of (mappings, total_count).
"""
total = (
self._session.scalar(select(func.count()).select_from(ScimUserMapping)) or 0
)
offset = max(start_index - 1, 0)
mappings = list(
self._session.scalars(
select(ScimUserMapping)
.order_by(ScimUserMapping.id)
.offset(offset)
.limit(count)
).all()
)
return mappings, total
def update_user_mapping_external_id(
self,
mapping_id: int,
external_id: str,
) -> ScimUserMapping:
"""Update the external ID on a user mapping.
Raises:
ValueError: If the mapping does not exist.
"""
mapping = self._session.get(ScimUserMapping, mapping_id)
if not mapping:
raise ValueError(f"SCIM user mapping with id {mapping_id} not found")
mapping.external_id = external_id
return mapping
def delete_user_mapping(self, mapping_id: int) -> None:
"""Delete a user mapping by ID. No-op if already deleted."""
mapping = self._session.get(ScimUserMapping, mapping_id)
if not mapping:
logger.warning("SCIM user mapping %d not found during delete", mapping_id)
return
self._session.delete(mapping)
# ------------------------------------------------------------------
# User query operations
# ------------------------------------------------------------------
def get_user(self, user_id: UUID) -> User | None:
"""Fetch a user by ID."""
return self._session.scalar(
select(User).where(User.id == user_id) # type: ignore[arg-type]
)
def get_user_by_email(self, email: str) -> User | None:
"""Fetch a user by email (case-insensitive)."""
return self._session.scalar(
select(User).where(func.lower(User.email) == func.lower(email))
)
def add_user(self, user: User) -> None:
"""Add a new user to the session and flush to assign an ID."""
self._session.add(user)
self._session.flush()
def update_user(
self,
user: User,
*,
email: str | None = None,
is_active: bool | None = None,
personal_name: str | None = None,
) -> None:
"""Update user attributes. Only sets fields that are provided."""
if email is not None:
user.email = email
if is_active is not None:
user.is_active = is_active
if personal_name is not None:
user.personal_name = personal_name
def deactivate_user(self, user: User) -> None:
"""Mark a user as inactive."""
user.is_active = False
def list_users(
self,
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[User, ScimUserMapping | None]], int]:
"""Query users with optional SCIM filter and pagination.
Returns:
A tuple of (list of (user, mapping) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
"""
# Inner-join with ScimUserMapping so only SCIM-managed users appear.
# Pre-existing system accounts (anonymous, admin, etc.) are excluded
# unless they were explicitly linked via SCIM provisioning.
query = (
select(User)
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
)
if scim_filter:
attr = scim_filter.attribute.lower()
if attr == "username":
# arg-type: fastapi-users types User.email as str, not a column expression
# assignment: union return type widens but query is still Select[tuple[User]]
query = _apply_scim_string_op(query, User.email, scim_filter) # type: ignore[arg-type, assignment]
elif attr == "active":
query = query.where(
User.is_active.is_(scim_filter.value.lower() == "true") # type: ignore[attr-defined]
)
elif attr == "externalid":
mapping = self.get_user_mapping_by_external_id(scim_filter.value)
if not mapping:
return [], 0
query = query.where(User.id == mapping.user_id) # type: ignore[arg-type]
else:
raise ValueError(
f"Unsupported filter attribute: {scim_filter.attribute}"
)
# Count total matching rows first, then paginate. SCIM uses 1-based
# indexing (RFC 7644 §3.4.2), so we convert to a 0-based offset.
total = (
self._session.scalar(select(func.count()).select_from(query.subquery()))
or 0
)
offset = max(start_index - 1, 0)
users = list(
self._session.scalars(
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
)
.unique()
.all()
)
# Batch-fetch SCIM mappings to avoid N+1 queries
mapping_map = self._get_user_mappings_batch([u.id for u in users])
return [(u, mapping_map.get(u.id)) for u in users], total
def sync_user_external_id(
self,
user_id: UUID,
new_external_id: str | None,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> None:
"""Sync the SCIM mapping for a user.
If a mapping already exists, its fields are updated (including
setting ``external_id`` to ``None`` when the IdP omits it).
If no mapping exists and ``new_external_id`` is provided, a new
mapping is created. A mapping is never deleted here — SCIM-managed
users must retain their mapping to remain visible in ``GET /Users``.
When *fields* is provided, all mapping fields are written
unconditionally — including ``None`` values — so that a caller can
clear a previously-set field (e.g. removing a department).
"""
mapping = self.get_user_mapping_by_user_id(user_id)
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
if fields is not None:
mapping.department = fields.department
mapping.manager = fields.manager
mapping.given_name = fields.given_name
mapping.family_name = fields.family_name
mapping.scim_emails_json = fields.scim_emails_json
elif new_external_id:
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
fields=fields,
)
def _get_user_mappings_batch(
self, user_ids: list[UUID]
) -> dict[UUID, ScimUserMapping]:
"""Batch-fetch SCIM user mappings keyed by user ID."""
if not user_ids:
return {}
mappings = self._session.scalars(
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
).all()
return {m.user_id: m for m in mappings}
def get_user_groups(self, user_id: UUID) -> list[tuple[int, str]]:
"""Get groups a user belongs to as ``(group_id, group_name)`` pairs.
Excludes groups marked for deletion.
"""
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_id == user_id)
).all()
group_ids = [r.user_group_id for r in rels]
if not group_ids:
return []
groups = self._session.scalars(
select(UserGroup).where(
UserGroup.id.in_(group_ids),
UserGroup.is_up_for_deletion.is_(False),
)
).all()
return [(g.id, g.name) for g in groups]
def get_users_groups_batch(
self, user_ids: list[UUID]
) -> dict[UUID, list[tuple[int, str]]]:
"""Batch-fetch group memberships for multiple users.
Returns a mapping of ``user_id → [(group_id, group_name), ...]``.
Avoids N+1 queries when building user list responses.
"""
if not user_ids:
return {}
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_id.in_(user_ids))
).all()
group_ids = list({r.user_group_id for r in rels})
if not group_ids:
return {}
groups = self._session.scalars(
select(UserGroup).where(
UserGroup.id.in_(group_ids),
UserGroup.is_up_for_deletion.is_(False),
)
).all()
groups_by_id = {g.id: g.name for g in groups}
result: dict[UUID, list[tuple[int, str]]] = {}
for r in rels:
if r.user_id and r.user_group_id in groups_by_id:
result.setdefault(r.user_id, []).append(
(r.user_group_id, groups_by_id[r.user_group_id])
)
return result
# ------------------------------------------------------------------
# Group mapping operations
# ------------------------------------------------------------------
def create_group_mapping(
self,
external_id: str,
user_group_id: int,
) -> ScimGroupMapping:
"""Create a mapping between a SCIM externalId and an Onyx user group."""
mapping = ScimGroupMapping(external_id=external_id, user_group_id=user_group_id)
self._session.add(mapping)
self._session.flush()
return mapping
def get_group_mapping_by_external_id(
self, external_id: str
) -> ScimGroupMapping | None:
"""Look up a group mapping by the IdP's external identifier."""
return self._session.scalar(
select(ScimGroupMapping).where(ScimGroupMapping.external_id == external_id)
)
def get_group_mapping_by_group_id(
self, user_group_id: int
) -> ScimGroupMapping | None:
"""Look up a group mapping by the Onyx user group ID."""
return self._session.scalar(
select(ScimGroupMapping).where(
ScimGroupMapping.user_group_id == user_group_id
)
)
def list_group_mappings(
self,
start_index: int = 1,
count: int = 100,
) -> tuple[list[ScimGroupMapping], int]:
"""List group mappings with SCIM-style pagination.
Args:
start_index: 1-based start index (SCIM convention).
count: Maximum number of results to return.
Returns:
A tuple of (mappings, total_count).
"""
total = (
self._session.scalar(select(func.count()).select_from(ScimGroupMapping))
or 0
)
offset = max(start_index - 1, 0)
mappings = list(
self._session.scalars(
select(ScimGroupMapping)
.order_by(ScimGroupMapping.id)
.offset(offset)
.limit(count)
).all()
)
return mappings, total
def delete_group_mapping(self, mapping_id: int) -> None:
"""Delete a group mapping by ID. No-op if already deleted."""
mapping = self._session.get(ScimGroupMapping, mapping_id)
if not mapping:
logger.warning("SCIM group mapping %d not found during delete", mapping_id)
return
self._session.delete(mapping)
# ------------------------------------------------------------------
# Group query operations
# ------------------------------------------------------------------
def get_group(self, group_id: int) -> UserGroup | None:
"""Fetch a group by ID, returning None if deleted or missing."""
group = self._session.get(UserGroup, group_id)
if group and group.is_up_for_deletion:
return None
return group
def get_group_by_name(self, name: str) -> UserGroup | None:
"""Fetch a group by exact name."""
return self._session.scalar(select(UserGroup).where(UserGroup.name == name))
def add_group(self, group: UserGroup) -> None:
"""Add a new group to the session and flush to assign an ID."""
self._session.add(group)
self._session.flush()
def update_group(
self,
group: UserGroup,
*,
name: str | None = None,
) -> None:
"""Update group attributes and set the modification timestamp."""
if name is not None:
group.name = name
group.time_last_modified_by_user = func.now()
def delete_group(self, group: UserGroup) -> None:
"""Delete a group from the session."""
self._session.delete(group)
def list_groups(
self,
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[UserGroup, str | None]], int]:
"""Query groups with optional SCIM filter and pagination.
Returns:
A tuple of (list of (group, external_id) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
"""
query = select(UserGroup).where(UserGroup.is_up_for_deletion.is_(False))
if scim_filter:
attr = scim_filter.attribute.lower()
if attr == "displayname":
# assignment: union return type widens but query is still Select[tuple[UserGroup]]
query = _apply_scim_string_op(query, UserGroup.name, scim_filter) # type: ignore[assignment]
elif attr == "externalid":
mapping = self.get_group_mapping_by_external_id(scim_filter.value)
if not mapping:
return [], 0
query = query.where(UserGroup.id == mapping.user_group_id)
else:
raise ValueError(
f"Unsupported filter attribute: {scim_filter.attribute}"
)
total = (
self._session.scalar(select(func.count()).select_from(query.subquery()))
or 0
)
offset = max(start_index - 1, 0)
groups = list(
self._session.scalars(
query.order_by(UserGroup.id).offset(offset).limit(count)
).all()
)
ext_id_map = self._get_group_external_ids([g.id for g in groups])
return [(g, ext_id_map.get(g.id)) for g in groups], total
def get_group_members(self, group_id: int) -> list[tuple[UUID, str | None]]:
"""Get group members as (user_id, email) pairs."""
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
).all()
user_ids = [r.user_id for r in rels if r.user_id]
if not user_ids:
return []
users = (
self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
)
.unique()
.all()
)
users_by_id = {u.id: u for u in users}
return [
(
r.user_id,
users_by_id[r.user_id].email if r.user_id in users_by_id else None,
)
for r in rels
if r.user_id
]
def validate_member_ids(self, uuids: list[UUID]) -> list[UUID]:
"""Return the subset of UUIDs that don't exist as users.
Returns an empty list if all IDs are valid.
"""
if not uuids:
return []
existing_users = (
self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
)
.unique()
.all()
)
existing_ids = {u.id for u in existing_users}
return [uid for uid in uuids if uid not in existing_ids]
def upsert_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Add user-group relationships, ignoring duplicates."""
if not user_ids:
return
self._session.execute(
pg_insert(User__UserGroup)
.values([{"user_id": uid, "user_group_id": group_id} for uid in user_ids])
.on_conflict_do_nothing(
index_elements=[
User__UserGroup.user_group_id,
User__UserGroup.user_id,
]
)
)
def replace_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Replace all members of a group."""
self._session.execute(
sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
)
self.upsert_group_members(group_id, user_ids)
def remove_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Remove specific members from a group."""
if not user_ids:
return
self._session.execute(
sa_delete(User__UserGroup).where(
User__UserGroup.user_group_id == group_id,
User__UserGroup.user_id.in_(user_ids),
)
)
def delete_group_with_members(self, group: UserGroup) -> None:
"""Remove all member relationships and delete the group."""
self._session.execute(
sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group.id)
)
self._session.delete(group)
def sync_group_external_id(
self, group_id: int, new_external_id: str | None
) -> None:
"""Create, update, or delete the external ID mapping for a group."""
mapping = self.get_group_mapping_by_group_id(group_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
else:
self.create_group_mapping(
external_id=new_external_id, user_group_id=group_id
)
elif mapping:
self.delete_group_mapping(mapping.id)
def _get_group_external_ids(self, group_ids: list[int]) -> dict[int, str]:
"""Batch-fetch external IDs for a list of group IDs."""
if not group_ids:
return {}
mappings = self._session.scalars(
select(ScimGroupMapping).where(
ScimGroupMapping.user_group_id.in_(group_ids)
)
).all()
return {m.user_group_id: m.external_id for m in mappings}
# ---------------------------------------------------------------------------
# Module-level helpers (used by DAL methods above)
# ---------------------------------------------------------------------------
def _apply_scim_string_op(
query: Select[tuple[User]] | Select[tuple[UserGroup]],
column: SQLColumnExpression[str],
scim_filter: ScimFilter,
) -> Select[tuple[User]] | Select[tuple[UserGroup]]:
"""Apply a SCIM string filter operator using SQLAlchemy column operators.
Handles eq (case-insensitive exact), co (contains), and sw (starts with).
SQLAlchemy's operators handle LIKE-pattern escaping internally.
"""
val = scim_filter.value
if scim_filter.operator == ScimFilterOperator.EQUAL:
return query.where(func.lower(column) == val.lower())
elif scim_filter.operator == ScimFilterOperator.CONTAINS:
return query.where(column.icontains(val, autoescape=True))
elif scim_filter.operator == ScimFilterOperator.STARTS_WITH:
return query.where(column.istartswith(val, autoescape=True))
else:
raise ValueError(f"Unsupported string filter operator: {scim_filter.operator}")

View File

@@ -9,26 +9,20 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
@@ -201,60 +195,8 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
return db_session.scalar(stmt)
def _add_user_group_snapshot_eager_loads(
stmt: Select,
) -> Select:
"""Add eager loading options needed by UserGroup.from_model snapshot creation."""
return stmt.options(
selectinload(UserGroup.users),
selectinload(UserGroup.user_group_relationships),
selectinload(UserGroup.cc_pair_relationships)
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
.options(
selectinload(ConnectorCredentialPair.connector),
selectinload(ConnectorCredentialPair.credential).selectinload(
Credential.user
),
),
selectinload(UserGroup.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(UserGroup.personas).options(
selectinload(Persona.tools),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
selectinload(Persona.groups),
),
)
def fetch_user_groups(
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
db_session: Session, only_up_to_date: bool = True
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -267,8 +209,6 @@ def fetch_user_groups(
db_session (Session): The SQLAlchemy session used to query the database.
only_up_to_date (bool, optional): Flag to determine whether to filter the results
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -276,16 +216,11 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def fetch_user_groups_for_user(
db_session: Session,
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
db_session: Session, user_id: UUID, only_curator_groups: bool = False
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -295,9 +230,7 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def construct_document_id_select_by_usergroup(
@@ -472,9 +405,7 @@ def _add_user_group__cc_pair_relationships__no_commit(
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
db_user_group = UserGroup(
name=user_group.name,
time_last_modified_by_user=func.now(),
is_up_to_date=DISABLE_VECTOR_DB,
name=user_group.name, time_last_modified_by_user=func.now()
)
db_session.add(db_user_group)
db_session.flush() # give the group an ID
@@ -777,7 +708,8 @@ def update_user_group(
cc_pair_ids=user_group_update.cc_pair_ids,
)
if cc_pairs_updated and not DISABLE_VECTOR_DB:
# only needs to sync with Vespa if the cc_pairs have been updated
if cc_pairs_updated:
db_user_group.is_up_to_date = False
removed_users = db_session.scalars(

View File

@@ -1,13 +1,9 @@
from collections.abc import Generator
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -47,27 +43,14 @@ def sharepoint_group_sync(
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
enumerate_all = connector_config.get(
"exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
)
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
sp_domain_suffix = connector.sharepoint_domain_suffix
# Process each site
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
)
ctx = connector._create_rest_client_context(site_descriptor.url)
external_groups = get_sharepoint_external_groups(
ctx,
connector.graph_client,
graph_api_base=connector.graph_api_base,
get_access_token=connector._get_graph_access_token,
enumerate_all_ad_groups=enumerate_all,
)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
# Yield each group
for group in external_groups:

View File

@@ -1,12 +1,9 @@
import re
import time
from collections import deque
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse
import requests as _requests
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
@@ -17,10 +14,7 @@ from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger
@@ -39,70 +33,6 @@ LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
AD_GROUP_ENUMERATION_THRESHOLD = 100_000
def _graph_api_get(
url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> dict[str, Any]:
"""Authenticated Graph API GET with retry on transient errors."""
for attempt in range(GRAPH_API_MAX_RETRIES + 1):
access_token = get_access_token()
headers = {"Authorization": f"Bearer {access_token}"}
try:
resp = _requests.get(
url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
)
if (
resp.status_code in GRAPH_API_RETRYABLE_STATUSES
and attempt < GRAPH_API_MAX_RETRIES
):
wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
logger.warning(
f"Graph API {resp.status_code} on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
resp.raise_for_status()
return resp.json()
except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
if attempt < GRAPH_API_MAX_RETRIES:
wait = min(2**attempt, 60)
logger.warning(
f"Graph API connection error on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
raise
raise RuntimeError(
f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
)
def _iter_graph_collection(
initial_url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> Generator[dict[str, Any], None, None]:
"""Paginate through a Graph API collection, yielding items one at a time."""
url: str | None = initial_url
while url:
data = _graph_api_get(url, get_access_token, params)
params = None
yield from data.get("value", [])
url = data.get("@odata.nextLink")
def _normalize_email(email: str) -> str:
if MICROSOFT_DOMAIN in email:
return email.replace(MICROSOFT_DOMAIN, "")
return email
class SharepointGroup(BaseModel):
model_config = {"frozen": True}
@@ -597,12 +527,8 @@ def get_external_access_from_sharepoint(
)
elif site_page:
site_url = site_page.get("webUrl")
# Keep percent-encoding intact so the path matches the encoding
# used by the Office365 library's SPResPath.create_relative(),
# which compares against urlparse(context.base_url).path.
# Decoding (e.g. %27 → ') causes a mismatch that duplicates
# the site prefix in the constructed URL.
server_relative_url = urlparse(site_url).path
# Prefer server-relative URL to avoid OData filters that break on apostrophes
server_relative_url = unquote(urlparse(site_url).path)
file_obj = client_context.web.get_file_by_server_relative_url(
server_relative_url
)
@@ -646,65 +572,8 @@ def get_external_access_from_sharepoint(
)
def _enumerate_ad_groups_paginated(
get_access_token: Callable[[], str],
already_resolved: set[str],
graph_api_base: str,
) -> Generator[ExternalUserGroup, None, None]:
"""Paginate through all Azure AD groups and yield ExternalUserGroup for each.
Skips groups whose suffixed name is already in *already_resolved*.
Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
"""
groups_url = f"{graph_api_base}/groups"
groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
total_groups = 0
for group_json in _iter_graph_collection(
groups_url, get_access_token, groups_params
):
group_id: str = group_json.get("id", "")
display_name: str = group_json.get("displayName", "")
if not group_id or not display_name:
continue
total_groups += 1
if total_groups > AD_GROUP_ENUMERATION_THRESHOLD:
logger.warning(
f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
"groups — stopping to avoid excessive memory/API usage. "
"Remaining groups will be resolved from role assignments only."
)
return
name = f"{display_name}_{group_id}"
if name in already_resolved:
continue
member_emails: list[str] = []
members_url = f"{graph_api_base}/groups/{group_id}/members"
members_params: dict[str, str] = {
"$select": "userPrincipalName,mail",
"$top": "999",
}
for member_json in _iter_graph_collection(
members_url, get_access_token, members_params
):
email = member_json.get("userPrincipalName") or member_json.get("mail")
if email:
member_emails.append(_normalize_email(email))
yield ExternalUserGroup(id=name, user_emails=member_emails)
logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API")
def get_sharepoint_external_groups(
client_context: ClientContext,
graph_client: GraphClient,
graph_api_base: str,
get_access_token: Callable[[], str] | None = None,
enumerate_all_ad_groups: bool = False,
client_context: ClientContext, graph_client: GraphClient
) -> list[ExternalUserGroup]:
groups: set[SharepointGroup] = set()
@@ -760,22 +629,57 @@ def get_sharepoint_external_groups(
client_context, graph_client, groups, is_group_sync=True
)
external_user_groups: list[ExternalUserGroup] = [
ExternalUserGroup(id=group_name, user_emails=list(emails))
for group_name, emails in groups_and_members.groups_to_emails.items()
]
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
azure_ad_groups = sleep_and_retry(
graph_client.groups.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups",
)
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
ad_groups_to_emails: dict[str, set[str]] = {}
for group in azure_ad_groups:
# If the group is already identified, we don't need to get the members
if group.display_name in identified_groups:
continue
# AD groups allows same display name for multiple groups, so we need to add the GUID to the name
name = group.display_name
name = _get_group_name_with_suffix(group.id, name, graph_client)
if not enumerate_all_ad_groups or get_access_token is None:
logger.info(
"Skipping exhaustive Azure AD group enumeration. "
"Only groups found in site role assignments are included."
members = sleep_and_retry(
group.members.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
)
return external_user_groups
for member in members:
member_data = member.to_json()
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
if not ad_groups_to_emails.get(name):
ad_groups_to_emails[name] = set()
if user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
ad_groups_to_emails[name].add(user_principal_name)
elif mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
ad_groups_to_emails[name].add(mail)
already_resolved = set(groups_and_members.groups_to_emails.keys())
for group in _enumerate_ad_groups_paginated(
get_access_token, already_resolved, graph_api_base
):
external_user_groups.append(group)
external_user_groups: list[ExternalUserGroup] = []
for group_name, emails in groups_and_members.groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
for group_name, emails in ad_groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
return external_user_groups

View File

@@ -31,8 +31,6 @@ from ee.onyx.server.query_and_chat.query_backend import (
from ee.onyx.server.query_and_chat.search_backend import router as search_router
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.scim.api import register_scim_exception_handlers
from ee.onyx.server.scim.api import scim_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.token_rate_limits.api import (
@@ -164,12 +162,6 @@ def get_application() -> FastAPI:
# Tenant management
include_router_with_global_prefix_prepended(application, tenants_router)
# SCIM 2.0 — protocol endpoints (unauthenticated by Onyx session auth;
# they use their own SCIM bearer token auth).
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
application.include_router(scim_router)
register_scim_exception_handlers(application)
# Ensure all routes have auth enabled or are explicitly marked as public
check_ee_router_auth(application)

View File

@@ -5,11 +5,6 @@ from onyx.server.auth_check import PUBLIC_ENDPOINT_SPECS
EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
# SCIM 2.0 service discovery — unauthenticated so IdPs can probe
# before bearer token configuration is complete
("/scim/v2/ServiceProviderConfig", {"GET"}),
("/scim/v2/ResourceTypes", {"GET"}),
("/scim/v2/Schemas", {"GET"}),
# needs to be accessible prior to user login
("/enterprise-settings", {"GET"}),
("/enterprise-settings/logo", {"GET"}),

View File

@@ -26,6 +26,7 @@ import asyncio
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
@@ -41,6 +42,7 @@ from ee.onyx.server.billing.models import SeatUpdateRequest
from ee.onyx.server.billing.models import SeatUpdateResponse
from ee.onyx.server.billing.models import StripePublishableKeyResponse
from ee.onyx.server.billing.models import SubscriptionStatusResponse
from ee.onyx.server.billing.service import BillingServiceError
from ee.onyx.server.billing.service import (
create_checkout_session as create_checkout_service,
)
@@ -56,8 +58,6 @@ from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.engine.sql_engine import get_session
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -169,23 +169,26 @@ async def create_checkout_session(
if seats is not None:
used_seats = get_used_seats(tenant_id)
if seats < used_seats:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Cannot subscribe with fewer seats than current usage. "
raise HTTPException(
status_code=400,
detail=f"Cannot subscribe with fewer seats than current usage. "
f"You have {used_seats} active users/integrations but requested {seats} seats.",
)
# Build redirect URL for after checkout completion
redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success"
return await create_checkout_service(
billing_period=billing_period,
seats=seats,
email=email,
license_data=license_data,
redirect_url=redirect_url,
tenant_id=tenant_id,
)
try:
return await create_checkout_service(
billing_period=billing_period,
seats=seats,
email=email,
license_data=license_data,
redirect_url=redirect_url,
tenant_id=tenant_id,
)
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.post("/create-customer-portal-session")
@@ -203,15 +206,18 @@ async def create_customer_portal_session(
# Self-hosted requires license
if not MULTI_TENANT and not license_data:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found")
raise HTTPException(status_code=400, detail="No license found")
return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"
return await create_portal_service(
license_data=license_data,
return_url=return_url,
tenant_id=tenant_id,
)
try:
return await create_portal_service(
license_data=license_data,
return_url=return_url,
tenant_id=tenant_id,
)
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.get("/billing-information")
@@ -234,9 +240,9 @@ async def get_billing_information(
# Check circuit breaker (self-hosted only)
if _is_billing_circuit_open():
raise OnyxError(
OnyxErrorCode.SERVICE_UNAVAILABLE,
"Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
raise HTTPException(
status_code=503,
detail="Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
)
try:
@@ -244,11 +250,11 @@ async def get_billing_information(
license_data=license_data,
tenant_id=tenant_id,
)
except OnyxError as e:
except BillingServiceError as e:
# Open circuit breaker on connection failures (self-hosted only)
if e.status_code in (502, 503, 504):
_open_billing_circuit()
raise
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.post("/seats/update")
@@ -268,25 +274,31 @@ async def update_seats(
# Self-hosted requires license
if not MULTI_TENANT and not license_data:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found")
raise HTTPException(status_code=400, detail="No license found")
# Validate that new seat count is not less than current used seats
used_seats = get_used_seats(tenant_id)
if request.new_seat_count < used_seats:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Cannot reduce seats below current usage. "
raise HTTPException(
status_code=400,
detail=f"Cannot reduce seats below current usage. "
f"You have {used_seats} active users/integrations but requested {request.new_seat_count} seats.",
)
# Note: Don't store license here - the control plane may still be processing
# the subscription update. The frontend should call /license/claim after a
# short delay to get the freshly generated license.
return await update_seat_service(
new_seat_count=request.new_seat_count,
license_data=license_data,
tenant_id=tenant_id,
)
try:
result = await update_seat_service(
new_seat_count=request.new_seat_count,
license_data=license_data,
tenant_id=tenant_id,
)
# Note: Don't store license here - the control plane may still be processing
# the subscription update. The frontend should call /license/claim after a
# short delay to get the freshly generated license.
return result
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.get("/stripe-publishable-key")
@@ -317,18 +329,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
if not key.startswith("pk_"):
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Invalid Stripe publishable key format",
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
# Fall back to S3 bucket
if not STRIPE_PUBLISHABLE_KEY_URL:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Stripe publishable key is not configured",
raise HTTPException(
status_code=500,
detail="Stripe publishable key is not configured",
)
try:
@@ -339,17 +351,17 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
# Validate key format
if not key.startswith("pk_"):
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Invalid Stripe publishable key format",
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
except httpx.HTTPError:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Failed to fetch Stripe publishable key",
raise HTTPException(
status_code=500,
detail="Failed to fetch Stripe publishable key",
)

View File

@@ -22,8 +22,6 @@ from ee.onyx.server.billing.models import SeatUpdateResponse
from ee.onyx.server.billing.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.access import generate_data_plane_token
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -33,6 +31,15 @@ logger = setup_logger()
_REQUEST_TIMEOUT = 30.0
class BillingServiceError(Exception):
"""Exception raised for billing service errors."""
def __init__(self, message: str, status_code: int = 500):
self.message = message
self.status_code = status_code
super().__init__(self.message)
def _get_proxy_headers(license_data: str | None) -> dict[str, str]:
"""Build headers for proxy requests (self-hosted).
@@ -94,7 +101,7 @@ async def _make_billing_request(
Response JSON as dict
Raises:
OnyxError: If request fails
BillingServiceError: If request fails
"""
base_url = _get_base_url()
@@ -121,17 +128,11 @@ async def _make_billing_request(
except Exception:
pass
logger.error(f"{error_message}: {e.response.status_code} - {detail}")
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
detail,
status_code_override=e.response.status_code,
)
raise BillingServiceError(detail, e.response.status_code)
except httpx.RequestError:
logger.exception("Failed to connect to billing service")
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY, "Failed to connect to billing service"
)
raise BillingServiceError("Failed to connect to billing service", 502)
async def create_checkout_session(

View File

@@ -13,7 +13,6 @@ from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.store import get_logo_filename
@@ -23,10 +22,6 @@ from ee.onyx.server.enterprise_settings.store import load_settings
from ee.onyx.server.enterprise_settings.store import store_analytics_script
from ee.onyx.server.enterprise_settings.store import store_settings
from ee.onyx.server.enterprise_settings.store import upload_logo
from ee.onyx.server.scim.auth import generate_scim_token
from ee.onyx.server.scim.models import ScimTokenCreate
from ee.onyx.server.scim.models import ScimTokenCreatedResponse
from ee.onyx.server.scim.models import ScimTokenResponse
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
@@ -203,73 +198,3 @@ def upload_custom_analytics_script(
@basic_router.get("/custom-analytics-script")
def fetch_custom_analytics_script() -> str | None:
return load_analytics_script()
# ---------------------------------------------------------------------------
# SCIM token management
# ---------------------------------------------------------------------------
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
@admin_router.get("/scim/token")
def get_active_scim_token(
_: User = Depends(current_admin_user),
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenResponse:
"""Return the currently active SCIM token's metadata, or 404 if none."""
token = dal.get_active_token()
if not token:
raise HTTPException(status_code=404, detail="No active SCIM token")
# Derive the IdP domain from the first synced user as a heuristic.
idp_domain: str | None = None
mappings, _total = dal.list_user_mappings(start_index=1, count=1)
if mappings:
user = dal.get_user(mappings[0].user_id)
if user and "@" in user.email:
idp_domain = user.email.rsplit("@", 1)[1]
return ScimTokenResponse(
id=token.id,
name=token.name,
token_display=token.token_display,
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
idp_domain=idp_domain,
)
@admin_router.post("/scim/token", status_code=201)
def create_scim_token(
body: ScimTokenCreate,
user: User = Depends(current_admin_user),
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenCreatedResponse:
"""Create a new SCIM bearer token.
Only one token is active at a time — creating a new token automatically
revokes all previous tokens. The raw token value is returned exactly once
in the response; it cannot be retrieved again.
"""
raw_token, hashed_token, token_display = generate_scim_token()
token = dal.create_token(
name=body.name,
hashed_token=hashed_token,
token_display=token_display,
created_by_id=user.id,
)
dal.commit()
return ScimTokenCreatedResponse(
id=token.id,
name=token.name,
token_display=token.token_display,
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
raw_token=raw_token,
)

View File

@@ -14,6 +14,7 @@ import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import HTTPException
from fastapi import UploadFile
from sqlalchemy.orm import Session
@@ -34,8 +35,6 @@ from ee.onyx.server.license.models import SeatUsageResponse
from ee.onyx.utils.license import verify_license_signature
from onyx.auth.users import User
from onyx.db.engine.sql_engine import get_session
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -128,9 +127,9 @@ async def claim_license(
2. Without session_id: Re-claim using existing license for auth
"""
if MULTI_TENANT:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"License claiming is only available for self-hosted deployments",
raise HTTPException(
status_code=400,
detail="License claiming is only available for self-hosted deployments",
)
try:
@@ -147,16 +146,15 @@ async def claim_license(
# Re-claim using existing license for auth
metadata = get_license_metadata(db_session)
if not metadata or not metadata.tenant_id:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No license found. Provide session_id after checkout.",
raise HTTPException(
status_code=400,
detail="No license found. Provide session_id after checkout.",
)
license_row = get_license(db_session)
if not license_row or not license_row.license_data:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No license found in database",
raise HTTPException(
status_code=400, detail="No license found in database"
)
url = f"{CLOUD_DATA_PLANE_URL}/proxy/license/{metadata.tenant_id}"
@@ -175,7 +173,7 @@ async def claim_license(
license_data = data.get("license")
if not license_data:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "No license in response")
raise HTTPException(status_code=404, detail="No license in response")
# Verify signature before persisting
payload = verify_license_signature(license_data)
@@ -201,14 +199,12 @@ async def claim_license(
detail = error_data.get("detail", detail)
except Exception:
pass
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=status_code
)
raise HTTPException(status_code=status_code, detail=detail)
except ValueError as e:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
raise HTTPException(status_code=400, detail=str(e))
except requests.RequestException:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY, "Failed to connect to license server"
raise HTTPException(
status_code=502, detail="Failed to connect to license server"
)
@@ -225,9 +221,9 @@ async def upload_license(
The license file must be cryptographically signed by Onyx.
"""
if MULTI_TENANT:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"License upload is only available for self-hosted deployments",
raise HTTPException(
status_code=400,
detail="License upload is only available for self-hosted deployments",
)
try:
@@ -238,14 +234,14 @@ async def upload_license(
# Remove any stray whitespace/newlines from user input
license_data = license_data.strip()
except UnicodeDecodeError:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Invalid license file format")
raise HTTPException(status_code=400, detail="Invalid license file format")
# Verify cryptographic signature - this is the only validation needed
# The license's tenant_id identifies the customer in control plane, not locally
try:
payload = verify_license_signature(license_data)
except ValueError as e:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
raise HTTPException(status_code=400, detail=str(e))
# Persist to DB and update cache
upsert_license(db_session, license_data)
@@ -301,9 +297,9 @@ async def delete_license(
Admin only - removes license from database and invalidates cache.
"""
if MULTI_TENANT:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"License deletion is only available for self-hosted deployments",
raise HTTPException(
status_code=400,
detail="License deletion is only available for self-hosted deployments",
)
try:

View File

@@ -46,6 +46,7 @@ from fastapi import FastAPI
from fastapi import Request
from fastapi import Response
from fastapi.responses import JSONResponse
from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
@@ -55,7 +56,6 @@ from ee.onyx.configs.license_enforcement_config import (
)
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.db.license import refresh_license_cache
from onyx.cache.interface import CACHE_TRANSIENT_ERRORS
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.models import ApplicationStatus
from shared_configs.contextvars import get_current_tenant_id
@@ -164,9 +164,9 @@ def add_license_enforcement_middleware(
"[license_enforcement] No license, allowing community features"
)
is_gated = False
except CACHE_TRANSIENT_ERRORS as e:
except RedisError as e:
logger.warning(f"Failed to check license metadata: {e}")
# Fail open - don't block users due to cache connectivity issues
# Fail open - don't block users due to Redis connectivity issues
is_gated = False
if is_gated:

View File

@@ -34,7 +34,7 @@ class SendSearchQueryRequest(BaseModel):
filters: BaseFilters | None = None
num_docs_fed_to_llm_selection: int | None = None
run_query_expansion: bool = False
num_hits: int = 30
num_hits: int = 50
include_content: bool = False
stream: bool = False

File diff suppressed because it is too large Load Diff

View File

@@ -1,109 +0,0 @@
"""SCIM bearer token authentication.
SCIM endpoints are authenticated via bearer tokens that admins create in the
Onyx UI. This module provides:
- ``verify_scim_token``: FastAPI dependency that extracts, hashes, and
validates the token from the Authorization header.
- ``generate_scim_token``: Creates a new cryptographically random token
and returns the raw value, its SHA-256 hash, and a display suffix.
Token format: ``onyx_scim_<random>`` where ``<random>`` is 48 bytes of
URL-safe base64 from ``secrets.token_urlsafe``.
The hash is stored in the ``scim_token`` table; the raw value is shown to
the admin exactly once at creation time.
"""
import hashlib
import secrets
from fastapi import Depends
from fastapi import Request
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from onyx.auth.utils import get_hashed_bearer_token_from_request
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
class ScimAuthError(Exception):
"""Raised when SCIM bearer token authentication fails.
Unlike HTTPException, this carries the status and detail so the SCIM
exception handler can wrap them in an RFC 7644 §3.12 error envelope
with ``schemas`` and ``status`` fields.
"""
def __init__(self, status_code: int, detail: str) -> None:
self.status_code = status_code
self.detail = detail
super().__init__(detail)
SCIM_TOKEN_PREFIX = "onyx_scim_"
SCIM_TOKEN_LENGTH = 48
def _hash_scim_token(token: str) -> str:
"""SHA-256 hash a SCIM token. No salt needed — tokens are random."""
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def generate_scim_token() -> tuple[str, str, str]:
"""Generate a new SCIM bearer token.
Returns:
A tuple of ``(raw_token, hashed_token, token_display)`` where
``token_display`` is a masked version showing only the last 4 chars.
"""
raw_token = SCIM_TOKEN_PREFIX + secrets.token_urlsafe(SCIM_TOKEN_LENGTH)
hashed_token = _hash_scim_token(raw_token)
token_display = SCIM_TOKEN_PREFIX + "****" + raw_token[-4:]
return raw_token, hashed_token, token_display
def _get_hashed_scim_token_from_request(request: Request) -> str | None:
"""Extract and hash a SCIM token from the request Authorization header."""
return get_hashed_bearer_token_from_request(
request,
valid_prefixes=[SCIM_TOKEN_PREFIX],
hash_fn=_hash_scim_token,
)
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
def verify_scim_token(
request: Request,
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimToken:
"""FastAPI dependency that authenticates SCIM requests.
Extracts the bearer token from the Authorization header, hashes it,
looks it up in the database, and verifies it is active.
Note:
This dependency does NOT update ``last_used_at`` — the endpoint
should do that via ``ScimDAL.update_token_last_used()`` so the
timestamp write is part of the endpoint's transaction.
Raises:
HTTPException(401): If the token is missing, invalid, or inactive.
"""
hashed = _get_hashed_scim_token_from_request(request)
if not hashed:
raise ScimAuthError(401, "Missing or invalid SCIM bearer token")
token = dal.get_token_by_hash(hashed)
if not token:
raise ScimAuthError(401, "Invalid SCIM bearer token")
if not token.is_active:
raise ScimAuthError(401, "SCIM token has been revoked")
return token

View File

@@ -7,14 +7,12 @@ SCIM protocol schemas follow the wire format defined in:
Admin API schemas are internal to Onyx and used for SCIM token management.
"""
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator
# ---------------------------------------------------------------------------
@@ -32,10 +30,6 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
)
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
SCIM_ENTERPRISE_USER_SCHEMA = (
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
)
# ---------------------------------------------------------------------------
@@ -68,43 +62,6 @@ class ScimMeta(BaseModel):
location: str | None = None
class ScimUserGroupRef(BaseModel):
"""Group reference within a User resource (RFC 7643 §4.1.2, read-only)."""
value: str
display: str | None = None
class ScimManagerRef(BaseModel):
"""Manager sub-attribute for the enterprise extension (RFC 7643 §4.3)."""
value: str | None = None
class ScimEnterpriseExtension(BaseModel):
"""Enterprise User extension attributes (RFC 7643 §4.3)."""
department: str | None = None
manager: ScimManagerRef | None = None
@dataclass
class ScimMappingFields:
"""Stored SCIM mapping fields that need to round-trip through the IdP.
Entra ID sends structured name components, email metadata, and enterprise
extension attributes that must be returned verbatim in subsequent GET
responses. These fields are persisted on ScimUserMapping and threaded
through the DAL, provider, and endpoint layers.
"""
department: str | None = None
manager: str | None = None
given_name: str | None = None
family_name: str | None = None
scim_emails_json: str | None = None
class ScimUserResource(BaseModel):
"""SCIM User resource representation (RFC 7643 §4.1).
@@ -113,22 +70,14 @@ class ScimUserResource(BaseModel):
to match the SCIM wire format (not Python convention).
"""
model_config = ConfigDict(populate_by_name=True)
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
id: str | None = None # Onyx's internal user ID, set on responses
externalId: str | None = None # IdP's identifier for this user
userName: str # Typically the user's email address
name: ScimName | None = None
displayName: str | None = None
emails: list[ScimEmail] = Field(default_factory=list)
active: bool = True
groups: list[ScimUserGroupRef] = Field(default_factory=list)
meta: ScimMeta | None = None
enterprise_extension: ScimEnterpriseExtension | None = Field(
default=None,
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
)
class ScimGroupMember(BaseModel):
@@ -171,53 +120,12 @@ class ScimPatchOperationType(str, Enum):
REMOVE = "remove"
class ScimPatchResourceValue(BaseModel):
"""Partial resource dict for path-less PATCH replace operations.
When an IdP sends a PATCH without a ``path``, the ``value`` is a dict
of resource attributes to set. IdPs may include read-only fields
(``id``, ``schemas``, ``meta``) alongside actual changes — these are
stripped by the provider's ``ignored_patch_paths`` before processing.
``extra="allow"`` lets unknown attributes pass through so the patch
handler can decide what to do with them (ignore or reject).
"""
model_config = ConfigDict(extra="allow")
active: bool | None = None
userName: str | None = None
displayName: str | None = None
externalId: str | None = None
name: ScimName | None = None
members: list[ScimGroupMember] | None = None
id: str | None = None
schemas: list[str] | None = None
meta: ScimMeta | None = None
ScimPatchValue = str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None
class ScimPatchOperation(BaseModel):
"""Single PATCH operation (RFC 7644 §3.5.2)."""
op: ScimPatchOperationType
path: str | None = None
value: ScimPatchValue = None
@field_validator("op", mode="before")
@classmethod
def normalize_operation(cls, v: object) -> object:
"""Normalize op to lowercase for case-insensitive matching.
Some IdPs (e.g. Entra ID) send capitalized ops like ``"Replace"``
instead of ``"replace"``. This is safe for all providers since the
enum values are lowercase. If a future provider requires other
pre-processing quirks, move patch deserialization into the provider
subclass instead of adding more special cases here.
"""
return v.lower() if isinstance(v, str) else v
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
class ScimPatchRequest(BaseModel):
@@ -287,39 +195,10 @@ class ScimServiceProviderConfig(BaseModel):
)
class ScimSchemaAttribute(BaseModel):
"""Attribute definition within a SCIM Schema (RFC 7643 §7)."""
name: str
type: str
multiValued: bool = False
required: bool = False
description: str = ""
caseExact: bool = False
mutability: str = "readWrite"
returned: str = "default"
uniqueness: str = "none"
subAttributes: list["ScimSchemaAttribute"] = Field(default_factory=list)
class ScimSchemaDefinition(BaseModel):
"""SCIM Schema definition (RFC 7643 §7).
Served at GET /scim/v2/Schemas. Describes the attributes available
on each resource type so IdPs know which fields they can provision.
"""
schemas: list[str] = Field(default_factory=lambda: [SCIM_SCHEMA_SCHEMA])
id: str
name: str
description: str
attributes: list[ScimSchemaAttribute] = Field(default_factory=list)
class ScimSchemaExtension(BaseModel):
"""Schema extension reference within ResourceType (RFC 7643 §6)."""
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
model_config = ConfigDict(populate_by_name=True)
schema_: str = Field(alias="schema")
required: bool
@@ -332,7 +211,7 @@ class ScimResourceType(BaseModel):
types are available (Users, Groups) and their respective endpoints.
"""
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
model_config = ConfigDict(populate_by_name=True)
schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
id: str
@@ -365,7 +244,6 @@ class ScimTokenResponse(BaseModel):
is_active: bool
created_at: datetime
last_used_at: datetime | None = None
idp_domain: str | None = None
class ScimTokenCreatedResponse(ScimTokenResponse):

View File

@@ -14,70 +14,13 @@ responsible for persisting changes.
from __future__ import annotations
import logging
import re
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchResourceValue
from ee.onyx.server.scim.models import ScimPatchValue
from ee.onyx.server.scim.models import ScimUserResource
logger = logging.getLogger(__name__)
# Lowercased enterprise extension URN for case-insensitive matching
_ENTERPRISE_URN_LOWER = SCIM_ENTERPRISE_USER_SCHEMA.lower()
# Pattern for email filter paths, e.g.:
# emails[primary eq true].value (Okta)
# emails[type eq "work"].value (Azure AD / Entra ID)
_EMAIL_FILTER_RE = re.compile(
r"^emails\[.+\]\.value$",
re.IGNORECASE,
)
# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
r'^members\[value\s+eq\s+"([^"]+)"\]$',
re.IGNORECASE,
)
# ---------------------------------------------------------------------------
# Dispatch tables for user PATCH paths
#
# Maps lowercased SCIM path → (camelCase key, target dict name).
# "data" writes to the top-level resource dict, "name" writes to the
# name sub-object dict. This replaces the elif chains for simple fields.
# ---------------------------------------------------------------------------
_USER_REPLACE_PATHS: dict[str, tuple[str, str]] = {
"active": ("active", "data"),
"username": ("userName", "data"),
"externalid": ("externalId", "data"),
"name.givenname": ("givenName", "name"),
"name.familyname": ("familyName", "name"),
"name.formatted": ("formatted", "name"),
}
_USER_REMOVE_PATHS: dict[str, tuple[str, str]] = {
"externalid": ("externalId", "data"),
"name.givenname": ("givenName", "name"),
"name.familyname": ("familyName", "name"),
"name.formatted": ("formatted", "name"),
"displayname": ("displayName", "data"),
}
_GROUP_REPLACE_PATHS: dict[str, tuple[str, str]] = {
"displayname": ("displayName", "data"),
"externalid": ("externalId", "data"),
}
class ScimPatchError(Exception):
"""Raised when a PATCH operation cannot be applied."""
@@ -88,223 +31,94 @@ class ScimPatchError(Exception):
super().__init__(detail)
@dataclass
class _UserPatchCtx:
"""Bundles the mutable state for user PATCH operations."""
data: dict[str, Any]
name_data: dict[str, Any]
ent_data: dict[str, str | None] = field(default_factory=dict)
# ---------------------------------------------------------------------------
# User PATCH
# ---------------------------------------------------------------------------
# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
r'^members\[value\s+eq\s+"([^"]+)"\]$',
re.IGNORECASE,
)
def apply_user_patch(
operations: list[ScimPatchOperation],
current: ScimUserResource,
ignored_paths: frozenset[str] = frozenset(),
) -> tuple[ScimUserResource, dict[str, str | None]]:
) -> ScimUserResource:
"""Apply SCIM PATCH operations to a user resource.
Args:
operations: The PATCH operations to apply.
current: The current user resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns:
A tuple of (modified user resource, enterprise extension data dict).
The enterprise dict has keys ``"department"`` and ``"manager"``
with values set only when a PATCH operation touched them.
Returns a new ``ScimUserResource`` with the modifications applied.
The original object is not mutated.
Raises:
ScimPatchError: If an operation targets an unsupported path.
"""
data = current.model_dump()
ctx = _UserPatchCtx(data=data, name_data=data.get("name") or {})
name_data = data.get("name") or {}
for op in operations:
if op.op in (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD):
_apply_user_replace(op, ctx, ignored_paths)
elif op.op == ScimPatchOperationType.REMOVE:
_apply_user_remove(op, ctx, ignored_paths)
if op.op == ScimPatchOperationType.REPLACE:
_apply_user_replace(op, data, name_data)
elif op.op == ScimPatchOperationType.ADD:
_apply_user_replace(op, data, name_data)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on User resource"
)
ctx.data["name"] = ctx.name_data
return ScimUserResource.model_validate(ctx.data), ctx.ent_data
data["name"] = name_data
return ScimUserResource.model_validate(data)
def _apply_user_replace(
op: ScimPatchOperation,
ctx: _UserPatchCtx,
ignored_paths: frozenset[str],
data: dict,
name_data: dict,
) -> None:
"""Apply a replace/add operation to user data."""
path = (op.path or "").lower()
if not path:
# No path — value is a resource dict of top-level attributes to set.
if isinstance(op.value, ScimPatchResourceValue):
for key, val in op.value.model_dump(exclude_unset=True).items():
_set_user_field(key.lower(), val, ctx, ignored_paths, strict=False)
# No path — value is a dict of top-level attributes to set
if isinstance(op.value, dict):
for key, val in op.value.items():
_set_user_field(key.lower(), val, data, name_data)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
_set_user_field(path, op.value, ctx, ignored_paths)
def _apply_user_remove(
op: ScimPatchOperation,
ctx: _UserPatchCtx,
ignored_paths: frozenset[str],
) -> None:
"""Apply a remove operation to user data — clears the target field."""
path = (op.path or "").lower()
if not path:
raise ScimPatchError("Remove operation requires a path")
if path in ignored_paths:
return
entry = _USER_REMOVE_PATHS.get(path)
if entry:
key, target = entry
target_dict = ctx.data if target == "data" else ctx.name_data
target_dict[key] = None
return
raise ScimPatchError(f"Unsupported remove path '{path}' for User PATCH")
_set_user_field(path, op.value, data, name_data)
def _set_user_field(
path: str,
value: ScimPatchValue,
ctx: _UserPatchCtx,
ignored_paths: frozenset[str],
*,
strict: bool = True,
value: str | bool | dict | list | None,
data: dict,
name_data: dict,
) -> None:
"""Set a single field on user data by SCIM path.
Args:
strict: When ``False`` (path-less replace), unknown attributes are
silently skipped. When ``True`` (explicit path), they raise.
"""
if path in ignored_paths:
return
# Simple field writes handled by the dispatch table
entry = _USER_REPLACE_PATHS.get(path)
if entry:
key, target = entry
target_dict = ctx.data if target == "data" else ctx.name_data
target_dict[key] = value
return
# displayName sets both the top-level field and the name.formatted sub-field
if path == "displayname":
ctx.data["displayName"] = value
ctx.name_data["formatted"] = value
elif path == "name":
if isinstance(value, dict):
for k, v in value.items():
ctx.name_data[k] = v
elif path == "emails":
if isinstance(value, list):
ctx.data["emails"] = value
elif _EMAIL_FILTER_RE.match(path):
_update_primary_email(ctx.data, value)
elif path.startswith(_ENTERPRISE_URN_LOWER):
_set_enterprise_field(path, value, ctx.ent_data)
elif not strict:
return
"""Set a single field on user data by SCIM path."""
if path == "active":
data["active"] = value
elif path == "username":
data["userName"] = value
elif path == "externalid":
data["externalId"] = value
elif path == "name.givenname":
name_data["givenName"] = value
elif path == "name.familyname":
name_data["familyName"] = value
elif path == "name.formatted":
name_data["formatted"] = value
elif path == "displayname":
# Some IdPs send displayName on users; map to formatted name
name_data["formatted"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
def _update_primary_email(data: dict[str, Any], value: ScimPatchValue) -> None:
"""Update the primary email entry via an email filter path."""
emails: list[dict] = data.get("emails") or []
for email_entry in emails:
if email_entry.get("primary"):
email_entry["value"] = value
break
else:
emails.append({"value": value, "type": "work", "primary": True})
data["emails"] = emails
def _to_dict(value: ScimPatchValue) -> dict | None:
"""Coerce a SCIM patch value to a plain dict if possible.
Pydantic may parse raw dicts as ``ScimPatchResourceValue`` (which uses
``extra="allow"``), so we also dump those back to a dict.
"""
if isinstance(value, dict):
return value
if isinstance(value, ScimPatchResourceValue):
return value.model_dump(exclude_unset=True)
return None
def _set_enterprise_field(
path: str,
value: ScimPatchValue,
ent_data: dict[str, str | None],
) -> None:
"""Handle enterprise extension URN paths or value dicts."""
# Full URN as key with dict value (path-less PATCH)
# e.g. key="urn:...:user", value={"department": "Eng", "manager": {...}}
if path == _ENTERPRISE_URN_LOWER:
d = _to_dict(value)
if d is not None:
if "department" in d:
ent_data["department"] = d["department"]
if "manager" in d:
mgr = d["manager"]
if isinstance(mgr, dict):
ent_data["manager"] = mgr.get("value")
return
# Dotted URN path, e.g. "urn:...:user:department"
suffix = path[len(_ENTERPRISE_URN_LOWER) :].lstrip(":").lower()
if suffix == "department":
ent_data["department"] = str(value) if value is not None else None
elif suffix == "manager":
d = _to_dict(value)
if d is not None:
ent_data["manager"] = d.get("value")
elif isinstance(value, str):
ent_data["manager"] = value
else:
# Unknown enterprise attributes are silently ignored rather than
# rejected — IdPs may send attributes we don't model yet.
logger.warning("Ignoring unknown enterprise extension attribute '%s'", suffix)
# ---------------------------------------------------------------------------
# Group PATCH
# ---------------------------------------------------------------------------
def apply_group_patch(
operations: list[ScimPatchOperation],
current: ScimGroupResource,
ignored_paths: frozenset[str] = frozenset(),
) -> tuple[ScimGroupResource, list[str], list[str]]:
"""Apply SCIM PATCH operations to a group resource.
Args:
operations: The PATCH operations to apply.
current: The current group resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns:
A tuple of (modified group, added member IDs, removed member IDs).
The caller uses the member ID lists to update the database.
@@ -319,9 +133,7 @@ def apply_group_patch(
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_group_replace(
op, data, current_members, added_ids, removed_ids, ignored_paths
)
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
elif op.op == ScimPatchOperationType.ADD:
_apply_group_add(op, current_members, added_ids)
elif op.op == ScimPatchOperationType.REMOVE:
@@ -342,48 +154,38 @@ def _apply_group_replace(
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace operation to group data."""
path = (op.path or "").lower()
if not path:
if isinstance(op.value, ScimPatchResourceValue):
dumped = op.value.model_dump(exclude_unset=True)
for key, val in dumped.items():
if isinstance(op.value, dict):
for key, val in op.value.items():
if key.lower() == "members":
_replace_members(val, current_members, added_ids, removed_ids)
else:
_set_group_field(key.lower(), val, data, ignored_paths)
_set_group_field(key.lower(), val, data)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
if path == "members":
_replace_members(
_members_to_dicts(op.value), current_members, added_ids, removed_ids
)
_replace_members(op.value, current_members, added_ids, removed_ids)
return
_set_group_field(path, op.value, data, ignored_paths)
def _members_to_dicts(
value: str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None,
) -> list[dict]:
"""Convert a member list value to a list of dicts for internal processing."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
return [m.model_dump(exclude_none=True) for m in value]
_set_group_field(path, op.value, data)
def _replace_members(
value: list[dict],
value: str | list | dict | bool | None,
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
) -> None:
"""Replace the entire group member list."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
old_ids = {m["value"] for m in current_members}
new_ids = {m.get("value", "") for m in value}
@@ -395,21 +197,16 @@ def _replace_members(
def _set_group_field(
path: str,
value: ScimPatchValue,
value: str | bool | dict | list | None,
data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Set a single field on group data by SCIM path."""
if path in ignored_paths:
return
entry = _GROUP_REPLACE_PATHS.get(path)
if entry:
key, _ = entry
data[key] = value
return
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
if path == "displayname":
data["displayName"] = value
elif path == "externalid":
data["externalId"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
def _apply_group_add(
@@ -426,10 +223,8 @@ def _apply_group_add(
if not isinstance(op.value, list):
raise ScimPatchError("Add members requires a list value")
member_dicts = [m.model_dump(exclude_none=True) for m in op.value]
existing_ids = {m["value"] for m in members}
for member_data in member_dicts:
for member_data in op.value:
member_id = member_data.get("value", "")
if member_id and member_id not in existing_ids:
members.append(member_data)

View File

@@ -1,215 +0,0 @@
"""Base SCIM provider abstraction."""
from __future__ import annotations
import json
import logging
from abc import ABC
from abc import abstractmethod
from uuid import UUID
from pydantic import ValidationError
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimEnterpriseExtension
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimManagerRef
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserGroupRef
from ee.onyx.server.scim.models import ScimUserResource
from onyx.db.models import User
from onyx.db.models import UserGroup
logger = logging.getLogger(__name__)
COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset(
{
"id",
"schemas",
"meta",
}
)
class ScimProvider(ABC):
"""Base class for provider-specific SCIM behavior.
Subclass this to handle IdP-specific quirks. The base class provides
RFC 7643-compliant response builders that populate all standard fields.
"""
@property
@abstractmethod
def name(self) -> str:
"""Short identifier for this provider (e.g. ``"okta"``)."""
...
@property
@abstractmethod
def ignored_patch_paths(self) -> frozenset[str]:
"""SCIM attribute paths to silently skip in PATCH value-object dicts.
IdPs may include read-only or meta fields alongside actual changes
(e.g. Okta sends ``{"id": "...", "active": false}``). Paths listed
here are silently dropped instead of raising an error.
"""
...
@property
def user_schemas(self) -> list[str]:
"""Schema URIs to include in User resource responses.
Override in subclasses to advertise additional schemas (e.g. the
enterprise extension for Entra ID).
"""
return [SCIM_USER_SCHEMA]
def build_user_resource(
self,
user: User,
external_id: str | None = None,
groups: list[tuple[int, str]] | None = None,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> ScimUserResource:
"""Build a SCIM User response from an Onyx User.
Args:
user: The Onyx user model.
external_id: The IdP's external identifier for this user.
groups: List of ``(group_id, group_name)`` tuples for the
``groups`` read-only attribute. Pass ``None`` or ``[]``
for newly-created users.
scim_username: The original-case userName from the IdP. Falls
back to ``user.email`` (lowercase) when not available.
fields: Stored mapping fields that the IdP expects round-tripped.
"""
f = fields or ScimMappingFields()
group_refs = [
ScimUserGroupRef(value=str(gid), display=gname)
for gid, gname in (groups or [])
]
username = scim_username or user.email
# Build enterprise extension when at least one value is present.
# Dynamically add the enterprise URN to schemas per RFC 7643 §3.0.
enterprise_ext: ScimEnterpriseExtension | None = None
schemas = list(self.user_schemas)
if f.department is not None or f.manager is not None:
manager_ref = (
ScimManagerRef(value=f.manager) if f.manager is not None else None
)
enterprise_ext = ScimEnterpriseExtension(
department=f.department,
manager=manager_ref,
)
if SCIM_ENTERPRISE_USER_SCHEMA not in schemas:
schemas.append(SCIM_ENTERPRISE_USER_SCHEMA)
name = self.build_scim_name(user, f)
emails = _deserialize_emails(f.scim_emails_json, username)
resource = ScimUserResource(
schemas=schemas,
id=str(user.id),
externalId=external_id,
userName=username,
name=name,
displayName=user.personal_name,
emails=emails,
active=user.is_active,
groups=group_refs,
meta=ScimMeta(resourceType="User"),
)
resource.enterprise_extension = enterprise_ext
return resource
def build_group_resource(
self,
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Build a SCIM Group response from an Onyx UserGroup."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
def build_scim_name(
self,
user: User,
fields: ScimMappingFields,
) -> ScimName:
"""Build SCIM name components for the response.
Round-trips stored ``given_name``/``family_name`` when available (so
the IdP gets back what it sent). Falls back to splitting
``personal_name`` for users provisioned before we stored components.
Always returns a ScimName — Okta's spec tests expect ``name``
(with ``givenName``/``familyName``) on every user resource.
Providers may override for custom behavior.
"""
if fields.given_name is not None or fields.family_name is not None:
return ScimName(
givenName=fields.given_name or "",
familyName=fields.family_name or "",
formatted=user.personal_name or "",
)
if not user.personal_name:
# Derive a reasonable name from the email so that SCIM spec tests
# see non-empty givenName / familyName for every user resource.
local = user.email.split("@")[0] if user.email else ""
return ScimName(givenName=local, familyName="", formatted=local)
parts = user.personal_name.split(" ", 1)
return ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else "",
formatted=user.personal_name,
)
def _deserialize_emails(stored_json: str | None, username: str) -> list[ScimEmail]:
"""Deserialize stored email entries or build a default work email."""
if stored_json:
try:
entries = json.loads(stored_json)
if isinstance(entries, list) and entries:
return [ScimEmail(**e) for e in entries]
except (json.JSONDecodeError, TypeError, ValidationError):
logger.warning(
"Corrupt scim_emails_json, falling back to default: %s", stored_json
)
return [ScimEmail(value=username, type="work", primary=True)]
def serialize_emails(emails: list[ScimEmail]) -> str | None:
"""Serialize SCIM email entries to JSON for storage."""
if not emails:
return None
return json.dumps([e.model_dump(exclude_none=True) for e in emails])
def get_default_provider() -> ScimProvider:
"""Return the default SCIM provider.
Currently returns ``OktaProvider`` since Okta is the primary supported
IdP. When provider detection is added (via token metadata or tenant
config), this can be replaced with dynamic resolution.
"""
from ee.onyx.server.scim.providers.okta import OktaProvider
return OktaProvider()

View File

@@ -1,36 +0,0 @@
"""Entra ID (Azure AD) SCIM provider."""
from __future__ import annotations
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import ScimProvider
_ENTRA_IGNORED_PATCH_PATHS = COMMON_IGNORED_PATCH_PATHS
class EntraProvider(ScimProvider):
"""Entra ID (Azure AD) SCIM provider.
Entra behavioral notes:
- Sends capitalized PATCH ops (``"Add"``, ``"Replace"``, ``"Remove"``)
— handled by ``ScimPatchOperation.normalize_op`` validator.
- Sends the enterprise extension URN as a key in path-less PATCH value
dicts — handled by ``_set_enterprise_field`` in ``patch.py`` to
store department/manager values.
- Expects the enterprise extension schema in ``schemas`` arrays and
``/Schemas`` + ``/ResourceTypes`` discovery endpoints.
"""
@property
def name(self) -> str:
return "entra"
@property
def ignored_patch_paths(self) -> frozenset[str]:
return _ENTRA_IGNORED_PATCH_PATHS
@property
def user_schemas(self) -> list[str]:
return [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]

View File

@@ -1,26 +0,0 @@
"""Okta SCIM provider."""
from __future__ import annotations
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import ScimProvider
class OktaProvider(ScimProvider):
"""Okta SCIM provider.
Okta behavioral notes:
- Uses ``PATCH {"active": false}`` for deprovisioning (not DELETE)
- Sends path-less PATCH with value dicts containing extra fields
(``id``, ``schemas``)
- Expects ``displayName`` and ``groups`` in user responses
- Only uses ``eq`` operator for ``userName`` filter
"""
@property
def name(self) -> str:
return "okta"
@property
def ignored_patch_paths(self) -> frozenset[str]:
return COMMON_IGNORED_PATCH_PATHS

View File

@@ -1,173 +0,0 @@
"""Static SCIM service discovery responses (RFC 7643 §5, §6, §7).
Pre-built at import time — these never change at runtime. Separated from
api.py to keep the endpoint module focused on request handling.
"""
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaAttribute
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
SERVICE_PROVIDER_CONFIG = ScimServiceProviderConfig()
USER_RESOURCE_TYPE = ScimResourceType.model_validate(
{
"id": "User",
"name": "User",
"endpoint": "/scim/v2/Users",
"description": "SCIM User resource",
"schema": SCIM_USER_SCHEMA,
"schemaExtensions": [
{"schema": SCIM_ENTERPRISE_USER_SCHEMA, "required": False}
],
}
)
GROUP_RESOURCE_TYPE = ScimResourceType.model_validate(
{
"id": "Group",
"name": "Group",
"endpoint": "/scim/v2/Groups",
"description": "SCIM Group resource",
"schema": SCIM_GROUP_SCHEMA,
}
)
USER_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_USER_SCHEMA,
name="User",
description="SCIM core User schema",
attributes=[
ScimSchemaAttribute(
name="userName",
type="string",
required=True,
uniqueness="server",
description="Unique identifier for the user, typically an email address.",
),
ScimSchemaAttribute(
name="name",
type="complex",
description="The components of the user's name.",
subAttributes=[
ScimSchemaAttribute(
name="givenName",
type="string",
description="The user's first name.",
),
ScimSchemaAttribute(
name="familyName",
type="string",
description="The user's last name.",
),
ScimSchemaAttribute(
name="formatted",
type="string",
description="The full name, including all middle names and titles.",
),
],
),
ScimSchemaAttribute(
name="emails",
type="complex",
multiValued=True,
description="Email addresses for the user.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="Email address value.",
),
ScimSchemaAttribute(
name="type",
type="string",
description="Label for this email (e.g. 'work').",
),
ScimSchemaAttribute(
name="primary",
type="boolean",
description="Whether this is the primary email.",
),
],
),
ScimSchemaAttribute(
name="active",
type="boolean",
description="Whether the user account is active.",
),
ScimSchemaAttribute(
name="externalId",
type="string",
description="Identifier from the provisioning client (IdP).",
caseExact=True,
),
],
)
ENTERPRISE_USER_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_ENTERPRISE_USER_SCHEMA,
name="EnterpriseUser",
description="Enterprise User extension (RFC 7643 §4.3)",
attributes=[
ScimSchemaAttribute(
name="department",
type="string",
description="Department.",
),
ScimSchemaAttribute(
name="manager",
type="complex",
description="The user's manager.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="Manager user ID.",
),
],
),
],
)
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_GROUP_SCHEMA,
name="Group",
description="SCIM core Group schema",
attributes=[
ScimSchemaAttribute(
name="displayName",
type="string",
required=True,
description="Human-readable name for the group.",
),
ScimSchemaAttribute(
name="members",
type="complex",
multiValued=True,
description="Members of the group.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="User ID of the group member.",
),
ScimSchemaAttribute(
name="display",
type="string",
mutability="readOnly",
description="Display name of the group member.",
),
],
),
ScimSchemaAttribute(
name="externalId",
type="string",
description="Identifier from the provisioning client (IdP).",
caseExact=True,
),
],
)

View File

@@ -18,8 +18,8 @@ from ee.onyx.server.enterprise_settings.store import (
store_settings as store_ee_settings,
)
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
@@ -117,38 +117,15 @@ def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None
def _seed_llms(
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
) -> None:
if not llm_upsert_requests:
return
logger.notice("Seeding LLMs")
for request in llm_upsert_requests:
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
if existing:
request.id = existing.id
seeded_providers = [
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
default_provider = next(
(p for p in seeded_providers if p.model_configurations), None
)
if not default_provider:
return
visible_configs = [
mc for mc in default_provider.model_configurations if mc.is_visible
]
default_config = (
visible_configs[0]
if visible_configs
else default_provider.model_configurations[0]
)
update_default_provider(
provider_id=default_provider.id,
model_name=default_config.name,
db_session=db_session,
)
if llm_upsert_requests:
logger.notice("Seeding LLMs")
seeded_providers = [
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
update_default_provider(
provider_id=seeded_providers[0].id, db_session=db_session
)
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
@@ -160,6 +137,12 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
user=None, # Seeding is done as admin
name=persona.name,
description=persona.description,
num_chunks=(
persona.num_chunks if persona.num_chunks is not None else 0.0
),
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=RecencyBiasSetting.AUTO,
document_set_ids=persona.document_set_ids,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
@@ -171,7 +154,6 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
system_prompt=persona.system_prompt,
task_prompt=persona.task_prompt,
datetime_aware=persona.datetime_aware,
featured=persona.featured,
commit=False,
)
db_session.commit()

View File

@@ -1,14 +1,10 @@
"""EE Settings API - provides license-aware settings override."""
from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.db.license import refresh_license_cache
from onyx.cache.interface import CACHE_TRANSIENT_ERRORS
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -45,14 +41,6 @@ def check_ee_features_enabled() -> bool:
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if not metadata:
# Cache miss — warm from DB so cold-start doesn't block EE features
try:
with get_session_with_current_tenant() as db_session:
metadata = refresh_license_cache(db_session, tenant_id)
except SQLAlchemyError as db_error:
logger.warning(f"Failed to load license from DB: {db_error}")
if metadata and metadata.status != _BLOCKING_STATUS:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
return True
@@ -94,39 +82,21 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if not metadata:
# Cache miss (e.g. after TTL expiry). Fall back to DB so
# the /settings request doesn't falsely return GATED_ACCESS
# while the cache is cold.
try:
with get_session_with_current_tenant() as db_session:
metadata = refresh_license_cache(db_session, tenant_id)
except SQLAlchemyError as db_error:
logger.warning(
f"Failed to load license from DB for settings: {db_error}"
)
if metadata:
if metadata.status == _BLOCKING_STATUS:
settings.application_status = metadata.status
settings.ee_features_enabled = False
elif metadata.used_seats > metadata.seats:
# License is valid but seat limit exceeded
settings.application_status = ApplicationStatus.SEAT_LIMIT_EXCEEDED
settings.seat_count = metadata.seats
settings.used_seats = metadata.used_seats
settings.ee_features_enabled = True
else:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
settings.ee_features_enabled = True
else:
# No license found in cache or DB.
# No license found.
if ENTERPRISE_EDITION_ENABLED:
# Legacy EE flag is set → prior EE usage (e.g. permission
# syncing) means indexed data may need protection.
settings.application_status = _BLOCKING_STATUS
settings.ee_features_enabled = False
except CACHE_TRANSIENT_ERRORS as e:
except RedisError as e:
logger.warning(f"Failed to check license metadata for settings: {e}")
# Fail closed - disable EE features if we can't verify license
settings.ee_features_enabled = False

View File

@@ -21,6 +21,7 @@ import asyncio
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.server.tenants.access import control_plane_dep
@@ -42,8 +43,6 @@ from onyx.auth.users import User
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
@@ -117,14 +116,9 @@ async def create_customer_portal_session(
try:
portal_url = fetch_customer_portal_session(tenant_id, return_url)
return {"stripe_customer_portal_url": portal_url}
except OnyxError:
raise
except Exception:
except Exception as e:
logger.exception("Failed to create customer portal session")
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Failed to create customer portal session",
)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-checkout-session")
@@ -140,14 +134,9 @@ async def create_checkout_session(
try:
checkout_url = fetch_stripe_checkout_session(tenant_id, billing_period, seats)
return {"stripe_checkout_url": checkout_url}
except OnyxError:
raise
except Exception:
except Exception as e:
logger.exception("Failed to create checkout session")
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Failed to create checkout session",
)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
@@ -158,20 +147,15 @@ async def create_subscription_session(
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Tenant ID not found")
raise HTTPException(status_code=400, detail="Tenant ID not found")
billing_period = request.billing_period if request else "monthly"
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
return SubscriptionSessionResponse(sessionId=session_id)
except OnyxError:
raise
except Exception:
except Exception as e:
logger.exception("Failed to create subscription session")
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Failed to create subscription session",
)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/stripe-publishable-key")
@@ -202,18 +186,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
if not key.startswith("pk_"):
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Invalid Stripe publishable key format",
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
# Fall back to S3 bucket
if not STRIPE_PUBLISHABLE_KEY_URL:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Stripe publishable key is not configured",
raise HTTPException(
status_code=500,
detail="Stripe publishable key is not configured",
)
try:
@@ -224,15 +208,15 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
# Validate key format
if not key.startswith("pk_"):
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Invalid Stripe publishable key format",
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
except httpx.HTTPError:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Failed to fetch Stripe publishable key",
raise HTTPException(
status_code=500,
detail="Failed to fetch Stripe publishable key",
)

View File

@@ -33,7 +33,6 @@ from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
@@ -303,17 +302,12 @@ def configure_default_api_keys(db_session: Session) -> None:
has_set_default_provider = False
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
def _upsert(request: LLMProviderUpsertRequest) -> None:
nonlocal has_set_default_provider
try:
existing = fetch_existing_llm_provider(
name=request.name, db_session=db_session
)
if existing:
request.id = existing.id
provider = upsert_llm_provider(request, db_session)
if not has_set_default_provider:
update_default_provider(provider.id, default_model, db_session)
update_default_provider(provider.id, db_session)
has_set_default_provider = True
except Exception as e:
logger.error(f"Failed to configure {request.provider} provider: {e}")
@@ -331,13 +325,14 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
OPENAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openai_provider, default_model_name)
_upsert(openai_provider)
# Create default image generation config using the OpenAI API key
try:
@@ -366,13 +361,14 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
ANTHROPIC_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(anthropic_provider, default_model_name)
_upsert(anthropic_provider)
else:
logger.info(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
@@ -397,13 +393,14 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Google Vertex AI",
provider=VERTEXAI_PROVIDER_NAME,
custom_config=custom_config,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
VERTEXAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(vertexai_provider, default_model_name)
_upsert(vertexai_provider)
else:
logger.info(
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
@@ -435,11 +432,12 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenRouter",
provider=OPENROUTER_PROVIDER_NAME,
api_key=OPENROUTER_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=model_configurations,
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openrouter_provider, default_model_name)
_upsert(openrouter_provider)
else:
logger.info(
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"

View File

@@ -5,8 +5,6 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.user_group import add_users_to_user_group
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import fetch_user_groups
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.db.user_group import insert_user_group
@@ -22,7 +20,6 @@ from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
@@ -40,15 +37,12 @@ def list_user_groups(
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_for_snapshot=True
)
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]
@@ -156,8 +150,3 @@ def delete_user_group(
prepare_user_group_for_deletion(db_session, user_group_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
if DISABLE_VECTOR_DB:
user_group = fetch_user_group(db_session, user_group_id)
if user_group:
db_delete_user_group(db_session, user_group)

View File

@@ -53,8 +53,7 @@ class UserGroup(BaseModel):
id=cc_pair_relationship.cc_pair.id,
name=cc_pair_relationship.cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_relationship.cc_pair.connector,
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
cc_pair_relationship.cc_pair.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential

View File

@@ -58,27 +58,16 @@ class OAuthTokenManager:
if not user_token.token_data:
raise ValueError("No token data available for refresh")
if (
self.oauth_config.client_id is None
or self.oauth_config.client_secret is None
):
raise ValueError(
"OAuth client_id and client_secret are required for token refresh"
)
token_data = self._unwrap_token_data(user_token.token_data)
data: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
"client_secret": self._unwrap_sensitive_str(
self.oauth_config.client_secret
),
}
response = requests.post(
self.oauth_config.token_url,
data=data,
data={
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": self.oauth_config.client_id,
"client_secret": self.oauth_config.client_secret,
},
headers={"Accept": "application/json"},
)
response.raise_for_status()
@@ -126,26 +115,15 @@ class OAuthTokenManager:
def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
"""Exchange authorization code for access token"""
if (
self.oauth_config.client_id is None
or self.oauth_config.client_secret is None
):
raise ValueError(
"OAuth client_id and client_secret are required for code exchange"
)
data: dict[str, str] = {
"grant_type": "authorization_code",
"code": code,
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
"client_secret": self._unwrap_sensitive_str(
self.oauth_config.client_secret
),
"redirect_uri": redirect_uri,
}
response = requests.post(
self.oauth_config.token_url,
data=data,
data={
"grant_type": "authorization_code",
"code": code,
"client_id": self.oauth_config.client_id,
"client_secret": self.oauth_config.client_secret,
"redirect_uri": redirect_uri,
},
headers={"Accept": "application/json"},
)
response.raise_for_status()
@@ -163,13 +141,8 @@ class OAuthTokenManager:
oauth_config: OAuthConfig, redirect_uri: str, state: str
) -> str:
"""Build OAuth authorization URL"""
if oauth_config.client_id is None:
raise ValueError("OAuth client_id is required to build authorization URL")
params: dict[str, Any] = {
"client_id": OAuthTokenManager._unwrap_sensitive_str(
oauth_config.client_id
),
"client_id": oauth_config.client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"state": state,
@@ -188,12 +161,6 @@ class OAuthTokenManager:
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
@staticmethod
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
if isinstance(value, SensitiveValue):
return value.get_value(apply_mask=False)
return value
@staticmethod
def _unwrap_token_data(
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],

View File

@@ -1,4 +1,5 @@
import json
import os
import random
import secrets
import string
@@ -120,7 +121,7 @@ from onyx.db.models import User
from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.server.settings.store import load_settings
from onyx.redis.redis_pool import get_redis_client
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -137,18 +138,28 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"
def is_user_admin(user: User) -> bool:
return user.role == UserRole.ADMIN
def verify_auth_setting() -> None:
if AUTH_TYPE == AuthType.CLOUD:
"""Log warnings for AUTH_TYPE issues.
This only runs on app startup not during migrations/scripts.
"""
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
if raw_auth_type == "cloud":
raise ValueError(
f"{AUTH_TYPE.value} is not a valid auth type for self-hosted deployments."
"'cloud' is not a valid auth type for self-hosted deployments."
)
if raw_auth_type == "disabled":
logger.warning(
"AUTH_TYPE='disabled' is no longer supported. "
"Using 'basic' instead. Please update your configuration."
)
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
@@ -200,45 +211,32 @@ def user_needs_to_be_verified() -> bool:
def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
from onyx.cache.factory import get_cache_backend
cache = get_cache_backend(tenant_id=tenant_id)
value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
redis_client = get_redis_client(tenant_id=tenant_id)
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
if value is None:
return False
assert isinstance(value, bytes)
return int(value.decode("utf-8")) == 1
def workspace_invite_only_enabled() -> bool:
settings = load_settings()
return settings.invite_only_enabled
def verify_email_is_invited(email: str) -> None:
if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}:
# SSO providers manage membership; allow JIT provisioning regardless of invites
return
if not workspace_invite_only_enabled():
whitelist = get_invited_users()
if not whitelist:
return
whitelist = get_invited_users()
if not email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email must be specified"},
)
raise PermissionError("Email must be specified")
try:
email_info = validate_email(email, check_deliverability=False)
except EmailUndeliverableError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email is not valid"},
)
raise PermissionError("Email is not valid")
for email_whitelist in whitelist:
try:
@@ -255,13 +253,7 @@ def verify_email_is_invited(email: str) -> None:
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"code": REGISTER_INVITE_ONLY_CODE,
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
},
)
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
@@ -277,32 +269,13 @@ def verify_email_domain(email: str) -> None:
detail="Email is not valid",
)
local_part, domain = email.split("@")
domain = domain.lower()
if AUTH_TYPE == AuthType.CLOUD:
# Normalize googlemail.com to gmail.com (they deliver to the same inbox)
if domain == "googlemail.com":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
)
if "+" in local_part and domain != "onyx.app":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "Email addresses with '+' are not allowed. Please use your base email address."
},
)
domain = email.split("@")[-1].lower()
# Check if email uses a disposable/temporary domain
if is_disposable_email(email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "Disposable email addresses are not allowed. Please use a permanent email address."
},
detail="Disposable email addresses are not allowed. Please use a permanent email address.",
)
# Check domain whitelist if configured
@@ -543,7 +516,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
result = await db_session.execute(
select(Persona.id)
.where(
Persona.featured.is_(True),
Persona.is_default_persona.is_(True),
Persona.is_public.is_(True),
Persona.is_visible.is_(True),
Persona.deleted.is_(False),
@@ -725,19 +698,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
if user_by_session:
user = user_by_session
# If the user is inactive, check seat availability before
# upgrading role — otherwise they'd become an inactive BASIC
# user who still can't log in.
if not user.is_active:
with get_session_with_current_tenant() as sync_db:
enforce_seat_limit(sync_db)
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"role": UserRole.BASIC,
**({"is_active": True} if not user.is_active else {}),
},
)
@@ -1698,10 +1663,7 @@ def get_oauth_router(
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
# Use WEB_DOMAIN instead of request.url_for() to prevent host
# header poisoning — request.url_for() trusts the Host header.
callback_path = request.app.url_path_for(callback_route_name)
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")

View File

@@ -0,0 +1,142 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.background")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
logger.info("worker_init signal received for consolidated background worker.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
# Initialize Vespa httpx pool (needed for light worker tasks)
if MANAGED_VESPA:
httpx_init_vespa_pool(
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
ssl_cert=VESPA_CLOUD_CERT_PATH,
ssl_key=VESPA_CLOUD_KEY_PATH,
)
else:
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
# Original background worker tasks
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.llm_model_update",
# Light worker tasks
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.opensearch_migration",
# Docprocessing worker tasks
"onyx.background.celery.tasks.docprocessing",
# Docfetching worker tasks
"onyx.background.celery.tasks.docfetching",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)
)

View File

@@ -0,0 +1,23 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
worker_pool = "threads"
# Increased from 1 to 4 to handle fast light worker tasks more efficiently
# This allows the worker to prefetch multiple tasks per thread
worker_prefetch_multiplier = 4

View File

@@ -241,7 +241,8 @@ _VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
"check-for-index-attempt-cleanup",
"check-for-doc-permissions-sync",
"check-for-external-group-sync",
"migrate-chunks-from-vespa-to-opensearch",
"check-for-documents-for-opensearch-migration",
"migrate-documents-from-vespa-to-opensearch",
}
if DISABLE_VECTOR_DB:

View File

@@ -0,0 +1,10 @@
"""Celery tasks for hierarchy fetching."""
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
check_for_hierarchy_fetching,
)
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
connector_hierarchy_fetching_task,
)
__all__ = ["check_for_hierarchy_fetching", "connector_hierarchy_fetching_task"]

View File

@@ -41,14 +41,3 @@ assert (
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S = 30 # 30 seconds.
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
# WARNING: Do not change these values without knowing what changes also need to
# be made to OpenSearchTenantMigrationRecord.
GET_VESPA_CHUNKS_PAGE_SIZE = 500
GET_VESPA_CHUNKS_SLICE_COUNT = 4
# String used to indicate in the vespa_visit_continuation_token mapping that the
# slice has finished and there is nothing left to visit.
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN = (
"FINISHED_VISITING_SLICE_CONTINUATION_TOKEN"
)

View File

@@ -8,12 +8,6 @@ from celery import Task
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.opensearch_migration.constants import (
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_PAGE_SIZE,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
@@ -30,7 +24,6 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
transform_vespa_chunks_to_opensearch_chunks,
)
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -48,21 +41,13 @@ from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import IndexingSetting
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
def is_continuation_token_done_for_all_slices(
continuation_token_map: dict[int, str | None],
) -> bool:
return all(
continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
for continuation_token in continuation_token_map.values()
)
GET_VESPA_CHUNKS_PAGE_SIZE = 1000
# shared_task allows this task to be shared across celery app instances.
@@ -91,15 +76,11 @@ def migrate_chunks_from_vespa_to_opensearch_task(
Uses Vespa's Visit API to iterate through ALL chunks in bulk (not
per-document), transform them, and index them into OpenSearch. Progress is
tracked via a continuation token map stored in the
tracked via a continuation token stored in the
OpenSearchTenantMigrationRecord.
The first time we see no continuation token map and non-zero chunks
migrated, we consider the migration complete and all subsequent invocations
are no-ops.
We divide the index into GET_VESPA_CHUNKS_SLICE_COUNT independent slices
where progress is tracked for each slice.
The first time we see no continuation token and non-zero chunks migrated, we
consider the migration complete and all subsequent invocations are no-ops.
Returns:
None if OpenSearch migration is not enabled, or if the lock could not be
@@ -148,27 +129,17 @@ def migrate_chunks_from_vespa_to_opensearch_task(
task_logger.error(err_str)
return False
with (
get_session_with_current_tenant() as db_session,
get_vespa_http_client(
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
) as vespa_client,
):
with get_session_with_current_tenant() as db_session:
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
search_settings = get_current_search_settings(db_session)
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
indexing_setting = IndexingSetting.from_db_model(search_settings)
opensearch_document_index = OpenSearchDocumentIndex(
tenant_state=tenant_state,
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
index_name=search_settings.index_name, tenant_state=tenant_state
)
vespa_document_index = VespaDocumentIndex(
index_name=search_settings.index_name,
tenant_state=tenant_state,
large_chunks_enabled=False,
httpx_client=vespa_client,
)
sanitized_doc_start_time = time.monotonic()
@@ -182,28 +153,15 @@ def migrate_chunks_from_vespa_to_opensearch_task(
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
)
approx_chunk_count_in_vespa: int | None = None
get_chunk_count_start_time = time.monotonic()
try:
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
except Exception:
task_logger.exception(
"Error getting approximate chunk count in Vespa. Moving on..."
)
task_logger.debug(
f"Took {time.monotonic() - get_chunk_count_start_time:.3f} seconds to attempt to get "
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
)
while (
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
(
continuation_token_map,
continuation_token,
total_chunks_migrated,
) = get_vespa_visit_state(db_session)
if is_continuation_token_done_for_all_slices(continuation_token_map):
if continuation_token is None and total_chunks_migrated > 0:
task_logger.info(
f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
f"Total chunks migrated: {total_chunks_migrated}."
@@ -212,19 +170,19 @@ def migrate_chunks_from_vespa_to_opensearch_task(
break
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token map: {continuation_token_map}"
f"Continuation token: {continuation_token}"
)
get_vespa_chunks_start_time = time.monotonic()
raw_vespa_chunks, next_continuation_token_map = (
raw_vespa_chunks, next_continuation_token = (
vespa_document_index.get_all_raw_document_chunks_paginated(
continuation_token_map=continuation_token_map,
continuation_token=continuation_token,
page_size=GET_VESPA_CHUNKS_PAGE_SIZE,
)
)
task_logger.debug(
f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} "
f"seconds. Next continuation token map: {next_continuation_token_map}"
f"seconds. Next continuation token: {next_continuation_token}"
)
opensearch_document_chunks, errored_chunks = (
@@ -254,11 +212,14 @@ def migrate_chunks_from_vespa_to_opensearch_task(
total_chunks_errored_this_task += len(errored_chunks)
update_vespa_visit_progress_with_commit(
db_session,
continuation_token_map=next_continuation_token_map,
continuation_token=next_continuation_token,
chunks_processed=len(opensearch_document_chunks),
chunks_errored=len(errored_chunks),
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)
if next_continuation_token is None and len(raw_vespa_chunks) == 0:
task_logger.info("Vespa reported no more chunks to migrate.")
break
except Exception:
traceback.print_exc()
task_logger.exception("Error in the OpenSearch migration task.")

View File

@@ -22,7 +22,6 @@ from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER
@@ -38,36 +37,6 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
DOCUMENT_ID,
CHUNK_ID,
TITLE,
TITLE_EMBEDDING,
CONTENT,
EMBEDDINGS,
SOURCE_TYPE,
METADATA_LIST,
DOC_UPDATED_AT,
HIDDEN,
BOOST,
SEMANTIC_IDENTIFIER,
IMAGE_FILE_NAME,
SOURCE_LINKS,
BLURB,
DOC_SUMMARY,
CHUNK_CONTEXT,
METADATA_SUFFIX,
DOCUMENT_SETS,
USER_PROJECT,
PERSONAS,
PRIMARY_OWNERS,
SECONDARY_OWNERS,
ACCESS_CONTROL_LIST,
]
if MULTI_TENANT:
FIELDS_NEEDED_FOR_TRANSFORMATION.append(TENANT_ID)
def _extract_content_vector(embeddings: Any) -> list[float]:
"""Extracts the full chunk embedding vector from Vespa's embeddings tensor.
@@ -278,7 +247,6 @@ def transform_vespa_chunks_to_opensearch_chunks(
)
)
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
personas: list[int] | None = vespa_chunk.get(PERSONAS)
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
@@ -328,7 +296,6 @@ def transform_vespa_chunks_to_opensearch_chunks(
metadata_suffix=metadata_suffix,
document_sets=document_sets,
user_projects=user_projects,
personas=personas,
primary_owners=primary_owners,
secondary_owners=secondary_owners,
tenant_id=tenant_state,

View File

@@ -0,0 +1,8 @@
"""Celery tasks for connector pruning."""
from onyx.background.celery.tasks.pruning.tasks import check_for_pruning # noqa: F401
from onyx.background.celery.tasks.pruning.tasks import ( # noqa: F401
connector_pruning_generator_task,
)
__all__ = ["check_for_pruning", "connector_pruning_generator_task"]

View File

@@ -5,18 +5,14 @@ from uuid import UUID
import httpx
import sqlalchemy as sa
from celery import Celery
from celery import shared_task
from celery import Task
from redis import Redis
from redis.lock import Lock as RedisLock
from retry import retry
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import DISABLE_VECTOR_DB
@@ -25,16 +21,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -65,73 +57,14 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
def _user_file_project_sync_queued_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_QUEUED_PREFIX}:{user_file_id}"
def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
return celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
)
def enqueue_user_file_project_sync_task(
*,
celery_app: Celery,
redis_client: Redis,
user_file_id: str | UUID,
tenant_id: str,
priority: OnyxCeleryPriority = OnyxCeleryPriority.HIGH,
) -> bool:
"""Enqueue a project-sync task if no matching queued task already exists."""
queued_key = _user_file_project_sync_queued_key(user_file_id)
# NX+EX gives us atomic dedupe and a self-healing TTL.
queued_guard_set = redis_client.set(
queued_key,
1,
nx=True,
ex=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
)
if not queued_guard_set:
return False
try:
celery_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=priority,
expires=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
)
except Exception:
# Roll back the queued guard if task publish fails.
redis_client.delete(queued_key)
raise
return True
@retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0))
def _visit_chunks(
*,
@@ -187,24 +120,7 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
Uses direct Redis locks to avoid overlapping runs.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -219,21 +135,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -246,35 +148,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -282,8 +161,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
@@ -414,31 +292,28 @@ def _process_user_file_with_indexing(
raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}")
def process_user_file_impl(
*, user_file_id: str, tenant_id: str, redis_locking: bool
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
) -> None:
"""Core implementation for processing a single user file.
When redis_locking=True, acquires a per-file Redis lock and clears the
queued-key guard (Celery path). When redis_locking=False, skips all Redis
operations (BackgroundTask path).
"""
task_logger.info(f"process_user_file_impl - Starting id={user_file_id}")
task_logger.info(f"process_single_user_file - Starting id={user_file_id}")
start = time.monotonic()
file_lock: RedisLock | None = None
if redis_locking:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
task_logger.info(
f"process_single_user_file - Lock held, skipping user_file_id={user_file_id}"
)
if file_lock is not None and not file_lock.acquire(blocking=False):
task_logger.info(
f"process_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
)
return
return None
documents: list[Document] = []
try:
@@ -446,18 +321,15 @@ def process_user_file_impl(
uf = db_session.get(UserFile, _as_uuid(user_file_id))
if not uf:
task_logger.warning(
f"process_user_file_impl - UserFile not found id={user_file_id}"
f"process_single_user_file - UserFile not found id={user_file_id}"
)
return
return None
if uf.status not in (
UserFileStatus.PROCESSING,
UserFileStatus.INDEXING,
):
if uf.status != UserFileStatus.PROCESSING:
task_logger.info(
f"process_user_file_impl - Skipping id={user_file_id} status={uf.status}"
f"process_single_user_file - Skipping id={user_file_id} status={uf.status}"
)
return
return None
connector = LocalFileConnector(
file_locations=[uf.file_id],
@@ -471,6 +343,7 @@ def process_user_file_impl(
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
)
# update the document id to userfile id in the documents
for document in documents:
document.id = str(user_file_id)
document.source = DocumentSource.USER_FILE
@@ -492,8 +365,9 @@ def process_user_file_impl(
except Exception as e:
task_logger.exception(
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
# don't update the status if the user file is being deleted
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if (
current_user_file
@@ -502,43 +376,33 @@ def process_user_file_impl(
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
return
return None
elapsed = time.monotonic() - start
task_logger.info(
f"process_user_file_impl - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
f"process_single_user_file - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
)
return None
except Exception as e:
# Attempt to mark the file as failed
with get_session_with_current_tenant() as db_session:
uf = db_session.get(UserFile, _as_uuid(user_file_id))
if uf:
# don't update the status if the user file is being deleted
if uf.status != UserFileStatus.DELETING:
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
task_logger.exception(
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
raise
return None
finally:
if file_lock is not None and file_lock.owned():
if file_lock.owned():
file_lock.release()
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
) -> None:
process_user_file_impl(
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
)
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE,
soft_time_limit=300,
@@ -589,38 +453,36 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
return None
def delete_user_file_impl(
*, user_file_id: str, tenant_id: str, redis_locking: bool
@shared_task(
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file_delete(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
) -> None:
"""Core implementation for deleting a single user file.
When redis_locking=True, acquires a per-file Redis lock (Celery path).
When redis_locking=False, skips Redis operations (BackgroundTask path).
"""
task_logger.info(f"delete_user_file_impl - Starting id={user_file_id}")
file_lock: RedisLock | None = None
if redis_locking:
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock = redis_client.lock(
_user_file_delete_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
"""Process a single user file delete."""
task_logger.info(f"process_single_user_file_delete - Starting id={user_file_id}")
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock: RedisLock = redis_client.lock(
_user_file_delete_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
task_logger.info(
f"process_single_user_file_delete - Lock held, skipping user_file_id={user_file_id}"
)
if file_lock is not None and not file_lock.acquire(blocking=False):
task_logger.info(
f"delete_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
)
return
return None
try:
with get_session_with_current_tenant() as db_session:
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if not user_file:
task_logger.info(
f"delete_user_file_impl - User file not found id={user_file_id}"
f"process_single_user_file_delete - User file not found id={user_file_id}"
)
return
return None
# 1) Delete vector DB chunks (skip when disabled)
if not DISABLE_VECTOR_DB:
if MANAGED_VESPA:
httpx_init_vespa_pool(
@@ -658,6 +520,7 @@ def delete_user_file_impl(
chunk_count=chunk_count,
)
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
file_store = get_default_file_store()
try:
file_store.delete_file(user_file.file_id)
@@ -665,34 +528,26 @@ def delete_user_file_impl(
user_file_id_to_plaintext_file_name(user_file.id)
)
except Exception as e:
# This block executed only if the file is not found in the filestore
task_logger.exception(
f"delete_user_file_impl - Error deleting file id={user_file.id} - {e.__class__.__name__}"
f"process_single_user_file_delete - Error deleting file id={user_file.id} - {e.__class__.__name__}"
)
# 3) Finally, delete the UserFile row
db_session.delete(user_file)
db_session.commit()
task_logger.info(f"delete_user_file_impl - Completed id={user_file_id}")
task_logger.info(
f"process_single_user_file_delete - Completed id={user_file_id}"
)
except Exception as e:
task_logger.exception(
f"delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
f"process_single_user_file_delete - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
raise
return None
finally:
if file_lock is not None and file_lock.owned():
if file_lock.owned():
file_lock.release()
@shared_task(
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file_delete(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
) -> None:
delete_user_file_impl(
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
)
return None
@shared_task(
@@ -702,8 +557,8 @@ def process_single_user_file_delete(
ignore_result=True,
)
def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
"""Scan for user files needing project sync and enqueue per-file tasks."""
task_logger.info("Starting")
"""Scan for user files with PROJECT_SYNC status and enqueue per-file tasks."""
task_logger.info("check_for_user_file_project_sync - Starting")
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
@@ -715,25 +570,13 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
queue_depth = get_user_file_project_sync_queue_depth(self.app)
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
task_logger.warning(
f"Queue depth {queue_depth} exceeds "
f"{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}, skipping enqueue for tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
select(UserFile.id).where(
sa.and_(
sa.or_(
UserFile.needs_project_sync.is_(True),
UserFile.needs_persona_sync.is_(True),
),
UserFile.needs_project_sync.is_(True),
UserFile.status == UserFileStatus.COMPLETED,
)
)
@@ -743,65 +586,58 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
if not enqueue_user_file_project_sync_task(
celery_app=self.app,
redis_client=redis_client,
user_file_id=user_file_id,
tenant_id=tenant_id,
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGH,
):
skipped_guard += 1
continue
)
enqueued += 1
finally:
if lock.owned():
lock.release()
task_logger.info(
f"Enqueued {enqueued} "
f"Skipped guard {skipped_guard} tasks for tenant={tenant_id}"
f"check_for_user_file_project_sync - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
def project_sync_user_file_impl(
*, user_file_id: str, tenant_id: str, redis_locking: bool
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
bind=True,
ignore_result=True,
)
def process_single_user_file_project_sync(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
) -> None:
"""Core implementation for syncing a user file's project/persona metadata.
"""Process a single user file project sync."""
task_logger.info(
f"process_single_user_file_project_sync - Starting id={user_file_id}"
)
When redis_locking=True, acquires a per-file Redis lock and clears the
queued-key guard (Celery path). When redis_locking=False, skips Redis
operations (BackgroundTask path).
"""
task_logger.info(f"project_sync_user_file_impl - Starting id={user_file_id}")
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock: RedisLock = redis_client.lock(
_user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
)
file_lock: RedisLock | None = None
if redis_locking:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
file_lock = redis_client.lock(
user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
if not file_lock.acquire(blocking=False):
task_logger.info(
f"process_single_user_file_project_sync - Lock held, skipping user_file_id={user_file_id}"
)
if file_lock is not None and not file_lock.acquire(blocking=False):
task_logger.info(
f"project_sync_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
)
return
return None
try:
with get_session_with_current_tenant() as db_session:
user_file = db_session.execute(
select(UserFile)
.where(UserFile.id == _as_uuid(user_file_id))
.options(selectinload(UserFile.assistants))
).scalar_one_or_none()
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if not user_file:
task_logger.info(
f"project_sync_user_file_impl - User file not found id={user_file_id}"
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
)
return
return None
# Sync project metadata to vector DB (skip when disabled)
if not DISABLE_VECTOR_DB:
if MANAGED_VESPA:
httpx_init_vespa_pool(
@@ -822,25 +658,20 @@ def project_sync_user_file_impl(
]
project_ids = [project.id for project in user_file.projects]
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
for retry_document_index in retry_document_indices:
retry_document_index.update_single(
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
fields=None,
user_fields=VespaDocumentUserFields(
user_projects=project_ids,
personas=persona_ids,
),
user_fields=VespaDocumentUserFields(user_projects=project_ids),
)
task_logger.info(
f"project_sync_user_file_impl - User file id={user_file_id}"
f"process_single_user_file_project_sync - User file id={user_file_id}"
)
user_file.needs_project_sync = False
user_file.needs_persona_sync = False
user_file.last_project_sync_at = datetime.datetime.now(
datetime.timezone.utc
)
@@ -849,22 +680,11 @@ def project_sync_user_file_impl(
except Exception as e:
task_logger.exception(
f"project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
f"process_single_user_file_project_sync - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
)
raise
return None
finally:
if file_lock is not None and file_lock.owned():
if file_lock.owned():
file_lock.release()
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
bind=True,
ignore_result=True,
)
def process_single_user_file_project_sync(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
) -> None:
project_sync_user_file_impl(
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
)
return None

View File

@@ -0,0 +1,10 @@
from celery import Celery
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app: Celery = fetch_versioned_implementation(
"onyx.background.celery.apps.background",
"celery_app",
)

View File

@@ -58,8 +58,6 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
@@ -158,7 +156,36 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
logger.warning(
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
)
cleaned_batch.append(sanitize_document_for_postgres(doc))
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
if cleaned_doc.title and "\x00" in cleaned_doc.title:
logger.warning(
f"NUL characters found in document title: {cleaned_doc.title}"
)
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
)
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
"\x00", ""
)
for section in cleaned_doc.sections:
if section.link is not None:
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
if isinstance(section, TextSection) and section.text is not None:
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
return cleaned_batch
@@ -575,13 +602,10 @@ def connector_document_extraction(
# Process hierarchy nodes batch - upsert to Postgres and cache in Redis
if hierarchy_node_batch:
hierarchy_node_batch_cleaned = (
sanitize_hierarchy_nodes_for_postgres(hierarchy_node_batch)
)
with get_session_with_current_tenant() as db_session:
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=hierarchy_node_batch_cleaned,
nodes=hierarchy_node_batch,
source=db_connector.source,
commit=True,
is_connector_public=is_connector_public,
@@ -600,7 +624,7 @@ def connector_document_extraction(
)
logger.debug(
f"Persisted and cached {len(hierarchy_node_batch_cleaned)} hierarchy nodes "
f"Persisted and cached {len(hierarchy_node_batch)} hierarchy nodes "
f"for attempt={index_attempt_id}"
)

View File

@@ -1,307 +0,0 @@
"""Periodic poller for NO_VECTOR_DB deployments.
Replaces Celery Beat and background workers with a lightweight daemon thread
that runs from the API server process. Two responsibilities:
1. Recovery polling (every 30 s): re-processes user files stuck in
PROCESSING / DELETING / needs_sync states via the drain loops defined
in ``task_utils.py``.
2. Periodic task execution (configurable intervals): runs LLM model updates
and scheduled evals at their configured cadences, with Postgres advisory
lock deduplication across multiple API server instances.
"""
import threading
import time
from collections.abc import Callable
from dataclasses import dataclass
from dataclasses import field
from onyx.utils.logger import setup_logger
logger = setup_logger()
RECOVERY_INTERVAL_SECONDS = 30
PERIODIC_TASK_LOCK_BASE = 20_000
PERIODIC_TASK_KV_PREFIX = "periodic_poller:last_claimed:"
# ------------------------------------------------------------------
# Periodic task definitions
# ------------------------------------------------------------------
_NEVER_RAN: float = -1e18
@dataclass
class _PeriodicTaskDef:
name: str
interval_seconds: float
lock_id: int
run_fn: Callable[[], None]
last_run_at: float = field(default=_NEVER_RAN)
def _run_auto_llm_update() -> None:
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
if not AUTO_LLM_CONFIG_URL:
return
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.well_known_providers.auto_update_service import (
sync_llm_models_from_github,
)
with get_session_with_current_tenant() as db_session:
sync_llm_models_from_github(db_session)
def _run_cache_cleanup() -> None:
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
cleanup_expired_cache_entries()
def _run_scheduled_eval() -> None:
from onyx.configs.app_configs import BRAINTRUST_API_KEY
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
if not all(
[
BRAINTRUST_API_KEY,
SCHEDULED_EVAL_PROJECT,
SCHEDULED_EVAL_DATASET_NAMES,
SCHEDULED_EVAL_PERMISSIONS_EMAIL,
]
):
return
from datetime import datetime
from datetime import timezone
from onyx.evals.eval import run_eval
from onyx.evals.models import EvalConfigurationOptions
run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
for dataset_name in SCHEDULED_EVAL_DATASET_NAMES:
try:
run_eval(
configuration=EvalConfigurationOptions(
search_permissions_email=SCHEDULED_EVAL_PERMISSIONS_EMAIL,
dataset_name=dataset_name,
no_send_logs=False,
braintrust_project=SCHEDULED_EVAL_PROJECT,
experiment_name=f"{dataset_name} - {run_timestamp}",
),
remote_dataset_name=dataset_name,
)
except Exception:
logger.exception(
f"Periodic poller - Failed scheduled eval for dataset {dataset_name}"
)
_CACHE_CLEANUP_INTERVAL_SECONDS = 300
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
from onyx.cache.interface import CacheBackendType
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import CACHE_BACKEND
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
tasks: list[_PeriodicTaskDef] = []
if CACHE_BACKEND == CacheBackendType.POSTGRES:
tasks.append(
_PeriodicTaskDef(
name="cache-cleanup",
interval_seconds=_CACHE_CLEANUP_INTERVAL_SECONDS,
lock_id=PERIODIC_TASK_LOCK_BASE + 2,
run_fn=_run_cache_cleanup,
)
)
if AUTO_LLM_CONFIG_URL:
tasks.append(
_PeriodicTaskDef(
name="auto-llm-update",
interval_seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS,
lock_id=PERIODIC_TASK_LOCK_BASE,
run_fn=_run_auto_llm_update,
)
)
if SCHEDULED_EVAL_DATASET_NAMES:
tasks.append(
_PeriodicTaskDef(
name="scheduled-eval",
interval_seconds=7 * 24 * 3600,
lock_id=PERIODIC_TASK_LOCK_BASE + 1,
run_fn=_run_scheduled_eval,
)
)
return tasks
# ------------------------------------------------------------------
# Periodic task runner with advisory-lock-guarded claim
# ------------------------------------------------------------------
def _try_claim_task(task_def: _PeriodicTaskDef) -> bool:
"""Atomically check whether *task_def* should run and record a claim.
Uses a transaction-scoped advisory lock for atomicity combined with a
``KVStore`` timestamp for cross-instance dedup. The DB session is held
only for this brief claim transaction, not during task execution.
"""
from datetime import datetime
from datetime import timezone
from sqlalchemy import text
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import KVStore
kv_key = PERIODIC_TASK_KV_PREFIX + task_def.name
with get_session_with_current_tenant() as db_session:
acquired = db_session.execute(
text("SELECT pg_try_advisory_xact_lock(:id)"),
{"id": task_def.lock_id},
).scalar()
if not acquired:
return False
row = db_session.query(KVStore).filter_by(key=kv_key).first()
if row and row.value is not None:
last_claimed = datetime.fromisoformat(str(row.value))
elapsed = (datetime.now(timezone.utc) - last_claimed).total_seconds()
if elapsed < task_def.interval_seconds:
return False
now_ts = datetime.now(timezone.utc).isoformat()
if row:
row.value = now_ts
else:
db_session.add(KVStore(key=kv_key, value=now_ts))
db_session.commit()
return True
def _try_run_periodic_task(task_def: _PeriodicTaskDef) -> None:
"""Run *task_def* if its interval has elapsed and no peer holds the lock."""
now = time.monotonic()
if now - task_def.last_run_at < task_def.interval_seconds:
return
if not _try_claim_task(task_def):
return
try:
task_def.run_fn()
task_def.last_run_at = now
except Exception:
logger.exception(
f"Periodic poller - Error running periodic task {task_def.name}"
)
# ------------------------------------------------------------------
# Recovery / drain loop runner
# ------------------------------------------------------------------
def _run_drain_loops(tenant_id: str) -> None:
from onyx.background.task_utils import drain_delete_loop
from onyx.background.task_utils import drain_processing_loop
from onyx.background.task_utils import drain_project_sync_loop
drain_processing_loop(tenant_id)
drain_delete_loop(tenant_id)
drain_project_sync_loop(tenant_id)
# ------------------------------------------------------------------
# Startup recovery (10g)
# ------------------------------------------------------------------
def recover_stuck_user_files(tenant_id: str) -> None:
"""Run all drain loops once to re-process files left in intermediate states.
Called from ``lifespan()`` on startup when ``DISABLE_VECTOR_DB`` is set.
"""
logger.info("recover_stuck_user_files - Checking for stuck user files")
try:
_run_drain_loops(tenant_id)
except Exception:
logger.exception("recover_stuck_user_files - Error during recovery")
# ------------------------------------------------------------------
# Daemon thread (10f)
# ------------------------------------------------------------------
_shutdown_event = threading.Event()
_poller_thread: threading.Thread | None = None
def _poller_loop(tenant_id: str) -> None:
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
periodic_tasks = _build_periodic_tasks()
logger.info(
f"Periodic poller started with {len(periodic_tasks)} periodic task(s): "
f"{[t.name for t in periodic_tasks]}"
)
while not _shutdown_event.is_set():
try:
_run_drain_loops(tenant_id)
except Exception:
logger.exception("Periodic poller - Error in recovery polling")
for task_def in periodic_tasks:
try:
_try_run_periodic_task(task_def)
except Exception:
logger.exception(
f"Periodic poller - Unhandled error checking task {task_def.name}"
)
_shutdown_event.wait(RECOVERY_INTERVAL_SECONDS)
def start_periodic_poller(tenant_id: str) -> None:
"""Start the periodic poller daemon thread."""
global _poller_thread # noqa: PLW0603
_shutdown_event.clear()
_poller_thread = threading.Thread(
target=_poller_loop,
args=(tenant_id,),
daemon=True,
name="no-vectordb-periodic-poller",
)
_poller_thread.start()
logger.info("Periodic poller thread started")
def stop_periodic_poller() -> None:
"""Signal the periodic poller to stop and wait for it to exit."""
global _poller_thread # noqa: PLW0603
if _poller_thread is None:
return
_shutdown_event.set()
_poller_thread.join(timeout=10)
if _poller_thread.is_alive():
logger.warning("Periodic poller thread did not stop within timeout")
_poller_thread = None
logger.info("Periodic poller thread stopped")

Some files were not shown because too many files have changed in this diff Show More