mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-23 10:45:44 +00:00
Compare commits
1 Commits
csv_render
...
jamison/od
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0bcd28e19b |
@@ -1 +0,0 @@
|
||||
../.cursor/skills
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"Playwright": {
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"@playwright/mcp"
|
||||
]
|
||||
},
|
||||
"Linear": {
|
||||
"url": "https://mcp.linear.app/mcp"
|
||||
},
|
||||
"Figma": {
|
||||
"url": "https://mcp.figma.com/mcp"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,248 +0,0 @@
|
||||
---
|
||||
name: playwright-e2e-tests
|
||||
description: Write and maintain Playwright end-to-end tests for the Onyx application. Use when creating new E2E tests, debugging test failures, adding test coverage, or when the user mentions Playwright, E2E tests, or browser testing.
|
||||
---
|
||||
|
||||
# Playwright E2E Tests
|
||||
|
||||
## Project Layout
|
||||
|
||||
- **Tests**: `web/tests/e2e/` — organized by feature (`auth/`, `admin/`, `chat/`, `assistants/`, `connectors/`, `mcp/`)
|
||||
- **Config**: `web/playwright.config.ts`
|
||||
- **Utilities**: `web/tests/e2e/utils/`
|
||||
- **Constants**: `web/tests/e2e/constants.ts`
|
||||
- **Global setup**: `web/tests/e2e/global-setup.ts`
|
||||
- **Output**: `web/output/playwright/`
|
||||
|
||||
## Imports
|
||||
|
||||
Always use absolute imports with the `@tests/e2e/` prefix — never relative paths (`../`, `../../`). The alias is defined in `web/tsconfig.json` and resolves to `web/tests/`.
|
||||
|
||||
```typescript
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
|
||||
import { TEST_ADMIN_CREDENTIALS } from "@tests/e2e/constants";
|
||||
```
|
||||
|
||||
All new files should be `.ts`, not `.js`.
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
# Run a specific test file
|
||||
npx playwright test web/tests/e2e/chat/default_assistant.spec.ts
|
||||
|
||||
# Run a specific project
|
||||
npx playwright test --project admin
|
||||
npx playwright test --project exclusive
|
||||
```
|
||||
|
||||
## Test Projects
|
||||
|
||||
| Project | Description | Parallelism |
|
||||
|---------|-------------|-------------|
|
||||
| `admin` | Standard tests (excludes `@exclusive`) | Parallel |
|
||||
| `exclusive` | Serial, slower tests (tagged `@exclusive`) | 1 worker |
|
||||
|
||||
All tests use `admin_auth.json` storage state by default (pre-authenticated admin session).
|
||||
|
||||
## Authentication
|
||||
|
||||
Global setup (`global-setup.ts`) runs automatically before all tests and handles:
|
||||
|
||||
- Server readiness check (polls health endpoint, 60s timeout)
|
||||
- Provisioning test users: admin, admin2, and a **pool of worker users** (`worker0@example.com` through `worker7@example.com`) (idempotent)
|
||||
- API login + saving storage states: `admin_auth.json`, `admin2_auth.json`, and `worker{N}_auth.json` for each worker user
|
||||
- Setting display name to `"worker"` for each worker user
|
||||
- Promoting admin2 to admin role
|
||||
- Ensuring a public LLM provider exists
|
||||
|
||||
Both test projects set `storageState: "admin_auth.json"`, so **every test starts pre-authenticated as admin with no login code needed**.
|
||||
|
||||
When a test needs a different user, use API-based login — never drive the login UI:
|
||||
|
||||
```typescript
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
|
||||
await page.context().clearCookies();
|
||||
await loginAs(page, "admin2");
|
||||
|
||||
// Log in as the worker-specific user (preferred for test isolation):
|
||||
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
|
||||
await page.context().clearCookies();
|
||||
await loginAsWorkerUser(page, testInfo.workerIndex);
|
||||
```
|
||||
|
||||
## Test Structure
|
||||
|
||||
Tests start pre-authenticated as admin — navigate and test directly:
|
||||
|
||||
```typescript
|
||||
import { test, expect } from "@playwright/test";
|
||||
|
||||
test.describe("Feature Name", () => {
|
||||
test("should describe expected behavior clearly", async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
// Already authenticated as admin — go straight to testing
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
**User isolation** — tests that modify visible app state (creating assistants, sending chat messages, pinning items) should run as a **worker-specific user** and clean up resources in `afterAll`. Global setup provisions a pool of worker users (`worker0@example.com` through `worker7@example.com`). `loginAsWorkerUser` maps `testInfo.workerIndex` to a pool slot via modulo, so retry workers (which get incrementing indices beyond the pool size) safely reuse existing users. This ensures parallel workers never share user state, keeps usernames deterministic for screenshots, and avoids cross-contamination:
|
||||
|
||||
```typescript
|
||||
import { test } from "@playwright/test";
|
||||
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
|
||||
|
||||
test.beforeEach(async ({ page }, testInfo) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAsWorkerUser(page, testInfo.workerIndex);
|
||||
});
|
||||
```
|
||||
|
||||
If the test requires admin privileges *and* modifies visible state, use `"admin2"` instead — it's a pre-provisioned admin account that keeps the primary `"admin"` clean for other parallel tests. Switch to `"admin"` only for privileged setup (creating providers, configuring tools), then back to the worker user for the actual test. See `chat/default_assistant.spec.ts` for a full example.
|
||||
|
||||
`loginAsRandomUser` exists for the rare case where the test requires a brand-new user (e.g. onboarding flows). Avoid it elsewhere — it produces non-deterministic usernames that complicate screenshots.
|
||||
|
||||
**API resource setup** — only when tests need to create backend resources (image gen configs, web search providers, MCP servers). Use `beforeAll`/`afterAll` with `OnyxApiClient` to create and clean up. See `chat/default_assistant.spec.ts` or `mcp/mcp_oauth_flow.spec.ts` for examples. This is uncommon (~4 of 37 test files).
|
||||
|
||||
## Key Utilities
|
||||
|
||||
### `OnyxApiClient` (`@tests/e2e/utils/onyxApiClient`)
|
||||
|
||||
Backend API client for test setup/teardown. Key methods:
|
||||
|
||||
- **Connectors**: `createFileConnector()`, `deleteCCPair()`, `pauseConnector()`
|
||||
- **LLM Providers**: `ensurePublicProvider()`, `createRestrictedProvider()`, `setProviderAsDefault()`
|
||||
- **Assistants**: `createAssistant()`, `deleteAssistant()`, `findAssistantByName()`
|
||||
- **User Groups**: `createUserGroup()`, `deleteUserGroup()`, `setUserRole()`
|
||||
- **Tools**: `createWebSearchProvider()`, `createImageGenerationConfig()`
|
||||
- **Chat**: `createChatSession()`, `deleteChatSession()`
|
||||
|
||||
### `chatActions` (`@tests/e2e/utils/chatActions`)
|
||||
|
||||
- `sendMessage(page, message)` — sends a message and waits for AI response
|
||||
- `startNewChat(page)` — clicks new-chat button and waits for intro
|
||||
- `verifyDefaultAssistantIsChosen(page)` — checks Onyx logo is visible
|
||||
- `verifyAssistantIsChosen(page, name)` — checks assistant name display
|
||||
- `switchModel(page, modelName)` — switches LLM model via popover
|
||||
|
||||
### `visualRegression` (`@tests/e2e/utils/visualRegression`)
|
||||
|
||||
- `expectScreenshot(page, { name, mask?, hide?, fullPage? })`
|
||||
- `expectElementScreenshot(locator, { name, mask?, hide? })`
|
||||
- Controlled by `VISUAL_REGRESSION=true` env var
|
||||
|
||||
### `theme` (`@tests/e2e/utils/theme`)
|
||||
|
||||
- `THEMES` — `["light", "dark"] as const` array for iterating over both themes
|
||||
- `setThemeBeforeNavigation(page, theme)` — sets `next-themes` theme via `localStorage` before navigation
|
||||
|
||||
When tests need light/dark screenshots, loop over `THEMES` at the `test.describe` level and call `setThemeBeforeNavigation` in `beforeEach` **before** any `page.goto()`. Include the theme in screenshot names. See `admin/admin_pages.spec.ts` or `chat/chat_message_rendering.spec.ts` for examples:
|
||||
|
||||
```typescript
|
||||
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";
|
||||
|
||||
for (const theme of THEMES) {
|
||||
test.describe(`Feature (${theme} mode)`, () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await setThemeBeforeNavigation(page, theme);
|
||||
});
|
||||
|
||||
test("renders correctly", async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await expectScreenshot(page, { name: `feature-${theme}` });
|
||||
});
|
||||
});
|
||||
}
|
||||
```
|
||||
|
||||
### `tools` (`@tests/e2e/utils/tools`)
|
||||
|
||||
- `TOOL_IDS` — centralized `data-testid` selectors for tool options
|
||||
- `openActionManagement(page)` — opens the tool management popover
|
||||
|
||||
## Locator Strategy
|
||||
|
||||
Use locators in this priority order:
|
||||
|
||||
1. **`data-testid` / `aria-label`** — preferred for Onyx components
|
||||
```typescript
|
||||
page.getByTestId("AppSidebar/new-session")
|
||||
page.getByLabel("admin-page-title")
|
||||
```
|
||||
|
||||
2. **Role-based** — for standard HTML elements
|
||||
```typescript
|
||||
page.getByRole("button", { name: "Create" })
|
||||
page.getByRole("dialog")
|
||||
```
|
||||
|
||||
3. **Text/Label** — for visible text content
|
||||
```typescript
|
||||
page.getByText("Custom Assistant")
|
||||
page.getByLabel("Email")
|
||||
```
|
||||
|
||||
4. **CSS selectors** — last resort, only when above won't work
|
||||
```typescript
|
||||
page.locator('input[name="name"]')
|
||||
page.locator("#onyx-chat-input-textarea")
|
||||
```
|
||||
|
||||
**Never use** `page.locator` with complex CSS/XPath when a built-in locator works.
|
||||
|
||||
## Assertions
|
||||
|
||||
Use web-first assertions — they auto-retry until the condition is met:
|
||||
|
||||
```typescript
|
||||
// Visibility
|
||||
await expect(page.getByTestId("onyx-logo")).toBeVisible({ timeout: 5000 });
|
||||
|
||||
// Text content
|
||||
await expect(page.getByTestId("assistant-name-display")).toHaveText("My Assistant");
|
||||
|
||||
// Count
|
||||
await expect(page.locator('[data-testid="onyx-ai-message"]')).toHaveCount(2, { timeout: 30000 });
|
||||
|
||||
// URL
|
||||
await expect(page).toHaveURL(/chatId=/);
|
||||
|
||||
// Element state
|
||||
await expect(toggle).toBeChecked();
|
||||
await expect(button).toBeEnabled();
|
||||
```
|
||||
|
||||
**Never use** `assert` statements or hardcoded `page.waitForTimeout()`.
|
||||
|
||||
## Waiting Strategy
|
||||
|
||||
```typescript
|
||||
// Wait for load state after navigation
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
// Wait for specific element
|
||||
await page.getByTestId("chat-intro").waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
// Wait for URL change
|
||||
await page.waitForFunction(() => window.location.href.includes("chatId="), null, { timeout: 10000 });
|
||||
|
||||
// Wait for network response
|
||||
await page.waitForResponse(resp => resp.url().includes("/api/chat") && resp.status() === 200);
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Descriptive test names** — clearly state expected behavior: `"should display greeting message when opening new chat"`
|
||||
2. **API-first setup** — use `OnyxApiClient` for backend state; reserve UI interactions for the behavior under test
|
||||
3. **User isolation** — tests that modify visible app state (sidebar, chat history) should run as the worker-specific user via `loginAsWorkerUser(page, testInfo.workerIndex)` (not admin) and clean up resources in `afterAll`. Each parallel worker gets its own user, preventing cross-contamination. Reserve `loginAsRandomUser` for flows that require a brand-new user (e.g. onboarding)
|
||||
4. **DRY helpers** — extract reusable logic into `utils/` with JSDoc comments
|
||||
5. **No hardcoded waits** — use `waitFor`, `waitForLoadState`, or web-first assertions
|
||||
6. **Parallel-safe** — no shared mutable state between tests. Prefer static, human-readable names (e.g. `"E2E-CMD Chat 1"`) and clean up resources by ID in `afterAll`. This keeps screenshots deterministic and avoids needing to mask/hide dynamic text. Only fall back to timestamps (`\`test-${Date.now()}\``) when resources cannot be reliably cleaned up or when name collisions across parallel workers would cause functional failures
|
||||
7. **Error context** — catch and re-throw with useful debug info (page text, URL, etc.)
|
||||
8. **Tag slow tests** — mark serial/slow tests with `@exclusive` in the test title
|
||||
9. **Visual regression** — use `expectScreenshot()` for UI consistency checks
|
||||
10. **Minimal comments** — only comment to clarify non-obvious intent; never restate what the next line of code does
|
||||
2
.github/workflows/deployment.yml
vendored
2
.github/workflows/deployment.yml
vendored
@@ -640,7 +640,6 @@ jobs:
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
|
||||
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
@@ -722,7 +721,6 @@ jobs:
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
|
||||
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
|
||||
2
.github/workflows/helm-chart-releases.yml
vendored
2
.github/workflows/helm-chart-releases.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
|
||||
helm repo update
|
||||
|
||||
- name: Build chart dependencies
|
||||
|
||||
151
.github/workflows/nightly-scan-licenses.yml
vendored
Normal file
151
.github/workflows/nightly-scan-licenses.yml
vendored
Normal file
@@ -0,0 +1,151 @@
|
||||
# Scan for problematic software licenses
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
|
||||
name: 'Nightly - Scan licenses'
|
||||
on:
|
||||
# schedule:
|
||||
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
scan-licenses:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
|
||||
timeout-minutes: 45
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
|
||||
- name: Get explicit and transitive dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
pip freeze > requirements-all.txt
|
||||
|
||||
- name: Check python
|
||||
id: license_check_report
|
||||
uses: pilosus/action-pip-license-checker@e909b0226ff49d3235c99c4585bc617f49fff16a # ratchet:pilosus/action-pip-license-checker@v3
|
||||
with:
|
||||
requirements: 'requirements-all.txt'
|
||||
fail: 'Copyleft'
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: always()
|
||||
env:
|
||||
REPORT: ${{ steps.license_check_report.outputs.report }}
|
||||
run: echo "$REPORT"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
# be careful enabling the sarif and upload as it may spam the security tab
|
||||
# with a huge amount of items. Work out the issues before enabling upload.
|
||||
# - name: Run Trivy vulnerability scanner in repo mode
|
||||
# if: always()
|
||||
# uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
# with:
|
||||
# scan-type: fs
|
||||
# scan-ref: .
|
||||
# scanners: license
|
||||
# format: table
|
||||
# severity: HIGH,CRITICAL
|
||||
# # format: sarif
|
||||
# # output: trivy-results.sarif
|
||||
#
|
||||
# # - name: Upload Trivy scan results to GitHub Security tab
|
||||
# # uses: github/codeql-action/upload-sarif@v3
|
||||
# # with:
|
||||
# # sarif_file: trivy-results.sarif
|
||||
|
||||
scan-trivy:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# Backend
|
||||
- name: Pull backend docker image
|
||||
run: docker pull onyxdotapp/onyx-backend:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on backend
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-backend:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
|
||||
|
||||
# Web server
|
||||
- name: Pull web server docker image
|
||||
run: docker pull onyxdotapp/onyx-web-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on web server
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-web-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
|
||||
# Model server
|
||||
- name: Pull model server docker image
|
||||
run: docker pull onyxdotapp/onyx-model-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-model-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
@@ -45,6 +45,9 @@ env:
|
||||
# TODO: debug why this is failing and enable
|
||||
CODE_INTERPRETER_BASE_URL: http://localhost:8000
|
||||
|
||||
# OpenSearch
|
||||
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
@@ -115,10 +118,9 @@ jobs:
|
||||
- name: Create .env file for Docker Compose
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
CODE_INTERPRETER_BETA_ENABLED=true
|
||||
DISABLE_TELEMETRY=true
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=true
|
||||
EOF
|
||||
|
||||
- name: Set up Standard Dependencies
|
||||
@@ -127,6 +129,7 @@ jobs:
|
||||
docker compose \
|
||||
-f docker-compose.yml \
|
||||
-f docker-compose.dev.yml \
|
||||
-f docker-compose.opensearch.yml \
|
||||
up -d \
|
||||
minio \
|
||||
relational_db \
|
||||
|
||||
5
.github/workflows/pr-helm-chart-testing.yml
vendored
5
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -41,7 +41,8 @@ jobs:
|
||||
version: v3.19.0
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
|
||||
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
|
||||
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
|
||||
with:
|
||||
uv_version: "0.9.9"
|
||||
|
||||
@@ -91,7 +92,7 @@ jobs:
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
|
||||
helm repo update
|
||||
|
||||
- name: Install Redis operator
|
||||
|
||||
175
.github/workflows/pr-integration-tests.yml
vendored
175
.github/workflows/pr-integration-tests.yml
vendored
@@ -46,7 +46,6 @@ jobs:
|
||||
timeout-minutes: 45
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
editions: ${{ steps.set-editions.outputs.editions }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
@@ -57,7 +56,7 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" ! -name "no_vectordb" -exec basename {} \; | sort)
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
@@ -73,16 +72,6 @@ jobs:
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Determine editions to test
|
||||
id: set-editions
|
||||
run: |
|
||||
# On PRs, only run EE tests. On merge_group and tags, run both EE and MIT.
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
echo 'editions=["ee"]' >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo 'editions=["ee","mit"]' >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
build-backend-image:
|
||||
runs-on:
|
||||
[
|
||||
@@ -278,7 +267,7 @@ jobs:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- ${{ format('run-id={0}-integration-tests-{1}-job-{2}', github.run_id, matrix.edition, strategy['job-index']) }}
|
||||
- ${{ format('run-id={0}-integration-tests-job-{1}', github.run_id, strategy['job-index']) }}
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
|
||||
@@ -286,7 +275,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
edition: ${{ fromJson(needs.discover-test-dirs.outputs.editions) }}
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -310,11 +298,12 @@ jobs:
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
EDITION: ${{ matrix.edition }}
|
||||
run: |
|
||||
# Base config shared by both editions
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
@@ -323,20 +312,11 @@ jobs:
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
MCP_SERVER_ENABLED=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
EOF
|
||||
|
||||
# EE-only config
|
||||
if [ "$EDITION" = "ee" ]; then
|
||||
cat <<EOF >> deployment/docker_compose/.env
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
MCP_SERVER_ENABLED=true
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
|
||||
EOF
|
||||
fi
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
@@ -399,14 +379,14 @@ jobs:
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
- name: Run Integration Tests (${{ matrix.edition }}) for ${{ matrix.test-dir.name }}
|
||||
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running ${{ matrix.edition }} integration tests for ${{ matrix.test-dir.path }}..."
|
||||
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
@@ -464,143 +444,10 @@ jobs:
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.edition }}-${{ matrix.test-dir.name }}
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
no-vectordb-tests:
|
||||
needs: [build-backend-image, build-integration-image]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=4cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-no-vectordb-tests",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create .env file for no-vectordb Docker Compose
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
DISABLE_VECTOR_DB=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=true
|
||||
EOF
|
||||
|
||||
# Start only the services needed for no-vectordb mode (no Vespa, no model servers)
|
||||
- name: Start Docker containers (no-vectordb)
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker_no_vectordb
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script (no-vectordb)..."
|
||||
start_time=$(date +%s)
|
||||
timeout=300
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in $timeout seconds."
|
||||
exit 1
|
||||
fi
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "API server is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error; retrying..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
|
||||
- name: Run No-VectorDB Integration Tests
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running no-vectordb integration tests..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
|
||||
/app/tests/integration/tests/no_vectordb
|
||||
|
||||
- name: Dump API server logs (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
|
||||
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_no_vectordb.log || true
|
||||
|
||||
- name: Dump all-container logs (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
|
||||
logs --no-color > $GITHUB_WORKSPACE/docker-compose-no-vectordb.log || true
|
||||
|
||||
- name: Upload logs (no-vectordb)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-no-vectordb
|
||||
path: ${{ github.workspace }}/docker-compose-no-vectordb.log
|
||||
|
||||
- name: Stop Docker containers (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml down -v
|
||||
|
||||
multitenant-tests:
|
||||
needs:
|
||||
[build-backend-image, build-model-server-image, build-integration-image]
|
||||
@@ -740,7 +587,7 @@ jobs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
needs: [integration-tests, no-vectordb-tests, multitenant-tests]
|
||||
needs: [integration-tests, multitenant-tests]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Check job status
|
||||
|
||||
443
.github/workflows/pr-mit-integration-tests.yml
vendored
Normal file
443
.github/workflows/pr-mit-integration-tests.yml
vendored
Normal file
@@ -0,0 +1,443 @@
|
||||
name: Run MIT Integration Tests v2
|
||||
concurrency:
|
||||
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
types: [checks_requested]
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Discover test directories
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
all_dirs=""
|
||||
for dir in $tests_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
|
||||
done
|
||||
for dir in $connector_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
|
||||
done
|
||||
|
||||
# Remove trailing comma and wrap in array
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
build-backend-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
push: true
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
|
||||
type=registry,ref=onyxdotapp/onyx-backend:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-model-server-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
push: true
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
|
||||
type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
|
||||
build-integration-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-integration-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling openapitools/openapi-generator-cli
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push integration test image with Docker Bake
|
||||
env:
|
||||
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
TAG: integration-test-${{ github.run_id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
run: |
|
||||
docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
|
||||
integration
|
||||
|
||||
integration-tests-mit:
|
||||
needs:
|
||||
[
|
||||
discover-test-dirs,
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- ${{ format('run-id={0}-integration-tests-mit-job-{1}', github.run_id, strategy['job-index']) }}
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
# NOTE: don't need web server for integration tests
|
||||
- name: Create .env file for Docker Compose
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
MCP_SERVER_ENABLED=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
wait_for_service() {
|
||||
local url=$1
|
||||
local label=$2
|
||||
local timeout=${3:-300} # default 5 minutes
|
||||
local start_time
|
||||
start_time=$(date +%s)
|
||||
|
||||
while true; do
|
||||
local current_time
|
||||
current_time=$(date +%s)
|
||||
local elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. ${label} did not become ready in $timeout seconds."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
local response
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" "$url" || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "${label} is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error while checking ${label}. Retrying in 5 seconds..."
|
||||
else
|
||||
echo "${label} not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
}
|
||||
|
||||
wait_for_service "http://localhost:8080/health" "API server"
|
||||
echo "Finished waiting for services."
|
||||
|
||||
- name: Start Mock Services
|
||||
run: |
|
||||
cd backend/tests/integration/mock_services
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
|
||||
/app/tests/integration/${{ matrix.test-dir.path }}
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
required:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
needs: [integration-tests-mit]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Check job status
|
||||
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
|
||||
run: exit 1
|
||||
231
.github/workflows/pr-playwright-tests.yml
vendored
231
.github/workflows/pr-playwright-tests.yml
vendored
@@ -22,9 +22,6 @@ env:
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
|
||||
GOOGLE_PSE_API_KEY: ${{ secrets.GOOGLE_PSE_API_KEY }}
|
||||
GOOGLE_PSE_SEARCH_ENGINE_ID: ${{ secrets.GOOGLE_PSE_SEARCH_ENGINE_ID }}
|
||||
|
||||
# for federated slack tests
|
||||
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
|
||||
@@ -55,9 +52,6 @@ env:
|
||||
MCP_SERVER_PUBLIC_HOST: host.docker.internal
|
||||
MCP_SERVER_PUBLIC_URL: http://host.docker.internal:8004/mcp
|
||||
|
||||
# Visual regression S3 bucket (shared across all jobs)
|
||||
PLAYWRIGHT_S3_BUCKET: onyx-playwright-artifacts
|
||||
|
||||
jobs:
|
||||
build-web-image:
|
||||
runs-on:
|
||||
@@ -245,9 +239,6 @@ jobs:
|
||||
playwright-tests:
|
||||
needs: [build-web-image, build-backend-image, build-model-server-image]
|
||||
name: Playwright Tests (${{ matrix.project }})
|
||||
permissions:
|
||||
id-token: write # Required for OIDC-based AWS credential exchange (S3 access)
|
||||
contents: read
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=8cpu-linux-arm64
|
||||
@@ -303,7 +294,6 @@ jobs:
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
|
||||
EXA_API_KEY=${EXA_API_KEY_VALUE}
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
@@ -438,6 +428,8 @@ jobs:
|
||||
env:
|
||||
PROJECT: ${{ matrix.project }}
|
||||
run: |
|
||||
# Create test-results directory to ensure it exists for artifact upload
|
||||
mkdir -p test-results
|
||||
npx playwright test --project ${PROJECT}
|
||||
|
||||
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
@@ -445,134 +437,9 @@ jobs:
|
||||
with:
|
||||
# Includes test results and trace.zip files
|
||||
name: playwright-test-results-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ./web/output/playwright/
|
||||
path: ./web/test-results/
|
||||
retention-days: 30
|
||||
|
||||
- name: Upload screenshots
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-screenshots-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ./web/output/screenshots/
|
||||
retention-days: 30
|
||||
|
||||
# --- Visual Regression Diff ---
|
||||
- name: Configure AWS credentials
|
||||
if: always()
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: always()
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Determine baseline revision
|
||||
if: always()
|
||||
id: baseline-rev
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
BASE_REF: ${{ github.event.pull_request.base.ref }}
|
||||
MERGE_GROUP_BASE_REF: ${{ github.event.merge_group.base_ref }}
|
||||
GH_REF: ${{ github.ref }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ "${EVENT_NAME}" = "pull_request" ]; then
|
||||
# PRs compare against the base branch (e.g. main, release/2.5)
|
||||
echo "rev=${BASE_REF}" >> "$GITHUB_OUTPUT"
|
||||
elif [ "${EVENT_NAME}" = "merge_group" ]; then
|
||||
# Merge queue compares against the target branch (e.g. refs/heads/main -> main)
|
||||
echo "rev=${MERGE_GROUP_BASE_REF#refs/heads/}" >> "$GITHUB_OUTPUT"
|
||||
elif [[ "${GH_REF}" == refs/tags/* ]]; then
|
||||
# Tag builds compare against the tag name
|
||||
echo "rev=${REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
# Push builds (main, release/*) compare against the branch name
|
||||
echo "rev=${REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Generate screenshot diff report
|
||||
if: always()
|
||||
env:
|
||||
PROJECT: ${{ matrix.project }}
|
||||
PLAYWRIGHT_S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
|
||||
BASELINE_REV: ${{ steps.baseline-rev.outputs.rev }}
|
||||
run: |
|
||||
uv run --no-sync --with onyx-devtools ods screenshot-diff compare \
|
||||
--project "${PROJECT}" \
|
||||
--rev "${BASELINE_REV}"
|
||||
|
||||
- name: Upload visual diff report to S3
|
||||
if: always()
|
||||
env:
|
||||
PROJECT: ${{ matrix.project }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
SUMMARY_FILE="web/output/screenshot-diff/${PROJECT}/summary.json"
|
||||
if [ ! -f "${SUMMARY_FILE}" ]; then
|
||||
echo "No summary file found — skipping S3 upload."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
HAS_DIFF=$(jq -r '.has_differences' "${SUMMARY_FILE}")
|
||||
if [ "${HAS_DIFF}" != "true" ]; then
|
||||
echo "No visual differences for ${PROJECT} — skipping S3 upload."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
aws s3 sync "web/output/screenshot-diff/${PROJECT}/" \
|
||||
"s3://${PLAYWRIGHT_S3_BUCKET}/reports/pr-${PR_NUMBER}/${RUN_ID}/${PROJECT}/"
|
||||
|
||||
- name: Upload visual diff summary
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
name: screenshot-diff-summary-${{ matrix.project }}
|
||||
path: ./web/output/screenshot-diff/${{ matrix.project }}/summary.json
|
||||
if-no-files-found: ignore
|
||||
retention-days: 5
|
||||
|
||||
- name: Upload visual diff report artifact
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
name: screenshot-diff-report-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ./web/output/screenshot-diff/${{ matrix.project }}/
|
||||
if-no-files-found: ignore
|
||||
retention-days: 30
|
||||
|
||||
- name: Update S3 baselines
|
||||
if: >-
|
||||
success() && (
|
||||
github.ref == 'refs/heads/main' ||
|
||||
startsWith(github.ref, 'refs/heads/release/') ||
|
||||
startsWith(github.ref, 'refs/tags/v') ||
|
||||
(
|
||||
github.event_name == 'merge_group' && (
|
||||
github.event.merge_group.base_ref == 'refs/heads/main' ||
|
||||
startsWith(github.event.merge_group.base_ref, 'refs/heads/release/')
|
||||
)
|
||||
)
|
||||
)
|
||||
env:
|
||||
PROJECT: ${{ matrix.project }}
|
||||
PLAYWRIGHT_S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
|
||||
BASELINE_REV: ${{ steps.baseline-rev.outputs.rev }}
|
||||
run: |
|
||||
if [ -d "web/output/screenshots/" ] && [ "$(ls -A web/output/screenshots/)" ]; then
|
||||
uv run --no-sync --with onyx-devtools ods screenshot-diff upload-baselines \
|
||||
--project "${PROJECT}" \
|
||||
--rev "${BASELINE_REV}" \
|
||||
--delete
|
||||
else
|
||||
echo "No screenshots to upload for ${PROJECT} — skipping baseline update."
|
||||
fi
|
||||
|
||||
# save before stopping the containers so the logs can be captured
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
@@ -590,98 +457,6 @@ jobs:
|
||||
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
# Post a single combined visual regression comment after all matrix jobs finish
|
||||
visual-regression-comment:
|
||||
needs: [playwright-tests]
|
||||
if: >-
|
||||
always() &&
|
||||
github.event_name == 'pull_request' &&
|
||||
needs.playwright-tests.result != 'cancelled'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
permissions:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # ratchet:actions/download-artifact@v4
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
- name: Post combined PR comment
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
REPO: ${{ github.repository }}
|
||||
S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
|
||||
run: |
|
||||
MARKER="<!-- visual-regression-report -->"
|
||||
|
||||
# Build the markdown table from all summary files
|
||||
TABLE_HEADER="| Project | Changed | Added | Removed | Unchanged | Report |"
|
||||
TABLE_DIVIDER="|---------|---------|-------|---------|-----------|--------|"
|
||||
TABLE_ROWS=""
|
||||
HAS_ANY_SUMMARY=false
|
||||
|
||||
for SUMMARY_DIR in summaries/screenshot-diff-summary-*/; do
|
||||
SUMMARY_FILE="${SUMMARY_DIR}summary.json"
|
||||
if [ ! -f "${SUMMARY_FILE}" ]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
HAS_ANY_SUMMARY=true
|
||||
PROJECT=$(jq -r '.project' "${SUMMARY_FILE}")
|
||||
CHANGED=$(jq -r '.changed' "${SUMMARY_FILE}")
|
||||
ADDED=$(jq -r '.added' "${SUMMARY_FILE}")
|
||||
REMOVED=$(jq -r '.removed' "${SUMMARY_FILE}")
|
||||
UNCHANGED=$(jq -r '.unchanged' "${SUMMARY_FILE}")
|
||||
TOTAL=$(jq -r '.total' "${SUMMARY_FILE}")
|
||||
HAS_DIFF=$(jq -r '.has_differences' "${SUMMARY_FILE}")
|
||||
|
||||
if [ "${TOTAL}" = "0" ]; then
|
||||
REPORT_LINK="_No screenshots_"
|
||||
elif [ "${HAS_DIFF}" = "true" ]; then
|
||||
REPORT_URL="https://${S3_BUCKET}.s3.us-east-2.amazonaws.com/reports/pr-${PR_NUMBER}/${RUN_ID}/${PROJECT}/index.html"
|
||||
REPORT_LINK="[View Report](${REPORT_URL})"
|
||||
else
|
||||
REPORT_LINK="✅ No changes"
|
||||
fi
|
||||
|
||||
TABLE_ROWS="${TABLE_ROWS}| \`${PROJECT}\` | ${CHANGED} | ${ADDED} | ${REMOVED} | ${UNCHANGED} | ${REPORT_LINK} |\n"
|
||||
done
|
||||
|
||||
if [ "${HAS_ANY_SUMMARY}" = "false" ]; then
|
||||
echo "No visual diff summaries found — skipping PR comment."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
BODY=$(printf '%s\n' \
|
||||
"${MARKER}" \
|
||||
"### 🖼️ Visual Regression Report" \
|
||||
"" \
|
||||
"${TABLE_HEADER}" \
|
||||
"${TABLE_DIVIDER}" \
|
||||
"$(printf '%b' "${TABLE_ROWS}")")
|
||||
|
||||
# Upsert: find existing comment with the marker, or create a new one
|
||||
EXISTING_COMMENT_ID=$(gh api \
|
||||
"repos/${REPO}/issues/${PR_NUMBER}/comments" \
|
||||
--jq ".[] | select(.body | startswith(\"${MARKER}\")) | .id" \
|
||||
2>/dev/null | head -1)
|
||||
|
||||
if [ -n "${EXISTING_COMMENT_ID}" ]; then
|
||||
gh api \
|
||||
--method PATCH \
|
||||
"repos/${REPO}/issues/comments/${EXISTING_COMMENT_ID}" \
|
||||
-f body="${BODY}"
|
||||
else
|
||||
gh api \
|
||||
--method POST \
|
||||
"repos/${REPO}/issues/${PR_NUMBER}/comments" \
|
||||
-f body="${BODY}"
|
||||
fi
|
||||
|
||||
playwright-required:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
|
||||
290
.github/workflows/sandbox-deployment.yml
vendored
290
.github/workflows/sandbox-deployment.yml
vendored
@@ -1,290 +0,0 @@
|
||||
name: Build and Push Sandbox Image on Tag
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "experimental-cc4a.*"
|
||||
|
||||
# Restrictive defaults; jobs declare what they need.
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
check-sandbox-changes:
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
permissions:
|
||||
contents: read
|
||||
outputs:
|
||||
sandbox-changed: ${{ steps.check.outputs.sandbox-changed }}
|
||||
new-version: ${{ steps.version.outputs.new-version }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Check for sandbox-relevant file changes
|
||||
id: check
|
||||
run: |
|
||||
# Get the previous tag to diff against
|
||||
CURRENT_TAG="${GITHUB_REF_NAME}"
|
||||
PREVIOUS_TAG=$(git tag --sort=-creatordate | grep '^experimental-cc4a\.' | grep -v "^${CURRENT_TAG}$" | head -n 1)
|
||||
|
||||
if [ -z "$PREVIOUS_TAG" ]; then
|
||||
echo "No previous experimental-cc4a tag found, building unconditionally"
|
||||
echo "sandbox-changed=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Comparing ${PREVIOUS_TAG}..${CURRENT_TAG}"
|
||||
|
||||
# Check if any sandbox-relevant files changed
|
||||
SANDBOX_PATHS=(
|
||||
"backend/onyx/server/features/build/sandbox/"
|
||||
)
|
||||
|
||||
CHANGED=false
|
||||
for path in "${SANDBOX_PATHS[@]}"; do
|
||||
if git diff --name-only "${PREVIOUS_TAG}..${CURRENT_TAG}" -- "$path" | grep -q .; then
|
||||
echo "Changes detected in: $path"
|
||||
CHANGED=true
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
echo "sandbox-changed=$CHANGED" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Determine new sandbox version
|
||||
id: version
|
||||
if: steps.check.outputs.sandbox-changed == 'true'
|
||||
run: |
|
||||
# Query Docker Hub for the latest versioned tag
|
||||
LATEST_TAG=$(curl -s "https://hub.docker.com/v2/repositories/onyxdotapp/sandbox/tags?page_size=100" \
|
||||
| jq -r '.results[].name' \
|
||||
| grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \
|
||||
| sort -V \
|
||||
| tail -n 1)
|
||||
|
||||
if [ -z "$LATEST_TAG" ]; then
|
||||
echo "No existing version tags found on Docker Hub, starting at 0.1.1"
|
||||
NEW_VERSION="0.1.1"
|
||||
else
|
||||
CURRENT_VERSION="${LATEST_TAG#v}"
|
||||
echo "Latest version on Docker Hub: $CURRENT_VERSION"
|
||||
|
||||
# Increment patch version
|
||||
MAJOR=$(echo "$CURRENT_VERSION" | cut -d. -f1)
|
||||
MINOR=$(echo "$CURRENT_VERSION" | cut -d. -f2)
|
||||
PATCH=$(echo "$CURRENT_VERSION" | cut -d. -f3)
|
||||
NEW_PATCH=$((PATCH + 1))
|
||||
NEW_VERSION="${MAJOR}.${MINOR}.${NEW_PATCH}"
|
||||
fi
|
||||
|
||||
echo "New version: $NEW_VERSION"
|
||||
echo "new-version=$NEW_VERSION" >> "$GITHUB_OUTPUT"
|
||||
|
||||
build-sandbox-amd64:
|
||||
needs: check-sandbox-changes
|
||||
if: needs.check-sandbox-changes.outputs.sandbox-changed == 'true'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-sandbox-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/sandbox
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker
|
||||
file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile
|
||||
platforms: linux/amd64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
build-sandbox-arm64:
|
||||
needs: check-sandbox-changes
|
||||
if: needs.check-sandbox-changes.outputs.sandbox-changed == 'true'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-sandbox-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/sandbox
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker
|
||||
file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile
|
||||
platforms: linux/arm64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
merge-sandbox:
|
||||
needs:
|
||||
- check-sandbox-changes
|
||||
- build-sandbox-amd64
|
||||
- build-sandbox-arm64
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-merge-sandbox
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 30
|
||||
environment: release
|
||||
permissions:
|
||||
id-token: write
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/sandbox
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=v${{ needs.check-sandbox-changes.outputs.new-version }}
|
||||
type=raw,value=latest
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-sandbox-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-sandbox-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
run: |
|
||||
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
|
||||
docker buildx imagetools create \
|
||||
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
|
||||
$IMAGES
|
||||
13
.github/workflows/zizmor.yml
vendored
13
.github/workflows/zizmor.yml
vendored
@@ -5,8 +5,6 @@ on:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["**"]
|
||||
paths:
|
||||
- ".github/**"
|
||||
|
||||
permissions: {}
|
||||
|
||||
@@ -23,18 +21,29 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Detect changes
|
||||
id: filter
|
||||
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
zizmor:
|
||||
- '.github/**'
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Run zizmor
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Upload SARIF file
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -6,8 +6,6 @@
|
||||
!/.vscode/tasks.template.jsonc
|
||||
.zed
|
||||
.cursor
|
||||
!/.cursor/mcp.json
|
||||
!/.cursor/skills/
|
||||
|
||||
# macos
|
||||
.DS_store
|
||||
|
||||
4
.vscode/launch.json
vendored
4
.vscode/launch.json
vendored
@@ -275,7 +275,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=background@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,user_files_indexing,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
@@ -419,7 +419,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docfetching@%n",
|
||||
"-Q",
|
||||
"connector_doc_fetching"
|
||||
"connector_doc_fetching,user_files_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
|
||||
5
LICENSE
5
LICENSE
@@ -2,10 +2,7 @@ Copyright (c) 2023-present DanswerAI, Inc.
|
||||
|
||||
Portions of this software are licensed as follows:
|
||||
|
||||
- All content that resides under "ee" directories of this repository is licensed under the Onyx Enterprise License. Each ee directory contains an identical copy of this license at its root:
|
||||
- backend/ee/LICENSE
|
||||
- web/src/app/ee/LICENSE
|
||||
- web/src/ee/LICENSE
|
||||
- All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
|
||||
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
|
||||
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
|
||||
|
||||
|
||||
@@ -474,7 +474,7 @@ def run_migrations_online() -> None:
|
||||
|
||||
if connectable is not None:
|
||||
# pytest-alembic is providing an engine - use it directly
|
||||
logger.debug("run_migrations_online starting (pytest-alembic mode).")
|
||||
logger.info("run_migrations_online starting (pytest-alembic mode).")
|
||||
|
||||
# For pytest-alembic, we use the default schema (public)
|
||||
schema_name = context.config.attributes.get(
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""add scim_username to scim_user_mapping
|
||||
|
||||
Revision ID: 0bb4558f35df
|
||||
Revises: 631fd2504136
|
||||
Create Date: 2026-02-20 10:45:30.340188
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0bb4558f35df"
|
||||
down_revision = "631fd2504136"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"scim_user_mapping",
|
||||
sa.Column("scim_username", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("scim_user_mapping", "scim_username")
|
||||
@@ -1,7 +1,7 @@
|
||||
"""add default_app_mode to user
|
||||
|
||||
Revision ID: 114a638452db
|
||||
Revises: feead2911109
|
||||
Revises: d56ffa94ca32
|
||||
Create Date: 2026-02-09 18:57:08.274640
|
||||
|
||||
"""
|
||||
|
||||
@@ -11,6 +11,7 @@ import sqlalchemy as sa
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
from httpx import HTTPStatusError
|
||||
import httpx
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.db.search_settings import SearchSettings
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
@@ -518,11 +519,15 @@ def delete_document_from_db(current_doc_id: str, index_name: str) -> None:
|
||||
def upgrade() -> None:
|
||||
if SKIP_CANON_DRIVE_IDS:
|
||||
return
|
||||
current_search_settings, _ = active_search_settings()
|
||||
current_search_settings, future_search_settings = active_search_settings()
|
||||
document_index = get_default_document_index(
|
||||
current_search_settings,
|
||||
future_search_settings,
|
||||
)
|
||||
|
||||
# Get the index name
|
||||
if hasattr(current_search_settings, "index_name"):
|
||||
index_name = current_search_settings.index_name
|
||||
if hasattr(document_index, "index_name"):
|
||||
index_name = document_index.index_name
|
||||
else:
|
||||
# Default index name if we can't get it from the document_index
|
||||
index_name = "danswer_index"
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
"""Migrate to contextual rag model
|
||||
|
||||
Revision ID: 19c0ccb01687
|
||||
Revises: 9c54986124c6
|
||||
Create Date: 2026-02-12 11:21:41.798037
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "19c0ccb01687"
|
||||
down_revision = "9c54986124c6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Widen the column to fit 'CONTEXTUAL_RAG' (15 chars); was varchar(10)
|
||||
# when the table was created with only CHAT/VISION values.
|
||||
op.alter_column(
|
||||
"llm_model_flow",
|
||||
"llm_model_flow_type",
|
||||
type_=sa.String(length=20),
|
||||
existing_type=sa.String(length=10),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# For every search_settings row that has contextual rag configured,
|
||||
# create an llm_model_flow entry. is_default is TRUE if the row
|
||||
# belongs to the PRESENT search settings, FALSE otherwise.
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO llm_model_flow (llm_model_flow_type, model_configuration_id, is_default)
|
||||
SELECT DISTINCT
|
||||
'CONTEXTUAL_RAG',
|
||||
mc.id,
|
||||
(ss.status = 'PRESENT')
|
||||
FROM search_settings ss
|
||||
JOIN llm_provider lp
|
||||
ON lp.name = ss.contextual_rag_llm_provider
|
||||
JOIN model_configuration mc
|
||||
ON mc.llm_provider_id = lp.id
|
||||
AND mc.name = ss.contextual_rag_llm_name
|
||||
WHERE ss.enable_contextual_rag = TRUE
|
||||
AND ss.contextual_rag_llm_name IS NOT NULL
|
||||
AND ss.contextual_rag_llm_provider IS NOT NULL
|
||||
ON CONFLICT (llm_model_flow_type, model_configuration_id)
|
||||
DO UPDATE SET is_default = EXCLUDED.is_default
|
||||
WHERE EXCLUDED.is_default = TRUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM llm_model_flow
|
||||
WHERE llm_model_flow_type = 'CONTEXTUAL_RAG'
|
||||
"""
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"llm_model_flow",
|
||||
"llm_model_flow_type",
|
||||
type_=sa.String(length=10),
|
||||
existing_type=sa.String(length=20),
|
||||
existing_nullable=False,
|
||||
)
|
||||
@@ -1,32 +0,0 @@
|
||||
"""add approx_chunk_count_in_vespa to opensearch tenant migration
|
||||
|
||||
Revision ID: 631fd2504136
|
||||
Revises: c7f2e1b4a9d3
|
||||
Create Date: 2026-02-18 21:07:52.831215
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "631fd2504136"
|
||||
down_revision = "c7f2e1b4a9d3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"approx_chunk_count_in_vespa",
|
||||
sa.Integer(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("opensearch_tenant_migration_record", "approx_chunk_count_in_vespa")
|
||||
@@ -16,6 +16,7 @@ from typing import Generator
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.db.search_settings import SearchSettings
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
@@ -125,11 +126,14 @@ def remove_old_tags() -> None:
|
||||
the document got reindexed, the old tag would not be removed.
|
||||
This function removes those old tags by comparing it against the tags in vespa.
|
||||
"""
|
||||
current_search_settings, _ = active_search_settings()
|
||||
current_search_settings, future_search_settings = active_search_settings()
|
||||
document_index = get_default_document_index(
|
||||
current_search_settings, future_search_settings
|
||||
)
|
||||
|
||||
# Get the index name
|
||||
if hasattr(current_search_settings, "index_name"):
|
||||
index_name = current_search_settings.index_name
|
||||
if hasattr(document_index, "index_name"):
|
||||
index_name = document_index.index_name
|
||||
else:
|
||||
# Default index name if we can't get it from the document_index
|
||||
index_name = "danswer_index"
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
"""add chunk error and vespa count columns to opensearch tenant migration
|
||||
|
||||
Revision ID: 93c15d6a6fbb
|
||||
Revises: d3fd499c829c
|
||||
Create Date: 2026-02-11 23:07:34.576725
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93c15d6a6fbb"
|
||||
down_revision = "d3fd499c829c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"total_chunks_errored",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default="0",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"total_chunks_in_vespa",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default="0",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("opensearch_tenant_migration_record", "total_chunks_in_vespa")
|
||||
op.drop_column("opensearch_tenant_migration_record", "total_chunks_errored")
|
||||
@@ -1,124 +0,0 @@
|
||||
"""add_scim_tables
|
||||
|
||||
Revision ID: 9c54986124c6
|
||||
Revises: b51c6844d1df
|
||||
Create Date: 2026-02-12 20:29:47.448614
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9c54986124c6"
|
||||
down_revision = "b51c6844d1df"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"scim_token",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("hashed_token", sa.String(length=64), nullable=False),
|
||||
sa.Column("token_display", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"created_by_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_active",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("true"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(["created_by_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("hashed_token"),
|
||||
)
|
||||
op.create_table(
|
||||
"scim_group_mapping",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("external_id", sa.String(), nullable=False),
|
||||
sa.Column("user_group_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_group_id"], ["user_group.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("user_group_id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_scim_group_mapping_external_id"),
|
||||
"scim_group_mapping",
|
||||
["external_id"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_table(
|
||||
"scim_user_mapping",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("external_id", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("user_id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_scim_user_mapping_external_id"),
|
||||
"scim_user_mapping",
|
||||
["external_id"],
|
||||
unique=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
op.f("ix_scim_user_mapping_external_id"),
|
||||
table_name="scim_user_mapping",
|
||||
)
|
||||
op.drop_table("scim_user_mapping")
|
||||
op.drop_index(
|
||||
op.f("ix_scim_group_mapping_external_id"),
|
||||
table_name="scim_group_mapping",
|
||||
)
|
||||
op.drop_table("scim_group_mapping")
|
||||
op.drop_table("scim_token")
|
||||
@@ -1,81 +0,0 @@
|
||||
"""seed_memory_tool and add enable_memory_tool to user
|
||||
|
||||
Revision ID: b51c6844d1df
|
||||
Revises: 93c15d6a6fbb
|
||||
Create Date: 2026-02-11 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b51c6844d1df"
|
||||
down_revision = "93c15d6a6fbb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
MEMORY_TOOL = {
|
||||
"name": "MemoryTool",
|
||||
"display_name": "Add Memory",
|
||||
"description": "Save memories about the user for future conversations.",
|
||||
"in_code_tool_id": "MemoryTool",
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
existing = conn.execute(
|
||||
sa.text(
|
||||
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id = :in_code_tool_id"
|
||||
),
|
||||
{"in_code_tool_id": MEMORY_TOOL["in_code_tool_id"]},
|
||||
).fetchone()
|
||||
|
||||
if existing:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
MEMORY_TOOL,
|
||||
)
|
||||
else:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
|
||||
"""
|
||||
),
|
||||
MEMORY_TOOL,
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"enable_memory_tool",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.true(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "enable_memory_tool")
|
||||
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text("DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
|
||||
{"in_code_tool_id": MEMORY_TOOL["in_code_tool_id"]},
|
||||
)
|
||||
@@ -1,31 +0,0 @@
|
||||
"""add sharing_scope to build_session
|
||||
|
||||
Revision ID: c7f2e1b4a9d3
|
||||
Revises: 19c0ccb01687
|
||||
Create Date: 2026-02-17 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "c7f2e1b4a9d3"
|
||||
down_revision = "19c0ccb01687"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"build_session",
|
||||
sa.Column(
|
||||
"sharing_scope",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
server_default="private",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("build_session", "sharing_scope")
|
||||
@@ -1,102 +0,0 @@
|
||||
"""add_file_reader_tool
|
||||
|
||||
Revision ID: d3fd499c829c
|
||||
Revises: 114a638452db
|
||||
Create Date: 2026-02-07 19:28:22.452337
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d3fd499c829c"
|
||||
down_revision = "114a638452db"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
FILE_READER_TOOL = {
|
||||
"name": "read_file",
|
||||
"display_name": "File Reader",
|
||||
"description": (
|
||||
"Read sections of user-uploaded files by character offset. "
|
||||
"Useful for inspecting large files that cannot fit entirely in context."
|
||||
),
|
||||
"in_code_tool_id": "FileReaderTool",
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Check if tool already exists
|
||||
existing = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
|
||||
{"in_code_tool_id": FILE_READER_TOOL["in_code_tool_id"]},
|
||||
).fetchone()
|
||||
|
||||
if existing:
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
FILE_READER_TOOL,
|
||||
)
|
||||
tool_id = existing[0]
|
||||
else:
|
||||
# Insert new tool
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
|
||||
RETURNING id
|
||||
"""
|
||||
),
|
||||
FILE_READER_TOOL,
|
||||
)
|
||||
tool_id = result.scalar_one()
|
||||
|
||||
# Attach to the default persona (id=0) if not already attached
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": tool_id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
in_code_tool_id = FILE_READER_TOOL["in_code_tool_id"]
|
||||
|
||||
# Remove persona associations first (FK constraint)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM persona__tool
|
||||
WHERE tool_id IN (
|
||||
SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id
|
||||
)
|
||||
"""
|
||||
),
|
||||
{"in_code_tool_id": in_code_tool_id},
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
sa.text("DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
|
||||
{"in_code_tool_id": in_code_tool_id},
|
||||
)
|
||||
@@ -1,20 +1,20 @@
|
||||
The Onyx Enterprise License (the "Enterprise License")
|
||||
The DanswerAI Enterprise license (the “Enterprise License”)
|
||||
Copyright (c) 2023-present DanswerAI, Inc.
|
||||
|
||||
With regard to the Onyx Software:
|
||||
|
||||
This software and associated documentation files (the "Software") may only be
|
||||
used in production, if you (and any entity that you represent) have agreed to,
|
||||
and are in compliance with, the Onyx Subscription Terms of Service, available
|
||||
at https://www.onyx.app/legal/self-host (the "Enterprise Terms"), or other
|
||||
and are in compliance with, the DanswerAI Subscription Terms of Service, available
|
||||
at https://onyx.app/terms (the “Enterprise Terms”), or other
|
||||
agreement governing the use of the Software, as agreed by you and DanswerAI,
|
||||
and otherwise have a valid Onyx Enterprise License for the
|
||||
and otherwise have a valid Onyx Enterprise license for the
|
||||
correct number of user seats. Subject to the foregoing sentence, you are free to
|
||||
modify this Software and publish patches to the Software. You agree that DanswerAI
|
||||
and/or its licensors (as applicable) retain all right, title and interest in and
|
||||
to all such modifications and/or patches, and all such modifications and/or
|
||||
patches may only be used, copied, modified, displayed, distributed, or otherwise
|
||||
exploited with a valid Onyx Enterprise License for the correct
|
||||
exploited with a valid Onyx Enterprise license for the correct
|
||||
number of user seats. Notwithstanding the foregoing, you may copy and modify
|
||||
the Software for development and testing purposes, without requiring a
|
||||
subscription. You agree that DanswerAI and/or its licensors (as applicable) retain
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.background import celery_app
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.heavy import celery_app
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.light import celery_app
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
]
|
||||
)
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.monitoring import celery_app
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
)
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.primary import celery_app
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cloud",
|
||||
"ee.onyx.background.celery.tasks.ttl_management",
|
||||
"ee.onyx.background.celery.tasks.usage_reporting",
|
||||
]
|
||||
)
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cloud",
|
||||
"ee.onyx.background.celery.tasks.ttl_management",
|
||||
"ee.onyx.background.celery.tasks.usage_reporting",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -536,9 +536,7 @@ def connector_permission_sync_generator_task(
|
||||
)
|
||||
redis_connector.permissions.set_fence(new_payload)
|
||||
|
||||
callback = PermissionSyncCallback(
|
||||
redis_connector, lock, r, timeout_seconds=JOB_TIMEOUT
|
||||
)
|
||||
callback = PermissionSyncCallback(redis_connector, lock, r)
|
||||
|
||||
# pass in the capability to fetch all existing docs for the cc_pair
|
||||
# this is can be used to determine documents that are "missing" and thus
|
||||
@@ -578,13 +576,6 @@ def connector_permission_sync_generator_task(
|
||||
tasks_generated = 0
|
||||
docs_with_errors = 0
|
||||
for doc_external_access in document_external_accesses:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
f"Permission sync task timed out or stop signal detected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
result = redis_connector.permissions.update_db(
|
||||
lock=lock,
|
||||
new_permissions=[doc_external_access],
|
||||
@@ -941,7 +932,6 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
redis_connector: RedisConnector,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
timeout_seconds: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.redis_connector: RedisConnector = redis_connector
|
||||
@@ -954,26 +944,11 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
self.last_tag: str = "PermissionSyncCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
self.start_monotonic = time.monotonic()
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_connector.stop.fenced:
|
||||
return True
|
||||
|
||||
# Check if the task has exceeded its timeout
|
||||
# NOTE: Celery's soft_time_limit does not work with thread pools,
|
||||
# so we must enforce timeouts internally.
|
||||
if self.timeout_seconds is not None:
|
||||
elapsed = time.monotonic() - self.start_monotonic
|
||||
if elapsed > self.timeout_seconds:
|
||||
logger.warning(
|
||||
f"PermissionSyncCallback - task timeout exceeded: "
|
||||
f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
|
||||
f"cc_pair={self.redis_connector.cc_pair_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None: # noqa: ARG002
|
||||
|
||||
@@ -466,7 +466,6 @@ def connector_external_group_sync_generator_task(
|
||||
def _perform_external_group_sync(
|
||||
cc_pair_id: int,
|
||||
tenant_id: str,
|
||||
timeout_seconds: int = JOB_TIMEOUT,
|
||||
) -> None:
|
||||
# Create attempt record at the start
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -519,23 +518,9 @@ def _perform_external_group_sync(
|
||||
seen_users: set[str] = set() # Track unique users across all groups
|
||||
total_groups_processed = 0
|
||||
total_group_memberships_synced = 0
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
|
||||
for external_user_group in external_user_group_generator:
|
||||
# Check if the task has exceeded its timeout
|
||||
# NOTE: Celery's soft_time_limit does not work with thread pools,
|
||||
# so we must enforce timeouts internally.
|
||||
elapsed = time.monotonic() - start_time
|
||||
if elapsed > timeout_seconds:
|
||||
raise RuntimeError(
|
||||
f"External group sync task timed out: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"elapsed={elapsed:.0f}s "
|
||||
f"timeout={timeout_seconds}s "
|
||||
f"groups_processed={total_groups_processed}"
|
||||
)
|
||||
|
||||
external_user_group_batch.append(external_user_group)
|
||||
|
||||
# Track progress
|
||||
|
||||
@@ -263,15 +263,9 @@ def refresh_license_cache(
|
||||
|
||||
try:
|
||||
payload = verify_license_signature(license_record.license_data)
|
||||
# Derive source from payload: manual licenses lack stripe_customer_id
|
||||
source: LicenseSource = (
|
||||
LicenseSource.AUTO_FETCH
|
||||
if payload.stripe_customer_id
|
||||
else LicenseSource.MANUAL_UPLOAD
|
||||
)
|
||||
return update_license_cache(
|
||||
payload,
|
||||
source=source,
|
||||
source=LicenseSource.AUTO_FETCH,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -1,604 +0,0 @@
|
||||
"""SCIM Data Access Layer.
|
||||
|
||||
All database operations for SCIM provisioning — token management, user
|
||||
mappings, and group mappings. Extends the base DAL (see ``onyx.db.dal``).
|
||||
|
||||
Usage from FastAPI::
|
||||
|
||||
def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
|
||||
return ScimDAL(db_session)
|
||||
|
||||
@router.post("/tokens")
|
||||
def create_token(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
|
||||
token = dal.create_token(name=..., hashed_token=..., ...)
|
||||
dal.commit()
|
||||
return token
|
||||
|
||||
Usage from background tasks::
|
||||
|
||||
with ScimDAL.from_tenant("tenant_abc") as dal:
|
||||
mapping = dal.create_user_mapping(external_id="idp-123", user_id=uid)
|
||||
dal.commit()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import delete as sa_delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import SQLColumnExpression
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from ee.onyx.server.scim.filtering import ScimFilter
|
||||
from ee.onyx.server.scim.filtering import ScimFilterOperator
|
||||
from onyx.db.dal import DAL
|
||||
from onyx.db.models import ScimGroupMapping
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ScimDAL(DAL):
|
||||
"""Data Access Layer for SCIM provisioning operations.
|
||||
|
||||
Methods mutate but do NOT commit — call ``dal.commit()`` explicitly
|
||||
when you want to persist changes. This follows the existing ``_no_commit``
|
||||
convention and lets callers batch multiple operations into one transaction.
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Token operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def create_token(
|
||||
self,
|
||||
name: str,
|
||||
hashed_token: str,
|
||||
token_display: str,
|
||||
created_by_id: UUID,
|
||||
) -> ScimToken:
|
||||
"""Create a new SCIM bearer token.
|
||||
|
||||
Only one token is active at a time — this method automatically revokes
|
||||
all existing active tokens before creating the new one.
|
||||
"""
|
||||
# Revoke any currently active tokens
|
||||
active_tokens = list(
|
||||
self._session.scalars(
|
||||
select(ScimToken).where(ScimToken.is_active.is_(True))
|
||||
).all()
|
||||
)
|
||||
for t in active_tokens:
|
||||
t.is_active = False
|
||||
|
||||
token = ScimToken(
|
||||
name=name,
|
||||
hashed_token=hashed_token,
|
||||
token_display=token_display,
|
||||
created_by_id=created_by_id,
|
||||
)
|
||||
self._session.add(token)
|
||||
self._session.flush()
|
||||
return token
|
||||
|
||||
def get_active_token(self) -> ScimToken | None:
|
||||
"""Return the single currently active token, or None."""
|
||||
return self._session.scalar(
|
||||
select(ScimToken).where(ScimToken.is_active.is_(True))
|
||||
)
|
||||
|
||||
def get_token_by_hash(self, hashed_token: str) -> ScimToken | None:
|
||||
"""Look up a token by its SHA-256 hash."""
|
||||
return self._session.scalar(
|
||||
select(ScimToken).where(ScimToken.hashed_token == hashed_token)
|
||||
)
|
||||
|
||||
def revoke_token(self, token_id: int) -> None:
|
||||
"""Deactivate a token by ID.
|
||||
|
||||
Raises:
|
||||
ValueError: If the token does not exist.
|
||||
"""
|
||||
token = self._session.get(ScimToken, token_id)
|
||||
if not token:
|
||||
raise ValueError(f"SCIM token with id {token_id} not found")
|
||||
token.is_active = False
|
||||
|
||||
def update_token_last_used(self, token_id: int) -> None:
|
||||
"""Update the last_used_at timestamp for a token."""
|
||||
token = self._session.get(ScimToken, token_id)
|
||||
if token:
|
||||
token.last_used_at = func.now() # type: ignore[assignment]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# User mapping operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def create_user_mapping(
|
||||
self,
|
||||
external_id: str,
|
||||
user_id: UUID,
|
||||
) -> ScimUserMapping:
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user."""
|
||||
mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
|
||||
self._session.add(mapping)
|
||||
self._session.flush()
|
||||
return mapping
|
||||
|
||||
def get_user_mapping_by_external_id(
|
||||
self, external_id: str
|
||||
) -> ScimUserMapping | None:
|
||||
"""Look up a user mapping by the IdP's external identifier."""
|
||||
return self._session.scalar(
|
||||
select(ScimUserMapping).where(ScimUserMapping.external_id == external_id)
|
||||
)
|
||||
|
||||
def get_user_mapping_by_user_id(self, user_id: UUID) -> ScimUserMapping | None:
|
||||
"""Look up a user mapping by the Onyx user ID."""
|
||||
return self._session.scalar(
|
||||
select(ScimUserMapping).where(ScimUserMapping.user_id == user_id)
|
||||
)
|
||||
|
||||
def list_user_mappings(
|
||||
self,
|
||||
start_index: int = 1,
|
||||
count: int = 100,
|
||||
) -> tuple[list[ScimUserMapping], int]:
|
||||
"""List user mappings with SCIM-style pagination.
|
||||
|
||||
Args:
|
||||
start_index: 1-based start index (SCIM convention).
|
||||
count: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
A tuple of (mappings, total_count).
|
||||
"""
|
||||
total = (
|
||||
self._session.scalar(select(func.count()).select_from(ScimUserMapping)) or 0
|
||||
)
|
||||
|
||||
offset = max(start_index - 1, 0)
|
||||
mappings = list(
|
||||
self._session.scalars(
|
||||
select(ScimUserMapping)
|
||||
.order_by(ScimUserMapping.id)
|
||||
.offset(offset)
|
||||
.limit(count)
|
||||
).all()
|
||||
)
|
||||
|
||||
return mappings, total
|
||||
|
||||
def update_user_mapping_external_id(
|
||||
self,
|
||||
mapping_id: int,
|
||||
external_id: str,
|
||||
) -> ScimUserMapping:
|
||||
"""Update the external ID on a user mapping.
|
||||
|
||||
Raises:
|
||||
ValueError: If the mapping does not exist.
|
||||
"""
|
||||
mapping = self._session.get(ScimUserMapping, mapping_id)
|
||||
if not mapping:
|
||||
raise ValueError(f"SCIM user mapping with id {mapping_id} not found")
|
||||
mapping.external_id = external_id
|
||||
return mapping
|
||||
|
||||
def delete_user_mapping(self, mapping_id: int) -> None:
|
||||
"""Delete a user mapping by ID. No-op if already deleted."""
|
||||
mapping = self._session.get(ScimUserMapping, mapping_id)
|
||||
if not mapping:
|
||||
logger.warning("SCIM user mapping %d not found during delete", mapping_id)
|
||||
return
|
||||
self._session.delete(mapping)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# User query operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_user(self, user_id: UUID) -> User | None:
|
||||
"""Fetch a user by ID."""
|
||||
return self._session.scalar(
|
||||
select(User).where(User.id == user_id) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def get_user_by_email(self, email: str) -> User | None:
|
||||
"""Fetch a user by email (case-insensitive)."""
|
||||
return self._session.scalar(
|
||||
select(User).where(func.lower(User.email) == func.lower(email))
|
||||
)
|
||||
|
||||
def add_user(self, user: User) -> None:
|
||||
"""Add a new user to the session and flush to assign an ID."""
|
||||
self._session.add(user)
|
||||
self._session.flush()
|
||||
|
||||
def update_user(
|
||||
self,
|
||||
user: User,
|
||||
*,
|
||||
email: str | None = None,
|
||||
is_active: bool | None = None,
|
||||
personal_name: str | None = None,
|
||||
) -> None:
|
||||
"""Update user attributes. Only sets fields that are provided."""
|
||||
if email is not None:
|
||||
user.email = email
|
||||
if is_active is not None:
|
||||
user.is_active = is_active
|
||||
if personal_name is not None:
|
||||
user.personal_name = personal_name
|
||||
|
||||
def deactivate_user(self, user: User) -> None:
|
||||
"""Mark a user as inactive."""
|
||||
user.is_active = False
|
||||
|
||||
def list_users(
|
||||
self,
|
||||
scim_filter: ScimFilter | None,
|
||||
start_index: int = 1,
|
||||
count: int = 100,
|
||||
) -> tuple[list[tuple[User, str | None]], int]:
|
||||
"""Query users with optional SCIM filter and pagination.
|
||||
|
||||
Returns:
|
||||
A tuple of (list of (user, external_id) pairs, total_count).
|
||||
|
||||
Raises:
|
||||
ValueError: If the filter uses an unsupported attribute.
|
||||
"""
|
||||
query = select(User).where(
|
||||
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
|
||||
)
|
||||
|
||||
if scim_filter:
|
||||
attr = scim_filter.attribute.lower()
|
||||
if attr == "username":
|
||||
# arg-type: fastapi-users types User.email as str, not a column expression
|
||||
# assignment: union return type widens but query is still Select[tuple[User]]
|
||||
query = _apply_scim_string_op(query, User.email, scim_filter) # type: ignore[arg-type, assignment]
|
||||
elif attr == "active":
|
||||
query = query.where(
|
||||
User.is_active.is_(scim_filter.value.lower() == "true") # type: ignore[attr-defined]
|
||||
)
|
||||
elif attr == "externalid":
|
||||
mapping = self.get_user_mapping_by_external_id(scim_filter.value)
|
||||
if not mapping:
|
||||
return [], 0
|
||||
query = query.where(User.id == mapping.user_id) # type: ignore[arg-type]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported filter attribute: {scim_filter.attribute}"
|
||||
)
|
||||
|
||||
# Count total matching rows first, then paginate. SCIM uses 1-based
|
||||
# indexing (RFC 7644 §3.4.2), so we convert to a 0-based offset.
|
||||
total = (
|
||||
self._session.scalar(select(func.count()).select_from(query.subquery()))
|
||||
or 0
|
||||
)
|
||||
|
||||
offset = max(start_index - 1, 0)
|
||||
users = list(
|
||||
self._session.scalars(
|
||||
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
|
||||
).all()
|
||||
)
|
||||
|
||||
# Batch-fetch external IDs to avoid N+1 queries
|
||||
ext_id_map = self._get_user_external_ids([u.id for u in users])
|
||||
return [(u, ext_id_map.get(u.id)) for u in users], total
|
||||
|
||||
def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
|
||||
"""Create, update, or delete the external ID mapping for a user."""
|
||||
mapping = self.get_user_mapping_by_user_id(user_id)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
else:
|
||||
self.create_user_mapping(external_id=new_external_id, user_id=user_id)
|
||||
elif mapping:
|
||||
self.delete_user_mapping(mapping.id)
|
||||
|
||||
def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
|
||||
"""Batch-fetch external IDs for a list of user IDs."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
mappings = self._session.scalars(
|
||||
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
|
||||
).all()
|
||||
return {m.user_id: m.external_id for m in mappings}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Group mapping operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def create_group_mapping(
|
||||
self,
|
||||
external_id: str,
|
||||
user_group_id: int,
|
||||
) -> ScimGroupMapping:
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user group."""
|
||||
mapping = ScimGroupMapping(external_id=external_id, user_group_id=user_group_id)
|
||||
self._session.add(mapping)
|
||||
self._session.flush()
|
||||
return mapping
|
||||
|
||||
def get_group_mapping_by_external_id(
|
||||
self, external_id: str
|
||||
) -> ScimGroupMapping | None:
|
||||
"""Look up a group mapping by the IdP's external identifier."""
|
||||
return self._session.scalar(
|
||||
select(ScimGroupMapping).where(ScimGroupMapping.external_id == external_id)
|
||||
)
|
||||
|
||||
def get_group_mapping_by_group_id(
|
||||
self, user_group_id: int
|
||||
) -> ScimGroupMapping | None:
|
||||
"""Look up a group mapping by the Onyx user group ID."""
|
||||
return self._session.scalar(
|
||||
select(ScimGroupMapping).where(
|
||||
ScimGroupMapping.user_group_id == user_group_id
|
||||
)
|
||||
)
|
||||
|
||||
def list_group_mappings(
|
||||
self,
|
||||
start_index: int = 1,
|
||||
count: int = 100,
|
||||
) -> tuple[list[ScimGroupMapping], int]:
|
||||
"""List group mappings with SCIM-style pagination.
|
||||
|
||||
Args:
|
||||
start_index: 1-based start index (SCIM convention).
|
||||
count: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
A tuple of (mappings, total_count).
|
||||
"""
|
||||
total = (
|
||||
self._session.scalar(select(func.count()).select_from(ScimGroupMapping))
|
||||
or 0
|
||||
)
|
||||
|
||||
offset = max(start_index - 1, 0)
|
||||
mappings = list(
|
||||
self._session.scalars(
|
||||
select(ScimGroupMapping)
|
||||
.order_by(ScimGroupMapping.id)
|
||||
.offset(offset)
|
||||
.limit(count)
|
||||
).all()
|
||||
)
|
||||
|
||||
return mappings, total
|
||||
|
||||
def delete_group_mapping(self, mapping_id: int) -> None:
|
||||
"""Delete a group mapping by ID. No-op if already deleted."""
|
||||
mapping = self._session.get(ScimGroupMapping, mapping_id)
|
||||
if not mapping:
|
||||
logger.warning("SCIM group mapping %d not found during delete", mapping_id)
|
||||
return
|
||||
self._session.delete(mapping)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Group query operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_group(self, group_id: int) -> UserGroup | None:
|
||||
"""Fetch a group by ID, returning None if deleted or missing."""
|
||||
group = self._session.get(UserGroup, group_id)
|
||||
if group and group.is_up_for_deletion:
|
||||
return None
|
||||
return group
|
||||
|
||||
def get_group_by_name(self, name: str) -> UserGroup | None:
|
||||
"""Fetch a group by exact name."""
|
||||
return self._session.scalar(select(UserGroup).where(UserGroup.name == name))
|
||||
|
||||
def add_group(self, group: UserGroup) -> None:
|
||||
"""Add a new group to the session and flush to assign an ID."""
|
||||
self._session.add(group)
|
||||
self._session.flush()
|
||||
|
||||
def update_group(
|
||||
self,
|
||||
group: UserGroup,
|
||||
*,
|
||||
name: str | None = None,
|
||||
) -> None:
|
||||
"""Update group attributes and set the modification timestamp."""
|
||||
if name is not None:
|
||||
group.name = name
|
||||
group.time_last_modified_by_user = func.now()
|
||||
|
||||
def delete_group(self, group: UserGroup) -> None:
|
||||
"""Delete a group from the session."""
|
||||
self._session.delete(group)
|
||||
|
||||
def list_groups(
|
||||
self,
|
||||
scim_filter: ScimFilter | None,
|
||||
start_index: int = 1,
|
||||
count: int = 100,
|
||||
) -> tuple[list[tuple[UserGroup, str | None]], int]:
|
||||
"""Query groups with optional SCIM filter and pagination.
|
||||
|
||||
Returns:
|
||||
A tuple of (list of (group, external_id) pairs, total_count).
|
||||
|
||||
Raises:
|
||||
ValueError: If the filter uses an unsupported attribute.
|
||||
"""
|
||||
query = select(UserGroup).where(UserGroup.is_up_for_deletion.is_(False))
|
||||
|
||||
if scim_filter:
|
||||
attr = scim_filter.attribute.lower()
|
||||
if attr == "displayname":
|
||||
# assignment: union return type widens but query is still Select[tuple[UserGroup]]
|
||||
query = _apply_scim_string_op(query, UserGroup.name, scim_filter) # type: ignore[assignment]
|
||||
elif attr == "externalid":
|
||||
mapping = self.get_group_mapping_by_external_id(scim_filter.value)
|
||||
if not mapping:
|
||||
return [], 0
|
||||
query = query.where(UserGroup.id == mapping.user_group_id)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported filter attribute: {scim_filter.attribute}"
|
||||
)
|
||||
|
||||
total = (
|
||||
self._session.scalar(select(func.count()).select_from(query.subquery()))
|
||||
or 0
|
||||
)
|
||||
|
||||
offset = max(start_index - 1, 0)
|
||||
groups = list(
|
||||
self._session.scalars(
|
||||
query.order_by(UserGroup.id).offset(offset).limit(count)
|
||||
).all()
|
||||
)
|
||||
|
||||
ext_id_map = self._get_group_external_ids([g.id for g in groups])
|
||||
return [(g, ext_id_map.get(g.id)) for g in groups], total
|
||||
|
||||
def get_group_members(self, group_id: int) -> list[tuple[UUID, str | None]]:
|
||||
"""Get group members as (user_id, email) pairs."""
|
||||
rels = self._session.scalars(
|
||||
select(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
|
||||
).all()
|
||||
|
||||
user_ids = [r.user_id for r in rels if r.user_id]
|
||||
if not user_ids:
|
||||
return []
|
||||
|
||||
users = self._session.scalars(
|
||||
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
|
||||
).all()
|
||||
users_by_id = {u.id: u for u in users}
|
||||
|
||||
return [
|
||||
(
|
||||
r.user_id,
|
||||
users_by_id[r.user_id].email if r.user_id in users_by_id else None,
|
||||
)
|
||||
for r in rels
|
||||
if r.user_id
|
||||
]
|
||||
|
||||
def validate_member_ids(self, uuids: list[UUID]) -> list[UUID]:
|
||||
"""Return the subset of UUIDs that don't exist as users.
|
||||
|
||||
Returns an empty list if all IDs are valid.
|
||||
"""
|
||||
if not uuids:
|
||||
return []
|
||||
existing_users = self._session.scalars(
|
||||
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
|
||||
).all()
|
||||
existing_ids = {u.id for u in existing_users}
|
||||
return [uid for uid in uuids if uid not in existing_ids]
|
||||
|
||||
def upsert_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
|
||||
"""Add user-group relationships, ignoring duplicates."""
|
||||
if not user_ids:
|
||||
return
|
||||
self._session.execute(
|
||||
pg_insert(User__UserGroup)
|
||||
.values([{"user_id": uid, "user_group_id": group_id} for uid in user_ids])
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[
|
||||
User__UserGroup.user_group_id,
|
||||
User__UserGroup.user_id,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
def replace_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
|
||||
"""Replace all members of a group."""
|
||||
self._session.execute(
|
||||
sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
|
||||
)
|
||||
self.upsert_group_members(group_id, user_ids)
|
||||
|
||||
def remove_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
|
||||
"""Remove specific members from a group."""
|
||||
if not user_ids:
|
||||
return
|
||||
self._session.execute(
|
||||
sa_delete(User__UserGroup).where(
|
||||
User__UserGroup.user_group_id == group_id,
|
||||
User__UserGroup.user_id.in_(user_ids),
|
||||
)
|
||||
)
|
||||
|
||||
def delete_group_with_members(self, group: UserGroup) -> None:
|
||||
"""Remove all member relationships and delete the group."""
|
||||
self._session.execute(
|
||||
sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group.id)
|
||||
)
|
||||
self._session.delete(group)
|
||||
|
||||
def sync_group_external_id(
|
||||
self, group_id: int, new_external_id: str | None
|
||||
) -> None:
|
||||
"""Create, update, or delete the external ID mapping for a group."""
|
||||
mapping = self.get_group_mapping_by_group_id(group_id)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
else:
|
||||
self.create_group_mapping(
|
||||
external_id=new_external_id, user_group_id=group_id
|
||||
)
|
||||
elif mapping:
|
||||
self.delete_group_mapping(mapping.id)
|
||||
|
||||
def _get_group_external_ids(self, group_ids: list[int]) -> dict[int, str]:
|
||||
"""Batch-fetch external IDs for a list of group IDs."""
|
||||
if not group_ids:
|
||||
return {}
|
||||
mappings = self._session.scalars(
|
||||
select(ScimGroupMapping).where(
|
||||
ScimGroupMapping.user_group_id.in_(group_ids)
|
||||
)
|
||||
).all()
|
||||
return {m.user_group_id: m.external_id for m in mappings}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level helpers (used by DAL methods above)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _apply_scim_string_op(
|
||||
query: Select[tuple[User]] | Select[tuple[UserGroup]],
|
||||
column: SQLColumnExpression[str],
|
||||
scim_filter: ScimFilter,
|
||||
) -> Select[tuple[User]] | Select[tuple[UserGroup]]:
|
||||
"""Apply a SCIM string filter operator using SQLAlchemy column operators.
|
||||
|
||||
Handles eq (case-insensitive exact), co (contains), and sw (starts with).
|
||||
SQLAlchemy's operators handle LIKE-pattern escaping internally.
|
||||
"""
|
||||
val = scim_filter.value
|
||||
if scim_filter.operator == ScimFilterOperator.EQUAL:
|
||||
return query.where(func.lower(column) == val.lower())
|
||||
elif scim_filter.operator == ScimFilterOperator.CONTAINS:
|
||||
return query.where(column.icontains(val, autoescape=True))
|
||||
elif scim_filter.operator == ScimFilterOperator.STARTS_WITH:
|
||||
return query.where(column.istartswith(val, autoescape=True))
|
||||
else:
|
||||
raise ValueError(f"Unsupported string filter operator: {scim_filter.operator}")
|
||||
@@ -9,7 +9,6 @@ from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
@@ -19,15 +18,11 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__UserGroup
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
from onyx.db.models import User
|
||||
@@ -200,60 +195,8 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def _add_user_group_snapshot_eager_loads(
|
||||
stmt: Select,
|
||||
) -> Select:
|
||||
"""Add eager loading options needed by UserGroup.from_model snapshot creation."""
|
||||
return stmt.options(
|
||||
selectinload(UserGroup.users),
|
||||
selectinload(UserGroup.user_group_relationships),
|
||||
selectinload(UserGroup.cc_pair_relationships)
|
||||
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
|
||||
.options(
|
||||
selectinload(ConnectorCredentialPair.connector),
|
||||
selectinload(ConnectorCredentialPair.credential).selectinload(
|
||||
Credential.user
|
||||
),
|
||||
),
|
||||
selectinload(UserGroup.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(UserGroup.personas).options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
selectinload(Persona.attached_documents).selectinload(
|
||||
Document.parent_hierarchy_node
|
||||
),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.user),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.users),
|
||||
selectinload(Persona.groups),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def fetch_user_groups(
|
||||
db_session: Session,
|
||||
only_up_to_date: bool = True,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
db_session: Session, only_up_to_date: bool = True
|
||||
) -> Sequence[UserGroup]:
|
||||
"""
|
||||
Fetches user groups from the database.
|
||||
@@ -266,8 +209,6 @@ def fetch_user_groups(
|
||||
db_session (Session): The SQLAlchemy session used to query the database.
|
||||
only_up_to_date (bool, optional): Flag to determine whether to filter the results
|
||||
to include only up to date user groups. Defaults to `True`.
|
||||
eager_load_for_snapshot: If True, adds eager loading for all relationships
|
||||
needed by UserGroup.from_model snapshot creation.
|
||||
|
||||
Returns:
|
||||
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
|
||||
@@ -275,16 +216,11 @@ def fetch_user_groups(
|
||||
stmt = select(UserGroup)
|
||||
if only_up_to_date:
|
||||
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def fetch_user_groups_for_user(
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
only_curator_groups: bool = False,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
db_session: Session, user_id: UUID, only_curator_groups: bool = False
|
||||
) -> Sequence[UserGroup]:
|
||||
stmt = (
|
||||
select(UserGroup)
|
||||
@@ -294,9 +230,7 @@ def fetch_user_groups_for_user(
|
||||
)
|
||||
if only_curator_groups:
|
||||
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def construct_document_id_select_by_usergroup(
|
||||
|
||||
@@ -65,7 +65,21 @@ def github_doc_sync(
|
||||
# Get all repositories from GitHub API
|
||||
logger.info("Fetching all repositories from GitHub API")
|
||||
try:
|
||||
repos = github_connector.fetch_configured_repos()
|
||||
repos = []
|
||||
if github_connector.repositories:
|
||||
if "," in github_connector.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = github_connector.get_github_repos(
|
||||
github_connector.github_client
|
||||
)
|
||||
else:
|
||||
# Single repository
|
||||
repos = [
|
||||
github_connector.get_github_repo(github_connector.github_client)
|
||||
]
|
||||
else:
|
||||
# All repositories
|
||||
repos = github_connector.get_all_repos(github_connector.github_client)
|
||||
|
||||
logger.info(f"Found {len(repos)} repositories to check")
|
||||
except Exception as e:
|
||||
|
||||
@@ -6,7 +6,6 @@ from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
get_sharepoint_external_groups,
|
||||
)
|
||||
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
|
||||
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -47,27 +46,19 @@ def sharepoint_group_sync(
|
||||
|
||||
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
|
||||
|
||||
enumerate_all = connector_config.get(
|
||||
"exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
|
||||
)
|
||||
|
||||
msal_app = connector.msal_app
|
||||
sp_tenant_domain = connector.sp_tenant_domain
|
||||
sp_domain_suffix = connector.sharepoint_domain_suffix
|
||||
# Process each site
|
||||
for site_descriptor in site_descriptors:
|
||||
logger.debug(f"Processing site: {site_descriptor.url}")
|
||||
|
||||
# Create client context for the site using connector's MSAL app
|
||||
ctx = ClientContext(site_descriptor.url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
|
||||
)
|
||||
|
||||
external_groups = get_sharepoint_external_groups(
|
||||
ctx,
|
||||
connector.graph_client,
|
||||
graph_api_base=connector.graph_api_base,
|
||||
get_access_token=connector._get_graph_access_token,
|
||||
enumerate_all_ad_groups=enumerate_all,
|
||||
)
|
||||
# Get external groups for this site
|
||||
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
|
||||
|
||||
# Yield each group
|
||||
for group in external_groups:
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
import re
|
||||
import time
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests as _requests
|
||||
from office365.graph_client import GraphClient # type: ignore[import-untyped]
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
|
||||
from office365.runtime.client_request import ClientRequestException # type: ignore
|
||||
@@ -18,10 +14,7 @@ from pydantic import BaseModel
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
|
||||
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
|
||||
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
|
||||
from onyx.connectors.sharepoint.connector import sleep_and_retry
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -40,70 +33,6 @@ LIMITED_ACCESS_ROLE_TYPES = [1, 9]
|
||||
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
|
||||
|
||||
|
||||
AD_GROUP_ENUMERATION_THRESHOLD = 100_000
|
||||
|
||||
|
||||
def _graph_api_get(
|
||||
url: str,
|
||||
get_access_token: Callable[[], str],
|
||||
params: dict[str, str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Authenticated Graph API GET with retry on transient errors."""
|
||||
for attempt in range(GRAPH_API_MAX_RETRIES + 1):
|
||||
access_token = get_access_token()
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
try:
|
||||
resp = _requests.get(
|
||||
url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
|
||||
)
|
||||
if (
|
||||
resp.status_code in GRAPH_API_RETRYABLE_STATUSES
|
||||
and attempt < GRAPH_API_MAX_RETRIES
|
||||
):
|
||||
wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
|
||||
logger.warning(
|
||||
f"Graph API {resp.status_code} on attempt {attempt + 1}, "
|
||||
f"retrying in {wait}s: {url}"
|
||||
)
|
||||
time.sleep(wait)
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
|
||||
if attempt < GRAPH_API_MAX_RETRIES:
|
||||
wait = min(2**attempt, 60)
|
||||
logger.warning(
|
||||
f"Graph API connection error on attempt {attempt + 1}, "
|
||||
f"retrying in {wait}s: {url}"
|
||||
)
|
||||
time.sleep(wait)
|
||||
continue
|
||||
raise
|
||||
raise RuntimeError(
|
||||
f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
|
||||
)
|
||||
|
||||
|
||||
def _iter_graph_collection(
|
||||
initial_url: str,
|
||||
get_access_token: Callable[[], str],
|
||||
params: dict[str, str] | None = None,
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
"""Paginate through a Graph API collection, yielding items one at a time."""
|
||||
url: str | None = initial_url
|
||||
while url:
|
||||
data = _graph_api_get(url, get_access_token, params)
|
||||
params = None
|
||||
yield from data.get("value", [])
|
||||
url = data.get("@odata.nextLink")
|
||||
|
||||
|
||||
def _normalize_email(email: str) -> str:
|
||||
if MICROSOFT_DOMAIN in email:
|
||||
return email.replace(MICROSOFT_DOMAIN, "")
|
||||
return email
|
||||
|
||||
|
||||
class SharepointGroup(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
|
||||
@@ -643,65 +572,8 @@ def get_external_access_from_sharepoint(
|
||||
)
|
||||
|
||||
|
||||
def _enumerate_ad_groups_paginated(
|
||||
get_access_token: Callable[[], str],
|
||||
already_resolved: set[str],
|
||||
graph_api_base: str,
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
"""Paginate through all Azure AD groups and yield ExternalUserGroup for each.
|
||||
|
||||
Skips groups whose suffixed name is already in *already_resolved*.
|
||||
Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
|
||||
"""
|
||||
groups_url = f"{graph_api_base}/groups"
|
||||
groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
|
||||
total_groups = 0
|
||||
|
||||
for group_json in _iter_graph_collection(
|
||||
groups_url, get_access_token, groups_params
|
||||
):
|
||||
group_id: str = group_json.get("id", "")
|
||||
display_name: str = group_json.get("displayName", "")
|
||||
if not group_id or not display_name:
|
||||
continue
|
||||
|
||||
total_groups += 1
|
||||
if total_groups > AD_GROUP_ENUMERATION_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
|
||||
"groups — stopping to avoid excessive memory/API usage. "
|
||||
"Remaining groups will be resolved from role assignments only."
|
||||
)
|
||||
return
|
||||
|
||||
name = f"{display_name}_{group_id}"
|
||||
if name in already_resolved:
|
||||
continue
|
||||
|
||||
member_emails: list[str] = []
|
||||
members_url = f"{graph_api_base}/groups/{group_id}/members"
|
||||
members_params: dict[str, str] = {
|
||||
"$select": "userPrincipalName,mail",
|
||||
"$top": "999",
|
||||
}
|
||||
for member_json in _iter_graph_collection(
|
||||
members_url, get_access_token, members_params
|
||||
):
|
||||
email = member_json.get("userPrincipalName") or member_json.get("mail")
|
||||
if email:
|
||||
member_emails.append(_normalize_email(email))
|
||||
|
||||
yield ExternalUserGroup(id=name, user_emails=member_emails)
|
||||
|
||||
logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API")
|
||||
|
||||
|
||||
def get_sharepoint_external_groups(
|
||||
client_context: ClientContext,
|
||||
graph_client: GraphClient,
|
||||
graph_api_base: str,
|
||||
get_access_token: Callable[[], str] | None = None,
|
||||
enumerate_all_ad_groups: bool = False,
|
||||
client_context: ClientContext, graph_client: GraphClient
|
||||
) -> list[ExternalUserGroup]:
|
||||
|
||||
groups: set[SharepointGroup] = set()
|
||||
@@ -757,22 +629,57 @@ def get_sharepoint_external_groups(
|
||||
client_context, graph_client, groups, is_group_sync=True
|
||||
)
|
||||
|
||||
external_user_groups: list[ExternalUserGroup] = [
|
||||
ExternalUserGroup(id=group_name, user_emails=list(emails))
|
||||
for group_name, emails in groups_and_members.groups_to_emails.items()
|
||||
]
|
||||
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
|
||||
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
|
||||
azure_ad_groups = sleep_and_retry(
|
||||
graph_client.groups.get_all(page_loaded=lambda _: None),
|
||||
"get_sharepoint_external_groups:get_azure_ad_groups",
|
||||
)
|
||||
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
|
||||
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
|
||||
ad_groups_to_emails: dict[str, set[str]] = {}
|
||||
for group in azure_ad_groups:
|
||||
# If the group is already identified, we don't need to get the members
|
||||
if group.display_name in identified_groups:
|
||||
continue
|
||||
# AD groups allows same display name for multiple groups, so we need to add the GUID to the name
|
||||
name = group.display_name
|
||||
name = _get_group_name_with_suffix(group.id, name, graph_client)
|
||||
|
||||
if not enumerate_all_ad_groups or get_access_token is None:
|
||||
logger.info(
|
||||
"Skipping exhaustive Azure AD group enumeration. "
|
||||
"Only groups found in site role assignments are included."
|
||||
members = sleep_and_retry(
|
||||
group.members.get_all(page_loaded=lambda _: None),
|
||||
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
|
||||
)
|
||||
return external_user_groups
|
||||
for member in members:
|
||||
member_data = member.to_json()
|
||||
user_principal_name = member_data.get("userPrincipalName")
|
||||
mail = member_data.get("mail")
|
||||
if not ad_groups_to_emails.get(name):
|
||||
ad_groups_to_emails[name] = set()
|
||||
if user_principal_name:
|
||||
if MICROSOFT_DOMAIN in user_principal_name:
|
||||
user_principal_name = user_principal_name.replace(
|
||||
MICROSOFT_DOMAIN, ""
|
||||
)
|
||||
ad_groups_to_emails[name].add(user_principal_name)
|
||||
elif mail:
|
||||
if MICROSOFT_DOMAIN in mail:
|
||||
mail = mail.replace(MICROSOFT_DOMAIN, "")
|
||||
ad_groups_to_emails[name].add(mail)
|
||||
|
||||
already_resolved = set(groups_and_members.groups_to_emails.keys())
|
||||
for group in _enumerate_ad_groups_paginated(
|
||||
get_access_token, already_resolved, graph_api_base
|
||||
):
|
||||
external_user_groups.append(group)
|
||||
external_user_groups: list[ExternalUserGroup] = []
|
||||
for group_name, emails in groups_and_members.groups_to_emails.items():
|
||||
external_user_group = ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(emails),
|
||||
)
|
||||
external_user_groups.append(external_user_group)
|
||||
|
||||
for group_name, emails in ad_groups_to_emails.items():
|
||||
external_user_group = ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(emails),
|
||||
)
|
||||
external_user_groups.append(external_user_group)
|
||||
|
||||
return external_user_groups
|
||||
|
||||
@@ -31,7 +31,6 @@ from ee.onyx.server.query_and_chat.query_backend import (
|
||||
from ee.onyx.server.query_and_chat.search_backend import router as search_router
|
||||
from ee.onyx.server.query_history.api import router as query_history_router
|
||||
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
|
||||
from ee.onyx.server.scim.api import scim_router
|
||||
from ee.onyx.server.seeding import seed_db
|
||||
from ee.onyx.server.tenants.api import router as tenants_router
|
||||
from ee.onyx.server.token_rate_limits.api import (
|
||||
@@ -163,11 +162,6 @@ def get_application() -> FastAPI:
|
||||
# Tenant management
|
||||
include_router_with_global_prefix_prepended(application, tenants_router)
|
||||
|
||||
# SCIM 2.0 — protocol endpoints (unauthenticated by Onyx session auth;
|
||||
# they use their own SCIM bearer token auth).
|
||||
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
|
||||
application.include_router(scim_router)
|
||||
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_ee_router_auth(application)
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ def stream_search_query(
|
||||
# Get document index
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
# This flow is for search so we do not get all indices.
|
||||
document_index = get_default_document_index(search_settings, None, db_session)
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
|
||||
# Determine queries to execute
|
||||
original_query = request.search_query
|
||||
|
||||
@@ -5,11 +5,6 @@ from onyx.server.auth_check import PUBLIC_ENDPOINT_SPECS
|
||||
|
||||
|
||||
EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
|
||||
# SCIM 2.0 service discovery — unauthenticated so IdPs can probe
|
||||
# before bearer token configuration is complete
|
||||
("/scim/v2/ServiceProviderConfig", {"GET"}),
|
||||
("/scim/v2/ResourceTypes", {"GET"}),
|
||||
("/scim/v2/Schemas", {"GET"}),
|
||||
# needs to be accessible prior to user login
|
||||
("/enterprise-settings", {"GET"}),
|
||||
("/enterprise-settings/logo", {"GET"}),
|
||||
|
||||
@@ -109,9 +109,7 @@ async def _make_billing_request(
|
||||
headers = _get_headers(license_data)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=_REQUEST_TIMEOUT, follow_redirects=True
|
||||
) as client:
|
||||
async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
else:
|
||||
|
||||
@@ -13,7 +13,6 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
|
||||
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
|
||||
from ee.onyx.server.enterprise_settings.store import get_logo_filename
|
||||
@@ -23,10 +22,6 @@ from ee.onyx.server.enterprise_settings.store import load_settings
|
||||
from ee.onyx.server.enterprise_settings.store import store_analytics_script
|
||||
from ee.onyx.server.enterprise_settings.store import store_settings
|
||||
from ee.onyx.server.enterprise_settings.store import upload_logo
|
||||
from ee.onyx.server.scim.auth import generate_scim_token
|
||||
from ee.onyx.server.scim.models import ScimTokenCreate
|
||||
from ee.onyx.server.scim.models import ScimTokenCreatedResponse
|
||||
from ee.onyx.server.scim.models import ScimTokenResponse
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user_with_expired_token
|
||||
from onyx.auth.users import get_user_manager
|
||||
@@ -203,63 +198,3 @@ def upload_custom_analytics_script(
|
||||
@basic_router.get("/custom-analytics-script")
|
||||
def fetch_custom_analytics_script() -> str | None:
|
||||
return load_analytics_script()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SCIM token management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
|
||||
return ScimDAL(db_session)
|
||||
|
||||
|
||||
@admin_router.get("/scim/token")
|
||||
def get_active_scim_token(
|
||||
_: User = Depends(current_admin_user),
|
||||
dal: ScimDAL = Depends(_get_scim_dal),
|
||||
) -> ScimTokenResponse:
|
||||
"""Return the currently active SCIM token's metadata, or 404 if none."""
|
||||
token = dal.get_active_token()
|
||||
if not token:
|
||||
raise HTTPException(status_code=404, detail="No active SCIM token")
|
||||
return ScimTokenResponse(
|
||||
id=token.id,
|
||||
name=token.name,
|
||||
token_display=token.token_display,
|
||||
is_active=token.is_active,
|
||||
created_at=token.created_at,
|
||||
last_used_at=token.last_used_at,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/scim/token", status_code=201)
|
||||
def create_scim_token(
|
||||
body: ScimTokenCreate,
|
||||
user: User = Depends(current_admin_user),
|
||||
dal: ScimDAL = Depends(_get_scim_dal),
|
||||
) -> ScimTokenCreatedResponse:
|
||||
"""Create a new SCIM bearer token.
|
||||
|
||||
Only one token is active at a time — creating a new token automatically
|
||||
revokes all previous tokens. The raw token value is returned exactly once
|
||||
in the response; it cannot be retrieved again.
|
||||
"""
|
||||
raw_token, hashed_token, token_display = generate_scim_token()
|
||||
token = dal.create_token(
|
||||
name=body.name,
|
||||
hashed_token=hashed_token,
|
||||
token_display=token_display,
|
||||
created_by_id=user.id,
|
||||
)
|
||||
dal.commit()
|
||||
|
||||
return ScimTokenCreatedResponse(
|
||||
id=token.id,
|
||||
name=token.name,
|
||||
token_display=token.token_display,
|
||||
is_active=token.is_active,
|
||||
created_at=token.created_at,
|
||||
last_used_at=token.last_used_at,
|
||||
raw_token=raw_token,
|
||||
)
|
||||
|
||||
@@ -27,8 +27,6 @@ class SearchFlowClassificationResponse(BaseModel):
|
||||
is_search_flow: bool
|
||||
|
||||
|
||||
# NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
|
||||
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
class SendSearchQueryRequest(BaseModel):
|
||||
search_query: str
|
||||
filters: BaseFilters | None = None
|
||||
|
||||
@@ -26,7 +26,6 @@ from onyx.db.models import User
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.server.utils_vector_db import require_vector_db
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -67,13 +66,7 @@ def search_flow_classification(
|
||||
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
|
||||
|
||||
|
||||
# NOTE: This endpoint is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
|
||||
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
@router.post(
|
||||
"/send-search-message",
|
||||
response_model=None,
|
||||
dependencies=[Depends(require_vector_db)],
|
||||
)
|
||||
@router.post("/send-search-message", response_model=None)
|
||||
def handle_send_search_message(
|
||||
request: SendSearchQueryRequest,
|
||||
user: User = Depends(current_user),
|
||||
|
||||
@@ -1,689 +0,0 @@
|
||||
"""SCIM 2.0 API endpoints (RFC 7644).
|
||||
|
||||
This module provides the FastAPI router for SCIM service discovery,
|
||||
User CRUD, and Group CRUD. Identity providers (Okta, Azure AD) call
|
||||
these endpoints to provision and manage users and groups.
|
||||
|
||||
Service discovery endpoints are unauthenticated — IdPs may probe them
|
||||
before bearer token configuration is complete. All other endpoints
|
||||
require a valid SCIM bearer token.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Query
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimError
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
from ee.onyx.server.scim.models import ScimSchemaDefinition
|
||||
from ee.onyx.server.scim.models import ScimServiceProviderConfig
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
|
||||
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
|
||||
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
|
||||
# changed to kebab-case.
|
||||
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
|
||||
_pw_helper = PasswordHelper()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service Discovery Endpoints (unauthenticated)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@scim_router.get("/ServiceProviderConfig")
|
||||
def get_service_provider_config() -> ScimServiceProviderConfig:
|
||||
"""Advertise supported SCIM features (RFC 7643 §5)."""
|
||||
return SERVICE_PROVIDER_CONFIG
|
||||
|
||||
|
||||
@scim_router.get("/ResourceTypes")
|
||||
def get_resource_types() -> list[ScimResourceType]:
|
||||
"""List available SCIM resource types (RFC 7643 §6)."""
|
||||
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
|
||||
|
||||
|
||||
@scim_router.get("/Schemas")
|
||||
def get_schemas() -> list[ScimSchemaDefinition]:
|
||||
"""Return SCIM schema definitions (RFC 7643 §7)."""
|
||||
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _scim_error_response(status: int, detail: str) -> JSONResponse:
|
||||
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
|
||||
body = ScimError(status=str(status), detail=detail)
|
||||
return JSONResponse(
|
||||
status_code=status,
|
||||
content=body.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
|
||||
"""Convert an Onyx User to a SCIM User resource representation."""
|
||||
name = None
|
||||
if user.personal_name:
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
name = ScimName(
|
||||
givenName=parts[0],
|
||||
familyName=parts[1] if len(parts) > 1 else None,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
|
||||
return ScimUserResource(
|
||||
id=str(user.id),
|
||||
externalId=external_id,
|
||||
userName=user.email,
|
||||
name=name,
|
||||
emails=[ScimEmail(value=user.email, type="work", primary=True)],
|
||||
active=user.is_active,
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)
|
||||
if check_fn is None:
|
||||
return None
|
||||
result = check_fn(dal.session, seats_needed=1)
|
||||
if not result.available:
|
||||
return result.error_message or "Seat limit reached"
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
|
||||
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
|
||||
try:
|
||||
uid = UUID(user_id)
|
||||
except ValueError:
|
||||
return _scim_error_response(404, f"User {user_id} not found")
|
||||
user = dal.get_user(uid)
|
||||
if not user:
|
||||
return _scim_error_response(404, f"User {user_id} not found")
|
||||
return user
|
||||
|
||||
|
||||
def _scim_name_to_str(name: ScimName | None) -> str | None:
|
||||
"""Extract a display name string from a SCIM name object.
|
||||
|
||||
Returns None if no name is provided, so the caller can decide
|
||||
whether to update the user's personal_name.
|
||||
"""
|
||||
if not name:
|
||||
return None
|
||||
return name.formatted or " ".join(
|
||||
part for part in [name.givenName, name.familyName] if part
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User CRUD (RFC 7644 §3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@scim_router.get("/Users", response_model=None)
|
||||
def list_users(
|
||||
filter: str | None = Query(None),
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List users with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
except ValueError as e:
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
try:
|
||||
users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
|
||||
except ValueError as e:
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
resources: list[ScimUserResource | ScimGroupResource] = [
|
||||
_user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
|
||||
]
|
||||
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Users/{user_id}", response_model=None)
|
||||
def get_user(
|
||||
user_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Get a single user by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
return _user_to_scim(user, mapping.external_id if mapping else None)
|
||||
|
||||
|
||||
@scim_router.post("/Users", status_code=201, response_model=None)
|
||||
def create_user(
|
||||
user_resource: ScimUserResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Create a new user from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
email = user_resource.userName.strip().lower()
|
||||
|
||||
# externalId is how the IdP correlates this user on subsequent requests.
|
||||
# Without it, the IdP can't find the user and will try to re-create,
|
||||
# hitting a 409 conflict — so we require it up front.
|
||||
if not user_resource.externalId:
|
||||
return _scim_error_response(400, "externalId is required")
|
||||
|
||||
# Enforce seat limit
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
# Check for existing user
|
||||
if dal.get_user_by_email(email):
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create user with a random password (SCIM users authenticate via IdP)
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
user = User(
|
||||
email=email,
|
||||
hashed_password=_pw_helper.hash(_pw_helper.generate()),
|
||||
role=UserRole.BASIC,
|
||||
is_active=user_resource.active,
|
||||
is_verified=True,
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
try:
|
||||
dal.add_user(user)
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
dal.create_user_mapping(external_id=external_id, user_id=user.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, external_id)
|
||||
|
||||
|
||||
@scim_router.put("/Users/{user_id}", response_model=None)
|
||||
def replace_user(
|
||||
user_id: str,
|
||||
user_resource: ScimUserResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Replace a user entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
# Handle activation (need seat check) / deactivation
|
||||
if user_resource.active and not user.is_active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
dal.update_user(
|
||||
user,
|
||||
email=user_resource.userName.strip().lower(),
|
||||
is_active=user_resource.active,
|
||||
personal_name=_scim_name_to_str(user_resource.name),
|
||||
)
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
dal.sync_user_external_id(user.id, new_external_id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, new_external_id)
|
||||
|
||||
|
||||
@scim_router.patch("/Users/{user_id}", response_model=None)
|
||||
def patch_user(
|
||||
user_id: str,
|
||||
patch_request: ScimPatchRequest,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Partially update a user (RFC 7644 §3.5.2).
|
||||
|
||||
This is the primary endpoint for user deprovisioning — Okta sends
|
||||
``PATCH {"active": false}`` rather than DELETE.
|
||||
"""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
|
||||
current = _user_to_scim(user, external_id)
|
||||
|
||||
try:
|
||||
patched = apply_user_patch(patch_request.Operations, current)
|
||||
except ScimPatchError as e:
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
# Apply changes back to the DB model
|
||||
if patched.active != user.is_active:
|
||||
if patched.active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
dal.update_user(
|
||||
user,
|
||||
email=(
|
||||
patched.userName.strip().lower()
|
||||
if patched.userName.lower() != user.email
|
||||
else None
|
||||
),
|
||||
is_active=patched.active if patched.active != user.is_active else None,
|
||||
personal_name=_scim_name_to_str(patched.name),
|
||||
)
|
||||
|
||||
dal.sync_user_external_id(user.id, patched.externalId)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, patched.externalId)
|
||||
|
||||
|
||||
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
|
||||
def delete_user(
|
||||
user_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | JSONResponse:
|
||||
"""Delete a user (RFC 7644 §3.6).
|
||||
|
||||
Deactivates the user and removes the SCIM mapping. Note that Okta
|
||||
typically uses PATCH active=false instead of DELETE.
|
||||
"""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
dal.deactivate_user(user)
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
if mapping:
|
||||
dal.delete_user_mapping(mapping.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _group_to_scim(
|
||||
group: UserGroup,
|
||||
members: list[tuple[UUID, str | None]],
|
||||
external_id: str | None = None,
|
||||
) -> ScimGroupResource:
|
||||
"""Convert an Onyx UserGroup to a SCIM Group resource."""
|
||||
scim_members = [
|
||||
ScimGroupMember(value=str(uid), display=email) for uid, email in members
|
||||
]
|
||||
return ScimGroupResource(
|
||||
id=str(group.id),
|
||||
externalId=external_id,
|
||||
displayName=group.name,
|
||||
members=scim_members,
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
|
||||
"""Parse *group_id* as int, look up the group, or return a 404 error."""
|
||||
try:
|
||||
gid = int(group_id)
|
||||
except ValueError:
|
||||
return _scim_error_response(404, f"Group {group_id} not found")
|
||||
group = dal.get_group(gid)
|
||||
if not group:
|
||||
return _scim_error_response(404, f"Group {group_id} not found")
|
||||
return group
|
||||
|
||||
|
||||
def _parse_member_uuids(
|
||||
members: list[ScimGroupMember],
|
||||
) -> tuple[list[UUID], str | None]:
|
||||
"""Parse member value strings to UUIDs.
|
||||
|
||||
Returns (uuid_list, error_message). error_message is None on success.
|
||||
"""
|
||||
uuids: list[UUID] = []
|
||||
for m in members:
|
||||
try:
|
||||
uuids.append(UUID(m.value))
|
||||
except ValueError:
|
||||
return [], f"Invalid member ID: {m.value}"
|
||||
return uuids, None
|
||||
|
||||
|
||||
def _validate_and_parse_members(
|
||||
members: list[ScimGroupMember], dal: ScimDAL
|
||||
) -> tuple[list[UUID], str | None]:
|
||||
"""Parse and validate member UUIDs exist in the database.
|
||||
|
||||
Returns (uuid_list, error_message). error_message is None on success.
|
||||
"""
|
||||
uuids, err = _parse_member_uuids(members)
|
||||
if err:
|
||||
return [], err
|
||||
|
||||
if uuids:
|
||||
missing = dal.validate_member_ids(uuids)
|
||||
if missing:
|
||||
return [], f"Member(s) not found: {', '.join(str(u) for u in missing)}"
|
||||
|
||||
return uuids, None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group CRUD (RFC 7644 §3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@scim_router.get("/Groups", response_model=None)
|
||||
def list_groups(
|
||||
filter: str | None = Query(None),
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
except ValueError as e:
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
try:
|
||||
groups_with_ext_ids, total = dal.list_groups(scim_filter, startIndex, count)
|
||||
except ValueError as e:
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
resources: list[ScimUserResource | ScimGroupResource] = [
|
||||
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
|
||||
for group, ext_id in groups_with_ext_ids
|
||||
]
|
||||
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Groups/{group_id}", response_model=None)
|
||||
def get_group(
|
||||
group_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Get a single group by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
members = dal.get_group_members(group.id)
|
||||
|
||||
return _group_to_scim(group, members, mapping.external_id if mapping else None)
|
||||
|
||||
|
||||
@scim_router.post("/Groups", status_code=201, response_model=None)
|
||||
def create_group(
|
||||
group_resource: ScimGroupResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Create a new group from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
if dal.get_group_by_name(group_resource.displayName):
|
||||
return _scim_error_response(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
)
|
||||
|
||||
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
|
||||
if err:
|
||||
return _scim_error_response(400, err)
|
||||
|
||||
db_group = UserGroup(
|
||||
name=group_resource.displayName,
|
||||
is_up_to_date=True,
|
||||
time_last_modified_by_user=func.now(),
|
||||
)
|
||||
try:
|
||||
dal.add_group(db_group)
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
)
|
||||
|
||||
dal.upsert_group_members(db_group.id, member_uuids)
|
||||
|
||||
external_id = group_resource.externalId
|
||||
if external_id:
|
||||
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(db_group.id)
|
||||
return _group_to_scim(db_group, members, external_id)
|
||||
|
||||
|
||||
@scim_router.put("/Groups/{group_id}", response_model=None)
|
||||
def replace_group(
|
||||
group_id: str,
|
||||
group_resource: ScimGroupResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Replace a group entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
|
||||
if err:
|
||||
return _scim_error_response(400, err)
|
||||
|
||||
dal.update_group(group, name=group_resource.displayName)
|
||||
dal.replace_group_members(group.id, member_uuids)
|
||||
dal.sync_group_external_id(group.id, group_resource.externalId)
|
||||
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _group_to_scim(group, members, group_resource.externalId)
|
||||
|
||||
|
||||
@scim_router.patch("/Groups/{group_id}", response_model=None)
|
||||
def patch_group(
|
||||
group_id: str,
|
||||
patch_request: ScimPatchRequest,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Partially update a group (RFC 7644 §3.5.2).
|
||||
|
||||
Handles member add/remove operations from Okta and Azure AD.
|
||||
"""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
|
||||
current_members = dal.get_group_members(group.id)
|
||||
current = _group_to_scim(group, current_members, external_id)
|
||||
|
||||
try:
|
||||
patched, added_ids, removed_ids = apply_group_patch(
|
||||
patch_request.Operations, current
|
||||
)
|
||||
except ScimPatchError as e:
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
new_name = patched.displayName if patched.displayName != group.name else None
|
||||
dal.update_group(group, name=new_name)
|
||||
|
||||
if added_ids:
|
||||
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
|
||||
if add_uuids:
|
||||
missing = dal.validate_member_ids(add_uuids)
|
||||
if missing:
|
||||
return _scim_error_response(
|
||||
400,
|
||||
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
|
||||
)
|
||||
dal.upsert_group_members(group.id, add_uuids)
|
||||
|
||||
if removed_ids:
|
||||
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
|
||||
dal.remove_group_members(group.id, remove_uuids)
|
||||
|
||||
dal.sync_group_external_id(group.id, patched.externalId)
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _group_to_scim(group, members, patched.externalId)
|
||||
|
||||
|
||||
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
|
||||
def delete_group(
|
||||
group_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | JSONResponse:
|
||||
"""Delete a group (RFC 7644 §3.6)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
if mapping:
|
||||
dal.delete_group_mapping(mapping.id)
|
||||
|
||||
dal.delete_group_with_members(group)
|
||||
dal.commit()
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
def _is_valid_uuid(value: str) -> bool:
|
||||
"""Check if a string is a valid UUID."""
|
||||
try:
|
||||
UUID(value)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
@@ -1,104 +0,0 @@
|
||||
"""SCIM bearer token authentication.
|
||||
|
||||
SCIM endpoints are authenticated via bearer tokens that admins create in the
|
||||
Onyx UI. This module provides:
|
||||
|
||||
- ``verify_scim_token``: FastAPI dependency that extracts, hashes, and
|
||||
validates the token from the Authorization header.
|
||||
- ``generate_scim_token``: Creates a new cryptographically random token
|
||||
and returns the raw value, its SHA-256 hash, and a display suffix.
|
||||
|
||||
Token format: ``onyx_scim_<random>`` where ``<random>`` is 48 bytes of
|
||||
URL-safe base64 from ``secrets.token_urlsafe``.
|
||||
|
||||
The hash is stored in the ``scim_token`` table; the raw value is shown to
|
||||
the admin exactly once at creation time.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from onyx.auth.utils import get_hashed_bearer_token_from_request
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
|
||||
SCIM_TOKEN_PREFIX = "onyx_scim_"
|
||||
SCIM_TOKEN_LENGTH = 48
|
||||
|
||||
|
||||
def _hash_scim_token(token: str) -> str:
|
||||
"""SHA-256 hash a SCIM token. No salt needed — tokens are random."""
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def generate_scim_token() -> tuple[str, str, str]:
|
||||
"""Generate a new SCIM bearer token.
|
||||
|
||||
Returns:
|
||||
A tuple of ``(raw_token, hashed_token, token_display)`` where
|
||||
``token_display`` is a masked version showing only the last 4 chars.
|
||||
"""
|
||||
raw_token = SCIM_TOKEN_PREFIX + secrets.token_urlsafe(SCIM_TOKEN_LENGTH)
|
||||
hashed_token = _hash_scim_token(raw_token)
|
||||
token_display = SCIM_TOKEN_PREFIX + "****" + raw_token[-4:]
|
||||
return raw_token, hashed_token, token_display
|
||||
|
||||
|
||||
def _get_hashed_scim_token_from_request(request: Request) -> str | None:
|
||||
"""Extract and hash a SCIM token from the request Authorization header."""
|
||||
return get_hashed_bearer_token_from_request(
|
||||
request,
|
||||
valid_prefixes=[SCIM_TOKEN_PREFIX],
|
||||
hash_fn=_hash_scim_token,
|
||||
)
|
||||
|
||||
|
||||
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
|
||||
return ScimDAL(db_session)
|
||||
|
||||
|
||||
def verify_scim_token(
|
||||
request: Request,
|
||||
dal: ScimDAL = Depends(_get_scim_dal),
|
||||
) -> ScimToken:
|
||||
"""FastAPI dependency that authenticates SCIM requests.
|
||||
|
||||
Extracts the bearer token from the Authorization header, hashes it,
|
||||
looks it up in the database, and verifies it is active.
|
||||
|
||||
Note:
|
||||
This dependency does NOT update ``last_used_at`` — the endpoint
|
||||
should do that via ``ScimDAL.update_token_last_used()`` so the
|
||||
timestamp write is part of the endpoint's transaction.
|
||||
|
||||
Raises:
|
||||
HTTPException(401): If the token is missing, invalid, or inactive.
|
||||
"""
|
||||
hashed = _get_hashed_scim_token_from_request(request)
|
||||
if not hashed:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing or invalid SCIM bearer token",
|
||||
)
|
||||
|
||||
token = dal.get_token_by_hash(hashed)
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid SCIM bearer token",
|
||||
)
|
||||
|
||||
if not token.is_active:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="SCIM token has been revoked",
|
||||
)
|
||||
|
||||
return token
|
||||
@@ -1,96 +0,0 @@
|
||||
"""SCIM filter expression parser (RFC 7644 §3.4.2.2).
|
||||
|
||||
Identity providers (Okta, Azure AD, OneLogin, etc.) use filters to look up
|
||||
resources before deciding whether to create or update them. For example, when
|
||||
an admin assigns a user to the Onyx app, the IdP first checks whether that
|
||||
user already exists::
|
||||
|
||||
GET /scim/v2/Users?filter=userName eq "john@example.com"
|
||||
|
||||
If zero results come back the IdP creates the user (``POST``); if a match is
|
||||
found it links to the existing record and uses ``PUT``/``PATCH`` going forward.
|
||||
The same pattern applies to groups (``displayName eq "Engineering"``).
|
||||
|
||||
This module parses the subset of the SCIM filter grammar that identity
|
||||
providers actually send in practice:
|
||||
|
||||
attribute SP operator SP value
|
||||
|
||||
Supported operators: ``eq``, ``co`` (contains), ``sw`` (starts with).
|
||||
Compound filters (``and`` / ``or``) are not supported; if an IdP sends one
|
||||
the parser returns ``None`` and the caller falls back to an unfiltered list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ScimFilterOperator(str, Enum):
|
||||
"""Supported SCIM filter operators."""
|
||||
|
||||
EQUAL = "eq"
|
||||
CONTAINS = "co"
|
||||
STARTS_WITH = "sw"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ScimFilter:
|
||||
"""Parsed SCIM filter expression."""
|
||||
|
||||
attribute: str
|
||||
operator: ScimFilterOperator
|
||||
value: str
|
||||
|
||||
|
||||
# Matches: attribute operator "value" (with or without quotes around value)
|
||||
# Groups: (attribute) (operator) ("quoted value" | unquoted_value)
|
||||
_FILTER_RE = re.compile(
|
||||
r"^(\S+)\s+(eq|co|sw)\s+" # attribute + operator
|
||||
r'(?:"([^"]*)"' # quoted value
|
||||
r"|'([^']*)')" # or single-quoted value
|
||||
r"$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def parse_scim_filter(filter_string: str | None) -> ScimFilter | None:
|
||||
"""Parse a simple SCIM filter expression.
|
||||
|
||||
Args:
|
||||
filter_string: Raw filter query parameter value, e.g.
|
||||
``'userName eq "john@example.com"'``
|
||||
|
||||
Returns:
|
||||
A ``ScimFilter`` if the expression is valid and uses a supported
|
||||
operator, or ``None`` if the input is empty / missing.
|
||||
|
||||
Raises:
|
||||
ValueError: If the filter string is present but malformed or uses
|
||||
an unsupported operator.
|
||||
"""
|
||||
if not filter_string or not filter_string.strip():
|
||||
return None
|
||||
|
||||
match = _FILTER_RE.match(filter_string.strip())
|
||||
if not match:
|
||||
raise ValueError(f"Unsupported or malformed SCIM filter: {filter_string}")
|
||||
|
||||
return _build_filter(match, filter_string)
|
||||
|
||||
|
||||
def _build_filter(match: re.Match[str], raw: str) -> ScimFilter:
|
||||
"""Extract fields from a regex match and construct a ScimFilter."""
|
||||
attribute = match.group(1)
|
||||
op_str = match.group(2).lower()
|
||||
# Value is in group 3 (double-quoted) or group 4 (single-quoted)
|
||||
value = match.group(3) if match.group(3) is not None else match.group(4)
|
||||
|
||||
if value is None:
|
||||
raise ValueError(f"Unsupported or malformed SCIM filter: {raw}")
|
||||
|
||||
operator = ScimFilterOperator(op_str)
|
||||
|
||||
return ScimFilter(attribute=attribute, operator=operator, value=value)
|
||||
@@ -1,285 +0,0 @@
|
||||
"""Pydantic schemas for SCIM 2.0 provisioning (RFC 7643 / RFC 7644).
|
||||
|
||||
SCIM protocol schemas follow the wire format defined in:
|
||||
- Core Schema: https://datatracker.ietf.org/doc/html/rfc7643
|
||||
- Protocol: https://datatracker.ietf.org/doc/html/rfc7644
|
||||
|
||||
Admin API schemas are internal to Onyx and used for SCIM token management.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SCIM Schema URIs (RFC 7643 §8)
|
||||
# Every SCIM JSON payload includes a "schemas" array identifying its type.
|
||||
# IdPs like Okta/Azure AD use these URIs to determine how to parse responses.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
SCIM_LIST_RESPONSE_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:ListResponse"
|
||||
SCIM_PATCH_OP_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
|
||||
SCIM_ERROR_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:Error"
|
||||
SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
|
||||
)
|
||||
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
|
||||
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SCIM Protocol Schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScimName(BaseModel):
|
||||
"""User name components (RFC 7643 §4.1.1)."""
|
||||
|
||||
givenName: str | None = None
|
||||
familyName: str | None = None
|
||||
formatted: str | None = None
|
||||
|
||||
|
||||
class ScimEmail(BaseModel):
|
||||
"""Email sub-attribute (RFC 7643 §4.1.2)."""
|
||||
|
||||
value: str
|
||||
type: str | None = None
|
||||
primary: bool = False
|
||||
|
||||
|
||||
class ScimMeta(BaseModel):
|
||||
"""Resource metadata (RFC 7643 §3.1)."""
|
||||
|
||||
resourceType: str | None = None
|
||||
created: datetime | None = None
|
||||
lastModified: datetime | None = None
|
||||
location: str | None = None
|
||||
|
||||
|
||||
class ScimUserResource(BaseModel):
|
||||
"""SCIM User resource representation (RFC 7643 §4.1).
|
||||
|
||||
This is the JSON shape that IdPs send when creating/updating a user via
|
||||
SCIM, and the shape we return in GET responses. Field names use camelCase
|
||||
to match the SCIM wire format (not Python convention).
|
||||
"""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
|
||||
id: str | None = None # Onyx's internal user ID, set on responses
|
||||
externalId: str | None = None # IdP's identifier for this user
|
||||
userName: str # Typically the user's email address
|
||||
name: ScimName | None = None
|
||||
emails: list[ScimEmail] = Field(default_factory=list)
|
||||
active: bool = True
|
||||
meta: ScimMeta | None = None
|
||||
|
||||
|
||||
class ScimGroupMember(BaseModel):
|
||||
"""Group member reference (RFC 7643 §4.2).
|
||||
|
||||
Represents a user within a SCIM group. The IdP sends these when adding
|
||||
or removing users from groups. ``value`` is the Onyx user ID.
|
||||
"""
|
||||
|
||||
value: str # User ID of the group member
|
||||
display: str | None = None
|
||||
|
||||
|
||||
class ScimGroupResource(BaseModel):
|
||||
"""SCIM Group resource representation (RFC 7643 §4.2)."""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_GROUP_SCHEMA])
|
||||
id: str | None = None
|
||||
externalId: str | None = None
|
||||
displayName: str
|
||||
members: list[ScimGroupMember] = Field(default_factory=list)
|
||||
meta: ScimMeta | None = None
|
||||
|
||||
|
||||
class ScimListResponse(BaseModel):
|
||||
"""Paginated list response (RFC 7644 §3.4.2)."""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_LIST_RESPONSE_SCHEMA])
|
||||
totalResults: int
|
||||
startIndex: int = 1
|
||||
itemsPerPage: int = 100
|
||||
Resources: list[ScimUserResource | ScimGroupResource] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ScimPatchOperationType(str, Enum):
|
||||
"""Supported PATCH operations (RFC 7644 §3.5.2)."""
|
||||
|
||||
ADD = "add"
|
||||
REPLACE = "replace"
|
||||
REMOVE = "remove"
|
||||
|
||||
|
||||
class ScimPatchOperation(BaseModel):
|
||||
"""Single PATCH operation (RFC 7644 §3.5.2)."""
|
||||
|
||||
op: ScimPatchOperationType
|
||||
path: str | None = None
|
||||
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
|
||||
|
||||
|
||||
class ScimPatchRequest(BaseModel):
|
||||
"""PATCH request body (RFC 7644 §3.5.2).
|
||||
|
||||
IdPs use PATCH to make incremental changes — e.g. deactivating a user
|
||||
(replace active=false) or adding/removing group members — instead of
|
||||
replacing the entire resource with PUT.
|
||||
"""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_PATCH_OP_SCHEMA])
|
||||
Operations: list[ScimPatchOperation]
|
||||
|
||||
|
||||
class ScimError(BaseModel):
|
||||
"""SCIM error response (RFC 7644 §3.12)."""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_ERROR_SCHEMA])
|
||||
status: str
|
||||
detail: str | None = None
|
||||
scimType: str | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service Provider Configuration (RFC 7643 §5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScimSupported(BaseModel):
|
||||
"""Generic supported/not-supported flag used in ServiceProviderConfig."""
|
||||
|
||||
supported: bool
|
||||
|
||||
|
||||
class ScimFilterConfig(BaseModel):
|
||||
"""Filter configuration within ServiceProviderConfig (RFC 7643 §5)."""
|
||||
|
||||
supported: bool
|
||||
maxResults: int = 100
|
||||
|
||||
|
||||
class ScimServiceProviderConfig(BaseModel):
|
||||
"""SCIM ServiceProviderConfig resource (RFC 7643 §5).
|
||||
|
||||
Served at GET /scim/v2/ServiceProviderConfig. IdPs fetch this during
|
||||
initial setup to discover which SCIM features our server supports
|
||||
(e.g. PATCH yes, bulk no, filtering yes).
|
||||
"""
|
||||
|
||||
schemas: list[str] = Field(
|
||||
default_factory=lambda: [SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA]
|
||||
)
|
||||
patch: ScimSupported = ScimSupported(supported=True)
|
||||
bulk: ScimSupported = ScimSupported(supported=False)
|
||||
filter: ScimFilterConfig = ScimFilterConfig(supported=True)
|
||||
changePassword: ScimSupported = ScimSupported(supported=False)
|
||||
sort: ScimSupported = ScimSupported(supported=False)
|
||||
etag: ScimSupported = ScimSupported(supported=False)
|
||||
authenticationSchemes: list[dict[str, str]] = Field(
|
||||
default_factory=lambda: [
|
||||
{
|
||||
"type": "oauthbearertoken",
|
||||
"name": "OAuth Bearer Token",
|
||||
"description": "Authentication scheme using a SCIM bearer token",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ScimSchemaAttribute(BaseModel):
|
||||
"""Attribute definition within a SCIM Schema (RFC 7643 §7)."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
multiValued: bool = False
|
||||
required: bool = False
|
||||
description: str = ""
|
||||
caseExact: bool = False
|
||||
mutability: str = "readWrite"
|
||||
returned: str = "default"
|
||||
uniqueness: str = "none"
|
||||
subAttributes: list["ScimSchemaAttribute"] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ScimSchemaDefinition(BaseModel):
|
||||
"""SCIM Schema definition (RFC 7643 §7).
|
||||
|
||||
Served at GET /scim/v2/Schemas. Describes the attributes available
|
||||
on each resource type so IdPs know which fields they can provision.
|
||||
"""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_SCHEMA_SCHEMA])
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
attributes: list[ScimSchemaAttribute] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ScimSchemaExtension(BaseModel):
|
||||
"""Schema extension reference within ResourceType (RFC 7643 §6)."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
schema_: str = Field(alias="schema")
|
||||
required: bool
|
||||
|
||||
|
||||
class ScimResourceType(BaseModel):
|
||||
"""SCIM ResourceType resource (RFC 7643 §6).
|
||||
|
||||
Served at GET /scim/v2/ResourceTypes. Tells the IdP which resource
|
||||
types are available (Users, Groups) and their respective endpoints.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
|
||||
id: str
|
||||
name: str
|
||||
endpoint: str
|
||||
description: str | None = None
|
||||
schema_: str = Field(alias="schema")
|
||||
schemaExtensions: list[ScimSchemaExtension] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Admin API Schemas (Onyx-internal, for SCIM token management)
|
||||
# These are NOT part of the SCIM protocol. They power the Onyx admin UI
|
||||
# where admins create/revoke the bearer tokens that IdPs use to authenticate.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScimTokenCreate(BaseModel):
|
||||
"""Request to create a new SCIM bearer token."""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class ScimTokenResponse(BaseModel):
|
||||
"""SCIM token metadata returned in list/get responses."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
token_display: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
last_used_at: datetime | None = None
|
||||
|
||||
|
||||
class ScimTokenCreatedResponse(ScimTokenResponse):
|
||||
"""Response returned when a new SCIM token is created.
|
||||
|
||||
Includes the raw token value which is only available at creation time.
|
||||
"""
|
||||
|
||||
raw_token: str
|
||||
@@ -1,256 +0,0 @@
|
||||
"""SCIM PATCH operation handler (RFC 7644 §3.5.2).
|
||||
|
||||
Identity providers use PATCH to make incremental changes to SCIM resources
|
||||
instead of replacing the entire resource with PUT. Common operations include:
|
||||
|
||||
- Deactivating a user: ``replace`` ``active`` with ``false``
|
||||
- Adding group members: ``add`` to ``members``
|
||||
- Removing group members: ``remove`` from ``members[value eq "..."]``
|
||||
|
||||
This module applies PATCH operations to Pydantic SCIM resource objects and
|
||||
returns the modified result. It does NOT touch the database — the caller is
|
||||
responsible for persisting changes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
|
||||
|
||||
class ScimPatchError(Exception):
|
||||
"""Raised when a PATCH operation cannot be applied."""
|
||||
|
||||
def __init__(self, detail: str, status: int = 400) -> None:
|
||||
self.detail = detail
|
||||
self.status = status
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
# Pattern for member removal path: members[value eq "user-id"]
|
||||
_MEMBER_FILTER_RE = re.compile(
|
||||
r'^members\[value\s+eq\s+"([^"]+)"\]$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def apply_user_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimUserResource,
|
||||
) -> ScimUserResource:
|
||||
"""Apply SCIM PATCH operations to a user resource.
|
||||
|
||||
Returns a new ``ScimUserResource`` with the modifications applied.
|
||||
The original object is not mutated.
|
||||
|
||||
Raises:
|
||||
ScimPatchError: If an operation targets an unsupported path.
|
||||
"""
|
||||
data = current.model_dump()
|
||||
name_data = data.get("name") or {}
|
||||
|
||||
for op in operations:
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_user_replace(op, data, name_data)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_user_replace(op, data, name_data)
|
||||
else:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported operation '{op.op.value}' on User resource"
|
||||
)
|
||||
|
||||
data["name"] = name_data
|
||||
return ScimUserResource.model_validate(data)
|
||||
|
||||
|
||||
def _apply_user_replace(
|
||||
op: ScimPatchOperation,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
) -> None:
|
||||
"""Apply a replace/add operation to user data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
# No path — value is a dict of top-level attributes to set
|
||||
if isinstance(op.value, dict):
|
||||
for key, val in op.value.items():
|
||||
_set_user_field(key.lower(), val, data, name_data)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
_set_user_field(path, op.value, data, name_data)
|
||||
|
||||
|
||||
def _set_user_field(
|
||||
path: str,
|
||||
value: str | bool | dict | list | None,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
) -> None:
|
||||
"""Set a single field on user data by SCIM path."""
|
||||
if path == "active":
|
||||
data["active"] = value
|
||||
elif path == "username":
|
||||
data["userName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
elif path == "name.givenname":
|
||||
name_data["givenName"] = value
|
||||
elif path == "name.familyname":
|
||||
name_data["familyName"] = value
|
||||
elif path == "name.formatted":
|
||||
name_data["formatted"] = value
|
||||
elif path == "displayname":
|
||||
# Some IdPs send displayName on users; map to formatted name
|
||||
name_data["formatted"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
|
||||
|
||||
|
||||
def apply_group_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimGroupResource,
|
||||
) -> tuple[ScimGroupResource, list[str], list[str]]:
|
||||
"""Apply SCIM PATCH operations to a group resource.
|
||||
|
||||
Returns:
|
||||
A tuple of (modified group, added member IDs, removed member IDs).
|
||||
The caller uses the member ID lists to update the database.
|
||||
|
||||
Raises:
|
||||
ScimPatchError: If an operation targets an unsupported path.
|
||||
"""
|
||||
data = current.model_dump()
|
||||
current_members: list[dict] = list(data.get("members") or [])
|
||||
added_ids: list[str] = []
|
||||
removed_ids: list[str] = []
|
||||
|
||||
for op in operations:
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_group_add(op, current_members, added_ids)
|
||||
elif op.op == ScimPatchOperationType.REMOVE:
|
||||
_apply_group_remove(op, current_members, removed_ids)
|
||||
else:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported operation '{op.op.value}' on Group resource"
|
||||
)
|
||||
|
||||
data["members"] = current_members
|
||||
group = ScimGroupResource.model_validate(data)
|
||||
return group, added_ids, removed_ids
|
||||
|
||||
|
||||
def _apply_group_replace(
|
||||
op: ScimPatchOperation,
|
||||
data: dict,
|
||||
current_members: list[dict],
|
||||
added_ids: list[str],
|
||||
removed_ids: list[str],
|
||||
) -> None:
|
||||
"""Apply a replace operation to group data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
if isinstance(op.value, dict):
|
||||
for key, val in op.value.items():
|
||||
if key.lower() == "members":
|
||||
_replace_members(val, current_members, added_ids, removed_ids)
|
||||
else:
|
||||
_set_group_field(key.lower(), val, data)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
if path == "members":
|
||||
_replace_members(op.value, current_members, added_ids, removed_ids)
|
||||
return
|
||||
|
||||
_set_group_field(path, op.value, data)
|
||||
|
||||
|
||||
def _replace_members(
|
||||
value: str | list | dict | bool | None,
|
||||
current_members: list[dict],
|
||||
added_ids: list[str],
|
||||
removed_ids: list[str],
|
||||
) -> None:
|
||||
"""Replace the entire group member list."""
|
||||
if not isinstance(value, list):
|
||||
raise ScimPatchError("Replace members requires a list value")
|
||||
|
||||
old_ids = {m["value"] for m in current_members}
|
||||
new_ids = {m.get("value", "") for m in value}
|
||||
|
||||
removed_ids.extend(old_ids - new_ids)
|
||||
added_ids.extend(new_ids - old_ids)
|
||||
|
||||
current_members[:] = value
|
||||
|
||||
|
||||
def _set_group_field(
|
||||
path: str,
|
||||
value: str | bool | dict | list | None,
|
||||
data: dict,
|
||||
) -> None:
|
||||
"""Set a single field on group data by SCIM path."""
|
||||
if path == "displayname":
|
||||
data["displayName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
|
||||
|
||||
|
||||
def _apply_group_add(
|
||||
op: ScimPatchOperation,
|
||||
members: list[dict],
|
||||
added_ids: list[str],
|
||||
) -> None:
|
||||
"""Add members to a group."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if path and path != "members":
|
||||
raise ScimPatchError(f"Unsupported add path '{op.path}' for Group")
|
||||
|
||||
if not isinstance(op.value, list):
|
||||
raise ScimPatchError("Add members requires a list value")
|
||||
|
||||
existing_ids = {m["value"] for m in members}
|
||||
for member_data in op.value:
|
||||
member_id = member_data.get("value", "")
|
||||
if member_id and member_id not in existing_ids:
|
||||
members.append(member_data)
|
||||
added_ids.append(member_id)
|
||||
existing_ids.add(member_id)
|
||||
|
||||
|
||||
def _apply_group_remove(
|
||||
op: ScimPatchOperation,
|
||||
members: list[dict],
|
||||
removed_ids: list[str],
|
||||
) -> None:
|
||||
"""Remove members from a group."""
|
||||
if not op.path:
|
||||
raise ScimPatchError("Remove operation requires a path")
|
||||
|
||||
match = _MEMBER_FILTER_RE.match(op.path)
|
||||
if not match:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported remove path '{op.path}'. "
|
||||
'Expected: members[value eq "user-id"]'
|
||||
)
|
||||
|
||||
target_id = match.group(1)
|
||||
original_len = len(members)
|
||||
members[:] = [m for m in members if m.get("value") != target_id]
|
||||
|
||||
if len(members) < original_len:
|
||||
removed_ids.append(target_id)
|
||||
@@ -1,144 +0,0 @@
|
||||
"""Static SCIM service discovery responses (RFC 7643 §5, §6, §7).
|
||||
|
||||
Pre-built at import time — these never change at runtime. Separated from
|
||||
api.py to keep the endpoint module focused on request handling.
|
||||
"""
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
from ee.onyx.server.scim.models import ScimSchemaAttribute
|
||||
from ee.onyx.server.scim.models import ScimSchemaDefinition
|
||||
from ee.onyx.server.scim.models import ScimServiceProviderConfig
|
||||
|
||||
SERVICE_PROVIDER_CONFIG = ScimServiceProviderConfig()
|
||||
|
||||
USER_RESOURCE_TYPE = ScimResourceType.model_validate(
|
||||
{
|
||||
"id": "User",
|
||||
"name": "User",
|
||||
"endpoint": "/scim/v2/Users",
|
||||
"description": "SCIM User resource",
|
||||
"schema": SCIM_USER_SCHEMA,
|
||||
}
|
||||
)
|
||||
|
||||
GROUP_RESOURCE_TYPE = ScimResourceType.model_validate(
|
||||
{
|
||||
"id": "Group",
|
||||
"name": "Group",
|
||||
"endpoint": "/scim/v2/Groups",
|
||||
"description": "SCIM Group resource",
|
||||
"schema": SCIM_GROUP_SCHEMA,
|
||||
}
|
||||
)
|
||||
|
||||
USER_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_USER_SCHEMA,
|
||||
name="User",
|
||||
description="SCIM core User schema",
|
||||
attributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="userName",
|
||||
type="string",
|
||||
required=True,
|
||||
uniqueness="server",
|
||||
description="Unique identifier for the user, typically an email address.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="name",
|
||||
type="complex",
|
||||
description="The components of the user's name.",
|
||||
subAttributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="givenName",
|
||||
type="string",
|
||||
description="The user's first name.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="familyName",
|
||||
type="string",
|
||||
description="The user's last name.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="formatted",
|
||||
type="string",
|
||||
description="The full name, including all middle names and titles.",
|
||||
),
|
||||
],
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="emails",
|
||||
type="complex",
|
||||
multiValued=True,
|
||||
description="Email addresses for the user.",
|
||||
subAttributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="value",
|
||||
type="string",
|
||||
description="Email address value.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="type",
|
||||
type="string",
|
||||
description="Label for this email (e.g. 'work').",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="primary",
|
||||
type="boolean",
|
||||
description="Whether this is the primary email.",
|
||||
),
|
||||
],
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="active",
|
||||
type="boolean",
|
||||
description="Whether the user account is active.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="externalId",
|
||||
type="string",
|
||||
description="Identifier from the provisioning client (IdP).",
|
||||
caseExact=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_GROUP_SCHEMA,
|
||||
name="Group",
|
||||
description="SCIM core Group schema",
|
||||
attributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="displayName",
|
||||
type="string",
|
||||
required=True,
|
||||
description="Human-readable name for the group.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="members",
|
||||
type="complex",
|
||||
multiValued=True,
|
||||
description="Members of the group.",
|
||||
subAttributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="value",
|
||||
type="string",
|
||||
description="User ID of the group member.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="display",
|
||||
type="string",
|
||||
mutability="readOnly",
|
||||
description="Display name of the group member.",
|
||||
),
|
||||
],
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="externalId",
|
||||
type="string",
|
||||
description="Identifier from the provisioning client (IdP).",
|
||||
caseExact=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -1,13 +1,9 @@
|
||||
"""EE Settings API - provides license-aware settings override."""
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -44,14 +40,6 @@ def check_ee_features_enabled() -> bool:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if not metadata:
|
||||
# Cache miss — warm from DB so cold-start doesn't block EE features
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(f"Failed to load license from DB: {db_error}")
|
||||
|
||||
if metadata and metadata.status != _BLOCKING_STATUS:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
return True
|
||||
@@ -93,18 +81,6 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if not metadata:
|
||||
# Cache miss (e.g. after TTL expiry). Fall back to DB so
|
||||
# the /settings request doesn't falsely return GATED_ACCESS
|
||||
# while the cache is cold.
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(
|
||||
f"Failed to load license from DB for settings: {db_error}"
|
||||
)
|
||||
|
||||
if metadata:
|
||||
if metadata.status == _BLOCKING_STATUS:
|
||||
settings.application_status = metadata.status
|
||||
@@ -113,11 +89,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
settings.ee_features_enabled = True
|
||||
else:
|
||||
# No license found in cache or DB.
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
# Legacy EE flag is set → prior EE usage (e.g. permission
|
||||
# syncing) means indexed data may need protection.
|
||||
settings.application_status = _BLOCKING_STATUS
|
||||
# No license = community edition, disable EE features
|
||||
settings.ee_features_enabled = False
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata for settings: {e}")
|
||||
|
||||
@@ -177,7 +177,7 @@ async def forward_to_control_plane(
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}{path}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
elif method == "POST":
|
||||
|
||||
@@ -37,15 +37,12 @@ def list_user_groups(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session, only_up_to_date=False, eager_load_for_snapshot=True
|
||||
)
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
only_curator_groups=user.role == UserRole.CURATOR,
|
||||
eager_load_for_snapshot=True,
|
||||
)
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@@ -53,8 +53,7 @@ class UserGroup(BaseModel):
|
||||
id=cc_pair_relationship.cc_pair.id,
|
||||
name=cc_pair_relationship.cc_pair.name,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair_relationship.cc_pair.connector,
|
||||
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
|
||||
cc_pair_relationship.cc_pair.connector
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair_relationship.cc_pair.credential
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from fastapi_users import schemas
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
@@ -43,21 +41,8 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
tenant_id: str | None = None
|
||||
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
|
||||
# Excluded from create_update_dict so it never reaches the DB layer
|
||||
captcha_token: str | None = None
|
||||
|
||||
@override
|
||||
def create_update_dict(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict()
|
||||
d.pop("captcha_token", None)
|
||||
return d
|
||||
|
||||
@override
|
||||
def create_update_dict_superuser(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict_superuser()
|
||||
d.pop("captcha_token", None)
|
||||
return d
|
||||
|
||||
|
||||
class UserUpdateWithRole(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
|
||||
@@ -60,7 +60,6 @@ from sqlalchemy import nulls_last
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.disposable_email_validator import is_disposable_email
|
||||
@@ -111,7 +110,6 @@ from onyx.db.auth import get_user_db
|
||||
from onyx.db.auth import SQLAlchemyUserAdminDB
|
||||
from onyx.db.engine.async_sql_engine import get_async_session
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
@@ -121,7 +119,6 @@ from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
@@ -138,8 +135,6 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"
|
||||
|
||||
|
||||
def is_user_admin(user: User) -> bool:
|
||||
return user.role == UserRole.ADMIN
|
||||
@@ -211,34 +206,22 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
|
||||
return int(value.decode("utf-8")) == 1
|
||||
|
||||
|
||||
def workspace_invite_only_enabled() -> bool:
|
||||
settings = load_settings()
|
||||
return settings.invite_only_enabled
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}:
|
||||
# SSO providers manage membership; allow JIT provisioning regardless of invites
|
||||
return
|
||||
|
||||
if not workspace_invite_only_enabled():
|
||||
whitelist = get_invited_users()
|
||||
if not whitelist:
|
||||
return
|
||||
|
||||
whitelist = get_invited_users()
|
||||
|
||||
if not email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Email must be specified"},
|
||||
)
|
||||
raise PermissionError("Email must be specified")
|
||||
|
||||
try:
|
||||
email_info = validate_email(email, check_deliverability=False)
|
||||
except EmailUndeliverableError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Email is not valid"},
|
||||
)
|
||||
raise PermissionError("Email is not valid")
|
||||
|
||||
for email_whitelist in whitelist:
|
||||
try:
|
||||
@@ -255,13 +238,7 @@ def verify_email_is_invited(email: str) -> None:
|
||||
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
|
||||
return
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"code": REGISTER_INVITE_ONLY_CODE,
|
||||
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
|
||||
},
|
||||
)
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
|
||||
@@ -295,22 +272,6 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def enforce_seat_limit(db_session: Session, seats_needed: int = 1) -> None:
|
||||
"""Raise HTTPException(402) if adding users would exceed the seat limit.
|
||||
|
||||
No-op for multi-tenant or CE deployments.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
result = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)(db_session, seats_needed=seats_needed)
|
||||
|
||||
if result is not None and not result.available:
|
||||
raise HTTPException(status_code=402, detail=result.error_message)
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
@@ -440,12 +401,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
):
|
||||
user_create.role = UserRole.ADMIN
|
||||
|
||||
# Check seat availability for new users (single-tenant only)
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
existing = get_user_by_email(user_create.email, sync_db)
|
||||
if existing is None:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
user_created = False
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request)
|
||||
@@ -655,10 +610,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
raise exceptions.UserNotExists()
|
||||
|
||||
except exceptions.UserNotExists:
|
||||
# Check seat availability before creating (single-tenant only)
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
password = self.password_helper.generate()
|
||||
user_dict = {
|
||||
"email": account_email,
|
||||
@@ -1480,7 +1431,6 @@ def get_anonymous_user() -> User:
|
||||
is_superuser=False,
|
||||
role=UserRole.LIMITED,
|
||||
use_memories=False,
|
||||
enable_memory_tool=False,
|
||||
)
|
||||
return user
|
||||
|
||||
@@ -1671,10 +1621,7 @@ def get_oauth_router(
|
||||
if redirect_url is not None:
|
||||
authorize_redirect_url = redirect_url
|
||||
else:
|
||||
# Use WEB_DOMAIN instead of request.url_for() to prevent host
|
||||
# header poisoning — request.url_for() trusts the Host header.
|
||||
callback_path = request.app.url_path_for(callback_route_name)
|
||||
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
|
||||
authorize_redirect_url = str(request.url_for(callback_route_name))
|
||||
|
||||
next_url = request.query_params.get("next", "/")
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.background.celery.celery_utils import make_probe_path
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
@@ -526,12 +525,6 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None: # noqa: ARG
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
logger.info(
|
||||
"DISABLE_VECTOR_DB is set — skipping Vespa/OpenSearch readiness check."
|
||||
)
|
||||
return
|
||||
|
||||
if not wait_for_vespa_with_timeout():
|
||||
msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
|
||||
logger.error(msg)
|
||||
@@ -573,31 +566,3 @@ class LivenessProbe(bootsteps.StartStopStep):
|
||||
|
||||
def get_bootsteps() -> list[type]:
|
||||
return [LivenessProbe]
|
||||
|
||||
|
||||
# Task modules that require a vector DB (Vespa/OpenSearch).
|
||||
# When DISABLE_VECTOR_DB is True these are excluded from autodiscover lists.
|
||||
_VECTOR_DB_TASK_MODULES: set[str] = {
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.opensearch_migration",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
# EE modules that are vector-DB-dependent
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
}
|
||||
# NOTE: "onyx.background.celery.tasks.shared" is intentionally NOT in the set
|
||||
# above. It contains celery_beat_heartbeat (which only writes to Redis) alongside
|
||||
# document cleanup tasks. The cleanup tasks won't be invoked in minimal mode
|
||||
# because the periodic tasks that trigger them are in other filtered modules.
|
||||
|
||||
|
||||
def filter_task_modules(modules: list[str]) -> list[str]:
|
||||
"""Remove vector-DB-dependent task modules when DISABLE_VECTOR_DB is True."""
|
||||
if not DISABLE_VECTOR_DB:
|
||||
return modules
|
||||
return [m for m in modules if m not in _VECTOR_DB_TASK_MODULES]
|
||||
|
||||
@@ -118,25 +118,23 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
# Original background worker tasks
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
# Light worker tasks
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.opensearch_migration",
|
||||
# Docprocessing worker tasks
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Docfetching worker tasks
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
[
|
||||
# Original background worker tasks
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
# Light worker tasks
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.opensearch_migration",
|
||||
# Docprocessing worker tasks
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Docfetching worker tasks
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -96,9 +96,7 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
]
|
||||
)
|
||||
[
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -107,9 +107,7 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
]
|
||||
)
|
||||
[
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -96,12 +96,10 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
# Sandbox tasks (file sync, cleanup)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
]
|
||||
)
|
||||
[
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
# Sandbox tasks (file sync, cleanup)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -110,16 +110,14 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.opensearch_migration",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
[
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.opensearch_migration",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -94,9 +94,7 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
]
|
||||
)
|
||||
[
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -314,18 +314,16 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
[
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -107,9 +107,7 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
[
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,30 +1,25 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
|
||||
from onyx.connectors.connector_runner import CheckpointOutputWrapper
|
||||
from onyx.connectors.connector_runner import batched_doc_ids
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import ConnectorCheckpoint
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -34,129 +29,63 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
PRUNING_CHECKPOINTED_BATCH_SIZE = 32
|
||||
|
||||
|
||||
class SlimConnectorExtractionResult(BaseModel):
|
||||
"""Result of extracting document IDs and hierarchy nodes from a connector."""
|
||||
|
||||
doc_ids: set[str]
|
||||
hierarchy_nodes: list[HierarchyNode]
|
||||
|
||||
|
||||
def _checkpointed_batched_items(
|
||||
connector: CheckpointedConnector[CT],
|
||||
start: float,
|
||||
end: float,
|
||||
) -> Generator[list[Document | HierarchyNode | ConnectorFailure], None, None]:
|
||||
"""Loop through all checkpoint steps and yield batched items.
|
||||
|
||||
Some checkpointed connectors (e.g. IMAP) are multi-step: the first
|
||||
checkpoint call may only initialize internal state without yielding
|
||||
any documents. This function loops until checkpoint.has_more is False
|
||||
to ensure all items are collected across every step.
|
||||
"""
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
while True:
|
||||
checkpoint_output = connector.load_from_checkpoint(
|
||||
start=start, end=end, checkpoint=checkpoint
|
||||
)
|
||||
wrapper: CheckpointOutputWrapper[CT] = CheckpointOutputWrapper()
|
||||
batch: list[Document | HierarchyNode | ConnectorFailure] = []
|
||||
for document, hierarchy_node, failure, next_checkpoint in wrapper(
|
||||
checkpoint_output
|
||||
):
|
||||
if document is not None:
|
||||
batch.append(document)
|
||||
elif hierarchy_node is not None:
|
||||
batch.append(hierarchy_node)
|
||||
elif failure is not None:
|
||||
batch.append(failure)
|
||||
|
||||
if next_checkpoint is not None:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
if not checkpoint.has_more:
|
||||
break
|
||||
|
||||
|
||||
def _get_failure_id(failure: ConnectorFailure) -> str | None:
|
||||
"""Extract the document/entity ID from a ConnectorFailure."""
|
||||
if failure.failed_document:
|
||||
return failure.failed_document.document_id
|
||||
if failure.failed_entity:
|
||||
return failure.failed_entity.entity_id
|
||||
return None
|
||||
|
||||
|
||||
def _extract_from_batch(
|
||||
doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
|
||||
) -> tuple[set[str], list[HierarchyNode]]:
|
||||
"""Separate a batch into document IDs and hierarchy nodes.
|
||||
|
||||
ConnectorFailure items have their failed document/entity IDs added to the
|
||||
ID set so that failed-to-retrieve documents are not accidentally pruned.
|
||||
"""
|
||||
ids: set[str] = set()
|
||||
hierarchy_nodes: list[HierarchyNode] = []
|
||||
for item in doc_list:
|
||||
if isinstance(item, HierarchyNode):
|
||||
hierarchy_nodes.append(item)
|
||||
ids.add(item.raw_node_id)
|
||||
elif isinstance(item, ConnectorFailure):
|
||||
failed_id = _get_failure_id(item)
|
||||
if failed_id:
|
||||
ids.add(failed_id)
|
||||
logger.warning(
|
||||
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
|
||||
)
|
||||
else:
|
||||
ids.add(item.id)
|
||||
return ids, hierarchy_nodes
|
||||
def document_batch_to_ids(
|
||||
doc_batch: (
|
||||
Iterator[list[Document | HierarchyNode]]
|
||||
| Iterator[list[SlimDocument | HierarchyNode]]
|
||||
),
|
||||
) -> Generator[set[str], None, None]:
|
||||
for doc_list in doc_batch:
|
||||
yield {
|
||||
doc.raw_node_id if isinstance(doc, HierarchyNode) else doc.id
|
||||
for doc in doc_list
|
||||
}
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(
|
||||
runnable_connector: BaseConnector,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> SlimConnectorExtractionResult:
|
||||
) -> set[str]:
|
||||
"""
|
||||
Extract document IDs and hierarchy nodes from a runnable connector.
|
||||
|
||||
Hierarchy nodes yielded alongside documents/slim docs are collected and
|
||||
returned in the result. ConnectorFailure items have their IDs preserved
|
||||
so that failed-to-retrieve documents are not accidentally pruned.
|
||||
If the given connector is neither a SlimConnector nor a SlimConnectorWithPermSync, just pull
|
||||
all docs using the load_from_state and grab out the IDs.
|
||||
|
||||
Optionally, a callback can be passed to handle the length of each document batch.
|
||||
"""
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
all_hierarchy_nodes: list[HierarchyNode] = []
|
||||
|
||||
# Sequence (covariant) lets all the specific list[...] iterator types unify here
|
||||
raw_batch_generator: (
|
||||
Iterator[Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure]]
|
||||
| None
|
||||
) = None
|
||||
|
||||
doc_batch_id_generator = None
|
||||
if isinstance(runnable_connector, SlimConnector):
|
||||
raw_batch_generator = runnable_connector.retrieve_all_slim_docs()
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.retrieve_all_slim_docs()
|
||||
)
|
||||
elif isinstance(runnable_connector, SlimConnectorWithPermSync):
|
||||
raw_batch_generator = runnable_connector.retrieve_all_slim_docs_perm_sync()
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.retrieve_all_slim_docs_perm_sync()
|
||||
)
|
||||
# If the connector isn't slim, fall back to running it normally to get ids
|
||||
elif isinstance(runnable_connector, LoadConnector):
|
||||
raw_batch_generator = runnable_connector.load_from_state()
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.load_from_state()
|
||||
)
|
||||
elif isinstance(runnable_connector, PollConnector):
|
||||
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
end = datetime.now(timezone.utc).timestamp()
|
||||
raw_batch_generator = runnable_connector.poll_source(start=start, end=end)
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.poll_source(start=start, end=end)
|
||||
)
|
||||
elif isinstance(runnable_connector, CheckpointedConnector):
|
||||
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
end = datetime.now(timezone.utc).timestamp()
|
||||
raw_batch_generator = _checkpointed_batched_items(
|
||||
runnable_connector, start, end
|
||||
checkpoint = runnable_connector.build_dummy_checkpoint()
|
||||
checkpoint_generator = runnable_connector.load_from_checkpoint(
|
||||
start=start, end=end, checkpoint=checkpoint
|
||||
)
|
||||
doc_batch_id_generator = batched_doc_ids(
|
||||
checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
|
||||
@@ -170,24 +99,19 @@ def extract_ids_from_runnable_connector(
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
# process raw batches to extract both IDs and hierarchy nodes
|
||||
for doc_list in raw_batch_generator:
|
||||
if callback and callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
for doc_batch_ids in doc_batch_id_generator:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
|
||||
batch_ids, batch_nodes = _extract_from_batch(doc_list)
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch_ids))
|
||||
|
||||
if callback:
|
||||
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
|
||||
callback.progress("extract_ids_from_runnable_connector", len(doc_batch_ids))
|
||||
|
||||
return SlimConnectorExtractionResult(
|
||||
doc_ids=all_connector_doc_ids,
|
||||
hierarchy_nodes=all_hierarchy_nodes,
|
||||
)
|
||||
return all_connector_doc_ids
|
||||
|
||||
|
||||
def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
|
||||
|
||||
@@ -6,7 +6,6 @@ from celery.schedules import crontab
|
||||
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
@@ -230,27 +229,6 @@ if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
)
|
||||
|
||||
|
||||
# Beat task names that require a vector DB. Filtered out when DISABLE_VECTOR_DB.
|
||||
_VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
|
||||
"check-for-indexing",
|
||||
"check-for-connector-deletion",
|
||||
"check-for-vespa-sync",
|
||||
"check-for-pruning",
|
||||
"check-for-hierarchy-fetching",
|
||||
"check-for-checkpoint-cleanup",
|
||||
"check-for-index-attempt-cleanup",
|
||||
"check-for-doc-permissions-sync",
|
||||
"check-for-external-group-sync",
|
||||
"check-for-documents-for-opensearch-migration",
|
||||
"migrate-documents-from-vespa-to-opensearch",
|
||||
}
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
beat_task_templates = [
|
||||
t for t in beat_task_templates if t["name"] not in _VECTOR_DB_BEAT_TASK_NAMES
|
||||
]
|
||||
|
||||
|
||||
def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
cloud_task: dict[str, Any] = {}
|
||||
|
||||
|
||||
@@ -37,7 +37,6 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
redis_connector: RedisConnector,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
timeout_seconds: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.parent_pid = parent_pid
|
||||
@@ -52,29 +51,11 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_parent_check = time.monotonic()
|
||||
self.start_monotonic = time.monotonic()
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
# Check if the associated indexing attempt has been cancelled
|
||||
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
|
||||
if bool(self.redis_connector.stop.fenced):
|
||||
return True
|
||||
|
||||
# Check if the task has exceeded its timeout
|
||||
# NOTE: Celery's soft_time_limit does not work with thread pools,
|
||||
# so we must enforce timeouts internally.
|
||||
if self.timeout_seconds is not None:
|
||||
elapsed = time.monotonic() - self.start_monotonic
|
||||
if elapsed > self.timeout_seconds:
|
||||
logger.warning(
|
||||
f"IndexingCallback Docprocessing - task timeout exceeded: "
|
||||
f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
|
||||
f"cc_pair={self.redis_connector.cc_pair_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis_connector.stop.fenced)
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None: # noqa: ARG002
|
||||
"""Amount isn't used yet."""
|
||||
|
||||
@@ -146,26 +146,14 @@ def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
|
||||
"""Collect metrics about queue lengths for different Celery queues"""
|
||||
metrics = []
|
||||
queue_mappings = {
|
||||
"celery_queue_length": OnyxCeleryQueues.PRIMARY,
|
||||
"docprocessing_queue_length": OnyxCeleryQueues.DOCPROCESSING,
|
||||
"docfetching_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
"sync_queue_length": OnyxCeleryQueues.VESPA_METADATA_SYNC,
|
||||
"deletion_queue_length": OnyxCeleryQueues.CONNECTOR_DELETION,
|
||||
"pruning_queue_length": OnyxCeleryQueues.CONNECTOR_PRUNING,
|
||||
"celery_queue_length": "celery",
|
||||
"docprocessing_queue_length": "docprocessing",
|
||||
"sync_queue_length": "sync",
|
||||
"deletion_queue_length": "deletion",
|
||||
"pruning_queue_length": "pruning",
|
||||
"permissions_sync_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
|
||||
"external_group_sync_queue_length": OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
|
||||
"permissions_upsert_queue_length": OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
|
||||
"hierarchy_fetching_queue_length": OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING,
|
||||
"llm_model_update_queue_length": OnyxCeleryQueues.LLM_MODEL_UPDATE,
|
||||
"checkpoint_cleanup_queue_length": OnyxCeleryQueues.CHECKPOINT_CLEANUP,
|
||||
"index_attempt_cleanup_queue_length": OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP,
|
||||
"csv_generation_queue_length": OnyxCeleryQueues.CSV_GENERATION,
|
||||
"user_file_processing_queue_length": OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
"user_file_project_sync_queue_length": OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
"user_file_delete_queue_length": OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
"monitoring_queue_length": OnyxCeleryQueues.MONITORING,
|
||||
"sandbox_queue_length": OnyxCeleryQueues.SANDBOX,
|
||||
"opensearch_migration_queue_length": OnyxCeleryQueues.OPENSEARCH_MIGRATION,
|
||||
}
|
||||
|
||||
for name, queue in queue_mappings.items():
|
||||
@@ -893,7 +881,7 @@ def monitor_celery_queues_helper(
|
||||
"""A task to monitor all celery queue lengths."""
|
||||
|
||||
r_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_docfetching = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
)
|
||||
@@ -920,26 +908,6 @@ def monitor_celery_queues_helper(
|
||||
n_permissions_upsert = celery_get_queue_length(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
n_hierarchy_fetching = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING, r_celery
|
||||
)
|
||||
n_llm_model_update = celery_get_queue_length(
|
||||
OnyxCeleryQueues.LLM_MODEL_UPDATE, r_celery
|
||||
)
|
||||
n_checkpoint_cleanup = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CHECKPOINT_CLEANUP, r_celery
|
||||
)
|
||||
n_index_attempt_cleanup = celery_get_queue_length(
|
||||
OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP, r_celery
|
||||
)
|
||||
n_csv_generation = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CSV_GENERATION, r_celery
|
||||
)
|
||||
n_monitoring = celery_get_queue_length(OnyxCeleryQueues.MONITORING, r_celery)
|
||||
n_sandbox = celery_get_queue_length(OnyxCeleryQueues.SANDBOX, r_celery)
|
||||
n_opensearch_migration = celery_get_queue_length(
|
||||
OnyxCeleryQueues.OPENSEARCH_MIGRATION, r_celery
|
||||
)
|
||||
|
||||
n_docfetching_prefetched = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
@@ -963,14 +931,6 @@ def monitor_celery_queues_helper(
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
f"hierarchy_fetching={n_hierarchy_fetching} "
|
||||
f"llm_model_update={n_llm_model_update} "
|
||||
f"checkpoint_cleanup={n_checkpoint_cleanup} "
|
||||
f"index_attempt_cleanup={n_index_attempt_cleanup} "
|
||||
f"csv_generation={n_csv_generation} "
|
||||
f"monitoring={n_monitoring} "
|
||||
f"sandbox={n_sandbox} "
|
||||
f"opensearch_migration={n_opensearch_migration} "
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,14 +41,3 @@ assert (
|
||||
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S = 30 # 30 seconds.
|
||||
|
||||
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
|
||||
|
||||
# WARNING: Do not change these values without knowing what changes also need to
|
||||
# be made to OpenSearchTenantMigrationRecord.
|
||||
GET_VESPA_CHUNKS_PAGE_SIZE = 500
|
||||
GET_VESPA_CHUNKS_SLICE_COUNT = 4
|
||||
|
||||
# String used to indicate in the vespa_visit_continuation_token mapping that the
|
||||
# slice has finished and there is nothing left to visit.
|
||||
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN = (
|
||||
"FINISHED_VISITING_SLICE_CONTINUATION_TOKEN"
|
||||
)
|
||||
|
||||
@@ -1,19 +1,12 @@
|
||||
"""Celery tasks for migrating documents from Vespa to OpenSearch."""
|
||||
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
GET_VESPA_CHUNKS_PAGE_SIZE,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
|
||||
)
|
||||
@@ -47,19 +40,14 @@ from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
def is_continuation_token_done_for_all_slices(
|
||||
continuation_token_map: dict[int, str | None],
|
||||
) -> bool:
|
||||
return all(
|
||||
continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
|
||||
for continuation_token in continuation_token_map.values()
|
||||
)
|
||||
GET_VESPA_CHUNKS_PAGE_SIZE = 1000
|
||||
|
||||
|
||||
# shared_task allows this task to be shared across celery app instances.
|
||||
@@ -88,15 +76,11 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
|
||||
Uses Vespa's Visit API to iterate through ALL chunks in bulk (not
|
||||
per-document), transform them, and index them into OpenSearch. Progress is
|
||||
tracked via a continuation token map stored in the
|
||||
tracked via a continuation token stored in the
|
||||
OpenSearchTenantMigrationRecord.
|
||||
|
||||
The first time we see no continuation token map and non-zero chunks
|
||||
migrated, we consider the migration complete and all subsequent invocations
|
||||
are no-ops.
|
||||
|
||||
We divide the index into GET_VESPA_CHUNKS_SLICE_COUNT independent slices
|
||||
where progress is tracked for each slice.
|
||||
The first time we see no continuation token and non-zero chunks migrated, we
|
||||
consider the migration complete and all subsequent invocations are no-ops.
|
||||
|
||||
Returns:
|
||||
None if OpenSearch migration is not enabled, or if the lock could not be
|
||||
@@ -134,7 +118,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
)
|
||||
|
||||
total_chunks_migrated_this_task = 0
|
||||
total_chunks_errored_this_task = 0
|
||||
try:
|
||||
# Double check that tenant info is correct.
|
||||
if tenant_id != get_current_tenant_id():
|
||||
@@ -169,52 +152,39 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
|
||||
)
|
||||
|
||||
approx_chunk_count_in_vespa: int | None = None
|
||||
get_chunk_count_start_time = time.monotonic()
|
||||
try:
|
||||
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Error getting approximate chunk count in Vespa. Moving on..."
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Took {time.monotonic() - get_chunk_count_start_time:.3f} seconds to attempt to get "
|
||||
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
|
||||
)
|
||||
|
||||
while (
|
||||
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
|
||||
and lock.owned()
|
||||
):
|
||||
(
|
||||
continuation_token_map,
|
||||
continuation_token,
|
||||
total_chunks_migrated,
|
||||
) = get_vespa_visit_state(db_session)
|
||||
if is_continuation_token_done_for_all_slices(continuation_token_map):
|
||||
if continuation_token is None and total_chunks_migrated > 0:
|
||||
task_logger.info(
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
|
||||
f"Total chunks migrated: {total_chunks_migrated}."
|
||||
)
|
||||
mark_migration_completed_time_if_not_set_with_commit(db_session)
|
||||
break
|
||||
return True
|
||||
task_logger.debug(
|
||||
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
|
||||
f"Continuation token map: {continuation_token_map}"
|
||||
f"Continuation token: {continuation_token}"
|
||||
)
|
||||
|
||||
get_vespa_chunks_start_time = time.monotonic()
|
||||
raw_vespa_chunks, next_continuation_token_map = (
|
||||
raw_vespa_chunks, next_continuation_token = (
|
||||
vespa_document_index.get_all_raw_document_chunks_paginated(
|
||||
continuation_token_map=continuation_token_map,
|
||||
continuation_token=continuation_token,
|
||||
page_size=GET_VESPA_CHUNKS_PAGE_SIZE,
|
||||
)
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} "
|
||||
f"seconds. Next continuation token map: {next_continuation_token_map}"
|
||||
f"seconds. Next continuation token: {next_continuation_token}"
|
||||
)
|
||||
|
||||
opensearch_document_chunks, errored_chunks = (
|
||||
opensearch_document_chunks: list[DocumentChunk] = (
|
||||
transform_vespa_chunks_to_opensearch_chunks(
|
||||
raw_vespa_chunks,
|
||||
tenant_state,
|
||||
@@ -222,10 +192,9 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
)
|
||||
)
|
||||
if len(opensearch_document_chunks) != len(raw_vespa_chunks):
|
||||
task_logger.error(
|
||||
f"Migration task error: Number of candidate chunks to migrate ({len(opensearch_document_chunks)}) does "
|
||||
f"not match number of chunks in Vespa ({len(raw_vespa_chunks)}). {len(errored_chunks)} chunks "
|
||||
"errored."
|
||||
raise RuntimeError(
|
||||
f"Bug: Number of candidate chunks to migrate ({len(opensearch_document_chunks)}) does not match "
|
||||
f"number of chunks in Vespa ({len(raw_vespa_chunks)})."
|
||||
)
|
||||
|
||||
index_opensearch_chunks_start_time = time.monotonic()
|
||||
@@ -238,16 +207,16 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
)
|
||||
|
||||
total_chunks_migrated_this_task += len(opensearch_document_chunks)
|
||||
total_chunks_errored_this_task += len(errored_chunks)
|
||||
update_vespa_visit_progress_with_commit(
|
||||
db_session,
|
||||
continuation_token_map=next_continuation_token_map,
|
||||
continuation_token=next_continuation_token,
|
||||
chunks_processed=len(opensearch_document_chunks),
|
||||
chunks_errored=len(errored_chunks),
|
||||
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
|
||||
)
|
||||
|
||||
if next_continuation_token is None and len(raw_vespa_chunks) == 0:
|
||||
task_logger.info("Vespa reported no more chunks to migrate.")
|
||||
break
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
task_logger.exception("Error in the OpenSearch migration task.")
|
||||
return False
|
||||
finally:
|
||||
@@ -261,7 +230,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
task_logger.info(
|
||||
f"OpenSearch chunk migration task pausing (time limit reached). "
|
||||
f"Total chunks migrated this task: {total_chunks_migrated_this_task}. "
|
||||
f"Total chunks errored this task: {total_chunks_errored_this_task}. "
|
||||
f"Elapsed: {time.monotonic() - task_start_time:.3f}s. "
|
||||
"Will resume from continuation token on next invocation."
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
@@ -37,35 +36,6 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
|
||||
DOCUMENT_ID,
|
||||
CHUNK_ID,
|
||||
TITLE,
|
||||
TITLE_EMBEDDING,
|
||||
CONTENT,
|
||||
EMBEDDINGS,
|
||||
SOURCE_TYPE,
|
||||
METADATA_LIST,
|
||||
DOC_UPDATED_AT,
|
||||
HIDDEN,
|
||||
BOOST,
|
||||
SEMANTIC_IDENTIFIER,
|
||||
IMAGE_FILE_NAME,
|
||||
SOURCE_LINKS,
|
||||
BLURB,
|
||||
DOC_SUMMARY,
|
||||
CHUNK_CONTEXT,
|
||||
METADATA_SUFFIX,
|
||||
DOCUMENT_SETS,
|
||||
USER_PROJECT,
|
||||
PRIMARY_OWNERS,
|
||||
SECONDARY_OWNERS,
|
||||
ACCESS_CONTROL_LIST,
|
||||
]
|
||||
if MULTI_TENANT:
|
||||
FIELDS_NEEDED_FOR_TRANSFORMATION.append(TENANT_ID)
|
||||
|
||||
|
||||
def _extract_content_vector(embeddings: Any) -> list[float]:
|
||||
"""Extracts the full chunk embedding vector from Vespa's embeddings tensor.
|
||||
|
||||
@@ -170,7 +140,9 @@ def _transform_vespa_acl_to_opensearch_acl(
|
||||
vespa_acl: dict[str, int] | None,
|
||||
) -> tuple[bool, list[str]]:
|
||||
if not vespa_acl:
|
||||
return False, []
|
||||
raise ValueError(
|
||||
"Missing ACL in Vespa chunk. This does not make sense as it implies the document is never searchable by anyone ever."
|
||||
)
|
||||
acl_list = list(vespa_acl.keys())
|
||||
is_public = PUBLIC_DOC_PAT in acl_list
|
||||
if is_public:
|
||||
@@ -182,162 +154,136 @@ def transform_vespa_chunks_to_opensearch_chunks(
|
||||
vespa_chunks: list[dict[str, Any]],
|
||||
tenant_state: TenantState,
|
||||
sanitized_to_original_doc_id_mapping: dict[str, str],
|
||||
) -> tuple[list[DocumentChunk], list[dict[str, Any]]]:
|
||||
) -> list[DocumentChunk]:
|
||||
result: list[DocumentChunk] = []
|
||||
errored_chunks: list[dict[str, Any]] = []
|
||||
for vespa_chunk in vespa_chunks:
|
||||
try:
|
||||
# This should exist; fail loudly if it does not.
|
||||
vespa_document_id: str = vespa_chunk[DOCUMENT_ID]
|
||||
if not vespa_document_id:
|
||||
raise ValueError("Missing document_id in Vespa chunk.")
|
||||
# Vespa doc IDs were sanitized using
|
||||
# replace_invalid_doc_id_characters. This was a poor design choice
|
||||
# and we don't want this in OpenSearch; whatever restrictions there
|
||||
# may be on indexed chunk ID should have no bearing on the chunk's
|
||||
# document ID field, even if document ID is an argument to the chunk
|
||||
# ID. Deliberately choose to use the real doc ID supplied to this
|
||||
# function.
|
||||
if vespa_document_id in sanitized_to_original_doc_id_mapping:
|
||||
logger.warning(
|
||||
f"Migration warning: Vespa document ID {vespa_document_id} does not match the document ID supplied "
|
||||
f"{sanitized_to_original_doc_id_mapping[vespa_document_id]}. "
|
||||
"The Vespa ID will be discarded."
|
||||
)
|
||||
document_id = sanitized_to_original_doc_id_mapping.get(
|
||||
vespa_document_id, vespa_document_id
|
||||
# This should exist; fail loudly if it does not.
|
||||
vespa_document_id: str = vespa_chunk[DOCUMENT_ID]
|
||||
if not vespa_document_id:
|
||||
raise ValueError("Missing document_id in Vespa chunk.")
|
||||
# Vespa doc IDs were sanitized using replace_invalid_doc_id_characters.
|
||||
# This was a poor design choice and we don't want this in OpenSearch;
|
||||
# whatever restrictions there may be on indexed chunk ID should have no
|
||||
# bearing on the chunk's document ID field, even if document ID is an
|
||||
# argument to the chunk ID. Deliberately choose to use the real doc ID
|
||||
# supplied to this function.
|
||||
if vespa_document_id in sanitized_to_original_doc_id_mapping:
|
||||
logger.warning(
|
||||
f"Vespa document ID {vespa_document_id} does not match the document ID supplied "
|
||||
f"{sanitized_to_original_doc_id_mapping[vespa_document_id]}. "
|
||||
"The Vespa ID will be discarded."
|
||||
)
|
||||
document_id = sanitized_to_original_doc_id_mapping.get(
|
||||
vespa_document_id, vespa_document_id
|
||||
)
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
chunk_index: int = vespa_chunk[CHUNK_ID]
|
||||
# This should exist; fail loudly if it does not.
|
||||
chunk_index: int = vespa_chunk[CHUNK_ID]
|
||||
|
||||
title: str | None = vespa_chunk.get(TITLE)
|
||||
# WARNING: Should supply format.tensors=short-value to the Vespa
|
||||
# client in order to get a supported format for the tensors.
|
||||
title_vector: list[float] | None = _extract_title_vector(
|
||||
vespa_chunk.get(TITLE_EMBEDDING)
|
||||
title: str | None = vespa_chunk.get(TITLE)
|
||||
# WARNING: Should supply format.tensors=short-value to the Vespa client
|
||||
# in order to get a supported format for the tensors.
|
||||
title_vector: list[float] | None = _extract_title_vector(
|
||||
vespa_chunk.get(TITLE_EMBEDDING)
|
||||
)
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
content: str = vespa_chunk[CONTENT]
|
||||
if not content:
|
||||
raise ValueError("Missing content in Vespa chunk.")
|
||||
# This should exist; fail loudly if it does not.
|
||||
# WARNING: Should supply format.tensors=short-value to the Vespa client
|
||||
# in order to get a supported format for the tensors.
|
||||
content_vector: list[float] = _extract_content_vector(vespa_chunk[EMBEDDINGS])
|
||||
if not content_vector:
|
||||
raise ValueError("Missing content_vector in Vespa chunk.")
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
source_type: str = vespa_chunk[SOURCE_TYPE]
|
||||
if not source_type:
|
||||
raise ValueError("Missing source_type in Vespa chunk.")
|
||||
|
||||
metadata_list: list[str] | None = vespa_chunk.get(METADATA_LIST)
|
||||
|
||||
_raw_doc_updated_at: int | None = vespa_chunk.get(DOC_UPDATED_AT)
|
||||
last_updated: datetime | None = (
|
||||
datetime.fromtimestamp(_raw_doc_updated_at, tz=timezone.utc)
|
||||
if _raw_doc_updated_at is not None
|
||||
else None
|
||||
)
|
||||
|
||||
hidden: bool = vespa_chunk.get(HIDDEN, False)
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
global_boost: int = vespa_chunk[BOOST]
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
semantic_identifier: str = vespa_chunk[SEMANTIC_IDENTIFIER]
|
||||
if not semantic_identifier:
|
||||
raise ValueError("Missing semantic_identifier in Vespa chunk.")
|
||||
|
||||
image_file_id: str | None = vespa_chunk.get(IMAGE_FILE_NAME)
|
||||
source_links: str | None = vespa_chunk.get(SOURCE_LINKS)
|
||||
blurb: str = vespa_chunk.get(BLURB, "")
|
||||
doc_summary: str = vespa_chunk.get(DOC_SUMMARY, "")
|
||||
chunk_context: str = vespa_chunk.get(CHUNK_CONTEXT, "")
|
||||
metadata_suffix: str | None = vespa_chunk.get(METADATA_SUFFIX)
|
||||
document_sets: list[str] | None = (
|
||||
_transform_vespa_document_sets_to_opensearch_document_sets(
|
||||
vespa_chunk.get(DOCUMENT_SETS)
|
||||
)
|
||||
)
|
||||
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
|
||||
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
|
||||
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
content: str = vespa_chunk[CONTENT]
|
||||
if not content:
|
||||
# This should exist; fail loudly if it does not; this function will
|
||||
# raise in that event.
|
||||
is_public, acl_list = _transform_vespa_acl_to_opensearch_acl(
|
||||
vespa_chunk.get(ACCESS_CONTROL_LIST)
|
||||
)
|
||||
|
||||
chunk_tenant_id: str | None = vespa_chunk.get(TENANT_ID)
|
||||
if MULTI_TENANT:
|
||||
if not chunk_tenant_id:
|
||||
raise ValueError(
|
||||
f"Missing content in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}."
|
||||
"Missing tenant_id in Vespa chunk in a multi-tenant environment."
|
||||
)
|
||||
# This should exist; fail loudly if it does not.
|
||||
# WARNING: Should supply format.tensors=short-value to the Vespa
|
||||
# client in order to get a supported format for the tensors.
|
||||
content_vector: list[float] = _extract_content_vector(
|
||||
vespa_chunk[EMBEDDINGS]
|
||||
)
|
||||
if not content_vector:
|
||||
if chunk_tenant_id != tenant_state.tenant_id:
|
||||
raise ValueError(
|
||||
f"Missing content_vector in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}."
|
||||
f"Chunk tenant_id {chunk_tenant_id} does not match expected tenant_id {tenant_state.tenant_id}"
|
||||
)
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
source_type: str = vespa_chunk[SOURCE_TYPE]
|
||||
if not source_type:
|
||||
raise ValueError(
|
||||
f"Missing source_type in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}."
|
||||
)
|
||||
opensearch_chunk = DocumentChunk(
|
||||
# We deliberately choose to use the doc ID supplied to this function
|
||||
# over the Vespa doc ID.
|
||||
document_id=document_id,
|
||||
chunk_index=chunk_index,
|
||||
title=title,
|
||||
title_vector=title_vector,
|
||||
content=content,
|
||||
content_vector=content_vector,
|
||||
source_type=source_type,
|
||||
metadata_list=metadata_list,
|
||||
last_updated=last_updated,
|
||||
public=is_public,
|
||||
access_control_list=acl_list,
|
||||
hidden=hidden,
|
||||
global_boost=global_boost,
|
||||
semantic_identifier=semantic_identifier,
|
||||
image_file_id=image_file_id,
|
||||
source_links=source_links,
|
||||
blurb=blurb,
|
||||
doc_summary=doc_summary,
|
||||
chunk_context=chunk_context,
|
||||
metadata_suffix=metadata_suffix,
|
||||
document_sets=document_sets,
|
||||
user_projects=user_projects,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
tenant_id=tenant_state,
|
||||
)
|
||||
|
||||
metadata_list: list[str] | None = vespa_chunk.get(METADATA_LIST)
|
||||
result.append(opensearch_chunk)
|
||||
|
||||
_raw_doc_updated_at: int | None = vespa_chunk.get(DOC_UPDATED_AT)
|
||||
last_updated: datetime | None = (
|
||||
datetime.fromtimestamp(_raw_doc_updated_at, tz=timezone.utc)
|
||||
if _raw_doc_updated_at is not None
|
||||
else None
|
||||
)
|
||||
|
||||
hidden: bool = vespa_chunk.get(HIDDEN, False)
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
global_boost: int = vespa_chunk[BOOST]
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
semantic_identifier: str = vespa_chunk[SEMANTIC_IDENTIFIER]
|
||||
if not semantic_identifier:
|
||||
raise ValueError(
|
||||
f"Missing semantic_identifier in Vespa chunk with document ID {vespa_document_id} and chunk "
|
||||
f"index {chunk_index}."
|
||||
)
|
||||
|
||||
image_file_id: str | None = vespa_chunk.get(IMAGE_FILE_NAME)
|
||||
source_links: str | None = vespa_chunk.get(SOURCE_LINKS)
|
||||
blurb: str = vespa_chunk.get(BLURB, "")
|
||||
doc_summary: str = vespa_chunk.get(DOC_SUMMARY, "")
|
||||
chunk_context: str = vespa_chunk.get(CHUNK_CONTEXT, "")
|
||||
metadata_suffix: str | None = vespa_chunk.get(METADATA_SUFFIX)
|
||||
document_sets: list[str] | None = (
|
||||
_transform_vespa_document_sets_to_opensearch_document_sets(
|
||||
vespa_chunk.get(DOCUMENT_SETS)
|
||||
)
|
||||
)
|
||||
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
|
||||
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
|
||||
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
|
||||
|
||||
is_public, acl_list = _transform_vespa_acl_to_opensearch_acl(
|
||||
vespa_chunk.get(ACCESS_CONTROL_LIST)
|
||||
)
|
||||
if not is_public and not acl_list:
|
||||
logger.warning(
|
||||
f"Migration warning: Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index} has no "
|
||||
"public ACL and no access control list. This does not make sense as it implies the document is never "
|
||||
"searchable. Continuing with the migration..."
|
||||
)
|
||||
|
||||
chunk_tenant_id: str | None = vespa_chunk.get(TENANT_ID)
|
||||
if MULTI_TENANT:
|
||||
if not chunk_tenant_id:
|
||||
raise ValueError(
|
||||
"Missing tenant_id in Vespa chunk in a multi-tenant environment."
|
||||
)
|
||||
if chunk_tenant_id != tenant_state.tenant_id:
|
||||
raise ValueError(
|
||||
f"Chunk tenant_id {chunk_tenant_id} does not match expected tenant_id {tenant_state.tenant_id}"
|
||||
)
|
||||
|
||||
opensearch_chunk = DocumentChunk(
|
||||
# We deliberately choose to use the doc ID supplied to this function
|
||||
# over the Vespa doc ID.
|
||||
document_id=document_id,
|
||||
chunk_index=chunk_index,
|
||||
title=title,
|
||||
title_vector=title_vector,
|
||||
content=content,
|
||||
content_vector=content_vector,
|
||||
source_type=source_type,
|
||||
metadata_list=metadata_list,
|
||||
last_updated=last_updated,
|
||||
public=is_public,
|
||||
access_control_list=acl_list,
|
||||
hidden=hidden,
|
||||
global_boost=global_boost,
|
||||
semantic_identifier=semantic_identifier,
|
||||
image_file_id=image_file_id,
|
||||
source_links=source_links,
|
||||
blurb=blurb,
|
||||
doc_summary=doc_summary,
|
||||
chunk_context=chunk_context,
|
||||
metadata_suffix=metadata_suffix,
|
||||
document_sets=document_sets,
|
||||
user_projects=user_projects,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
tenant_id=tenant_state,
|
||||
)
|
||||
|
||||
result.append(opensearch_chunk)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
logger.exception(
|
||||
f"Migration error: Error transforming Vespa chunk with document ID {vespa_chunk.get(DOCUMENT_ID)} "
|
||||
f"and chunk index {vespa_chunk.get(CHUNK_ID)} into an OpenSearch chunk. Continuing with "
|
||||
"the migration..."
|
||||
)
|
||||
errored_chunks.append(vespa_chunk)
|
||||
|
||||
return result, errored_chunks
|
||||
return result
|
||||
|
||||
@@ -43,11 +43,9 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import get_documents_for_connector_credential_pair
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
@@ -55,9 +53,6 @@ from onyx.db.tag import delete_orphan_tags__no_commit
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
@@ -528,47 +523,12 @@ def connector_pruning_generator_task(
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
timeout_seconds=JOB_TIMEOUT,
|
||||
)
|
||||
|
||||
# Extract docs and hierarchy nodes from the source
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
# a list of docs in the source
|
||||
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.doc_ids
|
||||
|
||||
# Process hierarchy nodes (same as docfetching):
|
||||
# upsert to Postgres and cache in Redis
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
ensure_source_node_exists(
|
||||
redis_client, db_session, cc_pair.connector.source
|
||||
)
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
source=cc_pair.connector.source,
|
||||
commit=True,
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in upserted_nodes
|
||||
]
|
||||
cache_hierarchy_nodes_batch(
|
||||
redis_client=redis_client,
|
||||
source=cc_pair.connector.source,
|
||||
entries=cache_entries,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning: persisted and cached {len(extraction_result.hierarchy_nodes)} "
|
||||
f"hierarchy nodes for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
# a list of docs in our local index
|
||||
all_indexed_document_ids = {
|
||||
|
||||
@@ -10,26 +10,21 @@ from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
@@ -42,7 +37,6 @@ from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.utils import store_user_file_plaintext
|
||||
from onyx.file_store.utils import user_file_id_to_plaintext_file_name
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
|
||||
@@ -60,17 +54,6 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
"""Key that exists while a process_single_user_file task is sitting in the queue.
|
||||
|
||||
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
before enqueuing and the worker deletes it as its first action. This prevents
|
||||
the beat from adding duplicate tasks for files that already have a live task
|
||||
in flight.
|
||||
"""
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
@@ -134,24 +117,7 @@ def _get_document_chunk_count(
|
||||
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
|
||||
|
||||
Three mechanisms prevent queue runaway:
|
||||
|
||||
1. **Queue depth backpressure** – if the broker queue already has more than
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
|
||||
entirely. Workers are clearly behind; adding more tasks would only make
|
||||
the backlog worse.
|
||||
|
||||
2. **Per-file queued guard** – before enqueuing a task we set a short-lived
|
||||
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
|
||||
already exists the file already has a live task in the queue, so we skip
|
||||
it. The worker deletes the key the moment it picks up the task so the
|
||||
next beat cycle can re-enqueue if the file is still PROCESSING.
|
||||
|
||||
3. **Task expiry** – every enqueued task carries an `expires` value equal to
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
|
||||
the queue after that deadline, Celery discards it without touching the DB.
|
||||
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
|
||||
Redis restart), stale tasks evict themselves rather than piling up forever.
|
||||
Uses direct Redis locks to avoid overlapping runs.
|
||||
"""
|
||||
task_logger.info("check_user_file_processing - Starting")
|
||||
|
||||
@@ -166,21 +132,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_user_file_processing - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -193,35 +145,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
nx=True,
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
# If task submission fails, clear the guard immediately so the
|
||||
# next beat cycle can retry enqueuing this file.
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
|
||||
finally:
|
||||
@@ -229,138 +158,11 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
|
||||
f"tasks for tenant={tenant_id}"
|
||||
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _process_user_file_without_vector_db(
|
||||
uf: UserFile,
|
||||
documents: list[Document],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Process a user file when the vector DB is disabled.
|
||||
|
||||
Extracts raw text and computes a token count, stores the plaintext in
|
||||
the file store, and marks the file as COMPLETED. Skips embedding and
|
||||
the indexing pipeline entirely.
|
||||
"""
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.factory import get_llm_tokenizer_encode_func
|
||||
|
||||
# Combine section text from all document sections
|
||||
combined_text = " ".join(
|
||||
section.text for doc in documents for section in doc.sections if section.text
|
||||
)
|
||||
|
||||
# Compute token count using the user's default LLM tokenizer
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
encode = get_llm_tokenizer_encode_func(llm)
|
||||
token_count: int | None = len(encode(combined_text))
|
||||
except Exception:
|
||||
task_logger.warning(
|
||||
f"_process_user_file_without_vector_db - "
|
||||
f"Failed to compute token count for {uf.id}, falling back to None"
|
||||
)
|
||||
token_count = None
|
||||
|
||||
# Persist plaintext for fast FileReaderTool loads
|
||||
store_user_file_plaintext(
|
||||
user_file_id=uf.id,
|
||||
plaintext_content=combined_text,
|
||||
)
|
||||
|
||||
# Update the DB record
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.COMPLETED
|
||||
uf.token_count = token_count
|
||||
uf.chunk_count = 0 # no chunks without vector DB
|
||||
uf.last_project_sync_at = datetime.datetime.now(datetime.timezone.utc)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
|
||||
task_logger.info(
|
||||
f"_process_user_file_without_vector_db - "
|
||||
f"Completed id={uf.id} tokens={token_count}"
|
||||
)
|
||||
|
||||
|
||||
def _process_user_file_with_indexing(
|
||||
uf: UserFile,
|
||||
user_file_id: str,
|
||||
documents: list[Document],
|
||||
tenant_id: str,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Process a user file through the full indexing pipeline (vector DB path)."""
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
search_settings_list = get_active_search_settings_list(db_session)
|
||||
current_search_settings = next(
|
||||
(ss for ss in search_settings_list if ss.status.is_current()),
|
||||
None,
|
||||
)
|
||||
if current_search_settings is None:
|
||||
raise RuntimeError(
|
||||
f"_process_user_file_with_indexing - "
|
||||
f"No current search settings found for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=current_search_settings,
|
||||
)
|
||||
|
||||
document_indices = get_all_document_indices(
|
||||
current_search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
document_indices=document_indices,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
document_batch=documents,
|
||||
request_id=None,
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"_process_user_file_with_indexing - "
|
||||
f"Indexing pipeline completed ={index_pipeline_result}"
|
||||
)
|
||||
|
||||
if (
|
||||
index_pipeline_result.failures
|
||||
or index_pipeline_result.total_docs != len(documents)
|
||||
or index_pipeline_result.total_chunks == 0
|
||||
):
|
||||
task_logger.error(
|
||||
f"_process_user_file_with_indexing - "
|
||||
f"Indexing pipeline failed id={user_file_id}"
|
||||
)
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}")
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
@@ -373,12 +175,6 @@ def process_single_user_file(
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
@@ -412,31 +208,93 @@ def process_single_user_file(
|
||||
)
|
||||
connector.load_credentials({})
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
search_settings_list = get_active_search_settings_list(db_session)
|
||||
|
||||
current_search_settings = next(
|
||||
(
|
||||
search_settings_instance
|
||||
for search_settings_instance in search_settings_list
|
||||
if search_settings_instance.status.is_current()
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if current_search_settings is None:
|
||||
raise RuntimeError(
|
||||
f"process_single_user_file - No current search settings found for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
for batch in connector.load_from_state():
|
||||
documents.extend(
|
||||
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
|
||||
)
|
||||
|
||||
# update the document id to userfile id in the documents
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set up indexing pipeline components
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=current_search_settings,
|
||||
)
|
||||
|
||||
# This flow is for indexing so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
current_search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
# update the doument id to userfile id in the documents
|
||||
for document in documents:
|
||||
document.id = str(user_file_id)
|
||||
document.source = DocumentSource.USER_FILE
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
_process_user_file_without_vector_db(
|
||||
uf=uf,
|
||||
documents=documents,
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
_process_user_file_with_indexing(
|
||||
uf=uf,
|
||||
user_file_id=user_file_id,
|
||||
documents=documents,
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
# real work happens here!
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
document_indices=document_indices,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
document_batch=documents,
|
||||
request_id=None,
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Indexing pipeline completed ={index_pipeline_result}"
|
||||
)
|
||||
|
||||
if (
|
||||
index_pipeline_result.failures
|
||||
or index_pipeline_result.total_docs != len(documents)
|
||||
or index_pipeline_result.total_chunks == 0
|
||||
):
|
||||
task_logger.error(
|
||||
f"process_single_user_file - Indexing pipeline failed id={user_file_id}"
|
||||
)
|
||||
# don't update the status if the user file is being deleted
|
||||
# Re-fetch to avoid mypy error
|
||||
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if (
|
||||
current_user_file
|
||||
and current_user_file.status != UserFileStatus.DELETING
|
||||
):
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
@@ -550,6 +408,28 @@ def process_single_user_file_delete(
|
||||
return None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
# This flow is for deletion so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
index_name = active_search_settings.primary.index_name
|
||||
selection = f"{index_name}.document_id=='{user_file_id}'"
|
||||
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
@@ -557,43 +437,22 @@ def process_single_user_file_delete(
|
||||
)
|
||||
return None
|
||||
|
||||
# 1) Delete vector DB chunks (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
# 1) Delete Vespa chunks for the document
|
||||
chunk_count = 0
|
||||
if user_file.chunk_count is None or user_file.chunk_count == 0:
|
||||
chunk_count = _get_document_chunk_count(
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
index_name = active_search_settings.primary.index_name
|
||||
selection = f"{index_name}.document_id=='{user_file_id}'"
|
||||
else:
|
||||
chunk_count = user_file.chunk_count
|
||||
|
||||
chunk_count = 0
|
||||
if user_file.chunk_count is None or user_file.chunk_count == 0:
|
||||
chunk_count = _get_document_chunk_count(
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
)
|
||||
else:
|
||||
chunk_count = user_file.chunk_count
|
||||
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.delete_single(
|
||||
doc_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.delete_single(
|
||||
doc_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
|
||||
file_store = get_default_file_store()
|
||||
@@ -705,6 +564,27 @@ def process_single_user_file_project_sync(
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
# This flow is for updates so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
@@ -712,35 +592,15 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
return None
|
||||
|
||||
# Sync project metadata to vector DB (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file id={user_file_id}"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user