Compare commits

..

33 Commits

Author SHA1 Message Date
justin-tahara
2032b76fbf chore(release): Fixing Release Branch 2026-02-20 14:45:30 -08:00
Jamison Lahman
055b30b00e chore(fe): fix drop-down overflow in API Key modal (#8574) 2026-02-20 14:26:31 -08:00
Jamison Lahman
360a4cf591 chore(fe): remove close button from image gen tooltip (#8585) 2026-02-20 14:13:16 -08:00
Jamison Lahman
3d3cab9f91 fix(fe): popover width can fit trigger element (#8624) 2026-02-20 14:13:16 -08:00
Justin Tahara
6120d012ba feat(web): FE Changes for Brave Web Search 3/3 (#8597) 2026-02-20 11:29:02 -08:00
Evan Lohn
3e7e2e93f2 fix: search tool enabled when nothing selected 2026-02-20 11:05:46 -08:00
Justin Tahara
ccf482fa3b hotfix/web 2026-02-20 11:03:32 -08:00
Justin Tahara
fd45a612da feat(web): Initial Framework for Brave Web Search 1/3 (#8594) 2026-02-20 10:58:41 -08:00
Danelegend
c444d8883b fix: /llm/provider route returns all providers (#8545) 2026-02-20 10:48:56 -08:00
SubashMohan
9947837f9f fix: update SourceTag component to use variant prop for sizing (#8582) 2026-02-20 11:54:18 +05:30
SubashMohan
bc324a8070 fix(ui): fix few common ui bugs (#8425) 2026-02-20 11:54:04 +05:30
SubashMohan
26f648c24a fix(chatpage): Improve agent message layout, sidebar nesting, and icon fixes (#8224) 2026-02-20 10:49:23 +05:30
SubashMohan
638f20f5f3 fix(timeline): reduce agent message re-renders with referential stability in usePacedTurnGroups (#8265) 2026-02-20 10:49:04 +05:30
Jamison Lahman
f6ee57f523 chore(gha): rm nightly license scan workflow (#8541) 2026-02-19 20:03:58 -08:00
Justin Tahara
aae6fc7aac fix(desktop): Link clicking within App (#8493) 2026-02-19 17:44:32 -08:00
Justin Tahara
5d7a664250 fix(bedrock): Fixing toolConfig call (#8342) 2026-02-19 17:44:11 -08:00
Wenxi
e7386490bf fix(manage-users): exclude slack users from /users list (#8602) 2026-02-19 17:09:47 -08:00
Wenxi
106e10a143 fix: open_url broken on non-normalized urls and enable web crawl tests (#8508) 2026-02-19 17:09:47 -08:00
Wenxi
513f430a1b refactor: connector config refresh elements/cleanup (#8428) 2026-02-19 17:09:47 -08:00
Wenxi
696d73822f fix: remove log error when authtype is not set (#8399) 2026-02-19 17:09:47 -08:00
Wenxi
bfcc5a20a2 chore: make chatbackgrounds local assets for air-gapped envs (#8381) 2026-02-19 17:09:47 -08:00
Wenxi
efe3613354 fix: allow basic users to share agents (#8269) 2026-02-19 17:09:47 -08:00
Nikolas Garza
62405bdc42 fix(ee): small ux fixes for licensing (#8498) 2026-02-19 14:32:28 -08:00
Yuhong Sun
8f505dc45f chore: License update (No change, just touchup) (#8460) 2026-02-19 14:32:28 -08:00
Jessica Singh
75f0db4fe5 chore(bulk invite): free trial limit (#8378) 2026-02-19 14:32:28 -08:00
Nikolas Garza
f0a5c579a3 feat(auth): enforce seat limits on all user creation paths (#8401) 2026-02-19 14:32:28 -08:00
Nikolas Garza
293bf30847 fix(billing): exclude inactive users from seat counts and allow users page when gated (#8397) 2026-02-19 14:32:28 -08:00
Nikolas Garza
8774ca3b0f feat(ee): gate access only when legacy EE flag is set and no license exists (#8368) 2026-02-19 14:32:28 -08:00
Nikolas Garza
016a73f85f fix(ee): follow HTTP→HTTPS redirects in forward_to_control_plane (#8360) 2026-02-19 14:32:28 -08:00
Wenxi
2eddb4e23e fix: upgrade plan page nits (#8346) 2026-02-19 14:32:28 -08:00
Nikolas Garza
0a61660a59 fix(ee): copy license public key into Docker image (#8322) 2026-02-19 14:32:28 -08:00
Danelegend
a10599e76e fix: model config not populating flow during sync (#8542) 2026-02-18 17:11:52 -08:00
Nikolas Garza
b3d3f7af76 feat(ee): Enable license enforcement by default (#8270) 2026-02-09 20:43:33 -08:00
1057 changed files with 22103 additions and 80877 deletions

View File

@@ -1 +0,0 @@
../.cursor/skills

View File

@@ -1,16 +0,0 @@
{
"mcpServers": {
"Playwright": {
"command": "npx",
"args": [
"@playwright/mcp"
]
},
"Linear": {
"url": "https://mcp.linear.app/mcp"
},
"Figma": {
"url": "https://mcp.figma.com/mcp"
}
}
}

View File

@@ -1,248 +0,0 @@
---
name: playwright-e2e-tests
description: Write and maintain Playwright end-to-end tests for the Onyx application. Use when creating new E2E tests, debugging test failures, adding test coverage, or when the user mentions Playwright, E2E tests, or browser testing.
---
# Playwright E2E Tests
## Project Layout
- **Tests**: `web/tests/e2e/` — organized by feature (`auth/`, `admin/`, `chat/`, `assistants/`, `connectors/`, `mcp/`)
- **Config**: `web/playwright.config.ts`
- **Utilities**: `web/tests/e2e/utils/`
- **Constants**: `web/tests/e2e/constants.ts`
- **Global setup**: `web/tests/e2e/global-setup.ts`
- **Output**: `web/output/playwright/`
## Imports
Always use absolute imports with the `@tests/e2e/` prefix — never relative paths (`../`, `../../`). The alias is defined in `web/tsconfig.json` and resolves to `web/tests/`.
```typescript
import { loginAs } from "@tests/e2e/utils/auth";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
import { TEST_ADMIN_CREDENTIALS } from "@tests/e2e/constants";
```
All new files should be `.ts`, not `.js`.
## Running Tests
```bash
# Run a specific test file
npx playwright test web/tests/e2e/chat/default_assistant.spec.ts
# Run a specific project
npx playwright test --project admin
npx playwright test --project exclusive
```
## Test Projects
| Project | Description | Parallelism |
|---------|-------------|-------------|
| `admin` | Standard tests (excludes `@exclusive`) | Parallel |
| `exclusive` | Serial, slower tests (tagged `@exclusive`) | 1 worker |
All tests use `admin_auth.json` storage state by default (pre-authenticated admin session).
## Authentication
Global setup (`global-setup.ts`) runs automatically before all tests and handles:
- Server readiness check (polls health endpoint, 60s timeout)
- Provisioning test users: admin, admin2, and a **pool of worker users** (`worker0@example.com` through `worker7@example.com`) (idempotent)
- API login + saving storage states: `admin_auth.json`, `admin2_auth.json`, and `worker{N}_auth.json` for each worker user
- Setting display name to `"worker"` for each worker user
- Promoting admin2 to admin role
- Ensuring a public LLM provider exists
Both test projects set `storageState: "admin_auth.json"`, so **every test starts pre-authenticated as admin with no login code needed**.
When a test needs a different user, use API-based login — never drive the login UI:
```typescript
import { loginAs } from "@tests/e2e/utils/auth";
await page.context().clearCookies();
await loginAs(page, "admin2");
// Log in as the worker-specific user (preferred for test isolation):
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
```
## Test Structure
Tests start pre-authenticated as admin — navigate and test directly:
```typescript
import { test, expect } from "@playwright/test";
test.describe("Feature Name", () => {
test("should describe expected behavior clearly", async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
// Already authenticated as admin — go straight to testing
});
});
```
**User isolation** — tests that modify visible app state (creating assistants, sending chat messages, pinning items) should run as a **worker-specific user** and clean up resources in `afterAll`. Global setup provisions a pool of worker users (`worker0@example.com` through `worker7@example.com`). `loginAsWorkerUser` maps `testInfo.workerIndex` to a pool slot via modulo, so retry workers (which get incrementing indices beyond the pool size) safely reuse existing users. This ensures parallel workers never share user state, keeps usernames deterministic for screenshots, and avoids cross-contamination:
```typescript
import { test } from "@playwright/test";
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
test.beforeEach(async ({ page }, testInfo) => {
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
});
```
If the test requires admin privileges *and* modifies visible state, use `"admin2"` instead — it's a pre-provisioned admin account that keeps the primary `"admin"` clean for other parallel tests. Switch to `"admin"` only for privileged setup (creating providers, configuring tools), then back to the worker user for the actual test. See `chat/default_assistant.spec.ts` for a full example.
`loginAsRandomUser` exists for the rare case where the test requires a brand-new user (e.g. onboarding flows). Avoid it elsewhere — it produces non-deterministic usernames that complicate screenshots.
**API resource setup** — only when tests need to create backend resources (image gen configs, web search providers, MCP servers). Use `beforeAll`/`afterAll` with `OnyxApiClient` to create and clean up. See `chat/default_assistant.spec.ts` or `mcp/mcp_oauth_flow.spec.ts` for examples. This is uncommon (~4 of 37 test files).
## Key Utilities
### `OnyxApiClient` (`@tests/e2e/utils/onyxApiClient`)
Backend API client for test setup/teardown. Key methods:
- **Connectors**: `createFileConnector()`, `deleteCCPair()`, `pauseConnector()`
- **LLM Providers**: `ensurePublicProvider()`, `createRestrictedProvider()`, `setProviderAsDefault()`
- **Assistants**: `createAssistant()`, `deleteAssistant()`, `findAssistantByName()`
- **User Groups**: `createUserGroup()`, `deleteUserGroup()`, `setUserRole()`
- **Tools**: `createWebSearchProvider()`, `createImageGenerationConfig()`
- **Chat**: `createChatSession()`, `deleteChatSession()`
### `chatActions` (`@tests/e2e/utils/chatActions`)
- `sendMessage(page, message)` — sends a message and waits for AI response
- `startNewChat(page)` — clicks new-chat button and waits for intro
- `verifyDefaultAssistantIsChosen(page)` — checks Onyx logo is visible
- `verifyAssistantIsChosen(page, name)` — checks assistant name display
- `switchModel(page, modelName)` — switches LLM model via popover
### `visualRegression` (`@tests/e2e/utils/visualRegression`)
- `expectScreenshot(page, { name, mask?, hide?, fullPage? })`
- `expectElementScreenshot(locator, { name, mask?, hide? })`
- Controlled by `VISUAL_REGRESSION=true` env var
### `theme` (`@tests/e2e/utils/theme`)
- `THEMES``["light", "dark"] as const` array for iterating over both themes
- `setThemeBeforeNavigation(page, theme)` — sets `next-themes` theme via `localStorage` before navigation
When tests need light/dark screenshots, loop over `THEMES` at the `test.describe` level and call `setThemeBeforeNavigation` in `beforeEach` **before** any `page.goto()`. Include the theme in screenshot names. See `admin/admin_pages.spec.ts` or `chat/chat_message_rendering.spec.ts` for examples:
```typescript
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";
for (const theme of THEMES) {
test.describe(`Feature (${theme} mode)`, () => {
test.beforeEach(async ({ page }) => {
await setThemeBeforeNavigation(page, theme);
});
test("renders correctly", async ({ page }) => {
await page.goto("/app");
await expectScreenshot(page, { name: `feature-${theme}` });
});
});
}
```
### `tools` (`@tests/e2e/utils/tools`)
- `TOOL_IDS` — centralized `data-testid` selectors for tool options
- `openActionManagement(page)` — opens the tool management popover
## Locator Strategy
Use locators in this priority order:
1. **`data-testid` / `aria-label`** — preferred for Onyx components
```typescript
page.getByTestId("AppSidebar/new-session")
page.getByLabel("admin-page-title")
```
2. **Role-based** — for standard HTML elements
```typescript
page.getByRole("button", { name: "Create" })
page.getByRole("dialog")
```
3. **Text/Label** — for visible text content
```typescript
page.getByText("Custom Assistant")
page.getByLabel("Email")
```
4. **CSS selectors** — last resort, only when above won't work
```typescript
page.locator('input[name="name"]')
page.locator("#onyx-chat-input-textarea")
```
**Never use** `page.locator` with complex CSS/XPath when a built-in locator works.
## Assertions
Use web-first assertions — they auto-retry until the condition is met:
```typescript
// Visibility
await expect(page.getByTestId("onyx-logo")).toBeVisible({ timeout: 5000 });
// Text content
await expect(page.getByTestId("assistant-name-display")).toHaveText("My Assistant");
// Count
await expect(page.locator('[data-testid="onyx-ai-message"]')).toHaveCount(2, { timeout: 30000 });
// URL
await expect(page).toHaveURL(/chatId=/);
// Element state
await expect(toggle).toBeChecked();
await expect(button).toBeEnabled();
```
**Never use** `assert` statements or hardcoded `page.waitForTimeout()`.
## Waiting Strategy
```typescript
// Wait for load state after navigation
await page.goto("/app");
await page.waitForLoadState("networkidle");
// Wait for specific element
await page.getByTestId("chat-intro").waitFor({ state: "visible", timeout: 10000 });
// Wait for URL change
await page.waitForFunction(() => window.location.href.includes("chatId="), null, { timeout: 10000 });
// Wait for network response
await page.waitForResponse(resp => resp.url().includes("/api/chat") && resp.status() === 200);
```
## Best Practices
1. **Descriptive test names** — clearly state expected behavior: `"should display greeting message when opening new chat"`
2. **API-first setup** — use `OnyxApiClient` for backend state; reserve UI interactions for the behavior under test
3. **User isolation** — tests that modify visible app state (sidebar, chat history) should run as the worker-specific user via `loginAsWorkerUser(page, testInfo.workerIndex)` (not admin) and clean up resources in `afterAll`. Each parallel worker gets its own user, preventing cross-contamination. Reserve `loginAsRandomUser` for flows that require a brand-new user (e.g. onboarding)
4. **DRY helpers** — extract reusable logic into `utils/` with JSDoc comments
5. **No hardcoded waits** — use `waitFor`, `waitForLoadState`, or web-first assertions
6. **Parallel-safe** — no shared mutable state between tests. Prefer static, human-readable names (e.g. `"E2E-CMD Chat 1"`) and clean up resources by ID in `afterAll`. This keeps screenshots deterministic and avoids needing to mask/hide dynamic text. Only fall back to timestamps (`\`test-${Date.now()}\``) when resources cannot be reliably cleaned up or when name collisions across parallel workers would cause functional failures
7. **Error context** — catch and re-throw with useful debug info (page text, URL, etc.)
8. **Tag slow tests** — mark serial/slow tests with `@exclusive` in the test title
9. **Visual regression** — use `expectScreenshot()` for UI consistency checks
10. **Minimal comments** — only comment to clarify non-obvious intent; never restate what the next line of code does

View File

@@ -91,8 +91,8 @@ jobs:
BUILD_WEB_CLOUD=true
else
BUILD_WEB=true
# Only build desktop for semver tags (excluding beta)
if [[ "$IS_VERSION_TAG" == "true" ]] && [[ "$IS_BETA" != "true" ]]; then
# Skip desktop builds on beta tags and nightly runs
if [[ "$IS_BETA" != "true" ]] && [[ "$IS_NIGHTLY" != "true" ]]; then
BUILD_DESKTOP=true
fi
fi
@@ -640,7 +640,6 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
@@ -722,7 +721,6 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true

View File

@@ -33,7 +33,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Build chart dependencies

View File

@@ -45,6 +45,9 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000
# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -115,10 +118,9 @@ jobs:
- name: Create .env file for Docker Compose
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
COMPOSE_PROFILES=s3-filestore
CODE_INTERPRETER_BETA_ENABLED=true
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF
- name: Set up Standard Dependencies
@@ -127,6 +129,7 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \

View File

@@ -41,7 +41,8 @@ jobs:
version: v3.19.0
- name: Set up chart-testing
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
with:
uv_version: "0.9.9"
@@ -91,7 +92,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Install Redis operator

View File

@@ -46,7 +46,6 @@ jobs:
timeout-minutes: 45
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
editions: ${{ steps.set-editions.outputs.editions }}
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
@@ -57,7 +56,7 @@ jobs:
id: set-matrix
run: |
# Find all leaf-level directories in both test directories
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" ! -name "no_vectordb" -exec basename {} \; | sort)
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
# Create JSON array with directory info
@@ -73,16 +72,6 @@ jobs:
all_dirs="[${all_dirs%,}]"
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
- name: Determine editions to test
id: set-editions
run: |
# On PRs, only run EE tests. On merge_group and tags, run both EE and MIT.
if [ "${{ github.event_name }}" = "pull_request" ]; then
echo 'editions=["ee"]' >> $GITHUB_OUTPUT
else
echo 'editions=["ee","mit"]' >> $GITHUB_OUTPUT
fi
build-backend-image:
runs-on:
[
@@ -278,7 +267,7 @@ jobs:
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- ${{ format('run-id={0}-integration-tests-{1}-job-{2}', github.run_id, matrix.edition, strategy['job-index']) }}
- ${{ format('run-id={0}-integration-tests-job-{1}', github.run_id, strategy['job-index']) }}
- extras=ecr-cache
timeout-minutes: 45
@@ -286,7 +275,6 @@ jobs:
fail-fast: false
matrix:
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
edition: ${{ fromJson(needs.discover-test-dirs.outputs.editions) }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -310,11 +298,12 @@ jobs:
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
EDITION: ${{ matrix.edition }}
run: |
# Base config shared by both editions
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
@@ -323,20 +312,11 @@ jobs:
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
MCP_SERVER_ENABLED=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
EOF
# EE-only config
if [ "$EDITION" = "ee" ]; then
cat <<EOF >> deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
MCP_SERVER_ENABLED=true
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF
fi
- name: Start Docker containers
run: |
@@ -399,14 +379,14 @@ jobs:
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
- name: Run Integration Tests (${{ matrix.edition }}) for ${{ matrix.test-dir.name }}
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
timeout_minutes: 20
max_attempts: 3
retry_wait_seconds: 10
command: |
echo "Running ${{ matrix.edition }} integration tests for ${{ matrix.test-dir.path }}..."
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
@@ -464,143 +444,10 @@ jobs:
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-${{ matrix.edition }}-${{ matrix.test-dir.name }}
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
no-vectordb-tests:
needs: [build-backend-image, build-integration-image]
runs-on:
[
runs-on,
runner=4cpu-linux-arm64,
"run-id=${{ github.run_id }}-no-vectordb-tests",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create .env file for no-vectordb Docker Compose
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
DISABLE_VECTOR_DB=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
USE_LIGHTWEIGHT_BACKGROUND_WORKER=true
EOF
# Start only the services needed for no-vectordb mode (no Vespa, no model servers)
- name: Start Docker containers (no-vectordb)
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml up \
relational_db \
cache \
minio \
api_server \
background \
-d
id: start_docker_no_vectordb
- name: Wait for services to be ready
run: |
echo "Starting wait-for-service script (no-vectordb)..."
start_time=$(date +%s)
timeout=300
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in $timeout seconds."
exit 1
fi
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "API server is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error; retrying..."
else
echo "Service not ready yet (HTTP $response). Retrying in 5 seconds..."
fi
sleep 5
done
- name: Run No-VectorDB Integration Tests
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
timeout_minutes: 20
max_attempts: 3
retry_wait_seconds: 10
command: |
echo "Running no-vectordb integration tests..."
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e TEST_WEB_HOSTNAME=test-runner \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/tests/no_vectordb
- name: Dump API server logs (no-vectordb)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_no_vectordb.log || true
- name: Dump all-container logs (no-vectordb)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
logs --no-color > $GITHUB_WORKSPACE/docker-compose-no-vectordb.log || true
- name: Upload logs (no-vectordb)
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-no-vectordb
path: ${{ github.workspace }}/docker-compose-no-vectordb.log
- name: Stop Docker containers (no-vectordb)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml down -v
multitenant-tests:
needs:
[build-backend-image, build-model-server-image, build-integration-image]
@@ -740,7 +587,7 @@ jobs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [integration-tests, no-vectordb-tests, multitenant-tests]
needs: [integration-tests, multitenant-tests]
if: ${{ always() }}
steps:
- name: Check job status

View File

@@ -0,0 +1,443 @@
name: Run MIT Integration Tests v2
concurrency:
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
types: [checks_requested]
push:
tags:
- "v*.*.*"
permissions:
contents: read
env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Discover test directories
id: set-matrix
run: |
# Find all leaf-level directories in both test directories
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
# Create JSON array with directory info
all_dirs=""
for dir in $tests_dirs; do
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
done
for dir in $connector_dirs; do
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
done
# Remove trailing comma and wrap in array
all_dirs="[${all_dirs%,}]"
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
build-backend-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-backend-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Format branch name for cache
id: format-branch
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push Backend Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-model-server-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-model-server-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Format branch name for cache
id: format-branch
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push Model Server Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
build-integration-image:
runs-on:
[
runs-on,
runner=2cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-integration-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Format branch name for cache
id: format-branch
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push integration test image with Docker Bake
env:
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
TAG: integration-test-${{ github.run_id }}
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
run: |
docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration
integration-tests-mit:
needs:
[
discover-test-dirs,
build-backend-image,
build-model-server-image,
build-integration-image,
]
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- ${{ format('run-id={0}-integration-tests-mit-job-{1}', github.run_id, strategy['job-index']) }}
- extras=ecr-cache
timeout-minutes: 45
strategy:
fail-fast: false
matrix:
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Create .env file for Docker Compose
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
MCP_SERVER_ENABLED=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
EOF
- name: Start Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
index \
cache \
minio \
api_server \
inference_model_server \
indexing_model_server \
background \
-d
id: start_docker
- name: Wait for services to be ready
run: |
echo "Starting wait-for-service script..."
wait_for_service() {
local url=$1
local label=$2
local timeout=${3:-300} # default 5 minutes
local start_time
start_time=$(date +%s)
while true; do
local current_time
current_time=$(date +%s)
local elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. ${label} did not become ready in $timeout seconds."
exit 1
fi
local response
response=$(curl -s -o /dev/null -w "%{http_code}" "$url" || echo "curl_error")
if [ "$response" = "200" ]; then
echo "${label} is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error while checking ${label}. Retrying in 5 seconds..."
else
echo "${label} not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
}
wait_for_service "http://localhost:8080/health" "API server"
echo "Finished waiting for services."
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
timeout_minutes: 20
max_attempts: 3
retry_wait_seconds: 10
command: |
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/${{ matrix.test-dir.path }}
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [integration-tests-mit]
if: ${{ always() }}
steps:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1

View File

@@ -55,9 +55,6 @@ env:
MCP_SERVER_PUBLIC_HOST: host.docker.internal
MCP_SERVER_PUBLIC_URL: http://host.docker.internal:8004/mcp
# Visual regression S3 bucket (shared across all jobs)
PLAYWRIGHT_S3_BUCKET: onyx-playwright-artifacts
jobs:
build-web-image:
runs-on:
@@ -245,9 +242,6 @@ jobs:
playwright-tests:
needs: [build-web-image, build-backend-image, build-model-server-image]
name: Playwright Tests (${{ matrix.project }})
permissions:
id-token: write # Required for OIDC-based AWS credential exchange (S3 access)
contents: read
runs-on:
- runs-on
- runner=8cpu-linux-arm64
@@ -303,7 +297,6 @@ jobs:
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
INTEGRATION_TESTS_MODE=true
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
EXA_API_KEY=${EXA_API_KEY_VALUE}
REQUIRE_EMAIL_VERIFICATION=false
@@ -438,6 +431,8 @@ jobs:
env:
PROJECT: ${{ matrix.project }}
run: |
# Create test-results directory to ensure it exists for artifact upload
mkdir -p test-results
npx playwright test --project ${PROJECT}
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
@@ -445,134 +440,9 @@ jobs:
with:
# Includes test results and trace.zip files
name: playwright-test-results-${{ matrix.project }}-${{ github.run_id }}
path: ./web/output/playwright/
path: ./web/test-results/
retention-days: 30
- name: Upload screenshots
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
if: always()
with:
name: playwright-screenshots-${{ matrix.project }}-${{ github.run_id }}
path: ./web/output/screenshots/
retention-days: 30
# --- Visual Regression Diff ---
- name: Configure AWS credentials
if: always()
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Install the latest version of uv
if: always()
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Determine baseline revision
if: always()
id: baseline-rev
env:
EVENT_NAME: ${{ github.event_name }}
BASE_REF: ${{ github.event.pull_request.base.ref }}
MERGE_GROUP_BASE_REF: ${{ github.event.merge_group.base_ref }}
GH_REF: ${{ github.ref }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ "${EVENT_NAME}" = "pull_request" ]; then
# PRs compare against the base branch (e.g. main, release/2.5)
echo "rev=${BASE_REF}" >> "$GITHUB_OUTPUT"
elif [ "${EVENT_NAME}" = "merge_group" ]; then
# Merge queue compares against the target branch (e.g. refs/heads/main -> main)
echo "rev=${MERGE_GROUP_BASE_REF#refs/heads/}" >> "$GITHUB_OUTPUT"
elif [[ "${GH_REF}" == refs/tags/* ]]; then
# Tag builds compare against the tag name
echo "rev=${REF_NAME}" >> "$GITHUB_OUTPUT"
else
# Push builds (main, release/*) compare against the branch name
echo "rev=${REF_NAME}" >> "$GITHUB_OUTPUT"
fi
- name: Generate screenshot diff report
if: always()
env:
PROJECT: ${{ matrix.project }}
PLAYWRIGHT_S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
BASELINE_REV: ${{ steps.baseline-rev.outputs.rev }}
run: |
uv run --no-sync --with onyx-devtools ods screenshot-diff compare \
--project "${PROJECT}" \
--rev "${BASELINE_REV}"
- name: Upload visual diff report to S3
if: always()
env:
PROJECT: ${{ matrix.project }}
PR_NUMBER: ${{ github.event.pull_request.number }}
RUN_ID: ${{ github.run_id }}
run: |
SUMMARY_FILE="web/output/screenshot-diff/${PROJECT}/summary.json"
if [ ! -f "${SUMMARY_FILE}" ]; then
echo "No summary file found — skipping S3 upload."
exit 0
fi
HAS_DIFF=$(jq -r '.has_differences' "${SUMMARY_FILE}")
if [ "${HAS_DIFF}" != "true" ]; then
echo "No visual differences for ${PROJECT} — skipping S3 upload."
exit 0
fi
aws s3 sync "web/output/screenshot-diff/${PROJECT}/" \
"s3://${PLAYWRIGHT_S3_BUCKET}/reports/pr-${PR_NUMBER}/${RUN_ID}/${PROJECT}/"
- name: Upload visual diff summary
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
if: always()
with:
name: screenshot-diff-summary-${{ matrix.project }}
path: ./web/output/screenshot-diff/${{ matrix.project }}/summary.json
if-no-files-found: ignore
retention-days: 5
- name: Upload visual diff report artifact
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
if: always()
with:
name: screenshot-diff-report-${{ matrix.project }}-${{ github.run_id }}
path: ./web/output/screenshot-diff/${{ matrix.project }}/
if-no-files-found: ignore
retention-days: 30
- name: Update S3 baselines
if: >-
success() && (
github.ref == 'refs/heads/main' ||
startsWith(github.ref, 'refs/heads/release/') ||
startsWith(github.ref, 'refs/tags/v') ||
(
github.event_name == 'merge_group' && (
github.event.merge_group.base_ref == 'refs/heads/main' ||
startsWith(github.event.merge_group.base_ref, 'refs/heads/release/')
)
)
)
env:
PROJECT: ${{ matrix.project }}
PLAYWRIGHT_S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
BASELINE_REV: ${{ steps.baseline-rev.outputs.rev }}
run: |
if [ -d "web/output/screenshots/" ] && [ "$(ls -A web/output/screenshots/)" ]; then
uv run --no-sync --with onyx-devtools ods screenshot-diff upload-baselines \
--project "${PROJECT}" \
--rev "${BASELINE_REV}" \
--delete
else
echo "No screenshots to upload for ${PROJECT} — skipping baseline update."
fi
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
@@ -590,98 +460,6 @@ jobs:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
# Post a single combined visual regression comment after all matrix jobs finish
visual-regression-comment:
needs: [playwright-tests]
if: >-
always() &&
github.event_name == 'pull_request' &&
needs.playwright-tests.result != 'cancelled'
runs-on: ubuntu-slim
timeout-minutes: 5
permissions:
pull-requests: write
steps:
- name: Download visual diff summaries
uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # ratchet:actions/download-artifact@v4
with:
pattern: screenshot-diff-summary-*
path: summaries/
- name: Post combined PR comment
env:
GH_TOKEN: ${{ github.token }}
PR_NUMBER: ${{ github.event.pull_request.number }}
RUN_ID: ${{ github.run_id }}
REPO: ${{ github.repository }}
S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
run: |
MARKER="<!-- visual-regression-report -->"
# Build the markdown table from all summary files
TABLE_HEADER="| Project | Changed | Added | Removed | Unchanged | Report |"
TABLE_DIVIDER="|---------|---------|-------|---------|-----------|--------|"
TABLE_ROWS=""
HAS_ANY_SUMMARY=false
for SUMMARY_DIR in summaries/screenshot-diff-summary-*/; do
SUMMARY_FILE="${SUMMARY_DIR}summary.json"
if [ ! -f "${SUMMARY_FILE}" ]; then
continue
fi
HAS_ANY_SUMMARY=true
PROJECT=$(jq -r '.project' "${SUMMARY_FILE}")
CHANGED=$(jq -r '.changed' "${SUMMARY_FILE}")
ADDED=$(jq -r '.added' "${SUMMARY_FILE}")
REMOVED=$(jq -r '.removed' "${SUMMARY_FILE}")
UNCHANGED=$(jq -r '.unchanged' "${SUMMARY_FILE}")
TOTAL=$(jq -r '.total' "${SUMMARY_FILE}")
HAS_DIFF=$(jq -r '.has_differences' "${SUMMARY_FILE}")
if [ "${TOTAL}" = "0" ]; then
REPORT_LINK="_No screenshots_"
elif [ "${HAS_DIFF}" = "true" ]; then
REPORT_URL="https://${S3_BUCKET}.s3.us-east-2.amazonaws.com/reports/pr-${PR_NUMBER}/${RUN_ID}/${PROJECT}/index.html"
REPORT_LINK="[View Report](${REPORT_URL})"
else
REPORT_LINK="✅ No changes"
fi
TABLE_ROWS="${TABLE_ROWS}| \`${PROJECT}\` | ${CHANGED} | ${ADDED} | ${REMOVED} | ${UNCHANGED} | ${REPORT_LINK} |\n"
done
if [ "${HAS_ANY_SUMMARY}" = "false" ]; then
echo "No visual diff summaries found — skipping PR comment."
exit 0
fi
BODY=$(printf '%s\n' \
"${MARKER}" \
"### 🖼️ Visual Regression Report" \
"" \
"${TABLE_HEADER}" \
"${TABLE_DIVIDER}" \
"$(printf '%b' "${TABLE_ROWS}")")
# Upsert: find existing comment with the marker, or create a new one
EXISTING_COMMENT_ID=$(gh api \
"repos/${REPO}/issues/${PR_NUMBER}/comments" \
--jq ".[] | select(.body | startswith(\"${MARKER}\")) | .id" \
2>/dev/null | head -1)
if [ -n "${EXISTING_COMMENT_ID}" ]; then
gh api \
--method PATCH \
"repos/${REPO}/issues/comments/${EXISTING_COMMENT_ID}" \
-f body="${BODY}"
else
gh api \
--method POST \
"repos/${REPO}/issues/${PR_NUMBER}/comments" \
-f body="${BODY}"
fi
playwright-required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
@@ -692,3 +470,48 @@ jobs:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1
# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.
# Chromatic may be reintroduced in the future for UI diff testing if needed.
# chromatic-tests:
# name: Chromatic Tests
# needs: playwright-tests
# runs-on:
# [
# runs-on,
# runner=32cpu-linux-x64,
# disk=large,
# "run-id=${{ github.run_id }}",
# ]
# steps:
# - name: Checkout code
# uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
# with:
# fetch-depth: 0
# - name: Setup node
# uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
# with:
# node-version: 22
# - name: Install node dependencies
# working-directory: ./web
# run: npm ci
# - name: Download Playwright test results
# uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # ratchet:actions/download-artifact@v4
# with:
# name: test-results
# path: ./web/test-results
# - name: Run Chromatic
# uses: chromaui/action@latest
# with:
# playwright: true
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
# workingDir: ./web
# env:
# CHROMATIC_ARCHIVE_LOCATION: ./test-results

View File

@@ -1,73 +0,0 @@
name: Preview Deployment
env:
VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }}
VERCEL_PROJECT_ID: ${{ secrets.VERCEL_PROJECT_ID }}
VERCEL_CLI: vercel@50.14.1
on:
push:
branches-ignore:
- main
paths:
- "web/**"
permissions:
contents: read
pull-requests: write
jobs:
Deploy-Preview:
runs-on: ubuntu-latest
timeout-minutes: 30
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
with:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm"
cache-dependency-path: ./web/package-lock.json
- name: Pull Vercel Environment Information
run: npx --yes ${{ env.VERCEL_CLI }} pull --yes --environment=preview --token=${{ secrets.VERCEL_TOKEN }}
- name: Build Project Artifacts
run: npx --yes ${{ env.VERCEL_CLI }} build --token=${{ secrets.VERCEL_TOKEN }}
- name: Deploy Project Artifacts to Vercel
id: deploy
run: |
DEPLOYMENT_URL=$(npx --yes ${{ env.VERCEL_CLI }} deploy --prebuilt --token=${{ secrets.VERCEL_TOKEN }})
echo "url=$DEPLOYMENT_URL" >> "$GITHUB_OUTPUT"
- name: Update PR comment with deployment URL
if: always() && steps.deploy.outputs.url
env:
GH_TOKEN: ${{ github.token }}
DEPLOYMENT_URL: ${{ steps.deploy.outputs.url }}
run: |
# Find the PR for this branch
PR_NUMBER=$(gh pr list --head "$GITHUB_REF_NAME" --json number --jq '.[0].number')
if [ -z "$PR_NUMBER" ]; then
echo "No open PR found for branch $GITHUB_REF_NAME, skipping comment."
exit 0
fi
COMMENT_MARKER="<!-- preview-deployment -->"
COMMENT_BODY="$COMMENT_MARKER
**Preview Deployment**
| Status | Preview | Commit | Updated |
| --- | --- | --- | --- |
| ✅ | $DEPLOYMENT_URL | \`${GITHUB_SHA::7}\` | $(date -u '+%Y-%m-%d %H:%M:%S UTC') |"
# Find existing comment by marker
EXISTING_COMMENT_ID=$(gh api "repos/$GITHUB_REPOSITORY/issues/$PR_NUMBER/comments" \
--jq ".[] | select(.body | startswith(\"$COMMENT_MARKER\")) | .id" | head -1)
if [ -n "$EXISTING_COMMENT_ID" ]; then
gh api "repos/$GITHUB_REPOSITORY/issues/comments/$EXISTING_COMMENT_ID" \
--method PATCH --field body="$COMMENT_BODY"
else
gh pr comment "$PR_NUMBER" --body "$COMMENT_BODY"
fi

View File

@@ -1,290 +0,0 @@
name: Build and Push Sandbox Image on Tag
on:
push:
tags:
- "experimental-cc4a.*"
# Restrictive defaults; jobs declare what they need.
permissions: {}
jobs:
check-sandbox-changes:
runs-on: ubuntu-slim
timeout-minutes: 10
permissions:
contents: read
outputs:
sandbox-changed: ${{ steps.check.outputs.sandbox-changed }}
new-version: ${{ steps.version.outputs.new-version }}
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
fetch-depth: 0
- name: Check for sandbox-relevant file changes
id: check
run: |
# Get the previous tag to diff against
CURRENT_TAG="${GITHUB_REF_NAME}"
PREVIOUS_TAG=$(git tag --sort=-creatordate | grep '^experimental-cc4a\.' | grep -v "^${CURRENT_TAG}$" | head -n 1)
if [ -z "$PREVIOUS_TAG" ]; then
echo "No previous experimental-cc4a tag found, building unconditionally"
echo "sandbox-changed=true" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "Comparing ${PREVIOUS_TAG}..${CURRENT_TAG}"
# Check if any sandbox-relevant files changed
SANDBOX_PATHS=(
"backend/onyx/server/features/build/sandbox/"
)
CHANGED=false
for path in "${SANDBOX_PATHS[@]}"; do
if git diff --name-only "${PREVIOUS_TAG}..${CURRENT_TAG}" -- "$path" | grep -q .; then
echo "Changes detected in: $path"
CHANGED=true
break
fi
done
echo "sandbox-changed=$CHANGED" >> "$GITHUB_OUTPUT"
- name: Determine new sandbox version
id: version
if: steps.check.outputs.sandbox-changed == 'true'
run: |
# Query Docker Hub for the latest versioned tag
LATEST_TAG=$(curl -s "https://hub.docker.com/v2/repositories/onyxdotapp/sandbox/tags?page_size=100" \
| jq -r '.results[].name' \
| grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \
| sort -V \
| tail -n 1)
if [ -z "$LATEST_TAG" ]; then
echo "No existing version tags found on Docker Hub, starting at 0.1.1"
NEW_VERSION="0.1.1"
else
CURRENT_VERSION="${LATEST_TAG#v}"
echo "Latest version on Docker Hub: $CURRENT_VERSION"
# Increment patch version
MAJOR=$(echo "$CURRENT_VERSION" | cut -d. -f1)
MINOR=$(echo "$CURRENT_VERSION" | cut -d. -f2)
PATCH=$(echo "$CURRENT_VERSION" | cut -d. -f3)
NEW_PATCH=$((PATCH + 1))
NEW_VERSION="${MAJOR}.${MINOR}.${NEW_PATCH}"
fi
echo "New version: $NEW_VERSION"
echo "new-version=$NEW_VERSION" >> "$GITHUB_OUTPUT"
build-sandbox-amd64:
needs: check-sandbox-changes
if: needs.check-sandbox-changes.outputs.sandbox-changed == 'true'
runs-on:
- runs-on
- runner=4cpu-linux-x64
- run-id=${{ github.run_id }}-sandbox-amd64
- extras=ecr-cache
timeout-minutes: 90
environment: release
permissions:
contents: read
id-token: write
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/sandbox
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker
file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile
platforms: linux/amd64
labels: ${{ steps.meta.outputs.labels }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: |
type=inline
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
build-sandbox-arm64:
needs: check-sandbox-changes
if: needs.check-sandbox-changes.outputs.sandbox-changed == 'true'
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- run-id=${{ github.run_id }}-sandbox-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
permissions:
contents: read
id-token: write
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/sandbox
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker
file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile
platforms: linux/arm64
labels: ${{ steps.meta.outputs.labels }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: |
type=inline
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
merge-sandbox:
needs:
- check-sandbox-changes
- build-sandbox-amd64
- build-sandbox-arm64
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-sandbox
- extras=ecr-cache
timeout-minutes: 30
environment: release
permissions:
id-token: write
env:
REGISTRY_IMAGE: onyxdotapp/sandbox
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false
tags: |
type=raw,value=v${{ needs.check-sandbox-changes.outputs.new-version }}
type=raw,value=latest
- name: Create and push manifest
env:
IMAGE_REPO: ${{ env.REGISTRY_IMAGE }}
AMD64_DIGEST: ${{ needs.build-sandbox-amd64.outputs.digest }}
ARM64_DIGEST: ${{ needs.build-sandbox-arm64.outputs.digest }}
META_TAGS: ${{ steps.meta.outputs.tags }}
run: |
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
docker buildx imagetools create \
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
$IMAGES

View File

@@ -5,8 +5,6 @@ on:
branches: ["main"]
pull_request:
branches: ["**"]
paths:
- ".github/**"
permissions: {}
@@ -23,18 +21,29 @@ jobs:
with:
persist-credentials: false
- name: Detect changes
id: filter
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
with:
filters: |
zizmor:
- '.github/**'
- name: Install the latest version of uv
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Run zizmor
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload SARIF file
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
with:
sarif_file: results.sarif

2
.gitignore vendored
View File

@@ -6,8 +6,6 @@
!/.vscode/tasks.template.jsonc
.zed
.cursor
!/.cursor/mcp.json
!/.cursor/skills/
# macos
.DS_store

6
.vscode/launch.json vendored
View File

@@ -246,7 +246,7 @@
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup,opensearch_migration"
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup"
],
"presentation": {
"group": "2"
@@ -275,7 +275,7 @@
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,user_files_indexing,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete"
],
"presentation": {
"group": "2"
@@ -419,7 +419,7 @@
"--loglevel=INFO",
"--hostname=docfetching@%n",
"-Q",
"connector_doc_fetching"
"connector_doc_fetching,user_files_indexing"
],
"presentation": {
"group": "2"

View File

@@ -144,10 +144,6 @@ function.
If you make any updates to a celery worker and you want to test these changes, you will need
to ask me to restart the celery worker. There is no auto-restart on code-change mechanism.
**Task Time Limits**:
Since all tasks are executed in thread pools, the time limit features of Celery are silently
disabled and won't work. Timeout logic must be implemented within the task itself.
### Code Quality
```bash

View File

@@ -474,7 +474,7 @@ def run_migrations_online() -> None:
if connectable is not None:
# pytest-alembic is providing an engine - use it directly
logger.debug("run_migrations_online starting (pytest-alembic mode).")
logger.info("run_migrations_online starting (pytest-alembic mode).")
# For pytest-alembic, we use the default schema (public)
schema_name = context.config.attributes.get(

View File

@@ -1,28 +0,0 @@
"""add scim_username to scim_user_mapping
Revision ID: 0bb4558f35df
Revises: 631fd2504136
Create Date: 2026-02-20 10:45:30.340188
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0bb4558f35df"
down_revision = "631fd2504136"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("scim_username", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_username")

View File

@@ -1,33 +0,0 @@
"""add default_app_mode to user
Revision ID: 114a638452db
Revises: feead2911109
Create Date: 2026-02-09 18:57:08.274640
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "114a638452db"
down_revision = "feead2911109"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"default_app_mode",
sa.String(),
nullable=False,
server_default="CHAT",
),
)
def downgrade() -> None:
op.drop_column("user", "default_app_mode")

View File

@@ -11,6 +11,7 @@ import sqlalchemy as sa
from urllib.parse import urlparse, urlunparse
from httpx import HTTPStatusError
import httpx
from onyx.document_index.factory import get_default_document_index
from onyx.db.search_settings import SearchSettings
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import (
@@ -518,11 +519,15 @@ def delete_document_from_db(current_doc_id: str, index_name: str) -> None:
def upgrade() -> None:
if SKIP_CANON_DRIVE_IDS:
return
current_search_settings, _ = active_search_settings()
current_search_settings, future_search_settings = active_search_settings()
document_index = get_default_document_index(
current_search_settings,
future_search_settings,
)
# Get the index name
if hasattr(current_search_settings, "index_name"):
index_name = current_search_settings.index_name
if hasattr(document_index, "index_name"):
index_name = document_index.index_name
else:
# Default index name if we can't get it from the document_index
index_name = "danswer_index"

View File

@@ -1,27 +0,0 @@
"""add_user_preferences
Revision ID: 175ea04c7087
Revises: d56ffa94ca32
Create Date: 2026-02-04 18:16:24.830873
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "175ea04c7087"
down_revision = "d56ffa94ca32"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("user_preferences", sa.Text(), nullable=True),
)
def downgrade() -> None:
op.drop_column("user", "user_preferences")

View File

@@ -1,71 +0,0 @@
"""Migrate to contextual rag model
Revision ID: 19c0ccb01687
Revises: 9c54986124c6
Create Date: 2026-02-12 11:21:41.798037
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "19c0ccb01687"
down_revision = "9c54986124c6"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Widen the column to fit 'CONTEXTUAL_RAG' (15 chars); was varchar(10)
# when the table was created with only CHAT/VISION values.
op.alter_column(
"llm_model_flow",
"llm_model_flow_type",
type_=sa.String(length=20),
existing_type=sa.String(length=10),
existing_nullable=False,
)
# For every search_settings row that has contextual rag configured,
# create an llm_model_flow entry. is_default is TRUE if the row
# belongs to the PRESENT search settings, FALSE otherwise.
op.execute(
"""
INSERT INTO llm_model_flow (llm_model_flow_type, model_configuration_id, is_default)
SELECT DISTINCT
'CONTEXTUAL_RAG',
mc.id,
(ss.status = 'PRESENT')
FROM search_settings ss
JOIN llm_provider lp
ON lp.name = ss.contextual_rag_llm_provider
JOIN model_configuration mc
ON mc.llm_provider_id = lp.id
AND mc.name = ss.contextual_rag_llm_name
WHERE ss.enable_contextual_rag = TRUE
AND ss.contextual_rag_llm_name IS NOT NULL
AND ss.contextual_rag_llm_provider IS NOT NULL
ON CONFLICT (llm_model_flow_type, model_configuration_id)
DO UPDATE SET is_default = EXCLUDED.is_default
WHERE EXCLUDED.is_default = TRUE
"""
)
def downgrade() -> None:
op.execute(
"""
DELETE FROM llm_model_flow
WHERE llm_model_flow_type = 'CONTEXTUAL_RAG'
"""
)
op.alter_column(
"llm_model_flow",
"llm_model_flow_type",
type_=sa.String(length=10),
existing_type=sa.String(length=20),
existing_nullable=False,
)

View File

@@ -1,32 +0,0 @@
"""add approx_chunk_count_in_vespa to opensearch tenant migration
Revision ID: 631fd2504136
Revises: c7f2e1b4a9d3
Create Date: 2026-02-18 21:07:52.831215
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "631fd2504136"
down_revision = "c7f2e1b4a9d3"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"approx_chunk_count_in_vespa",
sa.Integer(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("opensearch_tenant_migration_record", "approx_chunk_count_in_vespa")

View File

@@ -1,31 +0,0 @@
"""code interpreter server model
Revision ID: 7cb492013621
Revises: 0bb4558f35df
Create Date: 2026-02-22 18:54:54.007265
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7cb492013621"
down_revision = "0bb4558f35df"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"code_interpreter_server",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column(
"server_enabled", sa.Boolean, nullable=False, server_default=sa.true()
),
)
def downgrade() -> None:
op.drop_table("code_interpreter_server")

View File

@@ -16,6 +16,7 @@ from typing import Generator
from alembic import op
import sqlalchemy as sa
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.db.search_settings import SearchSettings
from onyx.configs.app_configs import AUTH_TYPE
@@ -125,11 +126,14 @@ def remove_old_tags() -> None:
the document got reindexed, the old tag would not be removed.
This function removes those old tags by comparing it against the tags in vespa.
"""
current_search_settings, _ = active_search_settings()
current_search_settings, future_search_settings = active_search_settings()
document_index = get_default_document_index(
current_search_settings, future_search_settings
)
# Get the index name
if hasattr(current_search_settings, "index_name"):
index_name = current_search_settings.index_name
if hasattr(document_index, "index_name"):
index_name = document_index.index_name
else:
# Default index name if we can't get it from the document_index
index_name = "danswer_index"

View File

@@ -1,43 +0,0 @@
"""add chunk error and vespa count columns to opensearch tenant migration
Revision ID: 93c15d6a6fbb
Revises: d3fd499c829c
Create Date: 2026-02-11 23:07:34.576725
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "93c15d6a6fbb"
down_revision = "d3fd499c829c"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"total_chunks_errored",
sa.Integer(),
nullable=False,
server_default="0",
),
)
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"total_chunks_in_vespa",
sa.Integer(),
nullable=False,
server_default="0",
),
)
def downgrade() -> None:
op.drop_column("opensearch_tenant_migration_record", "total_chunks_in_vespa")
op.drop_column("opensearch_tenant_migration_record", "total_chunks_errored")

View File

@@ -1,124 +0,0 @@
"""add_scim_tables
Revision ID: 9c54986124c6
Revises: b51c6844d1df
Create Date: 2026-02-12 20:29:47.448614
"""
from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9c54986124c6"
down_revision = "b51c6844d1df"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"scim_token",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("hashed_token", sa.String(length=64), nullable=False),
sa.Column("token_display", sa.String(), nullable=False),
sa.Column(
"created_by_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column(
"is_active",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(["created_by_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("hashed_token"),
)
op.create_table(
"scim_group_mapping",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_id", sa.String(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["user_group_id"], ["user_group.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_group_id"),
)
op.create_index(
op.f("ix_scim_group_mapping_external_id"),
"scim_group_mapping",
["external_id"],
unique=True,
)
op.create_table(
"scim_user_mapping",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_id", sa.String(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_id"),
)
op.create_index(
op.f("ix_scim_user_mapping_external_id"),
"scim_user_mapping",
["external_id"],
unique=True,
)
def downgrade() -> None:
op.drop_index(
op.f("ix_scim_user_mapping_external_id"),
table_name="scim_user_mapping",
)
op.drop_table("scim_user_mapping")
op.drop_index(
op.f("ix_scim_group_mapping_external_id"),
table_name="scim_group_mapping",
)
op.drop_table("scim_group_mapping")
op.drop_table("scim_token")

View File

@@ -1,81 +0,0 @@
"""seed_memory_tool and add enable_memory_tool to user
Revision ID: b51c6844d1df
Revises: 93c15d6a6fbb
Create Date: 2026-02-11 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b51c6844d1df"
down_revision = "93c15d6a6fbb"
branch_labels = None
depends_on = None
MEMORY_TOOL = {
"name": "MemoryTool",
"display_name": "Add Memory",
"description": "Save memories about the user for future conversations.",
"in_code_tool_id": "MemoryTool",
"enabled": True,
}
def upgrade() -> None:
conn = op.get_bind()
existing = conn.execute(
sa.text(
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id = :in_code_tool_id"
),
{"in_code_tool_id": MEMORY_TOOL["in_code_tool_id"]},
).fetchone()
if existing:
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
MEMORY_TOOL,
)
else:
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
"""
),
MEMORY_TOOL,
)
op.add_column(
"user",
sa.Column(
"enable_memory_tool",
sa.Boolean(),
nullable=False,
server_default=sa.true(),
),
)
def downgrade() -> None:
op.drop_column("user", "enable_memory_tool")
conn = op.get_bind()
conn.execute(
sa.text("DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": MEMORY_TOOL["in_code_tool_id"]},
)

View File

@@ -1,31 +0,0 @@
"""add sharing_scope to build_session
Revision ID: c7f2e1b4a9d3
Revises: 19c0ccb01687
Create Date: 2026-02-17 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "c7f2e1b4a9d3"
down_revision = "19c0ccb01687"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"build_session",
sa.Column(
"sharing_scope",
sa.String(),
nullable=False,
server_default="private",
),
)
def downgrade() -> None:
op.drop_column("build_session", "sharing_scope")

View File

@@ -1,102 +0,0 @@
"""add_file_reader_tool
Revision ID: d3fd499c829c
Revises: 114a638452db
Create Date: 2026-02-07 19:28:22.452337
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d3fd499c829c"
down_revision = "114a638452db"
branch_labels = None
depends_on = None
FILE_READER_TOOL = {
"name": "read_file",
"display_name": "File Reader",
"description": (
"Read sections of user-uploaded files by character offset. "
"Useful for inspecting large files that cannot fit entirely in context."
),
"in_code_tool_id": "FileReaderTool",
"enabled": True,
}
def upgrade() -> None:
conn = op.get_bind()
# Check if tool already exists
existing = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": FILE_READER_TOOL["in_code_tool_id"]},
).fetchone()
if existing:
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
FILE_READER_TOOL,
)
tool_id = existing[0]
else:
# Insert new tool
result = conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
RETURNING id
"""
),
FILE_READER_TOOL,
)
tool_id = result.scalar_one()
# Attach to the default persona (id=0) if not already attached
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": tool_id},
)
def downgrade() -> None:
conn = op.get_bind()
in_code_tool_id = FILE_READER_TOOL["in_code_tool_id"]
# Remove persona associations first (FK constraint)
conn.execute(
sa.text(
"""
DELETE FROM persona__tool
WHERE tool_id IN (
SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id
)
"""
),
{"in_code_tool_id": in_code_tool_id},
)
conn.execute(
sa.text("DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": in_code_tool_id},
)

View File

@@ -1,69 +0,0 @@
"""add_opensearch_tenant_migration_columns
Revision ID: feead2911109
Revises: d56ffa94ca32
Create Date: 2026-02-10 17:46:34.029937
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "feead2911109"
down_revision = "175ea04c7087"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"opensearch_tenant_migration_record",
sa.Column("vespa_visit_continuation_token", sa.Text(), nullable=True),
)
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"total_chunks_migrated",
sa.Integer(),
nullable=False,
server_default="0",
),
)
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"migration_completed_at",
sa.DateTime(timezone=True),
nullable=True,
),
)
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"enable_opensearch_retrieval",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
def downgrade() -> None:
op.drop_column("opensearch_tenant_migration_record", "enable_opensearch_retrieval")
op.drop_column("opensearch_tenant_migration_record", "migration_completed_at")
op.drop_column("opensearch_tenant_migration_record", "created_at")
op.drop_column("opensearch_tenant_migration_record", "total_chunks_migrated")
op.drop_column(
"opensearch_tenant_migration_record", "vespa_visit_continuation_token"
)

View File

@@ -1,15 +1,12 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.background import celery_app
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.tenant_provisioning",
"ee.onyx.background.celery.tasks.query_history",
]
)
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.tenant_provisioning",
"ee.onyx.background.celery.tasks.query_history",
]
)

View File

@@ -1,14 +1,11 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.heavy import celery_app
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.query_history",
]
)
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.query_history",
]
)

View File

@@ -1,11 +1,8 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.light import celery_app
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
]
)
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
]
)

View File

@@ -1,10 +1,7 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.monitoring import celery_app
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.tenant_provisioning",
]
)
[
"ee.onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -1,15 +1,12 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.primary import celery_app
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cloud",
"ee.onyx.background.celery.tasks.ttl_management",
"ee.onyx.background.celery.tasks.usage_reporting",
]
)
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cloud",
"ee.onyx.background.celery.tasks.ttl_management",
"ee.onyx.background.celery.tasks.usage_reporting",
]
)

View File

@@ -536,9 +536,7 @@ def connector_permission_sync_generator_task(
)
redis_connector.permissions.set_fence(new_payload)
callback = PermissionSyncCallback(
redis_connector, lock, r, timeout_seconds=JOB_TIMEOUT
)
callback = PermissionSyncCallback(redis_connector, lock, r)
# pass in the capability to fetch all existing docs for the cc_pair
# this is can be used to determine documents that are "missing" and thus
@@ -578,13 +576,6 @@ def connector_permission_sync_generator_task(
tasks_generated = 0
docs_with_errors = 0
for doc_external_access in document_external_accesses:
if callback.should_stop():
raise RuntimeError(
f"Permission sync task timed out or stop signal detected: "
f"cc_pair={cc_pair_id} "
f"tasks_generated={tasks_generated}"
)
result = redis_connector.permissions.update_db(
lock=lock,
new_permissions=[doc_external_access],
@@ -941,7 +932,6 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
timeout_seconds: int | None = None,
):
super().__init__()
self.redis_connector: RedisConnector = redis_connector
@@ -954,26 +944,11 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
self.last_tag: str = "PermissionSyncCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
self.start_monotonic = time.monotonic()
self.timeout_seconds = timeout_seconds
def should_stop(self) -> bool:
if self.redis_connector.stop.fenced:
return True
# Check if the task has exceeded its timeout
# NOTE: Celery's soft_time_limit does not work with thread pools,
# so we must enforce timeouts internally.
if self.timeout_seconds is not None:
elapsed = time.monotonic() - self.start_monotonic
if elapsed > self.timeout_seconds:
logger.warning(
f"PermissionSyncCallback - task timeout exceeded: "
f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
f"cc_pair={self.redis_connector.cc_pair_id}"
)
return True
return False
def progress(self, tag: str, amount: int) -> None: # noqa: ARG002

View File

@@ -466,7 +466,6 @@ def connector_external_group_sync_generator_task(
def _perform_external_group_sync(
cc_pair_id: int,
tenant_id: str,
timeout_seconds: int = JOB_TIMEOUT,
) -> None:
# Create attempt record at the start
with get_session_with_current_tenant() as db_session:
@@ -519,23 +518,9 @@ def _perform_external_group_sync(
seen_users: set[str] = set() # Track unique users across all groups
total_groups_processed = 0
total_group_memberships_synced = 0
start_time = time.monotonic()
try:
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
for external_user_group in external_user_group_generator:
# Check if the task has exceeded its timeout
# NOTE: Celery's soft_time_limit does not work with thread pools,
# so we must enforce timeouts internally.
elapsed = time.monotonic() - start_time
if elapsed > timeout_seconds:
raise RuntimeError(
f"External group sync task timed out: "
f"cc_pair={cc_pair_id} "
f"elapsed={elapsed:.0f}s "
f"timeout={timeout_seconds}s "
f"groups_processed={total_groups_processed}"
)
external_user_group_batch.append(external_user_group)
# Track progress

View File

@@ -1,604 +0,0 @@
"""SCIM Data Access Layer.
All database operations for SCIM provisioning — token management, user
mappings, and group mappings. Extends the base DAL (see ``onyx.db.dal``).
Usage from FastAPI::
def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
@router.post("/tokens")
def create_token(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
token = dal.create_token(name=..., hashed_token=..., ...)
dal.commit()
return token
Usage from background tasks::
with ScimDAL.from_tenant("tenant_abc") as dal:
mapping = dal.create_user_mapping(external_id="idp-123", user_id=uid)
dal.commit()
"""
from __future__ import annotations
from uuid import UUID
from sqlalchemy import delete as sa_delete
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import SQLColumnExpression
from sqlalchemy.dialects.postgresql import insert as pg_insert
from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from onyx.db.dal import DAL
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ScimDAL(DAL):
"""Data Access Layer for SCIM provisioning operations.
Methods mutate but do NOT commit — call ``dal.commit()`` explicitly
when you want to persist changes. This follows the existing ``_no_commit``
convention and lets callers batch multiple operations into one transaction.
"""
# ------------------------------------------------------------------
# Token operations
# ------------------------------------------------------------------
def create_token(
self,
name: str,
hashed_token: str,
token_display: str,
created_by_id: UUID,
) -> ScimToken:
"""Create a new SCIM bearer token.
Only one token is active at a time — this method automatically revokes
all existing active tokens before creating the new one.
"""
# Revoke any currently active tokens
active_tokens = list(
self._session.scalars(
select(ScimToken).where(ScimToken.is_active.is_(True))
).all()
)
for t in active_tokens:
t.is_active = False
token = ScimToken(
name=name,
hashed_token=hashed_token,
token_display=token_display,
created_by_id=created_by_id,
)
self._session.add(token)
self._session.flush()
return token
def get_active_token(self) -> ScimToken | None:
"""Return the single currently active token, or None."""
return self._session.scalar(
select(ScimToken).where(ScimToken.is_active.is_(True))
)
def get_token_by_hash(self, hashed_token: str) -> ScimToken | None:
"""Look up a token by its SHA-256 hash."""
return self._session.scalar(
select(ScimToken).where(ScimToken.hashed_token == hashed_token)
)
def revoke_token(self, token_id: int) -> None:
"""Deactivate a token by ID.
Raises:
ValueError: If the token does not exist.
"""
token = self._session.get(ScimToken, token_id)
if not token:
raise ValueError(f"SCIM token with id {token_id} not found")
token.is_active = False
def update_token_last_used(self, token_id: int) -> None:
"""Update the last_used_at timestamp for a token."""
token = self._session.get(ScimToken, token_id)
if token:
token.last_used_at = func.now() # type: ignore[assignment]
# ------------------------------------------------------------------
# User mapping operations
# ------------------------------------------------------------------
def create_user_mapping(
self,
external_id: str,
user_id: UUID,
) -> ScimUserMapping:
"""Create a mapping between a SCIM externalId and an Onyx user."""
mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
self._session.add(mapping)
self._session.flush()
return mapping
def get_user_mapping_by_external_id(
self, external_id: str
) -> ScimUserMapping | None:
"""Look up a user mapping by the IdP's external identifier."""
return self._session.scalar(
select(ScimUserMapping).where(ScimUserMapping.external_id == external_id)
)
def get_user_mapping_by_user_id(self, user_id: UUID) -> ScimUserMapping | None:
"""Look up a user mapping by the Onyx user ID."""
return self._session.scalar(
select(ScimUserMapping).where(ScimUserMapping.user_id == user_id)
)
def list_user_mappings(
self,
start_index: int = 1,
count: int = 100,
) -> tuple[list[ScimUserMapping], int]:
"""List user mappings with SCIM-style pagination.
Args:
start_index: 1-based start index (SCIM convention).
count: Maximum number of results to return.
Returns:
A tuple of (mappings, total_count).
"""
total = (
self._session.scalar(select(func.count()).select_from(ScimUserMapping)) or 0
)
offset = max(start_index - 1, 0)
mappings = list(
self._session.scalars(
select(ScimUserMapping)
.order_by(ScimUserMapping.id)
.offset(offset)
.limit(count)
).all()
)
return mappings, total
def update_user_mapping_external_id(
self,
mapping_id: int,
external_id: str,
) -> ScimUserMapping:
"""Update the external ID on a user mapping.
Raises:
ValueError: If the mapping does not exist.
"""
mapping = self._session.get(ScimUserMapping, mapping_id)
if not mapping:
raise ValueError(f"SCIM user mapping with id {mapping_id} not found")
mapping.external_id = external_id
return mapping
def delete_user_mapping(self, mapping_id: int) -> None:
"""Delete a user mapping by ID. No-op if already deleted."""
mapping = self._session.get(ScimUserMapping, mapping_id)
if not mapping:
logger.warning("SCIM user mapping %d not found during delete", mapping_id)
return
self._session.delete(mapping)
# ------------------------------------------------------------------
# User query operations
# ------------------------------------------------------------------
def get_user(self, user_id: UUID) -> User | None:
"""Fetch a user by ID."""
return self._session.scalar(
select(User).where(User.id == user_id) # type: ignore[arg-type]
)
def get_user_by_email(self, email: str) -> User | None:
"""Fetch a user by email (case-insensitive)."""
return self._session.scalar(
select(User).where(func.lower(User.email) == func.lower(email))
)
def add_user(self, user: User) -> None:
"""Add a new user to the session and flush to assign an ID."""
self._session.add(user)
self._session.flush()
def update_user(
self,
user: User,
*,
email: str | None = None,
is_active: bool | None = None,
personal_name: str | None = None,
) -> None:
"""Update user attributes. Only sets fields that are provided."""
if email is not None:
user.email = email
if is_active is not None:
user.is_active = is_active
if personal_name is not None:
user.personal_name = personal_name
def deactivate_user(self, user: User) -> None:
"""Mark a user as inactive."""
user.is_active = False
def list_users(
self,
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[User, str | None]], int]:
"""Query users with optional SCIM filter and pagination.
Returns:
A tuple of (list of (user, external_id) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
"""
query = select(User).where(
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
)
if scim_filter:
attr = scim_filter.attribute.lower()
if attr == "username":
# arg-type: fastapi-users types User.email as str, not a column expression
# assignment: union return type widens but query is still Select[tuple[User]]
query = _apply_scim_string_op(query, User.email, scim_filter) # type: ignore[arg-type, assignment]
elif attr == "active":
query = query.where(
User.is_active.is_(scim_filter.value.lower() == "true") # type: ignore[attr-defined]
)
elif attr == "externalid":
mapping = self.get_user_mapping_by_external_id(scim_filter.value)
if not mapping:
return [], 0
query = query.where(User.id == mapping.user_id) # type: ignore[arg-type]
else:
raise ValueError(
f"Unsupported filter attribute: {scim_filter.attribute}"
)
# Count total matching rows first, then paginate. SCIM uses 1-based
# indexing (RFC 7644 §3.4.2), so we convert to a 0-based offset.
total = (
self._session.scalar(select(func.count()).select_from(query.subquery()))
or 0
)
offset = max(start_index - 1, 0)
users = list(
self._session.scalars(
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
).all()
)
# Batch-fetch external IDs to avoid N+1 queries
ext_id_map = self._get_user_external_ids([u.id for u in users])
return [(u, ext_id_map.get(u.id)) for u in users], total
def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
"""Create, update, or delete the external ID mapping for a user."""
mapping = self.get_user_mapping_by_user_id(user_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
else:
self.create_user_mapping(external_id=new_external_id, user_id=user_id)
elif mapping:
self.delete_user_mapping(mapping.id)
def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
"""Batch-fetch external IDs for a list of user IDs."""
if not user_ids:
return {}
mappings = self._session.scalars(
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
).all()
return {m.user_id: m.external_id for m in mappings}
# ------------------------------------------------------------------
# Group mapping operations
# ------------------------------------------------------------------
def create_group_mapping(
self,
external_id: str,
user_group_id: int,
) -> ScimGroupMapping:
"""Create a mapping between a SCIM externalId and an Onyx user group."""
mapping = ScimGroupMapping(external_id=external_id, user_group_id=user_group_id)
self._session.add(mapping)
self._session.flush()
return mapping
def get_group_mapping_by_external_id(
self, external_id: str
) -> ScimGroupMapping | None:
"""Look up a group mapping by the IdP's external identifier."""
return self._session.scalar(
select(ScimGroupMapping).where(ScimGroupMapping.external_id == external_id)
)
def get_group_mapping_by_group_id(
self, user_group_id: int
) -> ScimGroupMapping | None:
"""Look up a group mapping by the Onyx user group ID."""
return self._session.scalar(
select(ScimGroupMapping).where(
ScimGroupMapping.user_group_id == user_group_id
)
)
def list_group_mappings(
self,
start_index: int = 1,
count: int = 100,
) -> tuple[list[ScimGroupMapping], int]:
"""List group mappings with SCIM-style pagination.
Args:
start_index: 1-based start index (SCIM convention).
count: Maximum number of results to return.
Returns:
A tuple of (mappings, total_count).
"""
total = (
self._session.scalar(select(func.count()).select_from(ScimGroupMapping))
or 0
)
offset = max(start_index - 1, 0)
mappings = list(
self._session.scalars(
select(ScimGroupMapping)
.order_by(ScimGroupMapping.id)
.offset(offset)
.limit(count)
).all()
)
return mappings, total
def delete_group_mapping(self, mapping_id: int) -> None:
"""Delete a group mapping by ID. No-op if already deleted."""
mapping = self._session.get(ScimGroupMapping, mapping_id)
if not mapping:
logger.warning("SCIM group mapping %d not found during delete", mapping_id)
return
self._session.delete(mapping)
# ------------------------------------------------------------------
# Group query operations
# ------------------------------------------------------------------
def get_group(self, group_id: int) -> UserGroup | None:
"""Fetch a group by ID, returning None if deleted or missing."""
group = self._session.get(UserGroup, group_id)
if group and group.is_up_for_deletion:
return None
return group
def get_group_by_name(self, name: str) -> UserGroup | None:
"""Fetch a group by exact name."""
return self._session.scalar(select(UserGroup).where(UserGroup.name == name))
def add_group(self, group: UserGroup) -> None:
"""Add a new group to the session and flush to assign an ID."""
self._session.add(group)
self._session.flush()
def update_group(
self,
group: UserGroup,
*,
name: str | None = None,
) -> None:
"""Update group attributes and set the modification timestamp."""
if name is not None:
group.name = name
group.time_last_modified_by_user = func.now()
def delete_group(self, group: UserGroup) -> None:
"""Delete a group from the session."""
self._session.delete(group)
def list_groups(
self,
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[UserGroup, str | None]], int]:
"""Query groups with optional SCIM filter and pagination.
Returns:
A tuple of (list of (group, external_id) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
"""
query = select(UserGroup).where(UserGroup.is_up_for_deletion.is_(False))
if scim_filter:
attr = scim_filter.attribute.lower()
if attr == "displayname":
# assignment: union return type widens but query is still Select[tuple[UserGroup]]
query = _apply_scim_string_op(query, UserGroup.name, scim_filter) # type: ignore[assignment]
elif attr == "externalid":
mapping = self.get_group_mapping_by_external_id(scim_filter.value)
if not mapping:
return [], 0
query = query.where(UserGroup.id == mapping.user_group_id)
else:
raise ValueError(
f"Unsupported filter attribute: {scim_filter.attribute}"
)
total = (
self._session.scalar(select(func.count()).select_from(query.subquery()))
or 0
)
offset = max(start_index - 1, 0)
groups = list(
self._session.scalars(
query.order_by(UserGroup.id).offset(offset).limit(count)
).all()
)
ext_id_map = self._get_group_external_ids([g.id for g in groups])
return [(g, ext_id_map.get(g.id)) for g in groups], total
def get_group_members(self, group_id: int) -> list[tuple[UUID, str | None]]:
"""Get group members as (user_id, email) pairs."""
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
).all()
user_ids = [r.user_id for r in rels if r.user_id]
if not user_ids:
return []
users = self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
).all()
users_by_id = {u.id: u for u in users}
return [
(
r.user_id,
users_by_id[r.user_id].email if r.user_id in users_by_id else None,
)
for r in rels
if r.user_id
]
def validate_member_ids(self, uuids: list[UUID]) -> list[UUID]:
"""Return the subset of UUIDs that don't exist as users.
Returns an empty list if all IDs are valid.
"""
if not uuids:
return []
existing_users = self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
).all()
existing_ids = {u.id for u in existing_users}
return [uid for uid in uuids if uid not in existing_ids]
def upsert_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Add user-group relationships, ignoring duplicates."""
if not user_ids:
return
self._session.execute(
pg_insert(User__UserGroup)
.values([{"user_id": uid, "user_group_id": group_id} for uid in user_ids])
.on_conflict_do_nothing(
index_elements=[
User__UserGroup.user_group_id,
User__UserGroup.user_id,
]
)
)
def replace_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Replace all members of a group."""
self._session.execute(
sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
)
self.upsert_group_members(group_id, user_ids)
def remove_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Remove specific members from a group."""
if not user_ids:
return
self._session.execute(
sa_delete(User__UserGroup).where(
User__UserGroup.user_group_id == group_id,
User__UserGroup.user_id.in_(user_ids),
)
)
def delete_group_with_members(self, group: UserGroup) -> None:
"""Remove all member relationships and delete the group."""
self._session.execute(
sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group.id)
)
self._session.delete(group)
def sync_group_external_id(
self, group_id: int, new_external_id: str | None
) -> None:
"""Create, update, or delete the external ID mapping for a group."""
mapping = self.get_group_mapping_by_group_id(group_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
else:
self.create_group_mapping(
external_id=new_external_id, user_group_id=group_id
)
elif mapping:
self.delete_group_mapping(mapping.id)
def _get_group_external_ids(self, group_ids: list[int]) -> dict[int, str]:
"""Batch-fetch external IDs for a list of group IDs."""
if not group_ids:
return {}
mappings = self._session.scalars(
select(ScimGroupMapping).where(
ScimGroupMapping.user_group_id.in_(group_ids)
)
).all()
return {m.user_group_id: m.external_id for m in mappings}
# ---------------------------------------------------------------------------
# Module-level helpers (used by DAL methods above)
# ---------------------------------------------------------------------------
def _apply_scim_string_op(
query: Select[tuple[User]] | Select[tuple[UserGroup]],
column: SQLColumnExpression[str],
scim_filter: ScimFilter,
) -> Select[tuple[User]] | Select[tuple[UserGroup]]:
"""Apply a SCIM string filter operator using SQLAlchemy column operators.
Handles eq (case-insensitive exact), co (contains), and sw (starts with).
SQLAlchemy's operators handle LIKE-pattern escaping internally.
"""
val = scim_filter.value
if scim_filter.operator == ScimFilterOperator.EQUAL:
return query.where(func.lower(column) == val.lower())
elif scim_filter.operator == ScimFilterOperator.CONTAINS:
return query.where(column.icontains(val, autoescape=True))
elif scim_filter.operator == ScimFilterOperator.STARTS_WITH:
return query.where(column.istartswith(val, autoescape=True))
else:
raise ValueError(f"Unsupported string filter operator: {scim_filter.operator}")

View File

@@ -9,7 +9,6 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
@@ -19,15 +18,11 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
@@ -200,60 +195,8 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
return db_session.scalar(stmt)
def _add_user_group_snapshot_eager_loads(
stmt: Select,
) -> Select:
"""Add eager loading options needed by UserGroup.from_model snapshot creation."""
return stmt.options(
selectinload(UserGroup.users),
selectinload(UserGroup.user_group_relationships),
selectinload(UserGroup.cc_pair_relationships)
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
.options(
selectinload(ConnectorCredentialPair.connector),
selectinload(ConnectorCredentialPair.credential).selectinload(
Credential.user
),
),
selectinload(UserGroup.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(UserGroup.personas).options(
selectinload(Persona.tools),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
selectinload(Persona.groups),
),
)
def fetch_user_groups(
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
db_session: Session, only_up_to_date: bool = True
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -266,8 +209,6 @@ def fetch_user_groups(
db_session (Session): The SQLAlchemy session used to query the database.
only_up_to_date (bool, optional): Flag to determine whether to filter the results
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -275,16 +216,11 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def fetch_user_groups_for_user(
db_session: Session,
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
db_session: Session, user_id: UUID, only_curator_groups: bool = False
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -294,9 +230,7 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def construct_document_id_select_by_usergroup(

View File

@@ -50,12 +50,7 @@ def github_doc_sync(
**cc_pair.connector.connector_specific_config
)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
github_connector.load_credentials(credential_json)
github_connector.load_credentials(cc_pair.credential.credential_json)
logger.info("GitHub connector credentials loaded successfully")
if not github_connector.github_client:
@@ -65,7 +60,21 @@ def github_doc_sync(
# Get all repositories from GitHub API
logger.info("Fetching all repositories from GitHub API")
try:
repos = github_connector.fetch_configured_repos()
repos = []
if github_connector.repositories:
if "," in github_connector.repositories:
# Multiple repositories specified
repos = github_connector.get_github_repos(
github_connector.github_client
)
else:
# Single repository
repos = [
github_connector.get_github_repo(github_connector.github_client)
]
else:
# All repositories
repos = github_connector.get_all_repos(github_connector.github_client)
logger.info(f"Found {len(repos)} repositories to check")
except Exception as e:

View File

@@ -18,12 +18,7 @@ def github_group_sync(
github_connector: GithubConnector = GithubConnector(
**cc_pair.connector.connector_specific_config
)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
github_connector.load_credentials(credential_json)
github_connector.load_credentials(cc_pair.credential.credential_json)
if not github_connector.github_client:
raise ValueError("github_client is required")

View File

@@ -50,12 +50,7 @@ def gmail_doc_sync(
already populated.
"""
gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
gmail_connector.load_credentials(credential_json)
gmail_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = _get_slim_doc_generator(
cc_pair, gmail_connector, callback=callback

View File

@@ -295,12 +295,7 @@ def gdrive_doc_sync(
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
google_drive_connector.load_credentials(credential_json)
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)

View File

@@ -391,12 +391,7 @@ def gdrive_group_sync(
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
google_drive_connector.load_credentials(credential_json)
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
admin_service = get_admin_service(
google_drive_connector.creds, google_drive_connector.primary_admin_email
)

View File

@@ -24,12 +24,7 @@ def jira_doc_sync(
jira_connector = JiraConnector(
**cc_pair.connector.connector_specific_config,
)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
jira_connector.load_credentials(credential_json)
jira_connector.load_credentials(cc_pair.credential.credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,

View File

@@ -119,13 +119,8 @@ def jira_group_sync(
if not jira_base_url:
raise ValueError("No jira_base_url found in connector config")
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
jira_client = build_jira_client(
credentials=credential_json,
credentials=cc_pair.credential.credential_json,
jira_base=jira_base_url,
scoped_token=scoped_token,
)

View File

@@ -30,11 +30,7 @@ def get_any_salesforce_client_for_doc_id(
if _ANY_SALESFORCE_CLIENT is None:
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
first_cc_pair = cc_pairs[0]
credential_json = (
first_cc_pair.credential.credential_json.get_value(apply_mask=False)
if first_cc_pair.credential.credential_json
else {}
)
credential_json = first_cc_pair.credential.credential_json
_ANY_SALESFORCE_CLIENT = Salesforce(
username=credential_json["sf_username"],
password=credential_json["sf_password"],
@@ -162,11 +158,7 @@ def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Sales
)
if cc_pair is None:
raise ValueError(f"CC pair {cc_pair_id} not found")
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
credential_json = cc_pair.credential.credential_json
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce(
username=credential_json["sf_username"],
password=credential_json["sf_password"],

View File

@@ -24,12 +24,7 @@ def sharepoint_doc_sync(
sharepoint_connector = SharepointConnector(
**cc_pair.connector.connector_specific_config,
)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
sharepoint_connector.load_credentials(credential_json)
sharepoint_connector.load_credentials(cc_pair.credential.credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,

View File

@@ -6,7 +6,6 @@ from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
@@ -26,12 +25,7 @@ def sharepoint_group_sync(
# Create SharePoint connector instance and load credentials
connector = SharepointConnector(**connector_config)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
connector.load_credentials(credential_json)
connector.load_credentials(cc_pair.credential.credential_json)
if not connector.msal_app:
raise RuntimeError("MSAL app not initialized in connector")
@@ -47,27 +41,19 @@ def sharepoint_group_sync(
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
enumerate_all = connector_config.get(
"exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
)
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
sp_domain_suffix = connector.sharepoint_domain_suffix
# Process each site
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
# Create client context for the site using connector's MSAL app
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
external_groups = get_sharepoint_external_groups(
ctx,
connector.graph_client,
graph_api_base=connector.graph_api_base,
get_access_token=connector._get_graph_access_token,
enumerate_all_ad_groups=enumerate_all,
)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
# Yield each group
for group in external_groups:

View File

@@ -1,13 +1,9 @@
import re
import time
from collections import deque
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse
import requests as _requests
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
@@ -18,10 +14,7 @@ from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger
@@ -40,70 +33,6 @@ LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
AD_GROUP_ENUMERATION_THRESHOLD = 100_000
def _graph_api_get(
url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> dict[str, Any]:
"""Authenticated Graph API GET with retry on transient errors."""
for attempt in range(GRAPH_API_MAX_RETRIES + 1):
access_token = get_access_token()
headers = {"Authorization": f"Bearer {access_token}"}
try:
resp = _requests.get(
url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
)
if (
resp.status_code in GRAPH_API_RETRYABLE_STATUSES
and attempt < GRAPH_API_MAX_RETRIES
):
wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
logger.warning(
f"Graph API {resp.status_code} on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
resp.raise_for_status()
return resp.json()
except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
if attempt < GRAPH_API_MAX_RETRIES:
wait = min(2**attempt, 60)
logger.warning(
f"Graph API connection error on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
raise
raise RuntimeError(
f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
)
def _iter_graph_collection(
initial_url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> Generator[dict[str, Any], None, None]:
"""Paginate through a Graph API collection, yielding items one at a time."""
url: str | None = initial_url
while url:
data = _graph_api_get(url, get_access_token, params)
params = None
yield from data.get("value", [])
url = data.get("@odata.nextLink")
def _normalize_email(email: str) -> str:
if MICROSOFT_DOMAIN in email:
return email.replace(MICROSOFT_DOMAIN, "")
return email
class SharepointGroup(BaseModel):
model_config = {"frozen": True}
@@ -643,65 +572,8 @@ def get_external_access_from_sharepoint(
)
def _enumerate_ad_groups_paginated(
get_access_token: Callable[[], str],
already_resolved: set[str],
graph_api_base: str,
) -> Generator[ExternalUserGroup, None, None]:
"""Paginate through all Azure AD groups and yield ExternalUserGroup for each.
Skips groups whose suffixed name is already in *already_resolved*.
Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
"""
groups_url = f"{graph_api_base}/groups"
groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
total_groups = 0
for group_json in _iter_graph_collection(
groups_url, get_access_token, groups_params
):
group_id: str = group_json.get("id", "")
display_name: str = group_json.get("displayName", "")
if not group_id or not display_name:
continue
total_groups += 1
if total_groups > AD_GROUP_ENUMERATION_THRESHOLD:
logger.warning(
f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
"groups — stopping to avoid excessive memory/API usage. "
"Remaining groups will be resolved from role assignments only."
)
return
name = f"{display_name}_{group_id}"
if name in already_resolved:
continue
member_emails: list[str] = []
members_url = f"{graph_api_base}/groups/{group_id}/members"
members_params: dict[str, str] = {
"$select": "userPrincipalName,mail",
"$top": "999",
}
for member_json in _iter_graph_collection(
members_url, get_access_token, members_params
):
email = member_json.get("userPrincipalName") or member_json.get("mail")
if email:
member_emails.append(_normalize_email(email))
yield ExternalUserGroup(id=name, user_emails=member_emails)
logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API")
def get_sharepoint_external_groups(
client_context: ClientContext,
graph_client: GraphClient,
graph_api_base: str,
get_access_token: Callable[[], str] | None = None,
enumerate_all_ad_groups: bool = False,
client_context: ClientContext, graph_client: GraphClient
) -> list[ExternalUserGroup]:
groups: set[SharepointGroup] = set()
@@ -757,22 +629,57 @@ def get_sharepoint_external_groups(
client_context, graph_client, groups, is_group_sync=True
)
external_user_groups: list[ExternalUserGroup] = [
ExternalUserGroup(id=group_name, user_emails=list(emails))
for group_name, emails in groups_and_members.groups_to_emails.items()
]
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
azure_ad_groups = sleep_and_retry(
graph_client.groups.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups",
)
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
ad_groups_to_emails: dict[str, set[str]] = {}
for group in azure_ad_groups:
# If the group is already identified, we don't need to get the members
if group.display_name in identified_groups:
continue
# AD groups allows same display name for multiple groups, so we need to add the GUID to the name
name = group.display_name
name = _get_group_name_with_suffix(group.id, name, graph_client)
if not enumerate_all_ad_groups or get_access_token is None:
logger.info(
"Skipping exhaustive Azure AD group enumeration. "
"Only groups found in site role assignments are included."
members = sleep_and_retry(
group.members.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
)
return external_user_groups
for member in members:
member_data = member.to_json()
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
if not ad_groups_to_emails.get(name):
ad_groups_to_emails[name] = set()
if user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
ad_groups_to_emails[name].add(user_principal_name)
elif mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
ad_groups_to_emails[name].add(mail)
already_resolved = set(groups_and_members.groups_to_emails.keys())
for group in _enumerate_ad_groups_paginated(
get_access_token, already_resolved, graph_api_base
):
external_user_groups.append(group)
external_user_groups: list[ExternalUserGroup] = []
for group_name, emails in groups_and_members.groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
for group_name, emails in ad_groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
return external_user_groups

View File

@@ -151,14 +151,9 @@ def slack_doc_sync(
tenant_id = get_current_tenant_id()
provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
r = get_redis_client(tenant_id=tenant_id)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
slack_client = SlackConnector.make_slack_web_client(
provider.get_provider_key(),
credential_json["slack_bot_token"],
cc_pair.credential.credential_json["slack_bot_token"],
SlackConnector.MAX_RETRIES,
r,
)

View File

@@ -63,14 +63,9 @@ def slack_group_sync(
provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
r = get_redis_client(tenant_id=tenant_id)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
slack_client = SlackConnector.make_slack_web_client(
provider.get_provider_key(),
credential_json["slack_bot_token"],
cc_pair.credential.credential_json["slack_bot_token"],
SlackConnector.MAX_RETRIES,
r,
)

View File

@@ -25,12 +25,7 @@ def teams_doc_sync(
teams_connector = TeamsConnector(
**cc_pair.connector.connector_specific_config,
)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
teams_connector.load_credentials(credential_json)
teams_connector.load_credentials(cc_pair.credential.credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,

View File

@@ -31,7 +31,6 @@ from ee.onyx.server.query_and_chat.query_backend import (
from ee.onyx.server.query_and_chat.search_backend import router as search_router
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.scim.api import scim_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.token_rate_limits.api import (
@@ -163,11 +162,6 @@ def get_application() -> FastAPI:
# Tenant management
include_router_with_global_prefix_prepended(application, tenants_router)
# SCIM 2.0 — protocol endpoints (unauthenticated by Onyx session auth;
# they use their own SCIM bearer token auth).
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
application.include_router(scim_router)
# Ensure all routes have auth enabled or are explicitly marked as public
check_ee_router_auth(application)

View File

@@ -77,7 +77,7 @@ def stream_search_query(
# Get document index
search_settings = get_current_search_settings(db_session)
# This flow is for search so we do not get all indices.
document_index = get_default_document_index(search_settings, None, db_session)
document_index = get_default_document_index(search_settings, None)
# Determine queries to execute
original_query = request.search_query

View File

@@ -5,11 +5,6 @@ from onyx.server.auth_check import PUBLIC_ENDPOINT_SPECS
EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
# SCIM 2.0 service discovery — unauthenticated so IdPs can probe
# before bearer token configuration is complete
("/scim/v2/ServiceProviderConfig", {"GET"}),
("/scim/v2/ResourceTypes", {"GET"}),
("/scim/v2/Schemas", {"GET"}),
# needs to be accessible prior to user login
("/enterprise-settings", {"GET"}),
("/enterprise-settings/logo", {"GET"}),

View File

@@ -13,7 +13,6 @@ from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.store import get_logo_filename
@@ -23,10 +22,6 @@ from ee.onyx.server.enterprise_settings.store import load_settings
from ee.onyx.server.enterprise_settings.store import store_analytics_script
from ee.onyx.server.enterprise_settings.store import store_settings
from ee.onyx.server.enterprise_settings.store import upload_logo
from ee.onyx.server.scim.auth import generate_scim_token
from ee.onyx.server.scim.models import ScimTokenCreate
from ee.onyx.server.scim.models import ScimTokenCreatedResponse
from ee.onyx.server.scim.models import ScimTokenResponse
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
@@ -203,63 +198,3 @@ def upload_custom_analytics_script(
@basic_router.get("/custom-analytics-script")
def fetch_custom_analytics_script() -> str | None:
return load_analytics_script()
# ---------------------------------------------------------------------------
# SCIM token management
# ---------------------------------------------------------------------------
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
@admin_router.get("/scim/token")
def get_active_scim_token(
_: User = Depends(current_admin_user),
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenResponse:
"""Return the currently active SCIM token's metadata, or 404 if none."""
token = dal.get_active_token()
if not token:
raise HTTPException(status_code=404, detail="No active SCIM token")
return ScimTokenResponse(
id=token.id,
name=token.name,
token_display=token.token_display,
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
)
@admin_router.post("/scim/token", status_code=201)
def create_scim_token(
body: ScimTokenCreate,
user: User = Depends(current_admin_user),
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenCreatedResponse:
"""Create a new SCIM bearer token.
Only one token is active at a time — creating a new token automatically
revokes all previous tokens. The raw token value is returned exactly once
in the response; it cannot be retrieved again.
"""
raw_token, hashed_token, token_display = generate_scim_token()
token = dal.create_token(
name=body.name,
hashed_token=hashed_token,
token_display=token_display,
created_by_id=user.id,
)
dal.commit()
return ScimTokenCreatedResponse(
id=token.id,
name=token.name,
token_display=token.token_display,
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
raw_token=raw_token,
)

View File

@@ -270,11 +270,7 @@ def confluence_oauth_accessible_resources(
if not credential:
raise HTTPException(400, f"Credential {credential_id} not found.")
credential_dict = (
credential.credential_json.get_value(apply_mask=False)
if credential.credential_json
else {}
)
credential_dict = credential.credential_json
access_token = credential_dict["confluence_access_token"]
try:
@@ -341,12 +337,7 @@ def confluence_oauth_finalize(
detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
)
existing_credential_json = (
credential.credential_json.get_value(apply_mask=False)
if credential.credential_json
else {}
)
new_credential_json: dict[str, Any] = dict(existing_credential_json)
new_credential_json: dict[str, Any] = dict(credential.credential_json)
new_credential_json["cloud_id"] = cloud_id
new_credential_json["cloud_name"] = cloud_name
new_credential_json["wiki_base"] = cloud_url

View File

@@ -27,8 +27,6 @@ class SearchFlowClassificationResponse(BaseModel):
is_search_flow: bool
# NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
class SendSearchQueryRequest(BaseModel):
search_query: str
filters: BaseFilters | None = None

View File

@@ -26,7 +26,6 @@ from onyx.db.models import User
from onyx.llm.factory import get_default_llm
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
from onyx.server.utils import get_json_line
from onyx.server.utils_vector_db import require_vector_db
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
@@ -67,13 +66,7 @@ def search_flow_classification(
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
# NOTE: This endpoint is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
@router.post(
"/send-search-message",
response_model=None,
dependencies=[Depends(require_vector_db)],
)
@router.post("/send-search-message", response_model=None)
def handle_send_search_message(
request: SendSearchQueryRequest,
user: User = Depends(current_user),

View File

@@ -1,689 +0,0 @@
"""SCIM 2.0 API endpoints (RFC 7644).
This module provides the FastAPI router for SCIM service discovery,
User CRUD, and Group CRUD. Identity providers (Okta, Azure AD) call
these endpoints to provision and manage users and groups.
Service discovery endpoints are unauthenticated — IdPs may probe them
before bearer token configuration is complete. All other endpoints
require a valid SCIM bearer token.
"""
from __future__ import annotations
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Query
from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
# changed to kebab-case.
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
# ---------------------------------------------------------------------------
# Service Discovery Endpoints (unauthenticated)
# ---------------------------------------------------------------------------
@scim_router.get("/ServiceProviderConfig")
def get_service_provider_config() -> ScimServiceProviderConfig:
"""Advertise supported SCIM features (RFC 7643 §5)."""
return SERVICE_PROVIDER_CONFIG
@scim_router.get("/ResourceTypes")
def get_resource_types() -> list[ScimResourceType]:
"""List available SCIM resource types (RFC 7643 §6)."""
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
@scim_router.get("/Schemas")
def get_schemas() -> list[ScimSchemaDefinition]:
"""Return SCIM schema definitions (RFC 7643 §7)."""
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _scim_error_response(status: int, detail: str) -> JSONResponse:
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
body = ScimError(status=str(status), detail=detail)
return JSONResponse(
status_code=status,
content=body.model_dump(exclude_none=True),
)
def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
"""Convert an Onyx User to a SCIM User resource representation."""
name = None
if user.personal_name:
parts = user.personal_name.split(" ", 1)
name = ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
return ScimUserResource(
id=str(user.id),
externalId=external_id,
userName=user.email,
name=name,
emails=[ScimEmail(value=user.email, type="work", primary=True)],
active=user.is_active,
meta=ScimMeta(resourceType="User"),
)
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None."""
check_fn = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)
if check_fn is None:
return None
result = check_fn(dal.session, seats_needed=1)
if not result.available:
return result.error_message or "Seat limit reached"
return None
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
try:
uid = UUID(user_id)
except ValueError:
return _scim_error_response(404, f"User {user_id} not found")
user = dal.get_user(uid)
if not user:
return _scim_error_response(404, f"User {user_id} not found")
return user
def _scim_name_to_str(name: ScimName | None) -> str | None:
"""Extract a display name string from a SCIM name object.
Returns None if no name is provided, so the caller can decide
whether to update the user's personal_name.
"""
if not name:
return None
return name.formatted or " ".join(
part for part in [name.givenName, name.familyName] if part
)
# ---------------------------------------------------------------------------
# User CRUD (RFC 7644 §3)
# ---------------------------------------------------------------------------
@scim_router.get("/Users", response_model=None)
def list_users(
filter: str | None = Query(None),
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List users with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
try:
scim_filter = parse_scim_filter(filter)
except ValueError as e:
return _scim_error_response(400, str(e))
try:
users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
resources: list[ScimUserResource | ScimGroupResource] = [
_user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
]
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
)
@scim_router.get("/Users/{user_id}", response_model=None)
def get_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Get a single user by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
return _user_to_scim(user, mapping.external_id if mapping else None)
@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Create a new user from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
email = user_resource.userName.strip().lower()
# externalId is how the IdP correlates this user on subsequent requests.
# Without it, the IdP can't find the user and will try to re-create,
# hitting a 409 conflict — so we require it up front.
if not user_resource.externalId:
return _scim_error_response(400, "externalId is required")
# Enforce seat limit
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
# Check for existing user
if dal.get_user_by_email(email):
return _scim_error_response(409, f"User with email {email} already exists")
# Create user with a random password (SCIM users authenticate via IdP)
personal_name = _scim_name_to_str(user_resource.name)
user = User(
email=email,
hashed_password=_pw_helper.hash(_pw_helper.generate()),
role=UserRole.BASIC,
is_active=user_resource.active,
is_verified=True,
personal_name=personal_name,
)
try:
dal.add_user(user)
except IntegrityError:
dal.rollback()
return _scim_error_response(409, f"User with email {email} already exists")
# Create SCIM mapping (externalId is validated above, always present)
external_id = user_resource.externalId
dal.create_user_mapping(external_id=external_id, user_id=user.id)
dal.commit()
return _user_to_scim(user, external_id)
@scim_router.put("/Users/{user_id}", response_model=None)
def replace_user(
user_id: str,
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Replace a user entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
# Handle activation (need seat check) / deactivation
if user_resource.active and not user.is_active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
dal.update_user(
user,
email=user_resource.userName.strip().lower(),
is_active=user_resource.active,
personal_name=_scim_name_to_str(user_resource.name),
)
new_external_id = user_resource.externalId
dal.sync_user_external_id(user.id, new_external_id)
dal.commit()
return _user_to_scim(user, new_external_id)
@scim_router.patch("/Users/{user_id}", response_model=None)
def patch_user(
user_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Partially update a user (RFC 7644 §3.5.2).
This is the primary endpoint for user deprovisioning — Okta sends
``PATCH {"active": false}`` rather than DELETE.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
external_id = mapping.external_id if mapping else None
current = _user_to_scim(user, external_id)
try:
patched = apply_user_patch(patch_request.Operations, current)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
# Apply changes back to the DB model
if patched.active != user.is_active:
if patched.active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
dal.update_user(
user,
email=(
patched.userName.strip().lower()
if patched.userName.lower() != user.email
else None
),
is_active=patched.active if patched.active != user.is_active else None,
personal_name=_scim_name_to_str(patched.name),
)
dal.sync_user_external_id(user.id, patched.externalId)
dal.commit()
return _user_to_scim(user, patched.externalId)
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
def delete_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | JSONResponse:
"""Delete a user (RFC 7644 §3.6).
Deactivates the user and removes the SCIM mapping. Note that Okta
typically uses PATCH active=false instead of DELETE.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
dal.deactivate_user(user)
mapping = dal.get_user_mapping_by_user_id(user.id)
if mapping:
dal.delete_user_mapping(mapping.id)
dal.commit()
return Response(status_code=204)
# ---------------------------------------------------------------------------
# Group helpers
# ---------------------------------------------------------------------------
def _group_to_scim(
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Convert an Onyx UserGroup to a SCIM Group resource."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
"""Parse *group_id* as int, look up the group, or return a 404 error."""
try:
gid = int(group_id)
except ValueError:
return _scim_error_response(404, f"Group {group_id} not found")
group = dal.get_group(gid)
if not group:
return _scim_error_response(404, f"Group {group_id} not found")
return group
def _parse_member_uuids(
members: list[ScimGroupMember],
) -> tuple[list[UUID], str | None]:
"""Parse member value strings to UUIDs.
Returns (uuid_list, error_message). error_message is None on success.
"""
uuids: list[UUID] = []
for m in members:
try:
uuids.append(UUID(m.value))
except ValueError:
return [], f"Invalid member ID: {m.value}"
return uuids, None
def _validate_and_parse_members(
members: list[ScimGroupMember], dal: ScimDAL
) -> tuple[list[UUID], str | None]:
"""Parse and validate member UUIDs exist in the database.
Returns (uuid_list, error_message). error_message is None on success.
"""
uuids, err = _parse_member_uuids(members)
if err:
return [], err
if uuids:
missing = dal.validate_member_ids(uuids)
if missing:
return [], f"Member(s) not found: {', '.join(str(u) for u in missing)}"
return uuids, None
# ---------------------------------------------------------------------------
# Group CRUD (RFC 7644 §3)
# ---------------------------------------------------------------------------
@scim_router.get("/Groups", response_model=None)
def list_groups(
filter: str | None = Query(None),
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List groups with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
try:
scim_filter = parse_scim_filter(filter)
except ValueError as e:
return _scim_error_response(400, str(e))
try:
groups_with_ext_ids, total = dal.list_groups(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
resources: list[ScimUserResource | ScimGroupResource] = [
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
for group, ext_id in groups_with_ext_ids
]
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
)
@scim_router.get("/Groups/{group_id}", response_model=None)
def get_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Get a single group by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, mapping.external_id if mapping else None)
@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Create a new group from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
if dal.get_group_by_name(group_resource.displayName):
return _scim_error_response(
409, f"Group with name '{group_resource.displayName}' already exists"
)
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
if err:
return _scim_error_response(400, err)
db_group = UserGroup(
name=group_resource.displayName,
is_up_to_date=True,
time_last_modified_by_user=func.now(),
)
try:
dal.add_group(db_group)
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"Group with name '{group_resource.displayName}' already exists"
)
dal.upsert_group_members(db_group.id, member_uuids)
external_id = group_resource.externalId
if external_id:
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
dal.commit()
members = dal.get_group_members(db_group.id)
return _group_to_scim(db_group, members, external_id)
@scim_router.put("/Groups/{group_id}", response_model=None)
def replace_group(
group_id: str,
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Replace a group entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
if err:
return _scim_error_response(400, err)
dal.update_group(group, name=group_resource.displayName)
dal.replace_group_members(group.id, member_uuids)
dal.sync_group_external_id(group.id, group_resource.externalId)
dal.commit()
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, group_resource.externalId)
@scim_router.patch("/Groups/{group_id}", response_model=None)
def patch_group(
group_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Partially update a group (RFC 7644 §3.5.2).
Handles member add/remove operations from Okta and Azure AD.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
external_id = mapping.external_id if mapping else None
current_members = dal.get_group_members(group.id)
current = _group_to_scim(group, current_members, external_id)
try:
patched, added_ids, removed_ids = apply_group_patch(
patch_request.Operations, current
)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
new_name = patched.displayName if patched.displayName != group.name else None
dal.update_group(group, name=new_name)
if added_ids:
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
if add_uuids:
missing = dal.validate_member_ids(add_uuids)
if missing:
return _scim_error_response(
400,
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
)
dal.upsert_group_members(group.id, add_uuids)
if removed_ids:
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
dal.remove_group_members(group.id, remove_uuids)
dal.sync_group_external_id(group.id, patched.externalId)
dal.commit()
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, patched.externalId)
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
def delete_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | JSONResponse:
"""Delete a group (RFC 7644 §3.6)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
if mapping:
dal.delete_group_mapping(mapping.id)
dal.delete_group_with_members(group)
dal.commit()
return Response(status_code=204)
def _is_valid_uuid(value: str) -> bool:
"""Check if a string is a valid UUID."""
try:
UUID(value)
return True
except ValueError:
return False

View File

@@ -1,104 +0,0 @@
"""SCIM bearer token authentication.
SCIM endpoints are authenticated via bearer tokens that admins create in the
Onyx UI. This module provides:
- ``verify_scim_token``: FastAPI dependency that extracts, hashes, and
validates the token from the Authorization header.
- ``generate_scim_token``: Creates a new cryptographically random token
and returns the raw value, its SHA-256 hash, and a display suffix.
Token format: ``onyx_scim_<random>`` where ``<random>`` is 48 bytes of
URL-safe base64 from ``secrets.token_urlsafe``.
The hash is stored in the ``scim_token`` table; the raw value is shown to
the admin exactly once at creation time.
"""
import hashlib
import secrets
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from onyx.auth.utils import get_hashed_bearer_token_from_request
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
SCIM_TOKEN_PREFIX = "onyx_scim_"
SCIM_TOKEN_LENGTH = 48
def _hash_scim_token(token: str) -> str:
"""SHA-256 hash a SCIM token. No salt needed — tokens are random."""
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def generate_scim_token() -> tuple[str, str, str]:
"""Generate a new SCIM bearer token.
Returns:
A tuple of ``(raw_token, hashed_token, token_display)`` where
``token_display`` is a masked version showing only the last 4 chars.
"""
raw_token = SCIM_TOKEN_PREFIX + secrets.token_urlsafe(SCIM_TOKEN_LENGTH)
hashed_token = _hash_scim_token(raw_token)
token_display = SCIM_TOKEN_PREFIX + "****" + raw_token[-4:]
return raw_token, hashed_token, token_display
def _get_hashed_scim_token_from_request(request: Request) -> str | None:
"""Extract and hash a SCIM token from the request Authorization header."""
return get_hashed_bearer_token_from_request(
request,
valid_prefixes=[SCIM_TOKEN_PREFIX],
hash_fn=_hash_scim_token,
)
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
def verify_scim_token(
request: Request,
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimToken:
"""FastAPI dependency that authenticates SCIM requests.
Extracts the bearer token from the Authorization header, hashes it,
looks it up in the database, and verifies it is active.
Note:
This dependency does NOT update ``last_used_at`` — the endpoint
should do that via ``ScimDAL.update_token_last_used()`` so the
timestamp write is part of the endpoint's transaction.
Raises:
HTTPException(401): If the token is missing, invalid, or inactive.
"""
hashed = _get_hashed_scim_token_from_request(request)
if not hashed:
raise HTTPException(
status_code=401,
detail="Missing or invalid SCIM bearer token",
)
token = dal.get_token_by_hash(hashed)
if not token:
raise HTTPException(
status_code=401,
detail="Invalid SCIM bearer token",
)
if not token.is_active:
raise HTTPException(
status_code=401,
detail="SCIM token has been revoked",
)
return token

View File

@@ -1,96 +0,0 @@
"""SCIM filter expression parser (RFC 7644 §3.4.2.2).
Identity providers (Okta, Azure AD, OneLogin, etc.) use filters to look up
resources before deciding whether to create or update them. For example, when
an admin assigns a user to the Onyx app, the IdP first checks whether that
user already exists::
GET /scim/v2/Users?filter=userName eq "john@example.com"
If zero results come back the IdP creates the user (``POST``); if a match is
found it links to the existing record and uses ``PUT``/``PATCH`` going forward.
The same pattern applies to groups (``displayName eq "Engineering"``).
This module parses the subset of the SCIM filter grammar that identity
providers actually send in practice:
attribute SP operator SP value
Supported operators: ``eq``, ``co`` (contains), ``sw`` (starts with).
Compound filters (``and`` / ``or``) are not supported; if an IdP sends one
the parser returns ``None`` and the caller falls back to an unfiltered list.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from enum import Enum
class ScimFilterOperator(str, Enum):
"""Supported SCIM filter operators."""
EQUAL = "eq"
CONTAINS = "co"
STARTS_WITH = "sw"
@dataclass(frozen=True, slots=True)
class ScimFilter:
"""Parsed SCIM filter expression."""
attribute: str
operator: ScimFilterOperator
value: str
# Matches: attribute operator "value" (with or without quotes around value)
# Groups: (attribute) (operator) ("quoted value" | unquoted_value)
_FILTER_RE = re.compile(
r"^(\S+)\s+(eq|co|sw)\s+" # attribute + operator
r'(?:"([^"]*)"' # quoted value
r"|'([^']*)')" # or single-quoted value
r"$",
re.IGNORECASE,
)
def parse_scim_filter(filter_string: str | None) -> ScimFilter | None:
"""Parse a simple SCIM filter expression.
Args:
filter_string: Raw filter query parameter value, e.g.
``'userName eq "john@example.com"'``
Returns:
A ``ScimFilter`` if the expression is valid and uses a supported
operator, or ``None`` if the input is empty / missing.
Raises:
ValueError: If the filter string is present but malformed or uses
an unsupported operator.
"""
if not filter_string or not filter_string.strip():
return None
match = _FILTER_RE.match(filter_string.strip())
if not match:
raise ValueError(f"Unsupported or malformed SCIM filter: {filter_string}")
return _build_filter(match, filter_string)
def _build_filter(match: re.Match[str], raw: str) -> ScimFilter:
"""Extract fields from a regex match and construct a ScimFilter."""
attribute = match.group(1)
op_str = match.group(2).lower()
# Value is in group 3 (double-quoted) or group 4 (single-quoted)
value = match.group(3) if match.group(3) is not None else match.group(4)
if value is None:
raise ValueError(f"Unsupported or malformed SCIM filter: {raw}")
operator = ScimFilterOperator(op_str)
return ScimFilter(attribute=attribute, operator=operator, value=value)

View File

@@ -1,285 +0,0 @@
"""Pydantic schemas for SCIM 2.0 provisioning (RFC 7643 / RFC 7644).
SCIM protocol schemas follow the wire format defined in:
- Core Schema: https://datatracker.ietf.org/doc/html/rfc7643
- Protocol: https://datatracker.ietf.org/doc/html/rfc7644
Admin API schemas are internal to Onyx and used for SCIM token management.
"""
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
# ---------------------------------------------------------------------------
# SCIM Schema URIs (RFC 7643 §8)
# Every SCIM JSON payload includes a "schemas" array identifying its type.
# IdPs like Okta/Azure AD use these URIs to determine how to parse responses.
# ---------------------------------------------------------------------------
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
SCIM_LIST_RESPONSE_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:ListResponse"
SCIM_PATCH_OP_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
SCIM_ERROR_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:Error"
SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
)
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
# ---------------------------------------------------------------------------
# SCIM Protocol Schemas
# ---------------------------------------------------------------------------
class ScimName(BaseModel):
"""User name components (RFC 7643 §4.1.1)."""
givenName: str | None = None
familyName: str | None = None
formatted: str | None = None
class ScimEmail(BaseModel):
"""Email sub-attribute (RFC 7643 §4.1.2)."""
value: str
type: str | None = None
primary: bool = False
class ScimMeta(BaseModel):
"""Resource metadata (RFC 7643 §3.1)."""
resourceType: str | None = None
created: datetime | None = None
lastModified: datetime | None = None
location: str | None = None
class ScimUserResource(BaseModel):
"""SCIM User resource representation (RFC 7643 §4.1).
This is the JSON shape that IdPs send when creating/updating a user via
SCIM, and the shape we return in GET responses. Field names use camelCase
to match the SCIM wire format (not Python convention).
"""
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
id: str | None = None # Onyx's internal user ID, set on responses
externalId: str | None = None # IdP's identifier for this user
userName: str # Typically the user's email address
name: ScimName | None = None
emails: list[ScimEmail] = Field(default_factory=list)
active: bool = True
meta: ScimMeta | None = None
class ScimGroupMember(BaseModel):
"""Group member reference (RFC 7643 §4.2).
Represents a user within a SCIM group. The IdP sends these when adding
or removing users from groups. ``value`` is the Onyx user ID.
"""
value: str # User ID of the group member
display: str | None = None
class ScimGroupResource(BaseModel):
"""SCIM Group resource representation (RFC 7643 §4.2)."""
schemas: list[str] = Field(default_factory=lambda: [SCIM_GROUP_SCHEMA])
id: str | None = None
externalId: str | None = None
displayName: str
members: list[ScimGroupMember] = Field(default_factory=list)
meta: ScimMeta | None = None
class ScimListResponse(BaseModel):
"""Paginated list response (RFC 7644 §3.4.2)."""
schemas: list[str] = Field(default_factory=lambda: [SCIM_LIST_RESPONSE_SCHEMA])
totalResults: int
startIndex: int = 1
itemsPerPage: int = 100
Resources: list[ScimUserResource | ScimGroupResource] = Field(default_factory=list)
class ScimPatchOperationType(str, Enum):
"""Supported PATCH operations (RFC 7644 §3.5.2)."""
ADD = "add"
REPLACE = "replace"
REMOVE = "remove"
class ScimPatchOperation(BaseModel):
"""Single PATCH operation (RFC 7644 §3.5.2)."""
op: ScimPatchOperationType
path: str | None = None
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
class ScimPatchRequest(BaseModel):
"""PATCH request body (RFC 7644 §3.5.2).
IdPs use PATCH to make incremental changes — e.g. deactivating a user
(replace active=false) or adding/removing group members — instead of
replacing the entire resource with PUT.
"""
schemas: list[str] = Field(default_factory=lambda: [SCIM_PATCH_OP_SCHEMA])
Operations: list[ScimPatchOperation]
class ScimError(BaseModel):
"""SCIM error response (RFC 7644 §3.12)."""
schemas: list[str] = Field(default_factory=lambda: [SCIM_ERROR_SCHEMA])
status: str
detail: str | None = None
scimType: str | None = None
# ---------------------------------------------------------------------------
# Service Provider Configuration (RFC 7643 §5)
# ---------------------------------------------------------------------------
class ScimSupported(BaseModel):
"""Generic supported/not-supported flag used in ServiceProviderConfig."""
supported: bool
class ScimFilterConfig(BaseModel):
"""Filter configuration within ServiceProviderConfig (RFC 7643 §5)."""
supported: bool
maxResults: int = 100
class ScimServiceProviderConfig(BaseModel):
"""SCIM ServiceProviderConfig resource (RFC 7643 §5).
Served at GET /scim/v2/ServiceProviderConfig. IdPs fetch this during
initial setup to discover which SCIM features our server supports
(e.g. PATCH yes, bulk no, filtering yes).
"""
schemas: list[str] = Field(
default_factory=lambda: [SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA]
)
patch: ScimSupported = ScimSupported(supported=True)
bulk: ScimSupported = ScimSupported(supported=False)
filter: ScimFilterConfig = ScimFilterConfig(supported=True)
changePassword: ScimSupported = ScimSupported(supported=False)
sort: ScimSupported = ScimSupported(supported=False)
etag: ScimSupported = ScimSupported(supported=False)
authenticationSchemes: list[dict[str, str]] = Field(
default_factory=lambda: [
{
"type": "oauthbearertoken",
"name": "OAuth Bearer Token",
"description": "Authentication scheme using a SCIM bearer token",
}
]
)
class ScimSchemaAttribute(BaseModel):
"""Attribute definition within a SCIM Schema (RFC 7643 §7)."""
name: str
type: str
multiValued: bool = False
required: bool = False
description: str = ""
caseExact: bool = False
mutability: str = "readWrite"
returned: str = "default"
uniqueness: str = "none"
subAttributes: list["ScimSchemaAttribute"] = Field(default_factory=list)
class ScimSchemaDefinition(BaseModel):
"""SCIM Schema definition (RFC 7643 §7).
Served at GET /scim/v2/Schemas. Describes the attributes available
on each resource type so IdPs know which fields they can provision.
"""
schemas: list[str] = Field(default_factory=lambda: [SCIM_SCHEMA_SCHEMA])
id: str
name: str
description: str
attributes: list[ScimSchemaAttribute] = Field(default_factory=list)
class ScimSchemaExtension(BaseModel):
"""Schema extension reference within ResourceType (RFC 7643 §6)."""
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
schema_: str = Field(alias="schema")
required: bool
class ScimResourceType(BaseModel):
"""SCIM ResourceType resource (RFC 7643 §6).
Served at GET /scim/v2/ResourceTypes. Tells the IdP which resource
types are available (Users, Groups) and their respective endpoints.
"""
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
id: str
name: str
endpoint: str
description: str | None = None
schema_: str = Field(alias="schema")
schemaExtensions: list[ScimSchemaExtension] = Field(default_factory=list)
# ---------------------------------------------------------------------------
# Admin API Schemas (Onyx-internal, for SCIM token management)
# These are NOT part of the SCIM protocol. They power the Onyx admin UI
# where admins create/revoke the bearer tokens that IdPs use to authenticate.
# ---------------------------------------------------------------------------
class ScimTokenCreate(BaseModel):
"""Request to create a new SCIM bearer token."""
name: str
class ScimTokenResponse(BaseModel):
"""SCIM token metadata returned in list/get responses."""
id: int
name: str
token_display: str
is_active: bool
created_at: datetime
last_used_at: datetime | None = None
class ScimTokenCreatedResponse(ScimTokenResponse):
"""Response returned when a new SCIM token is created.
Includes the raw token value which is only available at creation time.
"""
raw_token: str

View File

@@ -1,256 +0,0 @@
"""SCIM PATCH operation handler (RFC 7644 §3.5.2).
Identity providers use PATCH to make incremental changes to SCIM resources
instead of replacing the entire resource with PUT. Common operations include:
- Deactivating a user: ``replace`` ``active`` with ``false``
- Adding group members: ``add`` to ``members``
- Removing group members: ``remove`` from ``members[value eq "..."]``
This module applies PATCH operations to Pydantic SCIM resource objects and
returns the modified result. It does NOT touch the database — the caller is
responsible for persisting changes.
"""
from __future__ import annotations
import re
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimUserResource
class ScimPatchError(Exception):
"""Raised when a PATCH operation cannot be applied."""
def __init__(self, detail: str, status: int = 400) -> None:
self.detail = detail
self.status = status
super().__init__(detail)
# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
r'^members\[value\s+eq\s+"([^"]+)"\]$',
re.IGNORECASE,
)
def apply_user_patch(
operations: list[ScimPatchOperation],
current: ScimUserResource,
) -> ScimUserResource:
"""Apply SCIM PATCH operations to a user resource.
Returns a new ``ScimUserResource`` with the modifications applied.
The original object is not mutated.
Raises:
ScimPatchError: If an operation targets an unsupported path.
"""
data = current.model_dump()
name_data = data.get("name") or {}
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_user_replace(op, data, name_data)
elif op.op == ScimPatchOperationType.ADD:
_apply_user_replace(op, data, name_data)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on User resource"
)
data["name"] = name_data
return ScimUserResource.model_validate(data)
def _apply_user_replace(
op: ScimPatchOperation,
data: dict,
name_data: dict,
) -> None:
"""Apply a replace/add operation to user data."""
path = (op.path or "").lower()
if not path:
# No path — value is a dict of top-level attributes to set
if isinstance(op.value, dict):
for key, val in op.value.items():
_set_user_field(key.lower(), val, data, name_data)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
_set_user_field(path, op.value, data, name_data)
def _set_user_field(
path: str,
value: str | bool | dict | list | None,
data: dict,
name_data: dict,
) -> None:
"""Set a single field on user data by SCIM path."""
if path == "active":
data["active"] = value
elif path == "username":
data["userName"] = value
elif path == "externalid":
data["externalId"] = value
elif path == "name.givenname":
name_data["givenName"] = value
elif path == "name.familyname":
name_data["familyName"] = value
elif path == "name.formatted":
name_data["formatted"] = value
elif path == "displayname":
# Some IdPs send displayName on users; map to formatted name
name_data["formatted"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
def apply_group_patch(
operations: list[ScimPatchOperation],
current: ScimGroupResource,
) -> tuple[ScimGroupResource, list[str], list[str]]:
"""Apply SCIM PATCH operations to a group resource.
Returns:
A tuple of (modified group, added member IDs, removed member IDs).
The caller uses the member ID lists to update the database.
Raises:
ScimPatchError: If an operation targets an unsupported path.
"""
data = current.model_dump()
current_members: list[dict] = list(data.get("members") or [])
added_ids: list[str] = []
removed_ids: list[str] = []
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
elif op.op == ScimPatchOperationType.ADD:
_apply_group_add(op, current_members, added_ids)
elif op.op == ScimPatchOperationType.REMOVE:
_apply_group_remove(op, current_members, removed_ids)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on Group resource"
)
data["members"] = current_members
group = ScimGroupResource.model_validate(data)
return group, added_ids, removed_ids
def _apply_group_replace(
op: ScimPatchOperation,
data: dict,
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
) -> None:
"""Apply a replace operation to group data."""
path = (op.path or "").lower()
if not path:
if isinstance(op.value, dict):
for key, val in op.value.items():
if key.lower() == "members":
_replace_members(val, current_members, added_ids, removed_ids)
else:
_set_group_field(key.lower(), val, data)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
if path == "members":
_replace_members(op.value, current_members, added_ids, removed_ids)
return
_set_group_field(path, op.value, data)
def _replace_members(
value: str | list | dict | bool | None,
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
) -> None:
"""Replace the entire group member list."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
old_ids = {m["value"] for m in current_members}
new_ids = {m.get("value", "") for m in value}
removed_ids.extend(old_ids - new_ids)
added_ids.extend(new_ids - old_ids)
current_members[:] = value
def _set_group_field(
path: str,
value: str | bool | dict | list | None,
data: dict,
) -> None:
"""Set a single field on group data by SCIM path."""
if path == "displayname":
data["displayName"] = value
elif path == "externalid":
data["externalId"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
def _apply_group_add(
op: ScimPatchOperation,
members: list[dict],
added_ids: list[str],
) -> None:
"""Add members to a group."""
path = (op.path or "").lower()
if path and path != "members":
raise ScimPatchError(f"Unsupported add path '{op.path}' for Group")
if not isinstance(op.value, list):
raise ScimPatchError("Add members requires a list value")
existing_ids = {m["value"] for m in members}
for member_data in op.value:
member_id = member_data.get("value", "")
if member_id and member_id not in existing_ids:
members.append(member_data)
added_ids.append(member_id)
existing_ids.add(member_id)
def _apply_group_remove(
op: ScimPatchOperation,
members: list[dict],
removed_ids: list[str],
) -> None:
"""Remove members from a group."""
if not op.path:
raise ScimPatchError("Remove operation requires a path")
match = _MEMBER_FILTER_RE.match(op.path)
if not match:
raise ScimPatchError(
f"Unsupported remove path '{op.path}'. "
'Expected: members[value eq "user-id"]'
)
target_id = match.group(1)
original_len = len(members)
members[:] = [m for m in members if m.get("value") != target_id]
if len(members) < original_len:
removed_ids.append(target_id)

View File

@@ -1,144 +0,0 @@
"""Static SCIM service discovery responses (RFC 7643 §5, §6, §7).
Pre-built at import time — these never change at runtime. Separated from
api.py to keep the endpoint module focused on request handling.
"""
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaAttribute
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
SERVICE_PROVIDER_CONFIG = ScimServiceProviderConfig()
USER_RESOURCE_TYPE = ScimResourceType.model_validate(
{
"id": "User",
"name": "User",
"endpoint": "/scim/v2/Users",
"description": "SCIM User resource",
"schema": SCIM_USER_SCHEMA,
}
)
GROUP_RESOURCE_TYPE = ScimResourceType.model_validate(
{
"id": "Group",
"name": "Group",
"endpoint": "/scim/v2/Groups",
"description": "SCIM Group resource",
"schema": SCIM_GROUP_SCHEMA,
}
)
USER_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_USER_SCHEMA,
name="User",
description="SCIM core User schema",
attributes=[
ScimSchemaAttribute(
name="userName",
type="string",
required=True,
uniqueness="server",
description="Unique identifier for the user, typically an email address.",
),
ScimSchemaAttribute(
name="name",
type="complex",
description="The components of the user's name.",
subAttributes=[
ScimSchemaAttribute(
name="givenName",
type="string",
description="The user's first name.",
),
ScimSchemaAttribute(
name="familyName",
type="string",
description="The user's last name.",
),
ScimSchemaAttribute(
name="formatted",
type="string",
description="The full name, including all middle names and titles.",
),
],
),
ScimSchemaAttribute(
name="emails",
type="complex",
multiValued=True,
description="Email addresses for the user.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="Email address value.",
),
ScimSchemaAttribute(
name="type",
type="string",
description="Label for this email (e.g. 'work').",
),
ScimSchemaAttribute(
name="primary",
type="boolean",
description="Whether this is the primary email.",
),
],
),
ScimSchemaAttribute(
name="active",
type="boolean",
description="Whether the user account is active.",
),
ScimSchemaAttribute(
name="externalId",
type="string",
description="Identifier from the provisioning client (IdP).",
caseExact=True,
),
],
)
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_GROUP_SCHEMA,
name="Group",
description="SCIM core Group schema",
attributes=[
ScimSchemaAttribute(
name="displayName",
type="string",
required=True,
description="Human-readable name for the group.",
),
ScimSchemaAttribute(
name="members",
type="complex",
multiValued=True,
description="Members of the group.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="User ID of the group member.",
),
ScimSchemaAttribute(
name="display",
type="string",
mutability="readOnly",
description="Display name of the group member.",
),
],
),
ScimSchemaAttribute(
name="externalId",
type="string",
description="Identifier from the provisioning client (IdP).",
caseExact=True,
),
],
)

View File

@@ -37,15 +37,12 @@ def list_user_groups(
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_for_snapshot=True
)
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]

View File

@@ -53,8 +53,7 @@ class UserGroup(BaseModel):
id=cc_pair_relationship.cc_pair.id,
name=cc_pair_relationship.cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_relationship.cc_pair.connector,
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
cc_pair_relationship.cc_pair.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential

View File

@@ -11,7 +11,6 @@ from onyx.db.models import OAuthUserToken
from onyx.db.oauth_config import get_user_oauth_token
from onyx.db.oauth_config import upsert_user_oauth_token
from onyx.utils.logger import setup_logger
from onyx.utils.sensitive import SensitiveValue
logger = setup_logger()
@@ -34,10 +33,7 @@ class OAuthTokenManager:
if not user_token:
return None
if not user_token.token_data:
return None
token_data = self._unwrap_token_data(user_token.token_data)
token_data = user_token.token_data
# Check if token is expired
if OAuthTokenManager.is_token_expired(token_data):
@@ -55,10 +51,7 @@ class OAuthTokenManager:
def refresh_token(self, user_token: OAuthUserToken) -> str:
"""Refresh access token using refresh token"""
if not user_token.token_data:
raise ValueError("No token data available for refresh")
token_data = self._unwrap_token_data(user_token.token_data)
token_data = user_token.token_data
response = requests.post(
self.oauth_config.token_url,
@@ -160,11 +153,3 @@ class OAuthTokenManager:
separator = "&" if "?" in oauth_config.authorization_url else "?"
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
@staticmethod
def _unwrap_token_data(
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
) -> dict[str, Any]:
if isinstance(token_data, SensitiveValue):
return token_data.get_value(apply_mask=False)
return token_data

View File

@@ -1,9 +1,7 @@
import uuid
from enum import Enum
from typing import Any
from fastapi_users import schemas
from typing_extensions import override
class UserRole(str, Enum):
@@ -43,21 +41,8 @@ class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
tenant_id: str | None = None
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
# Excluded from create_update_dict so it never reaches the DB layer
captcha_token: str | None = None
@override
def create_update_dict(self) -> dict[str, Any]:
d = super().create_update_dict()
d.pop("captcha_token", None)
return d
@override
def create_update_dict_superuser(self) -> dict[str, Any]:
d = super().create_update_dict_superuser()
d.pop("captcha_token", None)
return d
class UserUpdateWithRole(schemas.BaseUserUpdate):
role: UserRole

View File

@@ -121,7 +121,6 @@ from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -138,8 +137,6 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"
def is_user_admin(user: User) -> bool:
return user.role == UserRole.ADMIN
@@ -211,34 +208,22 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
return int(value.decode("utf-8")) == 1
def workspace_invite_only_enabled() -> bool:
settings = load_settings()
return settings.invite_only_enabled
def verify_email_is_invited(email: str) -> None:
if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}:
# SSO providers manage membership; allow JIT provisioning regardless of invites
return
if not workspace_invite_only_enabled():
whitelist = get_invited_users()
if not whitelist:
return
whitelist = get_invited_users()
if not email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email must be specified"},
)
raise PermissionError("Email must be specified")
try:
email_info = validate_email(email, check_deliverability=False)
except EmailUndeliverableError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email is not valid"},
)
raise PermissionError("Email is not valid")
for email_whitelist in whitelist:
try:
@@ -255,13 +240,7 @@ def verify_email_is_invited(email: str) -> None:
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"code": REGISTER_INVITE_ONLY_CODE,
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
},
)
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
@@ -1480,7 +1459,6 @@ def get_anonymous_user() -> User:
is_superuser=False,
role=UserRole.LIMITED,
use_memories=False,
enable_memory_tool=False,
)
return user
@@ -1671,10 +1649,7 @@ def get_oauth_router(
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
# Use WEB_DOMAIN instead of request.url_for() to prevent host
# header poisoning — request.url_for() trusts the Host header.
callback_path = request.app.url_path_for(callback_route_name)
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")

View File

@@ -26,7 +26,6 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
@@ -526,12 +525,6 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None: # noqa: ARG
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""
if DISABLE_VECTOR_DB:
logger.info(
"DISABLE_VECTOR_DB is set — skipping Vespa/OpenSearch readiness check."
)
return
if not wait_for_vespa_with_timeout():
msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
@@ -573,31 +566,3 @@ class LivenessProbe(bootsteps.StartStopStep):
def get_bootsteps() -> list[type]:
return [LivenessProbe]
# Task modules that require a vector DB (Vespa/OpenSearch).
# When DISABLE_VECTOR_DB is True these are excluded from autodiscover lists.
_VECTOR_DB_TASK_MODULES: set[str] = {
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.docfetching",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.opensearch_migration",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.hierarchyfetching",
# EE modules that are vector-DB-dependent
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
}
# NOTE: "onyx.background.celery.tasks.shared" is intentionally NOT in the set
# above. It contains celery_beat_heartbeat (which only writes to Redis) alongside
# document cleanup tasks. The cleanup tasks won't be invoked in minimal mode
# because the periodic tasks that trigger them are in other filtered modules.
def filter_task_modules(modules: list[str]) -> list[str]:
"""Remove vector-DB-dependent task modules when DISABLE_VECTOR_DB is True."""
if not DISABLE_VECTOR_DB:
return modules
return [m for m in modules if m not in _VECTOR_DB_TASK_MODULES]

View File

@@ -118,25 +118,23 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
# Original background worker tasks
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.llm_model_update",
# Light worker tasks
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.opensearch_migration",
# Docprocessing worker tasks
"onyx.background.celery.tasks.docprocessing",
# Docfetching worker tasks
"onyx.background.celery.tasks.docfetching",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)
[
# Original background worker tasks
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.opensearch_migration",
# Light worker tasks
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
# Docprocessing worker tasks
"onyx.background.celery.tasks.docprocessing",
# Docfetching worker tasks
"onyx.background.celery.tasks.docfetching",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)

View File

@@ -96,9 +96,7 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"onyx.background.celery.tasks.docfetching",
]
)
[
"onyx.background.celery.tasks.docfetching",
]
)

View File

@@ -107,9 +107,7 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"onyx.background.celery.tasks.docprocessing",
]
)
[
"onyx.background.celery.tasks.docprocessing",
]
)

View File

@@ -96,12 +96,10 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"onyx.background.celery.tasks.pruning",
# Sandbox tasks (file sync, cleanup)
"onyx.server.features.build.sandbox.tasks",
"onyx.background.celery.tasks.hierarchyfetching",
]
)
[
"onyx.background.celery.tasks.pruning",
# Sandbox tasks (file sync, cleanup)
"onyx.server.features.build.sandbox.tasks",
"onyx.background.celery.tasks.hierarchyfetching",
]
)

View File

@@ -110,16 +110,13 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.opensearch_migration",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)
[
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.docprocessing",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)

View File

@@ -94,9 +94,7 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"onyx.background.celery.tasks.monitoring",
]
)
[
"onyx.background.celery.tasks.monitoring",
]
)

View File

@@ -314,18 +314,17 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.hierarchyfetching",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.user_file_processing",
]
)
[
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.hierarchyfetching",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.opensearch_migration",
]
)

View File

@@ -107,9 +107,7 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"onyx.background.celery.tasks.user_file_processing",
]
)
[
"onyx.background.celery.tasks.user_file_processing",
]
)

View File

@@ -1,30 +1,25 @@
from collections.abc import Generator
from collections.abc import Iterator
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from pathlib import Path
from typing import Any
from typing import cast
from typing import TypeVar
import httpx
from pydantic import BaseModel
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.connector_runner import CheckpointOutputWrapper
from onyx.connectors.connector_runner import batched_doc_ids
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
@@ -34,129 +29,63 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
CT = TypeVar("CT", bound=ConnectorCheckpoint)
PRUNING_CHECKPOINTED_BATCH_SIZE = 32
class SlimConnectorExtractionResult(BaseModel):
"""Result of extracting document IDs and hierarchy nodes from a connector."""
doc_ids: set[str]
hierarchy_nodes: list[HierarchyNode]
def _checkpointed_batched_items(
connector: CheckpointedConnector[CT],
start: float,
end: float,
) -> Generator[list[Document | HierarchyNode | ConnectorFailure], None, None]:
"""Loop through all checkpoint steps and yield batched items.
Some checkpointed connectors (e.g. IMAP) are multi-step: the first
checkpoint call may only initialize internal state without yielding
any documents. This function loops until checkpoint.has_more is False
to ensure all items are collected across every step.
"""
checkpoint = connector.build_dummy_checkpoint()
while True:
checkpoint_output = connector.load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint
)
wrapper: CheckpointOutputWrapper[CT] = CheckpointOutputWrapper()
batch: list[Document | HierarchyNode | ConnectorFailure] = []
for document, hierarchy_node, failure, next_checkpoint in wrapper(
checkpoint_output
):
if document is not None:
batch.append(document)
elif hierarchy_node is not None:
batch.append(hierarchy_node)
elif failure is not None:
batch.append(failure)
if next_checkpoint is not None:
checkpoint = next_checkpoint
if batch:
yield batch
if not checkpoint.has_more:
break
def _get_failure_id(failure: ConnectorFailure) -> str | None:
"""Extract the document/entity ID from a ConnectorFailure."""
if failure.failed_document:
return failure.failed_document.document_id
if failure.failed_entity:
return failure.failed_entity.entity_id
return None
def _extract_from_batch(
doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
) -> tuple[set[str], list[HierarchyNode]]:
"""Separate a batch into document IDs and hierarchy nodes.
ConnectorFailure items have their failed document/entity IDs added to the
ID set so that failed-to-retrieve documents are not accidentally pruned.
"""
ids: set[str] = set()
hierarchy_nodes: list[HierarchyNode] = []
for item in doc_list:
if isinstance(item, HierarchyNode):
hierarchy_nodes.append(item)
ids.add(item.raw_node_id)
elif isinstance(item, ConnectorFailure):
failed_id = _get_failure_id(item)
if failed_id:
ids.add(failed_id)
logger.warning(
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
)
else:
ids.add(item.id)
return ids, hierarchy_nodes
def document_batch_to_ids(
doc_batch: (
Iterator[list[Document | HierarchyNode]]
| Iterator[list[SlimDocument | HierarchyNode]]
),
) -> Generator[set[str], None, None]:
for doc_list in doc_batch:
yield {
doc.raw_node_id if isinstance(doc, HierarchyNode) else doc.id
for doc in doc_list
}
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
callback: IndexingHeartbeatInterface | None = None,
) -> SlimConnectorExtractionResult:
) -> set[str]:
"""
Extract document IDs and hierarchy nodes from a runnable connector.
Hierarchy nodes yielded alongside documents/slim docs are collected and
returned in the result. ConnectorFailure items have their IDs preserved
so that failed-to-retrieve documents are not accidentally pruned.
If the given connector is neither a SlimConnector nor a SlimConnectorWithPermSync, just pull
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.
"""
all_connector_doc_ids: set[str] = set()
all_hierarchy_nodes: list[HierarchyNode] = []
# Sequence (covariant) lets all the specific list[...] iterator types unify here
raw_batch_generator: (
Iterator[Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure]]
| None
) = None
doc_batch_id_generator = None
if isinstance(runnable_connector, SlimConnector):
raw_batch_generator = runnable_connector.retrieve_all_slim_docs()
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.retrieve_all_slim_docs()
)
elif isinstance(runnable_connector, SlimConnectorWithPermSync):
raw_batch_generator = runnable_connector.retrieve_all_slim_docs_perm_sync()
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.retrieve_all_slim_docs_perm_sync()
)
# If the connector isn't slim, fall back to running it normally to get ids
elif isinstance(runnable_connector, LoadConnector):
raw_batch_generator = runnable_connector.load_from_state()
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.load_from_state()
)
elif isinstance(runnable_connector, PollConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
raw_batch_generator = runnable_connector.poll_source(start=start, end=end)
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.poll_source(start=start, end=end)
)
elif isinstance(runnable_connector, CheckpointedConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
raw_batch_generator = _checkpointed_batched_items(
runnable_connector, start, end
checkpoint = runnable_connector.build_dummy_checkpoint()
checkpoint_generator = runnable_connector.load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint
)
doc_batch_id_generator = batched_doc_ids(
checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
)
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
@@ -170,24 +99,19 @@ def extract_ids_from_runnable_connector(
else lambda x: x
)
# process raw batches to extract both IDs and hierarchy nodes
for doc_list in raw_batch_generator:
if callback and callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
for doc_batch_ids in doc_batch_id_generator:
if callback:
if callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
batch_ids, batch_nodes = _extract_from_batch(doc_list)
all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
all_hierarchy_nodes.extend(batch_nodes)
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch_ids))
if callback:
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
callback.progress("extract_ids_from_runnable_connector", len(doc_batch_ids))
return SlimConnectorExtractionResult(
doc_ids=all_connector_doc_ids,
hierarchy_nodes=all_hierarchy_nodes,
)
return all_connector_doc_ids
def celery_is_listening_to_queue(worker: Any, name: str) -> bool:

View File

@@ -6,7 +6,6 @@ from celery.schedules import crontab
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
@@ -216,39 +215,36 @@ if SCHEDULED_EVAL_DATASET_NAMES:
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
beat_task_templates.append(
{
"name": "migrate-chunks-from-vespa-to-opensearch",
"task": OnyxCeleryTask.MIGRATE_CHUNKS_FROM_VESPA_TO_OPENSEARCH_TASK,
"name": "check-for-documents-for-opensearch-migration",
"task": OnyxCeleryTask.CHECK_FOR_DOCUMENTS_FOR_OPENSEARCH_MIGRATION_TASK,
# Try to enqueue an invocation of this task with this frequency.
"schedule": timedelta(seconds=120), # 2 minutes
"options": {
"priority": OnyxCeleryPriority.LOW,
# If the task was not dequeued in this time, revoke it.
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.OPENSEARCH_MIGRATION,
},
}
)
# Beat task names that require a vector DB. Filtered out when DISABLE_VECTOR_DB.
_VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
"check-for-indexing",
"check-for-connector-deletion",
"check-for-vespa-sync",
"check-for-pruning",
"check-for-hierarchy-fetching",
"check-for-checkpoint-cleanup",
"check-for-index-attempt-cleanup",
"check-for-doc-permissions-sync",
"check-for-external-group-sync",
"check-for-documents-for-opensearch-migration",
"migrate-documents-from-vespa-to-opensearch",
}
if DISABLE_VECTOR_DB:
beat_task_templates = [
t for t in beat_task_templates if t["name"] not in _VECTOR_DB_BEAT_TASK_NAMES
]
beat_task_templates.append(
{
"name": "migrate-documents-from-vespa-to-opensearch",
"task": OnyxCeleryTask.MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK,
# Try to enqueue an invocation of this task with this frequency.
# NOTE: If MIGRATION_TASK_SOFT_TIME_LIMIT_S is greater than this
# value and the task is maximally busy, we can expect to see some
# enqueued tasks be revoked over time. This is ok; by erring on the
# side of "there will probably always be at least one task of this
# type in the queue", we are minimizing this task's idleness while
# still giving chances for other tasks to execute.
"schedule": timedelta(seconds=120), # 2 minutes
"options": {
"priority": OnyxCeleryPriority.LOW,
# If the task was not dequeued in this time, revoke it.
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:

View File

@@ -37,7 +37,6 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
timeout_seconds: int | None = None,
):
super().__init__()
self.parent_pid = parent_pid
@@ -52,29 +51,11 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
self.last_lock_monotonic = time.monotonic()
self.last_parent_check = time.monotonic()
self.start_monotonic = time.monotonic()
self.timeout_seconds = timeout_seconds
def should_stop(self) -> bool:
# Check if the associated indexing attempt has been cancelled
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
if bool(self.redis_connector.stop.fenced):
return True
# Check if the task has exceeded its timeout
# NOTE: Celery's soft_time_limit does not work with thread pools,
# so we must enforce timeouts internally.
if self.timeout_seconds is not None:
elapsed = time.monotonic() - self.start_monotonic
if elapsed > self.timeout_seconds:
logger.warning(
f"IndexingCallback Docprocessing - task timeout exceeded: "
f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
f"cc_pair={self.redis_connector.cc_pair_id}"
)
return True
return False
return bool(self.redis_connector.stop.fenced)
def progress(self, tag: str, amount: int) -> None: # noqa: ARG002
"""Amount isn't used yet."""

Some files were not shown because too many files have changed in this diff Show More