Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-22 02:05:46 +00:00)

Compare commits: main...litellm_pr (1 commit, 5ac1f12af5)

@@ -1 +0,0 @@
../.cursor/skills

@@ -1,248 +0,0 @@
---
name: playwright-e2e-tests
description: Write and maintain Playwright end-to-end tests for the Onyx application. Use when creating new E2E tests, debugging test failures, adding test coverage, or when the user mentions Playwright, E2E tests, or browser testing.
---

# Playwright E2E Tests

## Project Layout

- **Tests**: `web/tests/e2e/` — organized by feature (`auth/`, `admin/`, `chat/`, `assistants/`, `connectors/`, `mcp/`)
- **Config**: `web/playwright.config.ts`
- **Utilities**: `web/tests/e2e/utils/`
- **Constants**: `web/tests/e2e/constants.ts`
- **Global setup**: `web/tests/e2e/global-setup.ts`
- **Output**: `web/output/playwright/`

## Imports

Always use absolute imports with the `@tests/e2e/` prefix — never relative paths (`../`, `../../`). The alias is defined in `web/tsconfig.json` and resolves to `web/tests/`.

```typescript
import { loginAs } from "@tests/e2e/utils/auth";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
import { TEST_ADMIN_CREDENTIALS } from "@tests/e2e/constants";
```

All new files should be `.ts`, not `.js`.

## Running Tests

```bash
# Run a specific test file
npx playwright test web/tests/e2e/chat/default_assistant.spec.ts

# Run a specific project
npx playwright test --project admin
npx playwright test --project exclusive
```

## Test Projects

| Project | Description | Parallelism |
|---------|-------------|-------------|
| `admin` | Standard tests (excludes `@exclusive`) | Parallel |
| `exclusive` | Serial, slower tests (tagged `@exclusive`) | 1 worker |

All tests use `admin_auth.json` storage state by default (pre-authenticated admin session).

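For orientation, the project split roughly corresponds to a `projects` block like the sketch below. This is illustrative only; the option names, patterns, and paths are assumptions, so treat `web/playwright.config.ts` as the source of truth.

```typescript
// Illustrative sketch of the project setup in web/playwright.config.ts.
// Exact paths, grep patterns, and worker limits are assumptions; check the real config.
import { defineConfig } from "@playwright/test";

export default defineConfig({
  globalSetup: "./tests/e2e/global-setup.ts",
  use: { storageState: "admin_auth.json" }, // every test starts as the pre-authenticated admin
  projects: [
    { name: "admin", grepInvert: /@exclusive/ }, // standard tests, run in parallel
    // Serial, slower tests; the real config runs these with a single worker (mechanism not shown here).
    { name: "exclusive", grep: /@exclusive/ },
  ],
});
```
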
## Authentication

Global setup (`global-setup.ts`) runs automatically before all tests and handles:

- Server readiness check (polls health endpoint, 60s timeout)
- Provisioning test users: admin, admin2, and a **pool of worker users** (`worker0@example.com` through `worker7@example.com`); provisioning is idempotent
- API login + saving storage states: `admin_auth.json`, `admin2_auth.json`, and `worker{N}_auth.json` for each worker user
- Setting display name to `"worker"` for each worker user
- Promoting admin2 to admin role
- Ensuring a public LLM provider exists

Both test projects set `storageState: "admin_auth.json"`, so **every test starts pre-authenticated as admin with no login code needed**.

When a test needs a different user, use API-based login — never drive the login UI:

```typescript
import { loginAs, loginAsWorkerUser } from "@tests/e2e/utils/auth";

// Log in as the secondary admin:
await page.context().clearCookies();
await loginAs(page, "admin2");

// Or log in as the worker-specific user (preferred for test isolation):
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
```

## Test Structure

Tests start pre-authenticated as admin — navigate and test directly:

```typescript
import { test, expect } from "@playwright/test";

test.describe("Feature Name", () => {
  test("should describe expected behavior clearly", async ({ page }) => {
    await page.goto("/app");
    await page.waitForLoadState("networkidle");
    // Already authenticated as admin — go straight to testing
  });
});
```

**User isolation** — tests that modify visible app state (creating assistants, sending chat messages, pinning items) should run as a **worker-specific user** and clean up resources in `afterAll`. Global setup provisions a pool of worker users (`worker0@example.com` through `worker7@example.com`). `loginAsWorkerUser` maps `testInfo.workerIndex` to a pool slot via modulo, so retry workers (which get incrementing indices beyond the pool size) safely reuse existing users. This ensures parallel workers never share user state, keeps usernames deterministic for screenshots, and avoids cross-contamination:

```typescript
import { test } from "@playwright/test";
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";

test.beforeEach(async ({ page }, testInfo) => {
  await page.context().clearCookies();
  await loginAsWorkerUser(page, testInfo.workerIndex);
});
```

If the test requires admin privileges *and* modifies visible state, use `"admin2"` instead — it's a pre-provisioned admin account that keeps the primary `"admin"` clean for other parallel tests. Switch to `"admin"` only for privileged setup (creating providers, configuring tools), then back to the worker user for the actual test. See `chat/default_assistant.spec.ts` for a full example.

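A minimal sketch of that switch, using the `loginAs` and `loginAsWorkerUser` helpers shown above; the setup and assertion steps are placeholders:

```typescript
import { test } from "@playwright/test";
import { loginAs, loginAsWorkerUser } from "@tests/e2e/utils/auth";

test("worker user sees a tool configured by an admin", async ({ page }, testInfo) => {
  // Privileged setup as the primary admin (placeholder steps).
  await page.context().clearCookies();
  await loginAs(page, "admin");
  // ...create the provider / configure the tool here...

  // Then run the behavior under test as the isolated worker user.
  await page.context().clearCookies();
  await loginAsWorkerUser(page, testInfo.workerIndex);
  // ...navigate and assert on the state visible to this worker...
});
```
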
`loginAsRandomUser` exists for the rare case where the test requires a brand-new user (e.g. onboarding flows). Avoid it elsewhere — it produces non-deterministic usernames that complicate screenshots.

**API resource setup** — only when tests need to create backend resources (image gen configs, web search providers, MCP servers). Use `beforeAll`/`afterAll` with `OnyxApiClient` to create and clean up. See `chat/default_assistant.spec.ts` or `mcp/mcp_oauth_flow.spec.ts` for examples. This is uncommon (~4 of 37 test files).

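The shape of that pattern is sketched below. The `OnyxApiClient` constructor and the exact `createAssistant`/`deleteAssistant` signatures are assumptions for illustration; the spec files above show the real usage.

```typescript
import { test } from "@playwright/test";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";

let assistantId: number; // return shape assumed; the real method may return a richer object

test.beforeAll(async () => {
  const client = new OnyxApiClient(); // construction/auth details assumed
  assistantId = await client.createAssistant({ name: "E2E-CMD Assistant 1" });
});

test.afterAll(async () => {
  const client = new OnyxApiClient();
  await client.deleteAssistant(assistantId); // clean up by ID so parallel runs stay deterministic
});
```
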
## Key Utilities

### `OnyxApiClient` (`@tests/e2e/utils/onyxApiClient`)

Backend API client for test setup/teardown. Key methods:

- **Connectors**: `createFileConnector()`, `deleteCCPair()`, `pauseConnector()`
- **LLM Providers**: `ensurePublicProvider()`, `createRestrictedProvider()`, `setProviderAsDefault()`
- **Assistants**: `createAssistant()`, `deleteAssistant()`, `findAssistantByName()`
- **User Groups**: `createUserGroup()`, `deleteUserGroup()`, `setUserRole()`
- **Tools**: `createWebSearchProvider()`, `createImageGenerationConfig()`
- **Chat**: `createChatSession()`, `deleteChatSession()`

### `chatActions` (`@tests/e2e/utils/chatActions`)

- `sendMessage(page, message)` — sends a message and waits for AI response
- `startNewChat(page)` — clicks new-chat button and waits for intro
- `verifyDefaultAssistantIsChosen(page)` — checks Onyx logo is visible
- `verifyAssistantIsChosen(page, name)` — checks assistant name display
- `switchModel(page, modelName)` — switches LLM model via popover

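A typical flow with these helpers looks roughly like the sketch below; the assistant name and message text are placeholders:

```typescript
import { test } from "@playwright/test";
import {
  startNewChat,
  sendMessage,
  verifyAssistantIsChosen,
} from "@tests/e2e/utils/chatActions";

test("sends a message in a fresh chat", async ({ page }) => {
  await page.goto("/app");
  await page.waitForLoadState("networkidle");
  await startNewChat(page);
  await verifyAssistantIsChosen(page, "My Assistant"); // placeholder assistant name
  await sendMessage(page, "Hello!"); // waits for the AI response before returning
});
```
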
### `visualRegression` (`@tests/e2e/utils/visualRegression`)

- `expectScreenshot(page, { name, mask?, hide?, fullPage? })`
- `expectElementScreenshot(locator, { name, mask?, hide? })`
- Controlled by `VISUAL_REGRESSION=true` env var

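For example, assuming `mask` accepts Playwright locators (as with the built-in screenshot assertions); the selector is a placeholder:

```typescript
import { test } from "@playwright/test";
import { expectScreenshot } from "@tests/e2e/utils/visualRegression";

test("chat page renders consistently", async ({ page }) => {
  await page.goto("/app");
  await page.waitForLoadState("networkidle");
  await expectScreenshot(page, {
    name: "chat-page",
    mask: [page.getByTestId("onyx-ai-message")], // placeholder: mask dynamic content
  });
});
```
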
### `theme` (`@tests/e2e/utils/theme`)

- `THEMES` — `["light", "dark"] as const` array for iterating over both themes
- `setThemeBeforeNavigation(page, theme)` — sets `next-themes` theme via `localStorage` before navigation

When tests need light/dark screenshots, loop over `THEMES` at the `test.describe` level and call `setThemeBeforeNavigation` in `beforeEach` **before** any `page.goto()`. Include the theme in screenshot names. See `admin/admin_pages.spec.ts` or `chat/chat_message_rendering.spec.ts` for examples:

```typescript
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";

for (const theme of THEMES) {
  test.describe(`Feature (${theme} mode)`, () => {
    test.beforeEach(async ({ page }) => {
      await setThemeBeforeNavigation(page, theme);
    });

    test("renders correctly", async ({ page }) => {
      await page.goto("/app");
      await expectScreenshot(page, { name: `feature-${theme}` });
    });
  });
}
```

### `tools` (`@tests/e2e/utils/tools`)

- `TOOL_IDS` — centralized `data-testid` selectors for tool options
- `openActionManagement(page)` — opens the tool management popover

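Roughly, with a hypothetical `TOOL_IDS` key (use whichever keys the real map defines):

```typescript
import { test, expect } from "@playwright/test";
import { TOOL_IDS, openActionManagement } from "@tests/e2e/utils/tools";

test("web search tool option is listed", async ({ page }) => {
  await page.goto("/app");
  await openActionManagement(page);
  // "webSearch" is a hypothetical key; TOOL_IDS centralizes the real data-testid values.
  await expect(page.getByTestId(TOOL_IDS.webSearch)).toBeVisible();
});
```
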
## Locator Strategy

Use locators in this priority order:

1. **`data-testid` / `aria-label`** — preferred for Onyx components
   ```typescript
   page.getByTestId("AppSidebar/new-session")
   page.getByLabel("admin-page-title")
   ```

2. **Role-based** — for standard HTML elements
   ```typescript
   page.getByRole("button", { name: "Create" })
   page.getByRole("dialog")
   ```

3. **Text/Label** — for visible text content
   ```typescript
   page.getByText("Custom Assistant")
   page.getByLabel("Email")
   ```

4. **CSS selectors** — last resort, only when above won't work
   ```typescript
   page.locator('input[name="name"]')
   page.locator("#onyx-chat-input-textarea")
   ```

**Never use** `page.locator` with complex CSS/XPath when a built-in locator works.

## Assertions

Use web-first assertions — they auto-retry until the condition is met:

```typescript
// Visibility
await expect(page.getByTestId("onyx-logo")).toBeVisible({ timeout: 5000 });

// Text content
await expect(page.getByTestId("assistant-name-display")).toHaveText("My Assistant");

// Count
await expect(page.locator('[data-testid="onyx-ai-message"]')).toHaveCount(2, { timeout: 30000 });

// URL
await expect(page).toHaveURL(/chatId=/);

// Element state
await expect(toggle).toBeChecked();
await expect(button).toBeEnabled();
```

**Never use** `assert` statements or hardcoded `page.waitForTimeout()`.

## Waiting Strategy

```typescript
// Wait for load state after navigation
await page.goto("/app");
await page.waitForLoadState("networkidle");

// Wait for specific element
await page.getByTestId("chat-intro").waitFor({ state: "visible", timeout: 10000 });

// Wait for URL change
await page.waitForFunction(() => window.location.href.includes("chatId="), null, { timeout: 10000 });

// Wait for network response
await page.waitForResponse(resp => resp.url().includes("/api/chat") && resp.status() === 200);
```

## Best Practices

1. **Descriptive test names** — clearly state expected behavior: `"should display greeting message when opening new chat"`
2. **API-first setup** — use `OnyxApiClient` for backend state; reserve UI interactions for the behavior under test
3. **User isolation** — tests that modify visible app state (sidebar, chat history) should run as the worker-specific user via `loginAsWorkerUser(page, testInfo.workerIndex)` (not admin) and clean up resources in `afterAll`. Each parallel worker gets its own user, preventing cross-contamination. Reserve `loginAsRandomUser` for flows that require a brand-new user (e.g. onboarding)
4. **DRY helpers** — extract reusable logic into `utils/` with JSDoc comments
5. **No hardcoded waits** — use `waitFor`, `waitForLoadState`, or web-first assertions
6. **Parallel-safe** — no shared mutable state between tests. Prefer static, human-readable names (e.g. `"E2E-CMD Chat 1"`) and clean up resources by ID in `afterAll`. This keeps screenshots deterministic and avoids needing to mask/hide dynamic text. Only fall back to timestamps (`\`test-${Date.now()}\``) when resources cannot be reliably cleaned up or when name collisions across parallel workers would cause functional failures
7. **Error context** — catch and re-throw with useful debug info (page text, URL, etc.); see the sketch after this list
8. **Tag slow tests** — mark serial/slow tests with `@exclusive` in the test title
9. **Visual regression** — use `expectScreenshot()` for UI consistency checks
10. **Minimal comments** — only comment to clarify non-obvious intent; never restate what the next line of code does

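A minimal sketch of the error-context pattern from item 7, reusing selectors from the examples above (the `@exclusive` tag from item 8 is shown in the title purely for illustration):

```typescript
import { test, expect } from "@playwright/test";

test("should show the Onyx logo on a new session @exclusive", async ({ page }) => {
  await page.goto("/app");
  try {
    await expect(page.getByTestId("onyx-logo")).toBeVisible({ timeout: 10000 });
  } catch (error) {
    // Re-throw with useful debug info so CI failures are easier to triage.
    const bodyText = await page.locator("body").innerText();
    throw new Error(
      `Onyx logo not visible at ${page.url()}. Original error: ${error}\n` +
        `Page text (truncated): ${bodyText.slice(0, 500)}`
    );
  }
});
```
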
2
.github/workflows/deployment.yml
vendored
@@ -640,7 +640,6 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true

@@ -722,7 +721,6 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true

2
.github/workflows/helm-chart-releases.yml
vendored
@@ -33,7 +33,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update

- name: Build chart dependencies

151
.github/workflows/nightly-scan-licenses.yml
vendored
Normal file
@@ -0,0 +1,151 @@
|
||||
# Scan for problematic software licenses
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
|
||||
name: 'Nightly - Scan licenses'
|
||||
on:
|
||||
# schedule:
|
||||
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
scan-licenses:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
|
||||
timeout-minutes: 45
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
|
||||
- name: Get explicit and transitive dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
pip freeze > requirements-all.txt
|
||||
|
||||
- name: Check python
|
||||
id: license_check_report
|
||||
uses: pilosus/action-pip-license-checker@e909b0226ff49d3235c99c4585bc617f49fff16a # ratchet:pilosus/action-pip-license-checker@v3
|
||||
with:
|
||||
requirements: 'requirements-all.txt'
|
||||
fail: 'Copyleft'
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: always()
|
||||
env:
|
||||
REPORT: ${{ steps.license_check_report.outputs.report }}
|
||||
run: echo "$REPORT"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
# be careful enabling the sarif and upload as it may spam the security tab
|
||||
# with a huge amount of items. Work out the issues before enabling upload.
|
||||
# - name: Run Trivy vulnerability scanner in repo mode
|
||||
# if: always()
|
||||
# uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
# with:
|
||||
# scan-type: fs
|
||||
# scan-ref: .
|
||||
# scanners: license
|
||||
# format: table
|
||||
# severity: HIGH,CRITICAL
|
||||
# # format: sarif
|
||||
# # output: trivy-results.sarif
|
||||
#
|
||||
# # - name: Upload Trivy scan results to GitHub Security tab
|
||||
# # uses: github/codeql-action/upload-sarif@v3
|
||||
# # with:
|
||||
# # sarif_file: trivy-results.sarif
|
||||
|
||||
scan-trivy:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# Backend
|
||||
- name: Pull backend docker image
|
||||
run: docker pull onyxdotapp/onyx-backend:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on backend
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-backend:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
|
||||
|
||||
# Web server
|
||||
- name: Pull web server docker image
|
||||
run: docker pull onyxdotapp/onyx-web-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on web server
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-web-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
|
||||
# Model server
|
||||
- name: Pull model server docker image
|
||||
run: docker pull onyxdotapp/onyx-model-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-model-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
@@ -45,6 +45,9 @@ env:
|
||||
# TODO: debug why this is failing and enable
|
||||
CODE_INTERPRETER_BASE_URL: http://localhost:8000
|
||||
|
||||
# OpenSearch
|
||||
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
@@ -115,10 +118,9 @@ jobs:
|
||||
- name: Create .env file for Docker Compose
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
CODE_INTERPRETER_BETA_ENABLED=true
|
||||
DISABLE_TELEMETRY=true
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=true
|
||||
EOF
|
||||
|
||||
- name: Set up Standard Dependencies
|
||||
@@ -127,6 +129,7 @@ jobs:
|
||||
docker compose \
|
||||
-f docker-compose.yml \
|
||||
-f docker-compose.dev.yml \
|
||||
-f docker-compose.opensearch.yml \
|
||||
up -d \
|
||||
minio \
|
||||
relational_db \
|
||||
|
||||
5
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -41,7 +41,8 @@ jobs:
|
||||
version: v3.19.0
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
|
||||
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
|
||||
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
|
||||
with:
|
||||
uv_version: "0.9.9"
|
||||
|
||||
@@ -91,7 +92,7 @@ jobs:
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
|
||||
helm repo update
|
||||
|
||||
- name: Install Redis operator
|
||||
|
||||
38
.github/workflows/pr-integration-tests.yml
vendored
@@ -46,7 +46,6 @@ jobs:
|
||||
timeout-minutes: 45
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
editions: ${{ steps.set-editions.outputs.editions }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
@@ -73,16 +72,6 @@ jobs:
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Determine editions to test
|
||||
id: set-editions
|
||||
run: |
|
||||
# On PRs, only run EE tests. On merge_group and tags, run both EE and MIT.
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
echo 'editions=["ee"]' >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo 'editions=["ee","mit"]' >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
build-backend-image:
|
||||
runs-on:
|
||||
[
|
||||
@@ -278,7 +267,7 @@ jobs:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- ${{ format('run-id={0}-integration-tests-{1}-job-{2}', github.run_id, matrix.edition, strategy['job-index']) }}
|
||||
- ${{ format('run-id={0}-integration-tests-job-{1}', github.run_id, strategy['job-index']) }}
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
|
||||
@@ -286,7 +275,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
edition: ${{ fromJson(needs.discover-test-dirs.outputs.editions) }}
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -310,11 +298,12 @@ jobs:
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
EDITION: ${{ matrix.edition }}
|
||||
run: |
|
||||
# Base config shared by both editions
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
@@ -323,20 +312,11 @@ jobs:
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
MCP_SERVER_ENABLED=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
EOF
|
||||
|
||||
# EE-only config
|
||||
if [ "$EDITION" = "ee" ]; then
|
||||
cat <<EOF >> deployment/docker_compose/.env
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
MCP_SERVER_ENABLED=true
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
|
||||
EOF
|
||||
fi
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
@@ -399,14 +379,14 @@ jobs:
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
- name: Run Integration Tests (${{ matrix.edition }}) for ${{ matrix.test-dir.name }}
|
||||
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running ${{ matrix.edition }} integration tests for ${{ matrix.test-dir.path }}..."
|
||||
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
@@ -464,7 +444,7 @@ jobs:
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.edition }}-${{ matrix.test-dir.name }}
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
|
||||
443
.github/workflows/pr-mit-integration-tests.yml
vendored
Normal file
@@ -0,0 +1,443 @@
|
||||
name: Run MIT Integration Tests v2
|
||||
concurrency:
|
||||
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
types: [checks_requested]
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Discover test directories
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
all_dirs=""
|
||||
for dir in $tests_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
|
||||
done
|
||||
for dir in $connector_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
|
||||
done
|
||||
|
||||
# Remove trailing comma and wrap in array
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
build-backend-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
push: true
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
|
||||
type=registry,ref=onyxdotapp/onyx-backend:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-model-server-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
push: true
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
|
||||
type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
|
||||
build-integration-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-integration-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling openapitools/openapi-generator-cli
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push integration test image with Docker Bake
|
||||
env:
|
||||
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
TAG: integration-test-${{ github.run_id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
run: |
|
||||
docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
|
||||
integration
|
||||
|
||||
integration-tests-mit:
|
||||
needs:
|
||||
[
|
||||
discover-test-dirs,
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- ${{ format('run-id={0}-integration-tests-mit-job-{1}', github.run_id, strategy['job-index']) }}
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
# NOTE: don't need web server for integration tests
|
||||
- name: Create .env file for Docker Compose
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
MCP_SERVER_ENABLED=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
wait_for_service() {
|
||||
local url=$1
|
||||
local label=$2
|
||||
local timeout=${3:-300} # default 5 minutes
|
||||
local start_time
|
||||
start_time=$(date +%s)
|
||||
|
||||
while true; do
|
||||
local current_time
|
||||
current_time=$(date +%s)
|
||||
local elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. ${label} did not become ready in $timeout seconds."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
local response
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" "$url" || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "${label} is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error while checking ${label}. Retrying in 5 seconds..."
|
||||
else
|
||||
echo "${label} not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
}
|
||||
|
||||
wait_for_service "http://localhost:8080/health" "API server"
|
||||
echo "Finished waiting for services."
|
||||
|
||||
- name: Start Mock Services
|
||||
run: |
|
||||
cd backend/tests/integration/mock_services
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
|
||||
/app/tests/integration/${{ matrix.test-dir.path }}
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
required:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
needs: [integration-tests-mit]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Check job status
|
||||
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
|
||||
run: exit 1
|
||||
9
.github/workflows/pr-playwright-tests.yml
vendored
@@ -22,9 +22,6 @@ env:
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
|
||||
GOOGLE_PSE_API_KEY: ${{ secrets.GOOGLE_PSE_API_KEY }}
|
||||
GOOGLE_PSE_SEARCH_ENGINE_ID: ${{ secrets.GOOGLE_PSE_SEARCH_ENGINE_ID }}
|
||||
|
||||
# for federated slack tests
|
||||
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
|
||||
@@ -303,7 +300,6 @@ jobs:
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
|
||||
EXA_API_KEY=${EXA_API_KEY_VALUE}
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
@@ -593,10 +589,7 @@ jobs:
|
||||
# Post a single combined visual regression comment after all matrix jobs finish
|
||||
visual-regression-comment:
|
||||
needs: [playwright-tests]
|
||||
if: >-
|
||||
always() &&
|
||||
github.event_name == 'pull_request' &&
|
||||
needs.playwright-tests.result != 'cancelled'
|
||||
if: always() && github.event_name == 'pull_request'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
permissions:
|
||||
|
||||
13
.github/workflows/zizmor.yml
vendored
@@ -5,8 +5,6 @@ on:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["**"]
|
||||
paths:
|
||||
- ".github/**"
|
||||
|
||||
permissions: {}
|
||||
|
||||
@@ -23,18 +21,29 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Detect changes
|
||||
id: filter
|
||||
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
zizmor:
|
||||
- '.github/**'
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Run zizmor
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Upload SARIF file
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
1
.gitignore
vendored
@@ -7,7 +7,6 @@
.zed
.cursor
!/.cursor/mcp.json
!/.cursor/skills/

# macos
.DS_store

4
.vscode/launch.json
vendored
@@ -275,7 +275,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=background@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,user_files_indexing,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
@@ -419,7 +419,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docfetching@%n",
|
||||
"-Q",
|
||||
"connector_doc_fetching"
|
||||
"connector_doc_fetching,user_files_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
|
||||
5
LICENSE
@@ -2,10 +2,7 @@ Copyright (c) 2023-present DanswerAI, Inc.

Portions of this software are licensed as follows:

- All content that resides under "ee" directories of this repository is licensed under the Onyx Enterprise License. Each ee directory contains an identical copy of this license at its root:
- backend/ee/LICENSE
- web/src/app/ee/LICENSE
- web/src/ee/LICENSE
- All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.

@@ -1,28 +0,0 @@
|
||||
"""add scim_username to scim_user_mapping
|
||||
|
||||
Revision ID: 0bb4558f35df
|
||||
Revises: 631fd2504136
|
||||
Create Date: 2026-02-20 10:45:30.340188
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0bb4558f35df"
|
||||
down_revision = "631fd2504136"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"scim_user_mapping",
|
||||
sa.Column("scim_username", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("scim_user_mapping", "scim_username")
|
||||
@@ -1,71 +0,0 @@
|
||||
"""Migrate to contextual rag model
|
||||
|
||||
Revision ID: 19c0ccb01687
|
||||
Revises: 9c54986124c6
|
||||
Create Date: 2026-02-12 11:21:41.798037
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "19c0ccb01687"
|
||||
down_revision = "9c54986124c6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Widen the column to fit 'CONTEXTUAL_RAG' (15 chars); was varchar(10)
|
||||
# when the table was created with only CHAT/VISION values.
|
||||
op.alter_column(
|
||||
"llm_model_flow",
|
||||
"llm_model_flow_type",
|
||||
type_=sa.String(length=20),
|
||||
existing_type=sa.String(length=10),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# For every search_settings row that has contextual rag configured,
|
||||
# create an llm_model_flow entry. is_default is TRUE if the row
|
||||
# belongs to the PRESENT search settings, FALSE otherwise.
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO llm_model_flow (llm_model_flow_type, model_configuration_id, is_default)
|
||||
SELECT DISTINCT
|
||||
'CONTEXTUAL_RAG',
|
||||
mc.id,
|
||||
(ss.status = 'PRESENT')
|
||||
FROM search_settings ss
|
||||
JOIN llm_provider lp
|
||||
ON lp.name = ss.contextual_rag_llm_provider
|
||||
JOIN model_configuration mc
|
||||
ON mc.llm_provider_id = lp.id
|
||||
AND mc.name = ss.contextual_rag_llm_name
|
||||
WHERE ss.enable_contextual_rag = TRUE
|
||||
AND ss.contextual_rag_llm_name IS NOT NULL
|
||||
AND ss.contextual_rag_llm_provider IS NOT NULL
|
||||
ON CONFLICT (llm_model_flow_type, model_configuration_id)
|
||||
DO UPDATE SET is_default = EXCLUDED.is_default
|
||||
WHERE EXCLUDED.is_default = TRUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM llm_model_flow
|
||||
WHERE llm_model_flow_type = 'CONTEXTUAL_RAG'
|
||||
"""
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"llm_model_flow",
|
||||
"llm_model_flow_type",
|
||||
type_=sa.String(length=10),
|
||||
existing_type=sa.String(length=20),
|
||||
existing_nullable=False,
|
||||
)
|
||||
@@ -1,32 +0,0 @@
|
||||
"""add approx_chunk_count_in_vespa to opensearch tenant migration
|
||||
|
||||
Revision ID: 631fd2504136
|
||||
Revises: c7f2e1b4a9d3
|
||||
Create Date: 2026-02-18 21:07:52.831215
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "631fd2504136"
|
||||
down_revision = "c7f2e1b4a9d3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"approx_chunk_count_in_vespa",
|
||||
sa.Integer(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("opensearch_tenant_migration_record", "approx_chunk_count_in_vespa")
|
||||
@@ -1,124 +0,0 @@
"""add_scim_tables

Revision ID: 9c54986124c6
Revises: b51c6844d1df
Create Date: 2026-02-12 20:29:47.448614

"""

from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "9c54986124c6"
down_revision = "b51c6844d1df"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "scim_token",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("hashed_token", sa.String(length=64), nullable=False),
        sa.Column("token_display", sa.String(), nullable=False),
        sa.Column(
            "created_by_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=False,
        ),
        sa.Column(
            "is_active",
            sa.Boolean(),
            server_default=sa.text("true"),
            nullable=False,
        ),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
        sa.ForeignKeyConstraint(["created_by_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("hashed_token"),
    )
    op.create_table(
        "scim_group_mapping",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("external_id", sa.String(), nullable=False),
        sa.Column("user_group_id", sa.Integer(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["user_group_id"], ["user_group.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("user_group_id"),
    )
    op.create_index(
        op.f("ix_scim_group_mapping_external_id"),
        "scim_group_mapping",
        ["external_id"],
        unique=True,
    )
    op.create_table(
        "scim_user_mapping",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("external_id", sa.String(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=False,
        ),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("user_id"),
    )
    op.create_index(
        op.f("ix_scim_user_mapping_external_id"),
        "scim_user_mapping",
        ["external_id"],
        unique=True,
    )


def downgrade() -> None:
    op.drop_index(
        op.f("ix_scim_user_mapping_external_id"),
        table_name="scim_user_mapping",
    )
    op.drop_table("scim_user_mapping")
    op.drop_index(
        op.f("ix_scim_group_mapping_external_id"),
        table_name="scim_group_mapping",
    )
    op.drop_table("scim_group_mapping")
    op.drop_table("scim_token")
@@ -1,31 +0,0 @@
"""add sharing_scope to build_session

Revision ID: c7f2e1b4a9d3
Revises: 19c0ccb01687
Create Date: 2026-02-17 12:00:00.000000

"""

from alembic import op
import sqlalchemy as sa

revision = "c7f2e1b4a9d3"
down_revision = "19c0ccb01687"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "build_session",
        sa.Column(
            "sharing_scope",
            sa.String(),
            nullable=False,
            server_default="private",
        ),
    )


def downgrade() -> None:
    op.drop_column("build_session", "sharing_scope")
@@ -1,20 +1,20 @@
The Onyx Enterprise License (the "Enterprise License")
The DanswerAI Enterprise license (the “Enterprise License”)
Copyright (c) 2023-present DanswerAI, Inc.

With regard to the Onyx Software:

This software and associated documentation files (the "Software") may only be
used in production, if you (and any entity that you represent) have agreed to,
and are in compliance with, the Onyx Subscription Terms of Service, available
at https://www.onyx.app/legal/self-host (the "Enterprise Terms"), or other
and are in compliance with, the DanswerAI Subscription Terms of Service, available
at https://onyx.app/terms (the “Enterprise Terms”), or other
agreement governing the use of the Software, as agreed by you and DanswerAI,
and otherwise have a valid Onyx Enterprise License for the
and otherwise have a valid Onyx Enterprise license for the
correct number of user seats. Subject to the foregoing sentence, you are free to
modify this Software and publish patches to the Software. You agree that DanswerAI
and/or its licensors (as applicable) retain all right, title and interest in and
to all such modifications and/or patches, and all such modifications and/or
patches may only be used, copied, modified, displayed, distributed, or otherwise
exploited with a valid Onyx Enterprise License for the correct
exploited with a valid Onyx Enterprise license for the correct
number of user seats. Notwithstanding the foregoing, you may copy and modify
the Software for development and testing purposes, without requiring a
subscription. You agree that DanswerAI and/or its licensors (as applicable) retain
@@ -536,9 +536,7 @@ def connector_permission_sync_generator_task(
        )
        redis_connector.permissions.set_fence(new_payload)

        callback = PermissionSyncCallback(
            redis_connector, lock, r, timeout_seconds=JOB_TIMEOUT
        )
        callback = PermissionSyncCallback(redis_connector, lock, r)

        # pass in the capability to fetch all existing docs for the cc_pair
        # this is can be used to determine documents that are "missing" and thus
@@ -578,13 +576,6 @@ def connector_permission_sync_generator_task(
        tasks_generated = 0
        docs_with_errors = 0
        for doc_external_access in document_external_accesses:
            if callback.should_stop():
                raise RuntimeError(
                    f"Permission sync task timed out or stop signal detected: "
                    f"cc_pair={cc_pair_id} "
                    f"tasks_generated={tasks_generated}"
                )

            result = redis_connector.permissions.update_db(
                lock=lock,
                new_permissions=[doc_external_access],
@@ -941,7 +932,6 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
        redis_connector: RedisConnector,
        redis_lock: RedisLock,
        redis_client: Redis,
        timeout_seconds: int | None = None,
    ):
        super().__init__()
        self.redis_connector: RedisConnector = redis_connector
@@ -954,26 +944,11 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
        self.last_tag: str = "PermissionSyncCallback.__init__"
        self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
        self.last_lock_monotonic = time.monotonic()
        self.start_monotonic = time.monotonic()
        self.timeout_seconds = timeout_seconds

    def should_stop(self) -> bool:
        if self.redis_connector.stop.fenced:
            return True

        # Check if the task has exceeded its timeout
        # NOTE: Celery's soft_time_limit does not work with thread pools,
        # so we must enforce timeouts internally.
        if self.timeout_seconds is not None:
            elapsed = time.monotonic() - self.start_monotonic
            if elapsed > self.timeout_seconds:
                logger.warning(
                    f"PermissionSyncCallback - task timeout exceeded: "
                    f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
                    f"cc_pair={self.redis_connector.cc_pair_id}"
                )
                return True

        return False

    def progress(self, tag: str, amount: int) -> None:  # noqa: ARG002
@@ -466,7 +466,6 @@ def connector_external_group_sync_generator_task(
def _perform_external_group_sync(
    cc_pair_id: int,
    tenant_id: str,
    timeout_seconds: int = JOB_TIMEOUT,
) -> None:
    # Create attempt record at the start
    with get_session_with_current_tenant() as db_session:
@@ -519,23 +518,9 @@ def _perform_external_group_sync(
    seen_users: set[str] = set()  # Track unique users across all groups
    total_groups_processed = 0
    total_group_memberships_synced = 0
    start_time = time.monotonic()
    try:
        external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
        for external_user_group in external_user_group_generator:
            # Check if the task has exceeded its timeout
            # NOTE: Celery's soft_time_limit does not work with thread pools,
            # so we must enforce timeouts internally.
            elapsed = time.monotonic() - start_time
            if elapsed > timeout_seconds:
                raise RuntimeError(
                    f"External group sync task timed out: "
                    f"cc_pair={cc_pair_id} "
                    f"elapsed={elapsed:.0f}s "
                    f"timeout={timeout_seconds}s "
                    f"groups_processed={total_groups_processed}"
                )

            external_user_group_batch.append(external_user_group)

            # Track progress
@@ -263,15 +263,9 @@ def refresh_license_cache(

    try:
        payload = verify_license_signature(license_record.license_data)
        # Derive source from payload: manual licenses lack stripe_customer_id
        source: LicenseSource = (
            LicenseSource.AUTO_FETCH
            if payload.stripe_customer_id
            else LicenseSource.MANUAL_UPLOAD
        )
        return update_license_cache(
            payload,
            source=source,
            source=LicenseSource.AUTO_FETCH,
            tenant_id=tenant_id,
        )
    except ValueError as e:
@@ -1,604 +0,0 @@
"""SCIM Data Access Layer.

All database operations for SCIM provisioning — token management, user
mappings, and group mappings. Extends the base DAL (see ``onyx.db.dal``).

Usage from FastAPI::

    def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
        return ScimDAL(db_session)

    @router.post("/tokens")
    def create_token(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
        token = dal.create_token(name=..., hashed_token=..., ...)
        dal.commit()
        return token

Usage from background tasks::

    with ScimDAL.from_tenant("tenant_abc") as dal:
        mapping = dal.create_user_mapping(external_id="idp-123", user_id=uid)
        dal.commit()
"""

from __future__ import annotations

from uuid import UUID

from sqlalchemy import delete as sa_delete
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import SQLColumnExpression
from sqlalchemy.dialects.postgresql import insert as pg_insert

from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from onyx.db.dal import DAL
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger

logger = setup_logger()


class ScimDAL(DAL):
    """Data Access Layer for SCIM provisioning operations.

    Methods mutate but do NOT commit — call ``dal.commit()`` explicitly
    when you want to persist changes. This follows the existing ``_no_commit``
    convention and lets callers batch multiple operations into one transaction.
    """

    # ------------------------------------------------------------------
    # Token operations
    # ------------------------------------------------------------------

    def create_token(
        self,
        name: str,
        hashed_token: str,
        token_display: str,
        created_by_id: UUID,
    ) -> ScimToken:
        """Create a new SCIM bearer token.

        Only one token is active at a time — this method automatically revokes
        all existing active tokens before creating the new one.
        """
        # Revoke any currently active tokens
        active_tokens = list(
            self._session.scalars(
                select(ScimToken).where(ScimToken.is_active.is_(True))
            ).all()
        )
        for t in active_tokens:
            t.is_active = False

        token = ScimToken(
            name=name,
            hashed_token=hashed_token,
            token_display=token_display,
            created_by_id=created_by_id,
        )
        self._session.add(token)
        self._session.flush()
        return token

    def get_active_token(self) -> ScimToken | None:
        """Return the single currently active token, or None."""
        return self._session.scalar(
            select(ScimToken).where(ScimToken.is_active.is_(True))
        )

    def get_token_by_hash(self, hashed_token: str) -> ScimToken | None:
        """Look up a token by its SHA-256 hash."""
        return self._session.scalar(
            select(ScimToken).where(ScimToken.hashed_token == hashed_token)
        )

    def revoke_token(self, token_id: int) -> None:
        """Deactivate a token by ID.

        Raises:
            ValueError: If the token does not exist.
        """
        token = self._session.get(ScimToken, token_id)
        if not token:
            raise ValueError(f"SCIM token with id {token_id} not found")
        token.is_active = False

    def update_token_last_used(self, token_id: int) -> None:
        """Update the last_used_at timestamp for a token."""
        token = self._session.get(ScimToken, token_id)
        if token:
            token.last_used_at = func.now()  # type: ignore[assignment]

    # ------------------------------------------------------------------
    # User mapping operations
    # ------------------------------------------------------------------

    def create_user_mapping(
        self,
        external_id: str,
        user_id: UUID,
    ) -> ScimUserMapping:
        """Create a mapping between a SCIM externalId and an Onyx user."""
        mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
        self._session.add(mapping)
        self._session.flush()
        return mapping

    def get_user_mapping_by_external_id(
        self, external_id: str
    ) -> ScimUserMapping | None:
        """Look up a user mapping by the IdP's external identifier."""
        return self._session.scalar(
            select(ScimUserMapping).where(ScimUserMapping.external_id == external_id)
        )

    def get_user_mapping_by_user_id(self, user_id: UUID) -> ScimUserMapping | None:
        """Look up a user mapping by the Onyx user ID."""
        return self._session.scalar(
            select(ScimUserMapping).where(ScimUserMapping.user_id == user_id)
        )

    def list_user_mappings(
        self,
        start_index: int = 1,
        count: int = 100,
    ) -> tuple[list[ScimUserMapping], int]:
        """List user mappings with SCIM-style pagination.

        Args:
            start_index: 1-based start index (SCIM convention).
            count: Maximum number of results to return.

        Returns:
            A tuple of (mappings, total_count).
        """
        total = (
            self._session.scalar(select(func.count()).select_from(ScimUserMapping)) or 0
        )

        offset = max(start_index - 1, 0)
        mappings = list(
            self._session.scalars(
                select(ScimUserMapping)
                .order_by(ScimUserMapping.id)
                .offset(offset)
                .limit(count)
            ).all()
        )

        return mappings, total

    def update_user_mapping_external_id(
        self,
        mapping_id: int,
        external_id: str,
    ) -> ScimUserMapping:
        """Update the external ID on a user mapping.

        Raises:
            ValueError: If the mapping does not exist.
        """
        mapping = self._session.get(ScimUserMapping, mapping_id)
        if not mapping:
            raise ValueError(f"SCIM user mapping with id {mapping_id} not found")
        mapping.external_id = external_id
        return mapping

    def delete_user_mapping(self, mapping_id: int) -> None:
        """Delete a user mapping by ID. No-op if already deleted."""
        mapping = self._session.get(ScimUserMapping, mapping_id)
        if not mapping:
            logger.warning("SCIM user mapping %d not found during delete", mapping_id)
            return
        self._session.delete(mapping)

    # ------------------------------------------------------------------
    # User query operations
    # ------------------------------------------------------------------

    def get_user(self, user_id: UUID) -> User | None:
        """Fetch a user by ID."""
        return self._session.scalar(
            select(User).where(User.id == user_id)  # type: ignore[arg-type]
        )

    def get_user_by_email(self, email: str) -> User | None:
        """Fetch a user by email (case-insensitive)."""
        return self._session.scalar(
            select(User).where(func.lower(User.email) == func.lower(email))
        )

    def add_user(self, user: User) -> None:
        """Add a new user to the session and flush to assign an ID."""
        self._session.add(user)
        self._session.flush()

    def update_user(
        self,
        user: User,
        *,
        email: str | None = None,
        is_active: bool | None = None,
        personal_name: str | None = None,
    ) -> None:
        """Update user attributes. Only sets fields that are provided."""
        if email is not None:
            user.email = email
        if is_active is not None:
            user.is_active = is_active
        if personal_name is not None:
            user.personal_name = personal_name

    def deactivate_user(self, user: User) -> None:
        """Mark a user as inactive."""
        user.is_active = False

    def list_users(
        self,
        scim_filter: ScimFilter | None,
        start_index: int = 1,
        count: int = 100,
    ) -> tuple[list[tuple[User, str | None]], int]:
        """Query users with optional SCIM filter and pagination.

        Returns:
            A tuple of (list of (user, external_id) pairs, total_count).

        Raises:
            ValueError: If the filter uses an unsupported attribute.
        """
        query = select(User).where(
            User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
        )

        if scim_filter:
            attr = scim_filter.attribute.lower()
            if attr == "username":
                # arg-type: fastapi-users types User.email as str, not a column expression
                # assignment: union return type widens but query is still Select[tuple[User]]
                query = _apply_scim_string_op(query, User.email, scim_filter)  # type: ignore[arg-type, assignment]
            elif attr == "active":
                query = query.where(
                    User.is_active.is_(scim_filter.value.lower() == "true")  # type: ignore[attr-defined]
                )
            elif attr == "externalid":
                mapping = self.get_user_mapping_by_external_id(scim_filter.value)
                if not mapping:
                    return [], 0
                query = query.where(User.id == mapping.user_id)  # type: ignore[arg-type]
            else:
                raise ValueError(
                    f"Unsupported filter attribute: {scim_filter.attribute}"
                )

        # Count total matching rows first, then paginate. SCIM uses 1-based
        # indexing (RFC 7644 §3.4.2), so we convert to a 0-based offset.
        total = (
            self._session.scalar(select(func.count()).select_from(query.subquery()))
            or 0
        )

        offset = max(start_index - 1, 0)
        users = list(
            self._session.scalars(
                query.order_by(User.id).offset(offset).limit(count)  # type: ignore[arg-type]
            ).all()
        )

        # Batch-fetch external IDs to avoid N+1 queries
        ext_id_map = self._get_user_external_ids([u.id for u in users])
        return [(u, ext_id_map.get(u.id)) for u in users], total

    def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
        """Create, update, or delete the external ID mapping for a user."""
        mapping = self.get_user_mapping_by_user_id(user_id)
        if new_external_id:
            if mapping:
                if mapping.external_id != new_external_id:
                    mapping.external_id = new_external_id
            else:
                self.create_user_mapping(external_id=new_external_id, user_id=user_id)
        elif mapping:
            self.delete_user_mapping(mapping.id)

    def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
        """Batch-fetch external IDs for a list of user IDs."""
        if not user_ids:
            return {}
        mappings = self._session.scalars(
            select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
        ).all()
        return {m.user_id: m.external_id for m in mappings}

    # ------------------------------------------------------------------
    # Group mapping operations
    # ------------------------------------------------------------------

    def create_group_mapping(
        self,
        external_id: str,
        user_group_id: int,
    ) -> ScimGroupMapping:
        """Create a mapping between a SCIM externalId and an Onyx user group."""
        mapping = ScimGroupMapping(external_id=external_id, user_group_id=user_group_id)
        self._session.add(mapping)
        self._session.flush()
        return mapping

    def get_group_mapping_by_external_id(
        self, external_id: str
    ) -> ScimGroupMapping | None:
        """Look up a group mapping by the IdP's external identifier."""
        return self._session.scalar(
            select(ScimGroupMapping).where(ScimGroupMapping.external_id == external_id)
        )

    def get_group_mapping_by_group_id(
        self, user_group_id: int
    ) -> ScimGroupMapping | None:
        """Look up a group mapping by the Onyx user group ID."""
        return self._session.scalar(
            select(ScimGroupMapping).where(
                ScimGroupMapping.user_group_id == user_group_id
            )
        )

    def list_group_mappings(
        self,
        start_index: int = 1,
        count: int = 100,
    ) -> tuple[list[ScimGroupMapping], int]:
        """List group mappings with SCIM-style pagination.

        Args:
            start_index: 1-based start index (SCIM convention).
            count: Maximum number of results to return.

        Returns:
            A tuple of (mappings, total_count).
        """
        total = (
            self._session.scalar(select(func.count()).select_from(ScimGroupMapping))
            or 0
        )

        offset = max(start_index - 1, 0)
        mappings = list(
            self._session.scalars(
                select(ScimGroupMapping)
                .order_by(ScimGroupMapping.id)
                .offset(offset)
                .limit(count)
            ).all()
        )

        return mappings, total

    def delete_group_mapping(self, mapping_id: int) -> None:
        """Delete a group mapping by ID. No-op if already deleted."""
        mapping = self._session.get(ScimGroupMapping, mapping_id)
        if not mapping:
            logger.warning("SCIM group mapping %d not found during delete", mapping_id)
            return
        self._session.delete(mapping)

    # ------------------------------------------------------------------
    # Group query operations
    # ------------------------------------------------------------------

    def get_group(self, group_id: int) -> UserGroup | None:
        """Fetch a group by ID, returning None if deleted or missing."""
        group = self._session.get(UserGroup, group_id)
        if group and group.is_up_for_deletion:
            return None
        return group

    def get_group_by_name(self, name: str) -> UserGroup | None:
        """Fetch a group by exact name."""
        return self._session.scalar(select(UserGroup).where(UserGroup.name == name))

    def add_group(self, group: UserGroup) -> None:
        """Add a new group to the session and flush to assign an ID."""
        self._session.add(group)
        self._session.flush()

    def update_group(
        self,
        group: UserGroup,
        *,
        name: str | None = None,
    ) -> None:
        """Update group attributes and set the modification timestamp."""
        if name is not None:
            group.name = name
        group.time_last_modified_by_user = func.now()

    def delete_group(self, group: UserGroup) -> None:
        """Delete a group from the session."""
        self._session.delete(group)

    def list_groups(
        self,
        scim_filter: ScimFilter | None,
        start_index: int = 1,
        count: int = 100,
    ) -> tuple[list[tuple[UserGroup, str | None]], int]:
        """Query groups with optional SCIM filter and pagination.

        Returns:
            A tuple of (list of (group, external_id) pairs, total_count).

        Raises:
            ValueError: If the filter uses an unsupported attribute.
        """
        query = select(UserGroup).where(UserGroup.is_up_for_deletion.is_(False))

        if scim_filter:
            attr = scim_filter.attribute.lower()
            if attr == "displayname":
                # assignment: union return type widens but query is still Select[tuple[UserGroup]]
                query = _apply_scim_string_op(query, UserGroup.name, scim_filter)  # type: ignore[assignment]
            elif attr == "externalid":
                mapping = self.get_group_mapping_by_external_id(scim_filter.value)
                if not mapping:
                    return [], 0
                query = query.where(UserGroup.id == mapping.user_group_id)
            else:
                raise ValueError(
                    f"Unsupported filter attribute: {scim_filter.attribute}"
                )

        total = (
            self._session.scalar(select(func.count()).select_from(query.subquery()))
            or 0
        )

        offset = max(start_index - 1, 0)
        groups = list(
            self._session.scalars(
                query.order_by(UserGroup.id).offset(offset).limit(count)
            ).all()
        )

        ext_id_map = self._get_group_external_ids([g.id for g in groups])
        return [(g, ext_id_map.get(g.id)) for g in groups], total

    def get_group_members(self, group_id: int) -> list[tuple[UUID, str | None]]:
        """Get group members as (user_id, email) pairs."""
        rels = self._session.scalars(
            select(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
        ).all()

        user_ids = [r.user_id for r in rels if r.user_id]
        if not user_ids:
            return []

        users = self._session.scalars(
            select(User).where(User.id.in_(user_ids))  # type: ignore[attr-defined]
        ).all()
        users_by_id = {u.id: u for u in users}

        return [
            (
                r.user_id,
                users_by_id[r.user_id].email if r.user_id in users_by_id else None,
            )
            for r in rels
            if r.user_id
        ]

    def validate_member_ids(self, uuids: list[UUID]) -> list[UUID]:
        """Return the subset of UUIDs that don't exist as users.

        Returns an empty list if all IDs are valid.
        """
        if not uuids:
            return []
        existing_users = self._session.scalars(
            select(User).where(User.id.in_(uuids))  # type: ignore[attr-defined]
        ).all()
        existing_ids = {u.id for u in existing_users}
        return [uid for uid in uuids if uid not in existing_ids]

    def upsert_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
        """Add user-group relationships, ignoring duplicates."""
        if not user_ids:
            return
        self._session.execute(
            pg_insert(User__UserGroup)
            .values([{"user_id": uid, "user_group_id": group_id} for uid in user_ids])
            .on_conflict_do_nothing(
                index_elements=[
                    User__UserGroup.user_group_id,
                    User__UserGroup.user_id,
                ]
            )
        )

    def replace_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
        """Replace all members of a group."""
        self._session.execute(
            sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
        )
        self.upsert_group_members(group_id, user_ids)

    def remove_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
        """Remove specific members from a group."""
        if not user_ids:
            return
        self._session.execute(
            sa_delete(User__UserGroup).where(
                User__UserGroup.user_group_id == group_id,
                User__UserGroup.user_id.in_(user_ids),
            )
        )

    def delete_group_with_members(self, group: UserGroup) -> None:
        """Remove all member relationships and delete the group."""
        self._session.execute(
            sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group.id)
        )
        self._session.delete(group)

    def sync_group_external_id(
        self, group_id: int, new_external_id: str | None
    ) -> None:
        """Create, update, or delete the external ID mapping for a group."""
        mapping = self.get_group_mapping_by_group_id(group_id)
        if new_external_id:
            if mapping:
                if mapping.external_id != new_external_id:
                    mapping.external_id = new_external_id
            else:
                self.create_group_mapping(
                    external_id=new_external_id, user_group_id=group_id
                )
        elif mapping:
            self.delete_group_mapping(mapping.id)

    def _get_group_external_ids(self, group_ids: list[int]) -> dict[int, str]:
        """Batch-fetch external IDs for a list of group IDs."""
        if not group_ids:
            return {}
        mappings = self._session.scalars(
            select(ScimGroupMapping).where(
                ScimGroupMapping.user_group_id.in_(group_ids)
            )
        ).all()
        return {m.user_group_id: m.external_id for m in mappings}


# ---------------------------------------------------------------------------
# Module-level helpers (used by DAL methods above)
# ---------------------------------------------------------------------------


def _apply_scim_string_op(
    query: Select[tuple[User]] | Select[tuple[UserGroup]],
    column: SQLColumnExpression[str],
    scim_filter: ScimFilter,
) -> Select[tuple[User]] | Select[tuple[UserGroup]]:
    """Apply a SCIM string filter operator using SQLAlchemy column operators.

    Handles eq (case-insensitive exact), co (contains), and sw (starts with).
    SQLAlchemy's operators handle LIKE-pattern escaping internally.
    """
    val = scim_filter.value
    if scim_filter.operator == ScimFilterOperator.EQUAL:
        return query.where(func.lower(column) == val.lower())
    elif scim_filter.operator == ScimFilterOperator.CONTAINS:
        return query.where(column.icontains(val, autoescape=True))
    elif scim_filter.operator == ScimFilterOperator.STARTS_WITH:
        return query.where(column.istartswith(val, autoescape=True))
    else:
        raise ValueError(f"Unsupported string filter operator: {scim_filter.operator}")
@@ -9,7 +9,6 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

from ee.onyx.server.user_group.models import SetCuratorRequest
@@ -19,15 +18,11 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
@@ -200,60 +195,8 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
    return db_session.scalar(stmt)


def _add_user_group_snapshot_eager_loads(
    stmt: Select,
) -> Select:
    """Add eager loading options needed by UserGroup.from_model snapshot creation."""
    return stmt.options(
        selectinload(UserGroup.users),
        selectinload(UserGroup.user_group_relationships),
        selectinload(UserGroup.cc_pair_relationships)
        .selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
        .options(
            selectinload(ConnectorCredentialPair.connector),
            selectinload(ConnectorCredentialPair.credential).selectinload(
                Credential.user
            ),
        ),
        selectinload(UserGroup.document_sets).options(
            selectinload(DocumentSet.connector_credential_pairs).selectinload(
                ConnectorCredentialPair.connector
            ),
            selectinload(DocumentSet.users),
            selectinload(DocumentSet.groups),
            selectinload(DocumentSet.federated_connectors).selectinload(
                FederatedConnector__DocumentSet.federated_connector
            ),
        ),
        selectinload(UserGroup.personas).options(
            selectinload(Persona.tools),
            selectinload(Persona.hierarchy_nodes),
            selectinload(Persona.attached_documents).selectinload(
                Document.parent_hierarchy_node
            ),
            selectinload(Persona.labels),
            selectinload(Persona.document_sets).options(
                selectinload(DocumentSet.connector_credential_pairs).selectinload(
                    ConnectorCredentialPair.connector
                ),
                selectinload(DocumentSet.users),
                selectinload(DocumentSet.groups),
                selectinload(DocumentSet.federated_connectors).selectinload(
                    FederatedConnector__DocumentSet.federated_connector
                ),
            ),
            selectinload(Persona.user),
            selectinload(Persona.user_files),
            selectinload(Persona.users),
            selectinload(Persona.groups),
        ),
    )


def fetch_user_groups(
    db_session: Session,
    only_up_to_date: bool = True,
    eager_load_for_snapshot: bool = False,
    db_session: Session, only_up_to_date: bool = True
) -> Sequence[UserGroup]:
    """
    Fetches user groups from the database.
@@ -266,8 +209,6 @@ def fetch_user_groups(
        db_session (Session): The SQLAlchemy session used to query the database.
        only_up_to_date (bool, optional): Flag to determine whether to filter the results
            to include only up to date user groups. Defaults to `True`.
        eager_load_for_snapshot: If True, adds eager loading for all relationships
            needed by UserGroup.from_model snapshot creation.

    Returns:
        Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -275,16 +216,11 @@ def fetch_user_groups(
    stmt = select(UserGroup)
    if only_up_to_date:
        stmt = stmt.where(UserGroup.is_up_to_date == True)  # noqa: E712
    if eager_load_for_snapshot:
        stmt = _add_user_group_snapshot_eager_loads(stmt)
    return db_session.scalars(stmt).unique().all()
    return db_session.scalars(stmt).all()


def fetch_user_groups_for_user(
    db_session: Session,
    user_id: UUID,
    only_curator_groups: bool = False,
    eager_load_for_snapshot: bool = False,
    db_session: Session, user_id: UUID, only_curator_groups: bool = False
) -> Sequence[UserGroup]:
    stmt = (
        select(UserGroup)
@@ -294,9 +230,7 @@ def fetch_user_groups_for_user(
    )
    if only_curator_groups:
        stmt = stmt.where(User__UserGroup.is_curator == True)  # noqa: E712
    if eager_load_for_snapshot:
        stmt = _add_user_group_snapshot_eager_loads(stmt)
    return db_session.scalars(stmt).unique().all()
    return db_session.scalars(stmt).all()


def construct_document_id_select_by_usergroup(
@@ -6,7 +6,6 @@ from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
    get_sharepoint_external_groups,
)
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
@@ -47,27 +46,19 @@ def sharepoint_group_sync(

    logger.info(f"Processing {len(site_descriptors)} sites for group sync")

    enumerate_all = connector_config.get(
        "exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
    )

    msal_app = connector.msal_app
    sp_tenant_domain = connector.sp_tenant_domain
    sp_domain_suffix = connector.sharepoint_domain_suffix
    # Process each site
    for site_descriptor in site_descriptors:
        logger.debug(f"Processing site: {site_descriptor.url}")

        # Create client context for the site using connector's MSAL app
        ctx = ClientContext(site_descriptor.url).with_access_token(
            lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
            lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
        )

        external_groups = get_sharepoint_external_groups(
            ctx,
            connector.graph_client,
            graph_api_base=connector.graph_api_base,
            get_access_token=connector._get_graph_access_token,
            enumerate_all_ad_groups=enumerate_all,
        )
        # Get external groups for this site
        external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)

        # Yield each group
        for group in external_groups:
@@ -1,13 +1,9 @@
import re
import time
from collections import deque
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse

import requests as _requests
from office365.graph_client import GraphClient  # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem  # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException  # type: ignore
@@ -18,10 +14,7 @@ from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger
@@ -40,70 +33,6 @@ LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]


AD_GROUP_ENUMERATION_THRESHOLD = 100_000


def _graph_api_get(
    url: str,
    get_access_token: Callable[[], str],
    params: dict[str, str] | None = None,
) -> dict[str, Any]:
    """Authenticated Graph API GET with retry on transient errors."""
    for attempt in range(GRAPH_API_MAX_RETRIES + 1):
        access_token = get_access_token()
        headers = {"Authorization": f"Bearer {access_token}"}
        try:
            resp = _requests.get(
                url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
            )
            if (
                resp.status_code in GRAPH_API_RETRYABLE_STATUSES
                and attempt < GRAPH_API_MAX_RETRIES
            ):
                wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
                logger.warning(
                    f"Graph API {resp.status_code} on attempt {attempt + 1}, "
                    f"retrying in {wait}s: {url}"
                )
                time.sleep(wait)
                continue
            resp.raise_for_status()
            return resp.json()
        except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
            if attempt < GRAPH_API_MAX_RETRIES:
                wait = min(2**attempt, 60)
                logger.warning(
                    f"Graph API connection error on attempt {attempt + 1}, "
                    f"retrying in {wait}s: {url}"
                )
                time.sleep(wait)
                continue
            raise
    raise RuntimeError(
        f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
    )


def _iter_graph_collection(
    initial_url: str,
    get_access_token: Callable[[], str],
    params: dict[str, str] | None = None,
) -> Generator[dict[str, Any], None, None]:
    """Paginate through a Graph API collection, yielding items one at a time."""
    url: str | None = initial_url
    while url:
        data = _graph_api_get(url, get_access_token, params)
        params = None
        yield from data.get("value", [])
        url = data.get("@odata.nextLink")


def _normalize_email(email: str) -> str:
    if MICROSOFT_DOMAIN in email:
        return email.replace(MICROSOFT_DOMAIN, "")
    return email


class SharepointGroup(BaseModel):
    model_config = {"frozen": True}

@@ -643,65 +572,8 @@ def get_external_access_from_sharepoint(
    )


def _enumerate_ad_groups_paginated(
    get_access_token: Callable[[], str],
    already_resolved: set[str],
    graph_api_base: str,
) -> Generator[ExternalUserGroup, None, None]:
    """Paginate through all Azure AD groups and yield ExternalUserGroup for each.

    Skips groups whose suffixed name is already in *already_resolved*.
    Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
    """
    groups_url = f"{graph_api_base}/groups"
    groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
    total_groups = 0

    for group_json in _iter_graph_collection(
        groups_url, get_access_token, groups_params
    ):
        group_id: str = group_json.get("id", "")
        display_name: str = group_json.get("displayName", "")
        if not group_id or not display_name:
            continue

        total_groups += 1
        if total_groups > AD_GROUP_ENUMERATION_THRESHOLD:
            logger.warning(
                f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
                "groups — stopping to avoid excessive memory/API usage. "
                "Remaining groups will be resolved from role assignments only."
            )
            return

        name = f"{display_name}_{group_id}"
        if name in already_resolved:
            continue

        member_emails: list[str] = []
        members_url = f"{graph_api_base}/groups/{group_id}/members"
        members_params: dict[str, str] = {
            "$select": "userPrincipalName,mail",
            "$top": "999",
        }
        for member_json in _iter_graph_collection(
            members_url, get_access_token, members_params
        ):
            email = member_json.get("userPrincipalName") or member_json.get("mail")
            if email:
                member_emails.append(_normalize_email(email))

        yield ExternalUserGroup(id=name, user_emails=member_emails)

    logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API")


def get_sharepoint_external_groups(
    client_context: ClientContext,
    graph_client: GraphClient,
    graph_api_base: str,
    get_access_token: Callable[[], str] | None = None,
    enumerate_all_ad_groups: bool = False,
    client_context: ClientContext, graph_client: GraphClient
) -> list[ExternalUserGroup]:

    groups: set[SharepointGroup] = set()
@@ -757,22 +629,57 @@
        client_context, graph_client, groups, is_group_sync=True
    )

    external_user_groups: list[ExternalUserGroup] = [
        ExternalUserGroup(id=group_name, user_emails=list(emails))
        for group_name, emails in groups_and_members.groups_to_emails.items()
    ]
    # get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
    # We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
    azure_ad_groups = sleep_and_retry(
        graph_client.groups.get_all(page_loaded=lambda _: None),
        "get_sharepoint_external_groups:get_azure_ad_groups",
    )
    logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
    identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
    ad_groups_to_emails: dict[str, set[str]] = {}
    for group in azure_ad_groups:
        # If the group is already identified, we don't need to get the members
        if group.display_name in identified_groups:
            continue
        # AD groups allows same display name for multiple groups, so we need to add the GUID to the name
        name = group.display_name
        name = _get_group_name_with_suffix(group.id, name, graph_client)

    if not enumerate_all_ad_groups or get_access_token is None:
        logger.info(
            "Skipping exhaustive Azure AD group enumeration. "
            "Only groups found in site role assignments are included."
        members = sleep_and_retry(
            group.members.get_all(page_loaded=lambda _: None),
            "get_sharepoint_external_groups:get_azure_ad_groups:get_members",
        )
        return external_user_groups
        for member in members:
            member_data = member.to_json()
            user_principal_name = member_data.get("userPrincipalName")
            mail = member_data.get("mail")
            if not ad_groups_to_emails.get(name):
                ad_groups_to_emails[name] = set()
            if user_principal_name:
                if MICROSOFT_DOMAIN in user_principal_name:
                    user_principal_name = user_principal_name.replace(
                        MICROSOFT_DOMAIN, ""
                    )
                ad_groups_to_emails[name].add(user_principal_name)
            elif mail:
                if MICROSOFT_DOMAIN in mail:
                    mail = mail.replace(MICROSOFT_DOMAIN, "")
                ad_groups_to_emails[name].add(mail)

    already_resolved = set(groups_and_members.groups_to_emails.keys())
    for group in _enumerate_ad_groups_paginated(
        get_access_token, already_resolved, graph_api_base
    ):
        external_user_groups.append(group)
    external_user_groups: list[ExternalUserGroup] = []
    for group_name, emails in groups_and_members.groups_to_emails.items():
        external_user_group = ExternalUserGroup(
            id=group_name,
            user_emails=list(emails),
        )
        external_user_groups.append(external_user_group)

    for group_name, emails in ad_groups_to_emails.items():
        external_user_group = ExternalUserGroup(
            id=group_name,
            user_emails=list(emails),
        )
        external_user_groups.append(external_user_group)

    return external_user_groups
@@ -31,7 +31,6 @@ from ee.onyx.server.query_and_chat.query_backend import (
from ee.onyx.server.query_and_chat.search_backend import router as search_router
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.scim.api import scim_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.token_rate_limits.api import (
@@ -163,11 +162,6 @@ def get_application() -> FastAPI:
    # Tenant management
    include_router_with_global_prefix_prepended(application, tenants_router)

    # SCIM 2.0 — protocol endpoints (unauthenticated by Onyx session auth;
    # they use their own SCIM bearer token auth).
    # Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
    application.include_router(scim_router)

    # Ensure all routes have auth enabled or are explicitly marked as public
    check_ee_router_auth(application)
@@ -5,11 +5,6 @@ from onyx.server.auth_check import PUBLIC_ENDPOINT_SPECS


EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
    # SCIM 2.0 service discovery — unauthenticated so IdPs can probe
    # before bearer token configuration is complete
    ("/scim/v2/ServiceProviderConfig", {"GET"}),
    ("/scim/v2/ResourceTypes", {"GET"}),
    ("/scim/v2/Schemas", {"GET"}),
    # needs to be accessible prior to user login
    ("/enterprise-settings", {"GET"}),
    ("/enterprise-settings/logo", {"GET"}),
@@ -13,7 +13,6 @@ from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session

from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.store import get_logo_filename
@@ -23,10 +22,6 @@ from ee.onyx.server.enterprise_settings.store import load_settings
from ee.onyx.server.enterprise_settings.store import store_analytics_script
from ee.onyx.server.enterprise_settings.store import store_settings
from ee.onyx.server.enterprise_settings.store import upload_logo
from ee.onyx.server.scim.auth import generate_scim_token
from ee.onyx.server.scim.models import ScimTokenCreate
from ee.onyx.server.scim.models import ScimTokenCreatedResponse
from ee.onyx.server.scim.models import ScimTokenResponse
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
@@ -203,63 +198,3 @@ def upload_custom_analytics_script(
@basic_router.get("/custom-analytics-script")
def fetch_custom_analytics_script() -> str | None:
    return load_analytics_script()


# ---------------------------------------------------------------------------
# SCIM token management
# ---------------------------------------------------------------------------


def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
    return ScimDAL(db_session)


@admin_router.get("/scim/token")
def get_active_scim_token(
    _: User = Depends(current_admin_user),
    dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenResponse:
    """Return the currently active SCIM token's metadata, or 404 if none."""
    token = dal.get_active_token()
    if not token:
        raise HTTPException(status_code=404, detail="No active SCIM token")
    return ScimTokenResponse(
        id=token.id,
        name=token.name,
        token_display=token.token_display,
        is_active=token.is_active,
        created_at=token.created_at,
        last_used_at=token.last_used_at,
    )


@admin_router.post("/scim/token", status_code=201)
def create_scim_token(
    body: ScimTokenCreate,
    user: User = Depends(current_admin_user),
    dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenCreatedResponse:
    """Create a new SCIM bearer token.

    Only one token is active at a time — creating a new token automatically
    revokes all previous tokens. The raw token value is returned exactly once
    in the response; it cannot be retrieved again.
    """
    raw_token, hashed_token, token_display = generate_scim_token()
    token = dal.create_token(
        name=body.name,
        hashed_token=hashed_token,
        token_display=token_display,
        created_by_id=user.id,
    )
    dal.commit()

    return ScimTokenCreatedResponse(
        id=token.id,
        name=token.name,
        token_display=token.token_display,
        is_active=token.is_active,
        created_at=token.created_at,
        last_used_at=token.last_used_at,
        raw_token=raw_token,
    )
@@ -27,8 +27,6 @@ class SearchFlowClassificationResponse(BaseModel):
    is_search_flow: bool


# NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
class SendSearchQueryRequest(BaseModel):
    search_query: str
    filters: BaseFilters | None = None

@@ -67,8 +67,6 @@ def search_flow_classification(
    return SearchFlowClassificationResponse(is_search_flow=is_search_flow)


# NOTE: This endpoint is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
@router.post(
    "/send-search-message",
    response_model=None,
@@ -1,689 +0,0 @@
|
||||
"""SCIM 2.0 API endpoints (RFC 7644).
|
||||
|
||||
This module provides the FastAPI router for SCIM service discovery,
|
||||
User CRUD, and Group CRUD. Identity providers (Okta, Azure AD) call
|
||||
these endpoints to provision and manage users and groups.
|
||||
|
||||
Service discovery endpoints are unauthenticated — IdPs may probe them
|
||||
before bearer token configuration is complete. All other endpoints
|
||||
require a valid SCIM bearer token.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Query
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimError
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
from ee.onyx.server.scim.models import ScimSchemaDefinition
|
||||
from ee.onyx.server.scim.models import ScimServiceProviderConfig
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
|
||||
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
|
||||
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
|
||||
# changed to kebab-case.
|
||||
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
|
||||
_pw_helper = PasswordHelper()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service Discovery Endpoints (unauthenticated)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@scim_router.get("/ServiceProviderConfig")
|
||||
def get_service_provider_config() -> ScimServiceProviderConfig:
|
||||
"""Advertise supported SCIM features (RFC 7643 §5)."""
|
||||
return SERVICE_PROVIDER_CONFIG
|
||||
|
||||
|
||||
@scim_router.get("/ResourceTypes")
|
||||
def get_resource_types() -> list[ScimResourceType]:
|
||||
"""List available SCIM resource types (RFC 7643 §6)."""
|
||||
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
|
||||
|
||||
|
||||
@scim_router.get("/Schemas")
|
||||
def get_schemas() -> list[ScimSchemaDefinition]:
|
||||
"""Return SCIM schema definitions (RFC 7643 §7)."""
|
||||
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
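
As a rough illustration, an IdP's initial probe of these discovery endpoints might look like the sketch below. The base URL is a placeholder and the exact mount path of `scim_router` in the full application is an assumption, not something this diff confirms.

```python
import requests

BASE = "https://onyx.example.com/scim/v2"  # hypothetical deployment URL

# Discovery endpoints are unauthenticated, so an IdP can call them before
# any bearer token has been configured.
config = requests.get(f"{BASE}/ServiceProviderConfig").json()
print(config["patch"])  # {"supported": True}
print(config["bulk"])   # {"supported": False}

resource_types = requests.get(f"{BASE}/ResourceTypes").json()
print([rt["name"] for rt in resource_types])  # ["User", "Group"]
```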
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _scim_error_response(status: int, detail: str) -> JSONResponse:
|
||||
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
|
||||
body = ScimError(status=str(status), detail=detail)
|
||||
return JSONResponse(
|
||||
status_code=status,
|
||||
content=body.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
|
||||
"""Convert an Onyx User to a SCIM User resource representation."""
|
||||
name = None
|
||||
if user.personal_name:
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
name = ScimName(
|
||||
givenName=parts[0],
|
||||
familyName=parts[1] if len(parts) > 1 else None,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
|
||||
return ScimUserResource(
|
||||
id=str(user.id),
|
||||
externalId=external_id,
|
||||
userName=user.email,
|
||||
name=name,
|
||||
emails=[ScimEmail(value=user.email, type="work", primary=True)],
|
||||
active=user.is_active,
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)
|
||||
if check_fn is None:
|
||||
return None
|
||||
result = check_fn(dal.session, seats_needed=1)
|
||||
if not result.available:
|
||||
return result.error_message or "Seat limit reached"
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
|
||||
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
|
||||
try:
|
||||
uid = UUID(user_id)
|
||||
except ValueError:
|
||||
return _scim_error_response(404, f"User {user_id} not found")
|
||||
user = dal.get_user(uid)
|
||||
if not user:
|
||||
return _scim_error_response(404, f"User {user_id} not found")
|
||||
return user
|
||||
|
||||
|
||||
def _scim_name_to_str(name: ScimName | None) -> str | None:
|
||||
"""Extract a display name string from a SCIM name object.
|
||||
|
||||
Returns None if no name is provided, so the caller can decide
|
||||
whether to update the user's personal_name.
|
||||
"""
|
||||
if not name:
|
||||
return None
|
||||
return name.formatted or " ".join(
|
||||
part for part in [name.givenName, name.familyName] if part
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User CRUD (RFC 7644 §3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@scim_router.get("/Users", response_model=None)
|
||||
def list_users(
|
||||
filter: str | None = Query(None),
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List users with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
except ValueError as e:
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
try:
|
||||
users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
|
||||
except ValueError as e:
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
resources: list[ScimUserResource | ScimGroupResource] = [
|
||||
_user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
|
||||
]
|
||||
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
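
Before provisioning a new account, an IdP typically issues a filtered list call along these lines (placeholder URL and token) and branches on `totalResults`.

```python
import requests

resp = requests.get(
    "https://onyx.example.com/scim/v2/Users",  # hypothetical URL
    params={"filter": 'userName eq "jane.doe@example.com"', "count": 1},
    headers={"Authorization": "Bearer onyx_scim_..."},  # placeholder token
)
listing = resp.json()
# totalResults == 0 -> the IdP creates the user with POST /Users
# totalResults == 1 -> it links to the returned id and uses PUT/PATCH
print(listing["totalResults"])
```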
|
||||
|
||||
|
||||
@scim_router.get("/Users/{user_id}", response_model=None)
|
||||
def get_user(
|
||||
user_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Get a single user by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
return _user_to_scim(user, mapping.external_id if mapping else None)
|
||||
|
||||
|
||||
@scim_router.post("/Users", status_code=201, response_model=None)
|
||||
def create_user(
|
||||
user_resource: ScimUserResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Create a new user from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
email = user_resource.userName.strip().lower()
|
||||
|
||||
# externalId is how the IdP correlates this user on subsequent requests.
|
||||
# Without it, the IdP can't find the user and will try to re-create,
|
||||
# hitting a 409 conflict — so we require it up front.
|
||||
if not user_resource.externalId:
|
||||
return _scim_error_response(400, "externalId is required")
|
||||
|
||||
# Enforce seat limit
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
# Check for existing user
|
||||
if dal.get_user_by_email(email):
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create user with a random password (SCIM users authenticate via IdP)
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
user = User(
|
||||
email=email,
|
||||
hashed_password=_pw_helper.hash(_pw_helper.generate()),
|
||||
role=UserRole.BASIC,
|
||||
is_active=user_resource.active,
|
||||
is_verified=True,
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
try:
|
||||
dal.add_user(user)
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
dal.create_user_mapping(external_id=external_id, user_id=user.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, external_id)
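
For reference, a provisioning payload from an IdP might look roughly like this (sample values only; real IdPs may send additional attributes that are simply ignored here).

```python
example_create_payload = {
    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
    "externalId": "00u1abcd2EFGHIJK345",  # made-up IdP-side identifier
    "userName": "jane.doe@example.com",
    "name": {"givenName": "Jane", "familyName": "Doe"},
    "emails": [
        {"value": "jane.doe@example.com", "type": "work", "primary": True}
    ],
    "active": True,
}
# FastAPI validates this body against ScimUserResource before create_user
# runs; a payload without externalId is rejected with a SCIM 400 error.
```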
|
||||
|
||||
|
||||
@scim_router.put("/Users/{user_id}", response_model=None)
|
||||
def replace_user(
|
||||
user_id: str,
|
||||
user_resource: ScimUserResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Replace a user entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
# Handle activation (need seat check) / deactivation
|
||||
if user_resource.active and not user.is_active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
dal.update_user(
|
||||
user,
|
||||
email=user_resource.userName.strip().lower(),
|
||||
is_active=user_resource.active,
|
||||
personal_name=_scim_name_to_str(user_resource.name),
|
||||
)
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
dal.sync_user_external_id(user.id, new_external_id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, new_external_id)
|
||||
|
||||
|
||||
@scim_router.patch("/Users/{user_id}", response_model=None)
|
||||
def patch_user(
|
||||
user_id: str,
|
||||
patch_request: ScimPatchRequest,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Partially update a user (RFC 7644 §3.5.2).
|
||||
|
||||
This is the primary endpoint for user deprovisioning — Okta sends
|
||||
``PATCH {"active": false}`` rather than DELETE.
|
||||
"""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
|
||||
current = _user_to_scim(user, external_id)
|
||||
|
||||
try:
|
||||
patched = apply_user_patch(patch_request.Operations, current)
|
||||
except ScimPatchError as e:
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
# Apply changes back to the DB model
|
||||
if patched.active != user.is_active:
|
||||
if patched.active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
dal.update_user(
|
||||
user,
|
||||
email=(
|
||||
patched.userName.strip().lower()
|
||||
if patched.userName.lower() != user.email
|
||||
else None
|
||||
),
|
||||
is_active=patched.active if patched.active != user.is_active else None,
|
||||
personal_name=_scim_name_to_str(patched.name),
|
||||
)
|
||||
|
||||
dal.sync_user_external_id(user.id, patched.externalId)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, patched.externalId)
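
As a concrete example of the deprovisioning flow described in the docstring, an Okta-style deactivation request body could look like this (illustrative only).

```python
deactivate_request = {
    "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
    "Operations": [
        {"op": "replace", "value": {"active": False}},
    ],
}
# patch_user applies this via apply_user_patch and then persists
# is_active=False through ScimDAL.update_user before committing.
```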
|
||||
|
||||
|
||||
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
|
||||
def delete_user(
|
||||
user_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | JSONResponse:
|
||||
"""Delete a user (RFC 7644 §3.6).
|
||||
|
||||
Deactivates the user and removes the SCIM mapping. Note that Okta
|
||||
typically uses PATCH active=false instead of DELETE.
|
||||
"""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
dal.deactivate_user(user)
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
if mapping:
|
||||
dal.delete_user_mapping(mapping.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _group_to_scim(
|
||||
group: UserGroup,
|
||||
members: list[tuple[UUID, str | None]],
|
||||
external_id: str | None = None,
|
||||
) -> ScimGroupResource:
|
||||
"""Convert an Onyx UserGroup to a SCIM Group resource."""
|
||||
scim_members = [
|
||||
ScimGroupMember(value=str(uid), display=email) for uid, email in members
|
||||
]
|
||||
return ScimGroupResource(
|
||||
id=str(group.id),
|
||||
externalId=external_id,
|
||||
displayName=group.name,
|
||||
members=scim_members,
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
|
||||
"""Parse *group_id* as int, look up the group, or return a 404 error."""
|
||||
try:
|
||||
gid = int(group_id)
|
||||
except ValueError:
|
||||
return _scim_error_response(404, f"Group {group_id} not found")
|
||||
group = dal.get_group(gid)
|
||||
if not group:
|
||||
return _scim_error_response(404, f"Group {group_id} not found")
|
||||
return group
|
||||
|
||||
|
||||
def _parse_member_uuids(
|
||||
members: list[ScimGroupMember],
|
||||
) -> tuple[list[UUID], str | None]:
|
||||
"""Parse member value strings to UUIDs.
|
||||
|
||||
Returns (uuid_list, error_message). error_message is None on success.
|
||||
"""
|
||||
uuids: list[UUID] = []
|
||||
for m in members:
|
||||
try:
|
||||
uuids.append(UUID(m.value))
|
||||
except ValueError:
|
||||
return [], f"Invalid member ID: {m.value}"
|
||||
return uuids, None
|
||||
|
||||
|
||||
def _validate_and_parse_members(
|
||||
members: list[ScimGroupMember], dal: ScimDAL
|
||||
) -> tuple[list[UUID], str | None]:
|
||||
"""Parse and validate member UUIDs exist in the database.
|
||||
|
||||
Returns (uuid_list, error_message). error_message is None on success.
|
||||
"""
|
||||
uuids, err = _parse_member_uuids(members)
|
||||
if err:
|
||||
return [], err
|
||||
|
||||
if uuids:
|
||||
missing = dal.validate_member_ids(uuids)
|
||||
if missing:
|
||||
return [], f"Member(s) not found: {', '.join(str(u) for u in missing)}"
|
||||
|
||||
return uuids, None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group CRUD (RFC 7644 §3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@scim_router.get("/Groups", response_model=None)
|
||||
def list_groups(
|
||||
filter: str | None = Query(None),
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
except ValueError as e:
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
try:
|
||||
groups_with_ext_ids, total = dal.list_groups(scim_filter, startIndex, count)
|
||||
except ValueError as e:
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
resources: list[ScimUserResource | ScimGroupResource] = [
|
||||
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
|
||||
for group, ext_id in groups_with_ext_ids
|
||||
]
|
||||
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Groups/{group_id}", response_model=None)
|
||||
def get_group(
|
||||
group_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Get a single group by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
members = dal.get_group_members(group.id)
|
||||
|
||||
return _group_to_scim(group, members, mapping.external_id if mapping else None)
|
||||
|
||||
|
||||
@scim_router.post("/Groups", status_code=201, response_model=None)
|
||||
def create_group(
|
||||
group_resource: ScimGroupResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Create a new group from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
if dal.get_group_by_name(group_resource.displayName):
|
||||
return _scim_error_response(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
)
|
||||
|
||||
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
|
||||
if err:
|
||||
return _scim_error_response(400, err)
|
||||
|
||||
db_group = UserGroup(
|
||||
name=group_resource.displayName,
|
||||
is_up_to_date=True,
|
||||
time_last_modified_by_user=func.now(),
|
||||
)
|
||||
try:
|
||||
dal.add_group(db_group)
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
)
|
||||
|
||||
dal.upsert_group_members(db_group.id, member_uuids)
|
||||
|
||||
external_id = group_resource.externalId
|
||||
if external_id:
|
||||
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(db_group.id)
|
||||
return _group_to_scim(db_group, members, external_id)
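
A group push from an IdP might send a payload along these lines (sample values; member IDs must be existing Onyx user IDs).

```python
example_group_payload = {
    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
    "externalId": "grp-okta-0001",  # made-up IdP-side identifier
    "displayName": "Engineering",
    "members": [
        {"value": "11111111-1111-1111-1111-111111111111"},  # Onyx user ID
    ],
}
# Unknown member IDs cause _validate_and_parse_members to turn the
# request into a SCIM 400 before any group is created.
```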
|
||||
|
||||
|
||||
@scim_router.put("/Groups/{group_id}", response_model=None)
|
||||
def replace_group(
|
||||
group_id: str,
|
||||
group_resource: ScimGroupResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Replace a group entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
|
||||
if err:
|
||||
return _scim_error_response(400, err)
|
||||
|
||||
dal.update_group(group, name=group_resource.displayName)
|
||||
dal.replace_group_members(group.id, member_uuids)
|
||||
dal.sync_group_external_id(group.id, group_resource.externalId)
|
||||
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _group_to_scim(group, members, group_resource.externalId)
|
||||
|
||||
|
||||
@scim_router.patch("/Groups/{group_id}", response_model=None)
|
||||
def patch_group(
|
||||
group_id: str,
|
||||
patch_request: ScimPatchRequest,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Partially update a group (RFC 7644 §3.5.2).
|
||||
|
||||
Handles member add/remove operations from Okta and Azure AD.
|
||||
"""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
|
||||
current_members = dal.get_group_members(group.id)
|
||||
current = _group_to_scim(group, current_members, external_id)
|
||||
|
||||
try:
|
||||
patched, added_ids, removed_ids = apply_group_patch(
|
||||
patch_request.Operations, current
|
||||
)
|
||||
except ScimPatchError as e:
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
new_name = patched.displayName if patched.displayName != group.name else None
|
||||
dal.update_group(group, name=new_name)
|
||||
|
||||
if added_ids:
|
||||
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
|
||||
if add_uuids:
|
||||
missing = dal.validate_member_ids(add_uuids)
|
||||
if missing:
|
||||
return _scim_error_response(
|
||||
400,
|
||||
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
|
||||
)
|
||||
dal.upsert_group_members(group.id, add_uuids)
|
||||
|
||||
if removed_ids:
|
||||
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
|
||||
dal.remove_group_members(group.id, remove_uuids)
|
||||
|
||||
dal.sync_group_external_id(group.id, patched.externalId)
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _group_to_scim(group, members, patched.externalId)
|
||||
|
||||
|
||||
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
|
||||
def delete_group(
|
||||
group_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | JSONResponse:
|
||||
"""Delete a group (RFC 7644 §3.6)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
if mapping:
|
||||
dal.delete_group_mapping(mapping.id)
|
||||
|
||||
dal.delete_group_with_members(group)
|
||||
dal.commit()
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
def _is_valid_uuid(value: str) -> bool:
|
||||
"""Check if a string is a valid UUID."""
|
||||
try:
|
||||
UUID(value)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
@@ -1,104 +0,0 @@
|
||||
"""SCIM bearer token authentication.
|
||||
|
||||
SCIM endpoints are authenticated via bearer tokens that admins create in the
|
||||
Onyx UI. This module provides:
|
||||
|
||||
- ``verify_scim_token``: FastAPI dependency that extracts, hashes, and
|
||||
validates the token from the Authorization header.
|
||||
- ``generate_scim_token``: Creates a new cryptographically random token
|
||||
and returns the raw value, its SHA-256 hash, and a display suffix.
|
||||
|
||||
Token format: ``onyx_scim_<random>`` where ``<random>`` is 48 bytes of
|
||||
URL-safe base64 from ``secrets.token_urlsafe``.
|
||||
|
||||
The hash is stored in the ``scim_token`` table; the raw value is shown to
|
||||
the admin exactly once at creation time.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from onyx.auth.utils import get_hashed_bearer_token_from_request
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
|
||||
SCIM_TOKEN_PREFIX = "onyx_scim_"
|
||||
SCIM_TOKEN_LENGTH = 48
|
||||
|
||||
|
||||
def _hash_scim_token(token: str) -> str:
    """SHA-256 hash a SCIM token. No salt needed — tokens are random."""
    return hashlib.sha256(token.encode("utf-8")).hexdigest()


def generate_scim_token() -> tuple[str, str, str]:
    """Generate a new SCIM bearer token.

    Returns:
        A tuple of ``(raw_token, hashed_token, token_display)`` where
        ``token_display`` is a masked version showing only the last 4 chars.
    """
    raw_token = SCIM_TOKEN_PREFIX + secrets.token_urlsafe(SCIM_TOKEN_LENGTH)
    hashed_token = _hash_scim_token(raw_token)
    token_display = SCIM_TOKEN_PREFIX + "****" + raw_token[-4:]
    return raw_token, hashed_token, token_display
|
||||
|
||||
|
||||
def _get_hashed_scim_token_from_request(request: Request) -> str | None:
|
||||
"""Extract and hash a SCIM token from the request Authorization header."""
|
||||
return get_hashed_bearer_token_from_request(
|
||||
request,
|
||||
valid_prefixes=[SCIM_TOKEN_PREFIX],
|
||||
hash_fn=_hash_scim_token,
|
||||
)
|
||||
|
||||
|
||||
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
|
||||
return ScimDAL(db_session)
|
||||
|
||||
|
||||
def verify_scim_token(
|
||||
request: Request,
|
||||
dal: ScimDAL = Depends(_get_scim_dal),
|
||||
) -> ScimToken:
|
||||
"""FastAPI dependency that authenticates SCIM requests.
|
||||
|
||||
Extracts the bearer token from the Authorization header, hashes it,
|
||||
looks it up in the database, and verifies it is active.
|
||||
|
||||
Note:
|
||||
This dependency does NOT update ``last_used_at`` — the endpoint
|
||||
should do that via ``ScimDAL.update_token_last_used()`` so the
|
||||
timestamp write is part of the endpoint's transaction.
|
||||
|
||||
Raises:
|
||||
HTTPException(401): If the token is missing, invalid, or inactive.
|
||||
"""
|
||||
hashed = _get_hashed_scim_token_from_request(request)
|
||||
if not hashed:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing or invalid SCIM bearer token",
|
||||
)
|
||||
|
||||
token = dal.get_token_by_hash(hashed)
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid SCIM bearer token",
|
||||
)
|
||||
|
||||
if not token.is_active:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="SCIM token has been revoked",
|
||||
)
|
||||
|
||||
return token
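
A short sketch of the token lifecycle these helpers implement together; no real secrets are shown and the assertion is illustrative.

```python
raw, hashed, display = generate_scim_token()

# verify_scim_token recomputes exactly this hash from the presented
# Authorization header and looks it up in the scim_token table.
assert _hash_scim_token(raw) == hashed

# Only `hashed` is persisted; presenting any other value, or a revoked
# token, results in HTTPException(401).
print(display)  # e.g. "onyx_scim_****9fQz"
```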
@@ -1,96 +0,0 @@
|
||||
"""SCIM filter expression parser (RFC 7644 §3.4.2.2).
|
||||
|
||||
Identity providers (Okta, Azure AD, OneLogin, etc.) use filters to look up
|
||||
resources before deciding whether to create or update them. For example, when
|
||||
an admin assigns a user to the Onyx app, the IdP first checks whether that
|
||||
user already exists::
|
||||
|
||||
GET /scim/v2/Users?filter=userName eq "john@example.com"
|
||||
|
||||
If zero results come back the IdP creates the user (``POST``); if a match is
|
||||
found it links to the existing record and uses ``PUT``/``PATCH`` going forward.
|
||||
The same pattern applies to groups (``displayName eq "Engineering"``).
|
||||
|
||||
This module parses the subset of the SCIM filter grammar that identity
|
||||
providers actually send in practice:
|
||||
|
||||
attribute SP operator SP value
|
||||
|
||||
Supported operators: ``eq``, ``co`` (contains), ``sw`` (starts with).
Compound filters (``and`` / ``or``) are not supported; if an IdP sends one
the parser raises ``ValueError`` and the endpoint returns a SCIM 400 error.
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ScimFilterOperator(str, Enum):
    """Supported SCIM filter operators."""

    EQUAL = "eq"
    CONTAINS = "co"
    STARTS_WITH = "sw"


@dataclass(frozen=True, slots=True)
class ScimFilter:
    """Parsed SCIM filter expression."""

    attribute: str
    operator: ScimFilterOperator
    value: str
|
||||
|
||||
|
||||
# Matches: attribute operator "value" (value must be double- or single-quoted)
# Groups: (attribute) (operator) ("double-quoted value") ('single-quoted value')
_FILTER_RE = re.compile(
    r"^(\S+)\s+(eq|co|sw)\s+"  # attribute + operator
    r'(?:"([^"]*)"'  # double-quoted value
    r"|'([^']*)')"  # or single-quoted value
    r"$",
    re.IGNORECASE,
)
|
||||
|
||||
|
||||
def parse_scim_filter(filter_string: str | None) -> ScimFilter | None:
|
||||
"""Parse a simple SCIM filter expression.
|
||||
|
||||
Args:
|
||||
filter_string: Raw filter query parameter value, e.g.
|
||||
``'userName eq "john@example.com"'``
|
||||
|
||||
Returns:
|
||||
A ``ScimFilter`` if the expression is valid and uses a supported
|
||||
operator, or ``None`` if the input is empty / missing.
|
||||
|
||||
Raises:
|
||||
ValueError: If the filter string is present but malformed or uses
|
||||
an unsupported operator.
|
||||
"""
|
||||
if not filter_string or not filter_string.strip():
|
||||
return None
|
||||
|
||||
match = _FILTER_RE.match(filter_string.strip())
|
||||
if not match:
|
||||
raise ValueError(f"Unsupported or malformed SCIM filter: {filter_string}")
|
||||
|
||||
return _build_filter(match, filter_string)
|
||||
|
||||
|
||||
def _build_filter(match: re.Match[str], raw: str) -> ScimFilter:
|
||||
"""Extract fields from a regex match and construct a ScimFilter."""
|
||||
attribute = match.group(1)
|
||||
op_str = match.group(2).lower()
|
||||
# Value is in group 3 (double-quoted) or group 4 (single-quoted)
|
||||
value = match.group(3) if match.group(3) is not None else match.group(4)
|
||||
|
||||
if value is None:
|
||||
raise ValueError(f"Unsupported or malformed SCIM filter: {raw}")
|
||||
|
||||
operator = ScimFilterOperator(op_str)
|
||||
|
||||
return ScimFilter(attribute=attribute, operator=operator, value=value)
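
A few illustrative calls, matching the grammar described above.

```python
f = parse_scim_filter('userName eq "john@example.com"')
# ScimFilter(attribute="userName", operator=ScimFilterOperator.EQUAL,
#            value="john@example.com")

parse_scim_filter(None)  # -> None (no filter supplied)
parse_scim_filter("")    # -> None

try:
    parse_scim_filter('userName eq "a" and active eq "true"')
except ValueError:
    # Compound filters are rejected; the endpoint maps this to a SCIM 400.
    pass
```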
@@ -1,285 +0,0 @@
|
||||
"""Pydantic schemas for SCIM 2.0 provisioning (RFC 7643 / RFC 7644).
|
||||
|
||||
SCIM protocol schemas follow the wire format defined in:
|
||||
- Core Schema: https://datatracker.ietf.org/doc/html/rfc7643
|
||||
- Protocol: https://datatracker.ietf.org/doc/html/rfc7644
|
||||
|
||||
Admin API schemas are internal to Onyx and used for SCIM token management.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SCIM Schema URIs (RFC 7643 §8)
|
||||
# Every SCIM JSON payload includes a "schemas" array identifying its type.
|
||||
# IdPs like Okta/Azure AD use these URIs to determine how to parse responses.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
SCIM_LIST_RESPONSE_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:ListResponse"
|
||||
SCIM_PATCH_OP_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
|
||||
SCIM_ERROR_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:Error"
|
||||
SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
|
||||
)
|
||||
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
|
||||
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SCIM Protocol Schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScimName(BaseModel):
|
||||
"""User name components (RFC 7643 §4.1.1)."""
|
||||
|
||||
givenName: str | None = None
|
||||
familyName: str | None = None
|
||||
formatted: str | None = None
|
||||
|
||||
|
||||
class ScimEmail(BaseModel):
|
||||
"""Email sub-attribute (RFC 7643 §4.1.2)."""
|
||||
|
||||
value: str
|
||||
type: str | None = None
|
||||
primary: bool = False
|
||||
|
||||
|
||||
class ScimMeta(BaseModel):
|
||||
"""Resource metadata (RFC 7643 §3.1)."""
|
||||
|
||||
resourceType: str | None = None
|
||||
created: datetime | None = None
|
||||
lastModified: datetime | None = None
|
||||
location: str | None = None
|
||||
|
||||
|
||||
class ScimUserResource(BaseModel):
|
||||
"""SCIM User resource representation (RFC 7643 §4.1).
|
||||
|
||||
This is the JSON shape that IdPs send when creating/updating a user via
|
||||
SCIM, and the shape we return in GET responses. Field names use camelCase
|
||||
to match the SCIM wire format (not Python convention).
|
||||
"""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
|
||||
id: str | None = None # Onyx's internal user ID, set on responses
|
||||
externalId: str | None = None # IdP's identifier for this user
|
||||
userName: str # Typically the user's email address
|
||||
name: ScimName | None = None
|
||||
emails: list[ScimEmail] = Field(default_factory=list)
|
||||
active: bool = True
|
||||
meta: ScimMeta | None = None
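
Because the field names follow the SCIM wire format, a raw IdP payload can be validated directly; a minimal sketch with sample values:

```python
payload = {
    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
    "userName": "jane.doe@example.com",
    "externalId": "00u1abcd2EFGHIJK345",  # made-up IdP identifier
    "name": {"givenName": "Jane", "familyName": "Doe"},
    "active": True,
}
user = ScimUserResource.model_validate(payload)
assert user.userName == "jane.doe@example.com"
assert user.emails == []  # defaults fill in any omitted fields
```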
|
||||
|
||||
|
||||
class ScimGroupMember(BaseModel):
|
||||
"""Group member reference (RFC 7643 §4.2).
|
||||
|
||||
Represents a user within a SCIM group. The IdP sends these when adding
|
||||
or removing users from groups. ``value`` is the Onyx user ID.
|
||||
"""
|
||||
|
||||
value: str # User ID of the group member
|
||||
display: str | None = None
|
||||
|
||||
|
||||
class ScimGroupResource(BaseModel):
|
||||
"""SCIM Group resource representation (RFC 7643 §4.2)."""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_GROUP_SCHEMA])
|
||||
id: str | None = None
|
||||
externalId: str | None = None
|
||||
displayName: str
|
||||
members: list[ScimGroupMember] = Field(default_factory=list)
|
||||
meta: ScimMeta | None = None
|
||||
|
||||
|
||||
class ScimListResponse(BaseModel):
|
||||
"""Paginated list response (RFC 7644 §3.4.2)."""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_LIST_RESPONSE_SCHEMA])
|
||||
totalResults: int
|
||||
startIndex: int = 1
|
||||
itemsPerPage: int = 100
|
||||
Resources: list[ScimUserResource | ScimGroupResource] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ScimPatchOperationType(str, Enum):
|
||||
"""Supported PATCH operations (RFC 7644 §3.5.2)."""
|
||||
|
||||
ADD = "add"
|
||||
REPLACE = "replace"
|
||||
REMOVE = "remove"
|
||||
|
||||
|
||||
class ScimPatchOperation(BaseModel):
|
||||
"""Single PATCH operation (RFC 7644 §3.5.2)."""
|
||||
|
||||
op: ScimPatchOperationType
|
||||
path: str | None = None
|
||||
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
|
||||
|
||||
|
||||
class ScimPatchRequest(BaseModel):
|
||||
"""PATCH request body (RFC 7644 §3.5.2).
|
||||
|
||||
IdPs use PATCH to make incremental changes — e.g. deactivating a user
|
||||
(replace active=false) or adding/removing group members — instead of
|
||||
replacing the entire resource with PUT.
|
||||
"""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_PATCH_OP_SCHEMA])
|
||||
Operations: list[ScimPatchOperation]
|
||||
|
||||
|
||||
class ScimError(BaseModel):
|
||||
"""SCIM error response (RFC 7644 §3.12)."""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_ERROR_SCHEMA])
|
||||
status: str
|
||||
detail: str | None = None
|
||||
scimType: str | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service Provider Configuration (RFC 7643 §5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScimSupported(BaseModel):
|
||||
"""Generic supported/not-supported flag used in ServiceProviderConfig."""
|
||||
|
||||
supported: bool
|
||||
|
||||
|
||||
class ScimFilterConfig(BaseModel):
|
||||
"""Filter configuration within ServiceProviderConfig (RFC 7643 §5)."""
|
||||
|
||||
supported: bool
|
||||
maxResults: int = 100
|
||||
|
||||
|
||||
class ScimServiceProviderConfig(BaseModel):
|
||||
"""SCIM ServiceProviderConfig resource (RFC 7643 §5).
|
||||
|
||||
Served at GET /scim/v2/ServiceProviderConfig. IdPs fetch this during
|
||||
initial setup to discover which SCIM features our server supports
|
||||
(e.g. PATCH yes, bulk no, filtering yes).
|
||||
"""
|
||||
|
||||
schemas: list[str] = Field(
|
||||
default_factory=lambda: [SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA]
|
||||
)
|
||||
patch: ScimSupported = ScimSupported(supported=True)
|
||||
bulk: ScimSupported = ScimSupported(supported=False)
|
||||
filter: ScimFilterConfig = ScimFilterConfig(supported=True)
|
||||
changePassword: ScimSupported = ScimSupported(supported=False)
|
||||
sort: ScimSupported = ScimSupported(supported=False)
|
||||
etag: ScimSupported = ScimSupported(supported=False)
|
||||
authenticationSchemes: list[dict[str, str]] = Field(
|
||||
default_factory=lambda: [
|
||||
{
|
||||
"type": "oauthbearertoken",
|
||||
"name": "OAuth Bearer Token",
|
||||
"description": "Authentication scheme using a SCIM bearer token",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ScimSchemaAttribute(BaseModel):
|
||||
"""Attribute definition within a SCIM Schema (RFC 7643 §7)."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
multiValued: bool = False
|
||||
required: bool = False
|
||||
description: str = ""
|
||||
caseExact: bool = False
|
||||
mutability: str = "readWrite"
|
||||
returned: str = "default"
|
||||
uniqueness: str = "none"
|
||||
subAttributes: list["ScimSchemaAttribute"] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ScimSchemaDefinition(BaseModel):
|
||||
"""SCIM Schema definition (RFC 7643 §7).
|
||||
|
||||
Served at GET /scim/v2/Schemas. Describes the attributes available
|
||||
on each resource type so IdPs know which fields they can provision.
|
||||
"""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_SCHEMA_SCHEMA])
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
attributes: list[ScimSchemaAttribute] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ScimSchemaExtension(BaseModel):
|
||||
"""Schema extension reference within ResourceType (RFC 7643 §6)."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
schema_: str = Field(alias="schema")
|
||||
required: bool
|
||||
|
||||
|
||||
class ScimResourceType(BaseModel):
|
||||
"""SCIM ResourceType resource (RFC 7643 §6).
|
||||
|
||||
Served at GET /scim/v2/ResourceTypes. Tells the IdP which resource
|
||||
types are available (Users, Groups) and their respective endpoints.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
|
||||
id: str
|
||||
name: str
|
||||
endpoint: str
|
||||
description: str | None = None
|
||||
schema_: str = Field(alias="schema")
|
||||
schemaExtensions: list[ScimSchemaExtension] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Admin API Schemas (Onyx-internal, for SCIM token management)
|
||||
# These are NOT part of the SCIM protocol. They power the Onyx admin UI
|
||||
# where admins create/revoke the bearer tokens that IdPs use to authenticate.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScimTokenCreate(BaseModel):
|
||||
"""Request to create a new SCIM bearer token."""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class ScimTokenResponse(BaseModel):
|
||||
"""SCIM token metadata returned in list/get responses."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
token_display: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
last_used_at: datetime | None = None
|
||||
|
||||
|
||||
class ScimTokenCreatedResponse(ScimTokenResponse):
|
||||
"""Response returned when a new SCIM token is created.
|
||||
|
||||
Includes the raw token value which is only available at creation time.
|
||||
"""
|
||||
|
||||
raw_token: str
@@ -1,256 +0,0 @@
|
||||
"""SCIM PATCH operation handler (RFC 7644 §3.5.2).
|
||||
|
||||
Identity providers use PATCH to make incremental changes to SCIM resources
|
||||
instead of replacing the entire resource with PUT. Common operations include:
|
||||
|
||||
- Deactivating a user: ``replace`` ``active`` with ``false``
|
||||
- Adding group members: ``add`` to ``members``
|
||||
- Removing group members: ``remove`` from ``members[value eq "..."]``
|
||||
|
||||
This module applies PATCH operations to Pydantic SCIM resource objects and
|
||||
returns the modified result. It does NOT touch the database — the caller is
|
||||
responsible for persisting changes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
|
||||
|
||||
class ScimPatchError(Exception):
    """Raised when a PATCH operation cannot be applied."""

    def __init__(self, detail: str, status: int = 400) -> None:
        self.detail = detail
        self.status = status
        super().__init__(detail)


# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
    r'^members\[value\s+eq\s+"([^"]+)"\]$',
    re.IGNORECASE,
)
|
||||
|
||||
|
||||
def apply_user_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimUserResource,
|
||||
) -> ScimUserResource:
|
||||
"""Apply SCIM PATCH operations to a user resource.
|
||||
|
||||
Returns a new ``ScimUserResource`` with the modifications applied.
|
||||
The original object is not mutated.
|
||||
|
||||
Raises:
|
||||
ScimPatchError: If an operation targets an unsupported path.
|
||||
"""
|
||||
data = current.model_dump()
|
||||
name_data = data.get("name") or {}
|
||||
|
||||
for op in operations:
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_user_replace(op, data, name_data)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_user_replace(op, data, name_data)
|
||||
else:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported operation '{op.op.value}' on User resource"
|
||||
)
|
||||
|
||||
data["name"] = name_data
|
||||
return ScimUserResource.model_validate(data)
|
||||
|
||||
|
||||
def _apply_user_replace(
|
||||
op: ScimPatchOperation,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
) -> None:
|
||||
"""Apply a replace/add operation to user data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
# No path — value is a dict of top-level attributes to set
|
||||
if isinstance(op.value, dict):
|
||||
for key, val in op.value.items():
|
||||
_set_user_field(key.lower(), val, data, name_data)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
_set_user_field(path, op.value, data, name_data)
|
||||
|
||||
|
||||
def _set_user_field(
|
||||
path: str,
|
||||
value: str | bool | dict | list | None,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
) -> None:
|
||||
"""Set a single field on user data by SCIM path."""
|
||||
if path == "active":
|
||||
data["active"] = value
|
||||
elif path == "username":
|
||||
data["userName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
elif path == "name.givenname":
|
||||
name_data["givenName"] = value
|
||||
elif path == "name.familyname":
|
||||
name_data["familyName"] = value
|
||||
elif path == "name.formatted":
|
||||
name_data["formatted"] = value
|
||||
elif path == "displayname":
|
||||
# Some IdPs send displayName on users; map to formatted name
|
||||
name_data["formatted"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
|
||||
|
||||
|
||||
def apply_group_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimGroupResource,
|
||||
) -> tuple[ScimGroupResource, list[str], list[str]]:
|
||||
"""Apply SCIM PATCH operations to a group resource.
|
||||
|
||||
Returns:
|
||||
A tuple of (modified group, added member IDs, removed member IDs).
|
||||
The caller uses the member ID lists to update the database.
|
||||
|
||||
Raises:
|
||||
ScimPatchError: If an operation targets an unsupported path.
|
||||
"""
|
||||
data = current.model_dump()
|
||||
current_members: list[dict] = list(data.get("members") or [])
|
||||
added_ids: list[str] = []
|
||||
removed_ids: list[str] = []
|
||||
|
||||
for op in operations:
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_group_add(op, current_members, added_ids)
|
||||
elif op.op == ScimPatchOperationType.REMOVE:
|
||||
_apply_group_remove(op, current_members, removed_ids)
|
||||
else:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported operation '{op.op.value}' on Group resource"
|
||||
)
|
||||
|
||||
data["members"] = current_members
|
||||
group = ScimGroupResource.model_validate(data)
|
||||
return group, added_ids, removed_ids
|
||||
|
||||
|
||||
def _apply_group_replace(
|
||||
op: ScimPatchOperation,
|
||||
data: dict,
|
||||
current_members: list[dict],
|
||||
added_ids: list[str],
|
||||
removed_ids: list[str],
|
||||
) -> None:
|
||||
"""Apply a replace operation to group data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
if isinstance(op.value, dict):
|
||||
for key, val in op.value.items():
|
||||
if key.lower() == "members":
|
||||
_replace_members(val, current_members, added_ids, removed_ids)
|
||||
else:
|
||||
_set_group_field(key.lower(), val, data)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
if path == "members":
|
||||
_replace_members(op.value, current_members, added_ids, removed_ids)
|
||||
return
|
||||
|
||||
_set_group_field(path, op.value, data)
|
||||
|
||||
|
||||
def _replace_members(
|
||||
value: str | list | dict | bool | None,
|
||||
current_members: list[dict],
|
||||
added_ids: list[str],
|
||||
removed_ids: list[str],
|
||||
) -> None:
|
||||
"""Replace the entire group member list."""
|
||||
if not isinstance(value, list):
|
||||
raise ScimPatchError("Replace members requires a list value")
|
||||
|
||||
old_ids = {m["value"] for m in current_members}
|
||||
new_ids = {m.get("value", "") for m in value}
|
||||
|
||||
removed_ids.extend(old_ids - new_ids)
|
||||
added_ids.extend(new_ids - old_ids)
|
||||
|
||||
current_members[:] = value
|
||||
|
||||
|
||||
def _set_group_field(
|
||||
path: str,
|
||||
value: str | bool | dict | list | None,
|
||||
data: dict,
|
||||
) -> None:
|
||||
"""Set a single field on group data by SCIM path."""
|
||||
if path == "displayname":
|
||||
data["displayName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
|
||||
|
||||
|
||||
def _apply_group_add(
|
||||
op: ScimPatchOperation,
|
||||
members: list[dict],
|
||||
added_ids: list[str],
|
||||
) -> None:
|
||||
"""Add members to a group."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if path and path != "members":
|
||||
raise ScimPatchError(f"Unsupported add path '{op.path}' for Group")
|
||||
|
||||
if not isinstance(op.value, list):
|
||||
raise ScimPatchError("Add members requires a list value")
|
||||
|
||||
existing_ids = {m["value"] for m in members}
|
||||
for member_data in op.value:
|
||||
member_id = member_data.get("value", "")
|
||||
if member_id and member_id not in existing_ids:
|
||||
members.append(member_data)
|
||||
added_ids.append(member_id)
|
||||
existing_ids.add(member_id)
|
||||
|
||||
|
||||
def _apply_group_remove(
|
||||
op: ScimPatchOperation,
|
||||
members: list[dict],
|
||||
removed_ids: list[str],
|
||||
) -> None:
|
||||
"""Remove members from a group."""
|
||||
if not op.path:
|
||||
raise ScimPatchError("Remove operation requires a path")
|
||||
|
||||
match = _MEMBER_FILTER_RE.match(op.path)
|
||||
if not match:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported remove path '{op.path}'. "
|
||||
'Expected: members[value eq "user-id"]'
|
||||
)
|
||||
|
||||
target_id = match.group(1)
|
||||
original_len = len(members)
|
||||
members[:] = [m for m in members if m.get("value") != target_id]
|
||||
|
||||
if len(members) < original_len:
|
||||
removed_ids.append(target_id)
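
A sketch of how a membership PATCH flows through this module; `current_group` is assumed to be an existing ScimGroupResource that already contains the member being removed, and the IDs are made up.

```python
ops = [
    ScimPatchOperation(
        op=ScimPatchOperationType.ADD,
        path="members",
        value=[{"value": "11111111-1111-1111-1111-111111111111"}],
    ),
    ScimPatchOperation(
        op=ScimPatchOperationType.REMOVE,
        path='members[value eq "22222222-2222-2222-2222-222222222222"]',
    ),
]
patched, added_ids, removed_ids = apply_group_patch(ops, current_group)
# added_ids / removed_ids tell the endpoint which membership rows to
# upsert or delete; nothing is persisted until the caller commits.
```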
@@ -1,144 +0,0 @@
|
||||
"""Static SCIM service discovery responses (RFC 7643 §5, §6, §7).
|
||||
|
||||
Pre-built at import time — these never change at runtime. Separated from
|
||||
api.py to keep the endpoint module focused on request handling.
|
||||
"""
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
from ee.onyx.server.scim.models import ScimSchemaAttribute
|
||||
from ee.onyx.server.scim.models import ScimSchemaDefinition
|
||||
from ee.onyx.server.scim.models import ScimServiceProviderConfig
|
||||
|
||||
SERVICE_PROVIDER_CONFIG = ScimServiceProviderConfig()
|
||||
|
||||
USER_RESOURCE_TYPE = ScimResourceType.model_validate(
|
||||
{
|
||||
"id": "User",
|
||||
"name": "User",
|
||||
"endpoint": "/scim/v2/Users",
|
||||
"description": "SCIM User resource",
|
||||
"schema": SCIM_USER_SCHEMA,
|
||||
}
|
||||
)
|
||||
|
||||
GROUP_RESOURCE_TYPE = ScimResourceType.model_validate(
|
||||
{
|
||||
"id": "Group",
|
||||
"name": "Group",
|
||||
"endpoint": "/scim/v2/Groups",
|
||||
"description": "SCIM Group resource",
|
||||
"schema": SCIM_GROUP_SCHEMA,
|
||||
}
|
||||
)
|
||||
|
||||
USER_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_USER_SCHEMA,
|
||||
name="User",
|
||||
description="SCIM core User schema",
|
||||
attributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="userName",
|
||||
type="string",
|
||||
required=True,
|
||||
uniqueness="server",
|
||||
description="Unique identifier for the user, typically an email address.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="name",
|
||||
type="complex",
|
||||
description="The components of the user's name.",
|
||||
subAttributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="givenName",
|
||||
type="string",
|
||||
description="The user's first name.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="familyName",
|
||||
type="string",
|
||||
description="The user's last name.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="formatted",
|
||||
type="string",
|
||||
description="The full name, including all middle names and titles.",
|
||||
),
|
||||
],
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="emails",
|
||||
type="complex",
|
||||
multiValued=True,
|
||||
description="Email addresses for the user.",
|
||||
subAttributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="value",
|
||||
type="string",
|
||||
description="Email address value.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="type",
|
||||
type="string",
|
||||
description="Label for this email (e.g. 'work').",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="primary",
|
||||
type="boolean",
|
||||
description="Whether this is the primary email.",
|
||||
),
|
||||
],
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="active",
|
||||
type="boolean",
|
||||
description="Whether the user account is active.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="externalId",
|
||||
type="string",
|
||||
description="Identifier from the provisioning client (IdP).",
|
||||
caseExact=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_GROUP_SCHEMA,
|
||||
name="Group",
|
||||
description="SCIM core Group schema",
|
||||
attributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="displayName",
|
||||
type="string",
|
||||
required=True,
|
||||
description="Human-readable name for the group.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="members",
|
||||
type="complex",
|
||||
multiValued=True,
|
||||
description="Members of the group.",
|
||||
subAttributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="value",
|
||||
type="string",
|
||||
description="User ID of the group member.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="display",
|
||||
type="string",
|
||||
mutability="readOnly",
|
||||
description="Display name of the group member.",
|
||||
),
|
||||
],
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="externalId",
|
||||
type="string",
|
||||
description="Identifier from the provisioning client (IdP).",
|
||||
caseExact=True,
|
||||
),
|
||||
],
|
||||
)
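
These static objects are what the discovery endpoints ultimately serialize; for example:

```python
print(USER_SCHEMA_DEF.name)  # "User"
print([a.name for a in USER_SCHEMA_DEF.attributes])
# ["userName", "name", "emails", "active", "externalId"]

# ScimResourceType aliases the reserved word "schema", so it appears
# correctly on the wire when dumped by alias:
print(USER_RESOURCE_TYPE.model_dump(by_alias=True)["schema"])
# "urn:ietf:params:scim:schemas:core:2.0:User"
```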
@@ -1,13 +1,10 @@
|
||||
"""EE Settings API - provides license-aware settings override."""
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -44,14 +41,6 @@ def check_ee_features_enabled() -> bool:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if not metadata:
|
||||
# Cache miss — warm from DB so cold-start doesn't block EE features
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(f"Failed to load license from DB: {db_error}")
|
||||
|
||||
if metadata and metadata.status != _BLOCKING_STATUS:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
return True
|
||||
@@ -93,18 +82,6 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if not metadata:
|
||||
# Cache miss (e.g. after TTL expiry). Fall back to DB so
|
||||
# the /settings request doesn't falsely return GATED_ACCESS
|
||||
# while the cache is cold.
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(
|
||||
f"Failed to load license from DB for settings: {db_error}"
|
||||
)
|
||||
|
||||
if metadata:
|
||||
if metadata.status == _BLOCKING_STATUS:
|
||||
settings.application_status = metadata.status
|
||||
@@ -113,7 +90,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
settings.ee_features_enabled = True
|
||||
else:
|
||||
# No license found in cache or DB.
|
||||
# No license found.
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
# Legacy EE flag is set → prior EE usage (e.g. permission
|
||||
# syncing) means indexed data may need protection.
|
||||
|
||||
@@ -37,15 +37,12 @@ def list_user_groups(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session, only_up_to_date=False, eager_load_for_snapshot=True
|
||||
)
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
only_curator_groups=user.role == UserRole.CURATOR,
|
||||
eager_load_for_snapshot=True,
|
||||
)
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@@ -53,8 +53,7 @@ class UserGroup(BaseModel):
|
||||
id=cc_pair_relationship.cc_pair.id,
|
||||
name=cc_pair_relationship.cc_pair.name,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair_relationship.cc_pair.connector,
|
||||
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
|
||||
cc_pair_relationship.cc_pair.connector
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair_relationship.cc_pair.credential
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from fastapi_users import schemas
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
@@ -43,21 +41,8 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
tenant_id: str | None = None
|
||||
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
|
||||
# Excluded from create_update_dict so it never reaches the DB layer
|
||||
captcha_token: str | None = None
|
||||
|
||||
@override
|
||||
def create_update_dict(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict()
|
||||
d.pop("captcha_token", None)
|
||||
return d
|
||||
|
||||
@override
|
||||
def create_update_dict_superuser(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict_superuser()
|
||||
d.pop("captcha_token", None)
|
||||
return d
|
||||
|
||||
|
||||
class UserUpdateWithRole(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
|
||||
@@ -121,7 +121,6 @@ from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
@@ -138,8 +137,6 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"
|
||||
|
||||
|
||||
def is_user_admin(user: User) -> bool:
|
||||
return user.role == UserRole.ADMIN
|
||||
@@ -211,34 +208,22 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
|
||||
return int(value.decode("utf-8")) == 1
|
||||
|
||||
|
||||
def workspace_invite_only_enabled() -> bool:
|
||||
settings = load_settings()
|
||||
return settings.invite_only_enabled
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}:
|
||||
# SSO providers manage membership; allow JIT provisioning regardless of invites
|
||||
return
|
||||
|
||||
if not workspace_invite_only_enabled():
|
||||
whitelist = get_invited_users()
|
||||
if not whitelist:
|
||||
return
|
||||
|
||||
whitelist = get_invited_users()
|
||||
|
||||
if not email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Email must be specified"},
|
||||
)
|
||||
raise PermissionError("Email must be specified")
|
||||
|
||||
try:
|
||||
email_info = validate_email(email, check_deliverability=False)
|
||||
except EmailUndeliverableError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Email is not valid"},
|
||||
)
|
||||
raise PermissionError("Email is not valid")
|
||||
|
||||
for email_whitelist in whitelist:
|
||||
try:
|
||||
@@ -255,13 +240,7 @@ def verify_email_is_invited(email: str) -> None:
|
||||
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
|
||||
return
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"code": REGISTER_INVITE_ONLY_CODE,
|
||||
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
|
||||
},
|
||||
)
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
|
||||
@@ -1671,10 +1650,7 @@ def get_oauth_router(
|
||||
if redirect_url is not None:
|
||||
authorize_redirect_url = redirect_url
|
||||
else:
|
||||
# Use WEB_DOMAIN instead of request.url_for() to prevent host
|
||||
# header poisoning — request.url_for() trusts the Host header.
|
||||
callback_path = request.app.url_path_for(callback_route_name)
|
||||
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
|
||||
authorize_redirect_url = str(request.url_for(callback_route_name))
|
||||
|
||||
next_url = request.query_params.get("next", "/")
|
||||
|
||||
|
||||
@@ -1,30 +1,25 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
|
||||
from onyx.connectors.connector_runner import CheckpointOutputWrapper
|
||||
from onyx.connectors.connector_runner import batched_doc_ids
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import ConnectorCheckpoint
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -34,129 +29,63 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
PRUNING_CHECKPOINTED_BATCH_SIZE = 32
|
||||
|
||||
|
||||
class SlimConnectorExtractionResult(BaseModel):
|
||||
"""Result of extracting document IDs and hierarchy nodes from a connector."""
|
||||
|
||||
doc_ids: set[str]
|
||||
hierarchy_nodes: list[HierarchyNode]
|
||||
|
||||
|
||||
def _checkpointed_batched_items(
|
||||
connector: CheckpointedConnector[CT],
|
||||
start: float,
|
||||
end: float,
|
||||
) -> Generator[list[Document | HierarchyNode | ConnectorFailure], None, None]:
|
||||
"""Loop through all checkpoint steps and yield batched items.
|
||||
|
||||
Some checkpointed connectors (e.g. IMAP) are multi-step: the first
|
||||
checkpoint call may only initialize internal state without yielding
|
||||
any documents. This function loops until checkpoint.has_more is False
|
||||
to ensure all items are collected across every step.
|
||||
"""
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
while True:
|
||||
checkpoint_output = connector.load_from_checkpoint(
|
||||
start=start, end=end, checkpoint=checkpoint
|
||||
)
|
||||
wrapper: CheckpointOutputWrapper[CT] = CheckpointOutputWrapper()
|
||||
batch: list[Document | HierarchyNode | ConnectorFailure] = []
|
||||
for document, hierarchy_node, failure, next_checkpoint in wrapper(
|
||||
checkpoint_output
|
||||
):
|
||||
if document is not None:
|
||||
batch.append(document)
|
||||
elif hierarchy_node is not None:
|
||||
batch.append(hierarchy_node)
|
||||
elif failure is not None:
|
||||
batch.append(failure)
|
||||
|
||||
if next_checkpoint is not None:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
if not checkpoint.has_more:
|
||||
break
|
||||
|
||||
|
||||
def _get_failure_id(failure: ConnectorFailure) -> str | None:
|
||||
"""Extract the document/entity ID from a ConnectorFailure."""
|
||||
if failure.failed_document:
|
||||
return failure.failed_document.document_id
|
||||
if failure.failed_entity:
|
||||
return failure.failed_entity.entity_id
|
||||
return None
|
||||
|
||||
|
||||
def _extract_from_batch(
|
||||
doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
|
||||
) -> tuple[set[str], list[HierarchyNode]]:
|
||||
"""Separate a batch into document IDs and hierarchy nodes.
|
||||
|
||||
ConnectorFailure items have their failed document/entity IDs added to the
|
||||
ID set so that failed-to-retrieve documents are not accidentally pruned.
|
||||
"""
|
||||
ids: set[str] = set()
|
||||
hierarchy_nodes: list[HierarchyNode] = []
|
||||
for item in doc_list:
|
||||
if isinstance(item, HierarchyNode):
|
||||
hierarchy_nodes.append(item)
|
||||
ids.add(item.raw_node_id)
|
||||
elif isinstance(item, ConnectorFailure):
|
||||
failed_id = _get_failure_id(item)
|
||||
if failed_id:
|
||||
ids.add(failed_id)
|
||||
logger.warning(
|
||||
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
|
||||
)
|
||||
else:
|
||||
ids.add(item.id)
|
||||
return ids, hierarchy_nodes
|
||||
def document_batch_to_ids(
|
||||
doc_batch: (
|
||||
Iterator[list[Document | HierarchyNode]]
|
||||
| Iterator[list[SlimDocument | HierarchyNode]]
|
||||
),
|
||||
) -> Generator[set[str], None, None]:
|
||||
for doc_list in doc_batch:
|
||||
yield {
|
||||
doc.raw_node_id if isinstance(doc, HierarchyNode) else doc.id
|
||||
for doc in doc_list
|
||||
}
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(
|
||||
runnable_connector: BaseConnector,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> SlimConnectorExtractionResult:
|
||||
) -> set[str]:
|
||||
"""
|
||||
Extract document IDs and hierarchy nodes from a runnable connector.
|
||||
|
||||
Hierarchy nodes yielded alongside documents/slim docs are collected and
|
||||
returned in the result. ConnectorFailure items have their IDs preserved
|
||||
so that failed-to-retrieve documents are not accidentally pruned.
|
||||
If the given connector is neither a SlimConnector nor a SlimConnectorWithPermSync, just pull
|
||||
all docs using the load_from_state and grab out the IDs.
|
||||
|
||||
Optionally, a callback can be passed to handle the length of each document batch.
|
||||
"""
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
all_hierarchy_nodes: list[HierarchyNode] = []
|
||||
|
||||
# Sequence (covariant) lets all the specific list[...] iterator types unify here
|
||||
raw_batch_generator: (
|
||||
Iterator[Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure]]
|
||||
| None
|
||||
) = None
|
||||
|
||||
doc_batch_id_generator = None
|
||||
if isinstance(runnable_connector, SlimConnector):
|
||||
raw_batch_generator = runnable_connector.retrieve_all_slim_docs()
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.retrieve_all_slim_docs()
|
||||
)
|
||||
elif isinstance(runnable_connector, SlimConnectorWithPermSync):
|
||||
raw_batch_generator = runnable_connector.retrieve_all_slim_docs_perm_sync()
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.retrieve_all_slim_docs_perm_sync()
|
||||
)
|
||||
# If the connector isn't slim, fall back to running it normally to get ids
|
||||
elif isinstance(runnable_connector, LoadConnector):
|
||||
raw_batch_generator = runnable_connector.load_from_state()
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.load_from_state()
|
||||
)
|
||||
elif isinstance(runnable_connector, PollConnector):
|
||||
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
end = datetime.now(timezone.utc).timestamp()
|
||||
raw_batch_generator = runnable_connector.poll_source(start=start, end=end)
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.poll_source(start=start, end=end)
|
||||
)
|
||||
elif isinstance(runnable_connector, CheckpointedConnector):
|
||||
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
end = datetime.now(timezone.utc).timestamp()
|
||||
raw_batch_generator = _checkpointed_batched_items(
|
||||
runnable_connector, start, end
|
||||
checkpoint = runnable_connector.build_dummy_checkpoint()
|
||||
checkpoint_generator = runnable_connector.load_from_checkpoint(
|
||||
start=start, end=end, checkpoint=checkpoint
|
||||
)
|
||||
doc_batch_id_generator = batched_doc_ids(
|
||||
checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
|
||||
@@ -170,24 +99,19 @@ def extract_ids_from_runnable_connector(
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
# process raw batches to extract both IDs and hierarchy nodes
|
||||
for doc_list in raw_batch_generator:
|
||||
if callback and callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
for doc_batch_ids in doc_batch_id_generator:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
|
||||
batch_ids, batch_nodes = _extract_from_batch(doc_list)
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch_ids))
|
||||
|
||||
if callback:
|
||||
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
|
||||
callback.progress("extract_ids_from_runnable_connector", len(doc_batch_ids))
|
||||
|
||||
return SlimConnectorExtractionResult(
|
||||
doc_ids=all_connector_doc_ids,
|
||||
hierarchy_nodes=all_hierarchy_nodes,
|
||||
)
|
||||
return all_connector_doc_ids
|
||||
|
||||
|
||||
def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
|
||||
|
||||
@@ -37,7 +37,6 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
redis_connector: RedisConnector,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
timeout_seconds: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.parent_pid = parent_pid
|
||||
@@ -52,29 +51,11 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_parent_check = time.monotonic()
|
||||
self.start_monotonic = time.monotonic()
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
# Check if the associated indexing attempt has been cancelled
|
||||
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
|
||||
if bool(self.redis_connector.stop.fenced):
|
||||
return True
|
||||
|
||||
# Check if the task has exceeded its timeout
|
||||
# NOTE: Celery's soft_time_limit does not work with thread pools,
|
||||
# so we must enforce timeouts internally.
|
||||
if self.timeout_seconds is not None:
|
||||
elapsed = time.monotonic() - self.start_monotonic
|
||||
if elapsed > self.timeout_seconds:
|
||||
logger.warning(
|
||||
f"IndexingCallback Docprocessing - task timeout exceeded: "
|
||||
f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
|
||||
f"cc_pair={self.redis_connector.cc_pair_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis_connector.stop.fenced)
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None: # noqa: ARG002
|
||||
"""Amount isn't used yet."""
|
||||
|
||||
@@ -146,26 +146,14 @@ def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
|
||||
"""Collect metrics about queue lengths for different Celery queues"""
|
||||
metrics = []
|
||||
queue_mappings = {
|
||||
"celery_queue_length": OnyxCeleryQueues.PRIMARY,
|
||||
"docprocessing_queue_length": OnyxCeleryQueues.DOCPROCESSING,
|
||||
"docfetching_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
"sync_queue_length": OnyxCeleryQueues.VESPA_METADATA_SYNC,
|
||||
"deletion_queue_length": OnyxCeleryQueues.CONNECTOR_DELETION,
|
||||
"pruning_queue_length": OnyxCeleryQueues.CONNECTOR_PRUNING,
|
||||
"celery_queue_length": "celery",
|
||||
"docprocessing_queue_length": "docprocessing",
|
||||
"sync_queue_length": "sync",
|
||||
"deletion_queue_length": "deletion",
|
||||
"pruning_queue_length": "pruning",
|
||||
"permissions_sync_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
|
||||
"external_group_sync_queue_length": OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
|
||||
"permissions_upsert_queue_length": OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
|
||||
"hierarchy_fetching_queue_length": OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING,
|
||||
"llm_model_update_queue_length": OnyxCeleryQueues.LLM_MODEL_UPDATE,
|
||||
"checkpoint_cleanup_queue_length": OnyxCeleryQueues.CHECKPOINT_CLEANUP,
|
||||
"index_attempt_cleanup_queue_length": OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP,
|
||||
"csv_generation_queue_length": OnyxCeleryQueues.CSV_GENERATION,
|
||||
"user_file_processing_queue_length": OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
"user_file_project_sync_queue_length": OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
"user_file_delete_queue_length": OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
"monitoring_queue_length": OnyxCeleryQueues.MONITORING,
|
||||
"sandbox_queue_length": OnyxCeleryQueues.SANDBOX,
|
||||
"opensearch_migration_queue_length": OnyxCeleryQueues.OPENSEARCH_MIGRATION,
|
||||
}
|
||||
|
||||
for name, queue in queue_mappings.items():
|
||||
@@ -893,7 +881,7 @@ def monitor_celery_queues_helper(
|
||||
"""A task to monitor all celery queue lengths."""
|
||||
|
||||
r_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_docfetching = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
)
|
||||
@@ -920,26 +908,6 @@ def monitor_celery_queues_helper(
|
||||
n_permissions_upsert = celery_get_queue_length(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
n_hierarchy_fetching = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING, r_celery
|
||||
)
|
||||
n_llm_model_update = celery_get_queue_length(
|
||||
OnyxCeleryQueues.LLM_MODEL_UPDATE, r_celery
|
||||
)
|
||||
n_checkpoint_cleanup = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CHECKPOINT_CLEANUP, r_celery
|
||||
)
|
||||
n_index_attempt_cleanup = celery_get_queue_length(
|
||||
OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP, r_celery
|
||||
)
|
||||
n_csv_generation = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CSV_GENERATION, r_celery
|
||||
)
|
||||
n_monitoring = celery_get_queue_length(OnyxCeleryQueues.MONITORING, r_celery)
|
||||
n_sandbox = celery_get_queue_length(OnyxCeleryQueues.SANDBOX, r_celery)
|
||||
n_opensearch_migration = celery_get_queue_length(
|
||||
OnyxCeleryQueues.OPENSEARCH_MIGRATION, r_celery
|
||||
)
|
||||
|
||||
n_docfetching_prefetched = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
@@ -963,14 +931,6 @@ def monitor_celery_queues_helper(
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
f"hierarchy_fetching={n_hierarchy_fetching} "
|
||||
f"llm_model_update={n_llm_model_update} "
|
||||
f"checkpoint_cleanup={n_checkpoint_cleanup} "
|
||||
f"index_attempt_cleanup={n_index_attempt_cleanup} "
|
||||
f"csv_generation={n_csv_generation} "
|
||||
f"monitoring={n_monitoring} "
|
||||
f"sandbox={n_sandbox} "
|
||||
f"opensearch_migration={n_opensearch_migration} "
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,14 +41,3 @@ assert (
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S = 30  # 30 seconds.

TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15

# WARNING: Do not change these values without knowing what changes also need to
# be made to OpenSearchTenantMigrationRecord.
GET_VESPA_CHUNKS_PAGE_SIZE = 500
GET_VESPA_CHUNKS_SLICE_COUNT = 4

# String used to indicate in the vespa_visit_continuation_token mapping that the
# slice has finished and there is nothing left to visit.
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN = (
"FINISHED_VISITING_SLICE_CONTINUATION_TOKEN"
)

@@ -8,12 +8,6 @@ from celery import Task
from redis.lock import Lock as RedisLock

from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.opensearch_migration.constants import (
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_PAGE_SIZE,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
@@ -53,13 +47,7 @@ from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id


def is_continuation_token_done_for_all_slices(
continuation_token_map: dict[int, str | None],
) -> bool:
return all(
continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
for continuation_token in continuation_token_map.values()
)
GET_VESPA_CHUNKS_PAGE_SIZE = 1000


# shared_task allows this task to be shared across celery app instances.
@@ -88,15 +76,11 @@ def migrate_chunks_from_vespa_to_opensearch_task(

Uses Vespa's Visit API to iterate through ALL chunks in bulk (not
per-document), transform them, and index them into OpenSearch. Progress is
tracked via a continuation token map stored in the
tracked via a continuation token stored in the
OpenSearchTenantMigrationRecord.

The first time we see no continuation token map and non-zero chunks
migrated, we consider the migration complete and all subsequent invocations
are no-ops.

We divide the index into GET_VESPA_CHUNKS_SLICE_COUNT independent slices
where progress is tracked for each slice.
The first time we see no continuation token and non-zero chunks migrated, we
consider the migration complete and all subsequent invocations are no-ops.

Returns:
None if OpenSearch migration is not enabled, or if the lock could not be
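For readers skimming this hunk, here is a hedged sketch of the per-slice progress tracking the docstring above describes. The map shape and the completion check mirror `is_continuation_token_done_for_all_slices` and the constants from this diff; the initialization step is an illustrative assumption, not the actual implementation.

```python
# Illustrative sketch only: one continuation-token entry per Vespa visit slice.
GET_VESPA_CHUNKS_SLICE_COUNT = 4
FINISHED = "FINISHED_VISITING_SLICE_CONTINUATION_TOKEN"

# None means "slice not started yet"; a string is Vespa's opaque visit token.
continuation_token_map: dict[int, str | None] = {
    slice_id: None for slice_id in range(GET_VESPA_CHUNKS_SLICE_COUNT)
}

def migration_done(token_map: dict[int, str | None]) -> bool:
    # Same condition as is_continuation_token_done_for_all_slices above.
    return all(token == FINISHED for token in token_map.values())
```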
@@ -169,28 +153,15 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
|
||||
)
|
||||
|
||||
approx_chunk_count_in_vespa: int | None = None
|
||||
get_chunk_count_start_time = time.monotonic()
|
||||
try:
|
||||
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Error getting approximate chunk count in Vespa. Moving on..."
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Took {time.monotonic() - get_chunk_count_start_time:.3f} seconds to attempt to get "
|
||||
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
|
||||
)
|
||||
|
||||
while (
|
||||
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
|
||||
and lock.owned()
|
||||
):
|
||||
(
|
||||
continuation_token_map,
|
||||
continuation_token,
|
||||
total_chunks_migrated,
|
||||
) = get_vespa_visit_state(db_session)
|
||||
if is_continuation_token_done_for_all_slices(continuation_token_map):
|
||||
if continuation_token is None and total_chunks_migrated > 0:
|
||||
task_logger.info(
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
|
||||
f"Total chunks migrated: {total_chunks_migrated}."
|
||||
@@ -199,19 +170,19 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
break
|
||||
task_logger.debug(
|
||||
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
|
||||
f"Continuation token map: {continuation_token_map}"
|
||||
f"Continuation token: {continuation_token}"
|
||||
)
|
||||
|
||||
get_vespa_chunks_start_time = time.monotonic()
|
||||
raw_vespa_chunks, next_continuation_token_map = (
|
||||
raw_vespa_chunks, next_continuation_token = (
|
||||
vespa_document_index.get_all_raw_document_chunks_paginated(
|
||||
continuation_token_map=continuation_token_map,
|
||||
continuation_token=continuation_token,
|
||||
page_size=GET_VESPA_CHUNKS_PAGE_SIZE,
|
||||
)
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} "
|
||||
f"seconds. Next continuation token map: {next_continuation_token_map}"
|
||||
f"seconds. Next continuation token: {next_continuation_token}"
|
||||
)
|
||||
|
||||
opensearch_document_chunks, errored_chunks = (
|
||||
@@ -241,11 +212,14 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
total_chunks_errored_this_task += len(errored_chunks)
|
||||
update_vespa_visit_progress_with_commit(
|
||||
db_session,
|
||||
continuation_token_map=next_continuation_token_map,
|
||||
continuation_token=next_continuation_token,
|
||||
chunks_processed=len(opensearch_document_chunks),
|
||||
chunks_errored=len(errored_chunks),
|
||||
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
|
||||
)
|
||||
|
||||
if next_continuation_token is None and len(raw_vespa_chunks) == 0:
|
||||
task_logger.info("Vespa reported no more chunks to migrate.")
|
||||
break
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
task_logger.exception("Error in the OpenSearch migration task.")
|
||||
|
||||
@@ -37,35 +37,6 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
|
||||
DOCUMENT_ID,
|
||||
CHUNK_ID,
|
||||
TITLE,
|
||||
TITLE_EMBEDDING,
|
||||
CONTENT,
|
||||
EMBEDDINGS,
|
||||
SOURCE_TYPE,
|
||||
METADATA_LIST,
|
||||
DOC_UPDATED_AT,
|
||||
HIDDEN,
|
||||
BOOST,
|
||||
SEMANTIC_IDENTIFIER,
|
||||
IMAGE_FILE_NAME,
|
||||
SOURCE_LINKS,
|
||||
BLURB,
|
||||
DOC_SUMMARY,
|
||||
CHUNK_CONTEXT,
|
||||
METADATA_SUFFIX,
|
||||
DOCUMENT_SETS,
|
||||
USER_PROJECT,
|
||||
PRIMARY_OWNERS,
|
||||
SECONDARY_OWNERS,
|
||||
ACCESS_CONTROL_LIST,
|
||||
]
|
||||
if MULTI_TENANT:
|
||||
FIELDS_NEEDED_FOR_TRANSFORMATION.append(TENANT_ID)
|
||||
|
||||
|
||||
def _extract_content_vector(embeddings: Any) -> list[float]:
|
||||
"""Extracts the full chunk embedding vector from Vespa's embeddings tensor.
|
||||
|
||||
|
||||
@@ -43,11 +43,9 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import get_documents_for_connector_credential_pair
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
@@ -55,9 +53,6 @@ from onyx.db.tag import delete_orphan_tags__no_commit
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
@@ -528,47 +523,12 @@ def connector_pruning_generator_task(
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
timeout_seconds=JOB_TIMEOUT,
|
||||
)
|
||||
|
||||
# Extract docs and hierarchy nodes from the source
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
# a list of docs in the source
|
||||
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.doc_ids
|
||||
|
||||
# Process hierarchy nodes (same as docfetching):
|
||||
# upsert to Postgres and cache in Redis
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
ensure_source_node_exists(
|
||||
redis_client, db_session, cc_pair.connector.source
|
||||
)
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
source=cc_pair.connector.source,
|
||||
commit=True,
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in upserted_nodes
|
||||
]
|
||||
cache_hierarchy_nodes_batch(
|
||||
redis_client=redis_client,
|
||||
source=cc_pair.connector.source,
|
||||
entries=cache_entries,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning: persisted and cached {len(extraction_result.hierarchy_nodes)} "
|
||||
f"hierarchy nodes for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
# a list of docs in our local index
|
||||
all_indexed_document_ids = {
|
||||
|
||||
@@ -13,7 +13,6 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
@@ -22,14 +21,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
@@ -60,17 +57,6 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
"""Key that exists while a process_single_user_file task is sitting in the queue.
|
||||
|
||||
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
before enqueuing and the worker deletes it as its first action. This prevents
|
||||
the beat from adding duplicate tasks for files that already have a live task
|
||||
in flight.
|
||||
"""
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
@@ -134,24 +120,7 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.

Three mechanisms prevent queue runaway:

1. **Queue depth backpressure** – if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.

2. **Per-file queued guard** – before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.

3. **Task expiry** – every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
Uses direct Redis locks to avoid overlapping runs.
"""
task_logger.info("check_user_file_processing - Starting")

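The three protections above reduce to a fairly small enqueue pattern. The sketch below is an assumption-laden illustration of that pattern, not the exact code from this diff; the key format, TTL value, and task name are made up for readability.

```python
# Hypothetical sketch of the "queued guard" + "task expiry" combination.
def enqueue_with_guard(app, redis_client, user_file_id: str, tenant_id: str) -> bool:
    queued_key = f"userfile_queued:{user_file_id}"  # assumed key format

    # SET NX EX: only one beat cycle wins; later cycles see the key and skip.
    if not redis_client.set(queued_key, 1, ex=300, nx=True):
        return False  # a live task is already queued for this file

    try:
        app.send_task(
            "process_single_user_file",  # assumed task name
            kwargs={"user_file_id": user_file_id, "tenant_id": tenant_id},
            expires=300,  # stale tasks self-evict even if the guard key is lost
        )
    except Exception:
        # If submission fails, drop the guard so the next beat cycle can retry.
        redis_client.delete(queued_key)
        raise
    return True
```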
@@ -166,21 +135,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_user_file_processing - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -193,35 +148,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
nx=True,
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
# If task submission fails, clear the guard immediately so the
|
||||
# next beat cycle can retry enqueuing this file.
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
|
||||
finally:
|
||||
@@ -229,8 +161,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
|
||||
f"tasks for tenant={tenant_id}"
|
||||
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -373,12 +304,6 @@ def process_single_user_file(
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
|
||||
@@ -9,8 +9,10 @@ Summaries are stored as `ChatMessage` records with two key fields:
- `parent_message_id` → last message when compression triggered (places summary in the tree)
- `last_summarized_message_id` → pointer to an older message up the chain (the cutoff). Messages after this are kept verbatim.

**Why store summary as a separate message?** If we embedded the summary in the `last_summarized_message_id` message itself, that message would contain context from messages that came after it—context that doesn't exist in other branches. By creating the summary as a new message attached to the branch tip, it only applies to the specific branch where compression occurred. It's only back-pointed to by the
branch which it applies to. All of this is necessary because we keep the last few messages verbatim and also to support branching logic.
**Why store summary as a separate message?** If we embedded the summary in the `last_summarized_message_id` message itself, that message would contain context from messages that came after it—context that doesn't exist in other branches. By creating the summary as a new message attached to the branch tip, it only applies to the specific branch where compression occurred.

### Timestamp-Based Ordering
Messages are filtered by `time_sent` (not ID) so the logic remains intact if IDs are changed to UUIDs in the future.

### Progressive Summarization
Subsequent compressions incorporate the existing summary text + new messages, preventing information loss in very long conversations.
@@ -24,11 +26,10 @@ Context window breakdown:
- `max_context_tokens` — LLM's total context window
- `reserved_tokens` — space for system prompt, tools, files, etc.
- Available for chat history = `max_context_tokens - reserved_tokens`
Note: If there is a lot of reserved tokens, chat compression may happen fairly frequently which is costly, slow, and leads to a bad user experience. Possible area of future improvement.

Configurable ratios:
- `COMPRESSION_TRIGGER_RATIO` (default 0.75) — compress when chat history exceeds this ratio of available space
- `RECENT_MESSAGES_RATIO` (default 0.2) — portion of chat history to keep verbatim when compressing
- `RECENT_MESSAGES_RATIO` (default 0.25) — portion of chat history to keep verbatim when compressing
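To make the arithmetic concrete, here is a minimal sketch of how these knobs combine, using the names from this doc; the example values and rounding are assumptions, and the real wiring lives in the compression code later in this diff.

```python
# Hedged sketch of the budget math described above (example numbers only).
max_context_tokens = 128_000  # LLM's total context window
reserved_tokens = 8_000       # system prompt, tools, files, etc.
COMPRESSION_TRIGGER_RATIO = 0.75
RECENT_MESSAGES_RATIO = 0.25  # 0.2 before this change

available_for_history = max_context_tokens - reserved_tokens
trigger_threshold = int(available_for_history * COMPRESSION_TRIGGER_RATIO)
tokens_for_recent = int(available_for_history * RECENT_MESSAGES_RATIO)

def should_compress(history_tokens: int) -> bool:
    # Compress once the branch's history outgrows the trigger threshold;
    # roughly tokens_for_recent worth of recent messages stays verbatim.
    return history_tokens > trigger_threshold
```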

## Flow

@@ -1,29 +1,36 @@
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import is_user_admin
|
||||
from onyx.chat.models import ChatHistoryResult
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.kg_config import is_kg_config_settings_enabled_valid
|
||||
from onyx.db.llm import fetch_existing_doc_sets
|
||||
from onyx.db.llm import fetch_existing_tools
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import check_project_ownership
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
@@ -40,13 +47,15 @@ from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
IMAGE_GENERATION_TOOL_NAME = "generate_image"
|
||||
|
||||
|
||||
def create_chat_session_from_request(
|
||||
@@ -269,6 +278,70 @@ def extract_headers(
|
||||
return extracted_headers
|
||||
|
||||
|
||||
def create_temporary_persona(
|
||||
persona_config: PersonaOverrideConfig, db_session: Session, user: User
|
||||
) -> Persona:
|
||||
if not is_user_admin(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User is not authorized to create a persona in one shot queries",
|
||||
)
|
||||
|
||||
"""Create a temporary Persona object from the provided configuration."""
|
||||
persona = Persona(
|
||||
name=persona_config.name,
|
||||
description=persona_config.description,
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.BASE_DECAY,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
|
||||
if persona_config.prompts:
|
||||
# Use the first prompt from the override config for embedded prompt fields
|
||||
first_prompt = persona_config.prompts[0]
|
||||
persona.system_prompt = first_prompt.system_prompt
|
||||
persona.task_prompt = first_prompt.task_prompt
|
||||
persona.datetime_aware = first_prompt.datetime_aware
|
||||
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
|
||||
for schema in persona_config.custom_tools_openapi:
|
||||
tools = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
tool_id=0, # dummy tool id
|
||||
openapi_schema=schema,
|
||||
emitter=get_default_emitter(),
|
||||
),
|
||||
)
|
||||
persona.tools.extend(tools)
|
||||
|
||||
if persona_config.tools:
|
||||
tool_ids = [tool.id for tool in persona_config.tools]
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
)
|
||||
|
||||
if persona_config.tool_ids:
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(
|
||||
db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
)
|
||||
)
|
||||
|
||||
fetched_docs = fetch_existing_doc_sets(
|
||||
db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
)
|
||||
persona.document_sets = fetched_docs
|
||||
|
||||
return persona
|
||||
|
||||
|
||||
def process_kg_commands(
|
||||
message: str, persona_name: str, tenant_id: str, db_session: Session # noqa: ARG001
|
||||
) -> None:
|
||||
@@ -424,40 +497,6 @@ def convert_chat_history_basic(
|
||||
return list(reversed(trimmed_reversed))
|
||||
|
||||
|
||||
def _build_tool_call_response_history_message(
|
||||
tool_name: str,
|
||||
generated_images: list[dict] | None,
|
||||
tool_call_response: str | None,
|
||||
) -> str:
|
||||
if tool_name != IMAGE_GENERATION_TOOL_NAME:
|
||||
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
|
||||
|
||||
if generated_images:
|
||||
llm_image_context: list[dict[str, str]] = []
|
||||
for image in generated_images:
|
||||
file_id = image.get("file_id")
|
||||
revised_prompt = image.get("revised_prompt")
|
||||
if not isinstance(file_id, str):
|
||||
continue
|
||||
|
||||
llm_image_context.append(
|
||||
{
|
||||
"file_id": file_id,
|
||||
"revised_prompt": (
|
||||
revised_prompt if isinstance(revised_prompt, str) else ""
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
if llm_image_context:
|
||||
return json.dumps(llm_image_context)
|
||||
|
||||
if tool_call_response:
|
||||
return tool_call_response
|
||||
|
||||
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
|
||||
|
||||
|
||||
def convert_chat_history(
|
||||
chat_history: list[ChatMessage],
|
||||
files: list[ChatLoadedFile],
|
||||
@@ -618,24 +657,10 @@ def convert_chat_history(
|
||||
|
||||
# Add TOOL_CALL_RESPONSE messages for each tool call in this turn
|
||||
for tool_call in turn_tool_calls:
|
||||
tool_name = tool_id_to_name_map.get(
|
||||
tool_call.tool_id, "unknown"
|
||||
)
|
||||
tool_response_message = (
|
||||
_build_tool_call_response_history_message(
|
||||
tool_name=tool_name,
|
||||
generated_images=tool_call.generated_images,
|
||||
tool_call_response=tool_call.tool_call_response,
|
||||
)
|
||||
)
|
||||
simple_messages.append(
|
||||
ChatMessageSimple(
|
||||
message=tool_response_message,
|
||||
token_count=(
|
||||
token_counter(tool_response_message)
|
||||
if tool_name == IMAGE_GENERATION_TOOL_NAME
|
||||
else 20
|
||||
),
|
||||
message=TOOL_CALL_RESPONSE_CROSS_MESSAGE,
|
||||
token_count=20, # Tiny overestimate
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
@@ -663,34 +688,28 @@ def convert_chat_history(


def get_custom_agent_prompt(persona: Persona, chat_session: ChatSession) -> str | None:
"""Get the custom agent prompt from persona or project instructions. If it's replacing the base system prompt,
it does not count as a custom agent prompt (logic exists later also to drop it in this case).
"""Get the custom agent prompt from persona or project instructions.

Chat Sessions in Projects that are using a custom agent will retain the custom agent prompt.
Priority: persona.system_prompt (if not default Agent) > chat_session.project.instructions

# NOTE: Logic elsewhere allows saving empty strings for potentially other purposes but for constructing the prompts
# we never want to return an empty string for a prompt so it's translated into an explicit None.
Priority: persona.system_prompt > chat_session.project.instructions > None

Args:
persona: The Persona object
chat_session: The ChatSession object

Returns:
The prompt to use for the custom Agent part of the prompt.
The custom agent prompt string, or None if neither persona nor project has one
"""
# If using a custom Agent, always respect its prompt, even if in a Project, and even if it's an empty custom prompt.
if persona.id != DEFAULT_PERSONA_ID:
# Logic exists later also to drop it in this case but this is strictly correct anyhow.
if persona.replace_base_system_prompt:
return None
return persona.system_prompt or None
# Not considered a custom agent if it's the default behavior persona
if persona.id == DEFAULT_PERSONA_ID:
return None

# If in a project and using the default Agent, respect the project instructions.
if chat_session.project and chat_session.project.instructions:
if persona.system_prompt:
return persona.system_prompt
elif chat_session.project and chat_session.project.instructions:
return chat_session.project.instructions

return None
else:
return None

def is_last_assistant_message_clarification(chat_history: list[ChatMessage]) -> bool:
|
||||
|
||||
@@ -17,26 +17,20 @@ from onyx.configs.chat_configs import COMPRESSION_TRIGGER_RATIO
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import AssistantMessage
|
||||
from onyx.llm.models import ChatCompletionMessage
|
||||
from onyx.llm.models import SystemMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.compression_prompts import PROGRESSIVE_SUMMARY_SYSTEM_PROMPT_BLOCK
|
||||
from onyx.prompts.compression_prompts import PROGRESSIVE_SUMMARY_PROMPT
|
||||
from onyx.prompts.compression_prompts import PROGRESSIVE_USER_REMINDER
|
||||
from onyx.prompts.compression_prompts import SUMMARIZATION_CUTOFF_MARKER
|
||||
from onyx.prompts.compression_prompts import SUMMARIZATION_PROMPT
|
||||
from onyx.prompts.compression_prompts import USER_REMINDER
|
||||
from onyx.tracing.framework.create import ensure_trace
|
||||
from onyx.tracing.llm_utils import llm_generation_span
|
||||
from onyx.tracing.llm_utils import record_llm_response
|
||||
from onyx.prompts.compression_prompts import USER_FINAL_REMINDER
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Ratio of available context to allocate for recent messages after compression
|
||||
RECENT_MESSAGES_RATIO = 0.2
|
||||
RECENT_MESSAGES_RATIO = 0.25
|
||||
|
||||
|
||||
class CompressionResult(BaseModel):
|
||||
@@ -193,11 +187,6 @@ def get_messages_to_summarize(
|
||||
recent_messages.insert(0, msg)
|
||||
tokens_used += msg_tokens
|
||||
|
||||
# Ensure cutoff is right before a user message by moving any leading
|
||||
# non-user messages from recent_messages to older_messages
|
||||
while recent_messages and recent_messages[0].message_type != MessageType.USER:
|
||||
recent_messages.pop(0)
|
||||
|
||||
# Everything else gets summarized
|
||||
recent_ids = {m.id for m in recent_messages}
|
||||
older_messages = [m for m in messages if m.id not in recent_ids]
|
||||
@@ -207,47 +196,31 @@ def get_messages_to_summarize(
|
||||
)
|
||||
|
||||
|
||||
def _build_llm_messages_for_summarization(
|
||||
def format_messages_for_summary(
|
||||
messages: list[ChatMessage],
|
||||
tool_id_to_name: dict[int, str],
|
||||
) -> list[UserMessage | AssistantMessage]:
|
||||
"""Convert ChatMessage objects to LLM message format for summarization.
|
||||
) -> str:
|
||||
"""Format messages into a string for the summarization prompt.
|
||||
|
||||
This is intentionally different from translate_history_to_llm_format in llm_step.py:
|
||||
- Compacts tool calls to "[Used tools: tool1, tool2]" to save tokens in summaries
|
||||
- Skips TOOL_CALL_RESPONSE messages entirely (tool usage captured in assistant message)
|
||||
- No image/multimodal handling (summaries are text-only)
|
||||
- No caching or LLMConfig-specific behavior needed
|
||||
Tool call messages are formatted compactly to save tokens.
|
||||
"""
|
||||
result: list[UserMessage | AssistantMessage] = []
|
||||
|
||||
formatted = []
|
||||
for msg in messages:
|
||||
# Skip empty messages
|
||||
if not msg.message:
|
||||
continue
|
||||
|
||||
# Handle assistant messages with tool calls compactly
|
||||
if msg.message_type == MessageType.ASSISTANT:
|
||||
if msg.tool_calls:
|
||||
tool_names = [
|
||||
tool_id_to_name.get(tc.tool_id, "unknown") for tc in msg.tool_calls
|
||||
]
|
||||
result.append(
|
||||
AssistantMessage(content=f"[Used tools: {', '.join(tool_names)}]")
|
||||
)
|
||||
else:
|
||||
result.append(AssistantMessage(content=msg.message))
|
||||
# Format assistant messages with tool calls compactly
|
||||
if msg.message_type == MessageType.ASSISTANT and msg.tool_calls:
|
||||
tool_names = [
|
||||
tool_id_to_name.get(tc.tool_id, "unknown") for tc in msg.tool_calls
|
||||
]
|
||||
formatted.append(f"[assistant used tools: {', '.join(tool_names)}]")
|
||||
continue
|
||||
|
||||
# Skip tool call response messages - tool calls are captured above via assistant messages
|
||||
if msg.message_type == MessageType.TOOL_CALL_RESPONSE:
|
||||
continue
|
||||
|
||||
# Handle user messages
|
||||
if msg.message_type == MessageType.USER:
|
||||
result.append(UserMessage(content=msg.message))
|
||||
|
||||
return result
|
||||
role = msg.message_type.value
|
||||
formatted.append(f"[{role}]: {msg.message}")
|
||||
return "\n\n".join(formatted)
|
||||
|
||||
|
||||
def generate_summary(
|
||||
@@ -263,9 +236,6 @@ def generate_summary(
|
||||
The cutoff marker tells the LLM to summarize only older messages,
|
||||
while using recent messages as context to inform what's important.
|
||||
|
||||
Messages are sent as separate UserMessage/AssistantMessage objects rather
|
||||
than being concatenated into a single message.
|
||||
|
||||
Args:
|
||||
older_messages: Messages to compress into summary (before cutoff)
|
||||
recent_messages: Messages kept verbatim (after cutoff, for context only)
|
||||
@@ -276,54 +246,37 @@ def generate_summary(
|
||||
Returns:
|
||||
Summary text
|
||||
"""
|
||||
# Build system prompt
|
||||
system_content = SUMMARIZATION_PROMPT
|
||||
older_messages_str = format_messages_for_summary(older_messages, tool_id_to_name)
|
||||
recent_messages_str = format_messages_for_summary(recent_messages, tool_id_to_name)
|
||||
|
||||
# Build user prompt with cutoff marker
|
||||
if existing_summary:
|
||||
# Progressive summarization: append existing summary to system prompt
|
||||
system_content += PROGRESSIVE_SUMMARY_SYSTEM_PROMPT_BLOCK.format(
|
||||
previous_summary=existing_summary
|
||||
# Progressive summarization: include existing summary
|
||||
user_prompt = PROGRESSIVE_SUMMARY_PROMPT.format(
|
||||
existing_summary=existing_summary
|
||||
)
|
||||
user_prompt += f"\n\n{older_messages_str}"
|
||||
final_reminder = PROGRESSIVE_USER_REMINDER
|
||||
else:
|
||||
final_reminder = USER_REMINDER
|
||||
# Initial summarization
|
||||
user_prompt = older_messages_str
|
||||
final_reminder = USER_FINAL_REMINDER
|
||||
|
||||
# Convert messages to LLM format (using compression-specific conversion)
|
||||
older_llm_messages = _build_llm_messages_for_summarization(
|
||||
older_messages, tool_id_to_name
|
||||
)
|
||||
recent_llm_messages = _build_llm_messages_for_summarization(
|
||||
recent_messages, tool_id_to_name
|
||||
)
|
||||
|
||||
# Build message list with separate messages
|
||||
input_messages: list[ChatCompletionMessage] = [
|
||||
SystemMessage(content=system_content),
|
||||
]
|
||||
|
||||
# Add older messages (to be summarized)
|
||||
input_messages.extend(older_llm_messages)
|
||||
|
||||
# Add cutoff marker as a user message
|
||||
input_messages.append(UserMessage(content=SUMMARIZATION_CUTOFF_MARKER))
|
||||
|
||||
# Add recent messages (for context only)
|
||||
input_messages.extend(recent_llm_messages)
|
||||
# Add cutoff marker and recent messages as context
|
||||
user_prompt += f"\n\n{SUMMARIZATION_CUTOFF_MARKER}"
|
||||
if recent_messages_str:
|
||||
user_prompt += f"\n\n{recent_messages_str}"
|
||||
|
||||
# Add final reminder
|
||||
input_messages.append(UserMessage(content=final_reminder))
|
||||
user_prompt += f"\n\n{final_reminder}"
|
||||
|
||||
with llm_generation_span(
|
||||
llm=llm,
|
||||
flow="chat_history_summarization",
|
||||
input_messages=input_messages,
|
||||
) as span_generation:
|
||||
response = llm.invoke(input_messages)
|
||||
record_llm_response(span_generation, response)
|
||||
|
||||
content = response.choice.message.content
|
||||
if not (content and content.strip()):
|
||||
raise ValueError("LLM returned empty summary")
|
||||
return content.strip()
|
||||
response = llm.invoke(
|
||||
[
|
||||
SystemMessage(content=SUMMARIZATION_PROMPT),
|
||||
UserMessage(content=user_prompt),
|
||||
]
|
||||
)
|
||||
return response.choice.message.content or ""
|
||||
|
||||
|
||||
def compress_chat_history(
@@ -339,19 +292,6 @@ def compress_chat_history(
The summary message's parent_message_id points to the last message in
chat_history, making it branch-aware via the tree structure.

Note: This takes the entire chat history as input, splits it into older
messages (to summarize) and recent messages (kept verbatim within the
token budget), generates a summary of the older part, and persists the
new summary message with its parent set to the last message in history.

Past summary is taken into context (progressive summarization): we find
at most one existing summary for this branch. If present, only messages
after that summary's last_summarized_message_id are considered; the
existing summary text is passed into the LLM so the new summary
incorporates it instead of summarizing from scratch.

For more details, see the COMPRESSION.md file.

Args:
db_session: Database session
chat_history: Branch-aware list of messages
@@ -365,84 +305,74 @@ def compress_chat_history(
if not chat_history:
return CompressionResult(summary_created=False, messages_summarized=0)

chat_session_id = chat_history[0].chat_session_id

logger.info(
f"Starting compression for session {chat_session_id}, "
f"Starting compression for session {chat_history[0].chat_session_id}, "
f"history_len={len(chat_history)}, tokens_for_recent={compression_params.tokens_for_recent}"
)

with ensure_trace(
"chat_history_compression",
group_id=str(chat_session_id),
metadata={
"tenant_id": get_current_tenant_id(),
"chat_session_id": str(chat_session_id),
},
):
try:
# Find existing summary for this branch
existing_summary = find_summary_for_branch(db_session, chat_history)
try:
# Find existing summary for this branch
existing_summary = find_summary_for_branch(db_session, chat_history)

# Get messages to summarize
summary_content = get_messages_to_summarize(
chat_history,
existing_summary,
tokens_for_recent=compression_params.tokens_for_recent,
)
# Get messages to summarize
summary_content = get_messages_to_summarize(
chat_history,
existing_summary,
tokens_for_recent=compression_params.tokens_for_recent,
)

if not summary_content.older_messages:
logger.debug("No messages to summarize, skipping compression")
return CompressionResult(summary_created=False, messages_summarized=0)
if not summary_content.older_messages:
logger.debug("No messages to summarize, skipping compression")
return CompressionResult(summary_created=False, messages_summarized=0)

# Generate summary (incorporate existing summary if present)
existing_summary_text = (
existing_summary.message if existing_summary else None
)
summary_text = generate_summary(
older_messages=summary_content.older_messages,
recent_messages=summary_content.recent_messages,
llm=llm,
tool_id_to_name=tool_id_to_name,
existing_summary=existing_summary_text,
)
# Generate summary (incorporate existing summary if present)
existing_summary_text = existing_summary.message if existing_summary else None
summary_text = generate_summary(
older_messages=summary_content.older_messages,
recent_messages=summary_content.recent_messages,
llm=llm,
tool_id_to_name=tool_id_to_name,
existing_summary=existing_summary_text,
)

# Calculate token count for the summary
tokenizer = get_tokenizer(None, None)
summary_token_count = len(tokenizer.encode(summary_text))
logger.debug(
f"Generated summary ({summary_token_count} tokens): {summary_text[:200]}..."
)
# Calculate token count for the summary
tokenizer = get_tokenizer(None, None)
summary_token_count = len(tokenizer.encode(summary_text))
logger.debug(
f"Generated summary ({summary_token_count} tokens): {summary_text[:200]}..."
)

# Create new summary as a ChatMessage
# Parent is the last message in history - this makes the summary branch-aware
summary_message = ChatMessage(
chat_session_id=chat_session_id,
message_type=MessageType.ASSISTANT,
message=summary_text,
token_count=summary_token_count,
parent_message_id=chat_history[-1].id,
last_summarized_message_id=summary_content.older_messages[-1].id,
)
db_session.add(summary_message)
db_session.commit()
# Create new summary as a ChatMessage
# Parent is the last message in history - this makes the summary branch-aware
summary_message = ChatMessage(
chat_session_id=chat_history[0].chat_session_id,
message_type=MessageType.ASSISTANT,
message=summary_text,
token_count=summary_token_count,
parent_message_id=chat_history[-1].id,
last_summarized_message_id=summary_content.older_messages[-1].id,
)
db_session.add(summary_message)
db_session.commit()

logger.info(
f"Compressed {len(summary_content.older_messages)} messages into summary "
f"(session_id={chat_session_id}, "
f"summary_tokens={summary_token_count})"
)
logger.info(
f"Compressed {len(summary_content.older_messages)} messages into summary "
f"(session_id={chat_history[0].chat_session_id}, "
f"summary_tokens={summary_token_count})"
)

return CompressionResult(
summary_created=True,
messages_summarized=len(summary_content.older_messages),
)
return CompressionResult(
summary_created=True,
messages_summarized=len(summary_content.older_messages),
)

except Exception as e:
logger.exception(f"Compression failed for session {chat_session_id}: {e}")
db_session.rollback()
return CompressionResult(
summary_created=False,
messages_summarized=0,
error=str(e),
)
except Exception as e:
logger.exception(
f"Compression failed for session {chat_history[0].chat_session_id}: {e}"
)
db_session.rollback()
return CompressionResult(
summary_created=False,
messages_summarized=0,
error=str(e),
)

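# A small, self-contained sketch of the persist-or-roll-back pattern used in
# compress_chat_history above; the result type and session here are duck-typed
# stand-ins, not the real CompressionResult or SQLAlchemy session.
from dataclasses import dataclass
from typing import Any


@dataclass
class _SketchCompressionResult:
    summary_created: bool
    messages_summarized: int
    error: str | None = None


def _sketch_persist_summary(
    session: Any, summary_row: Any, num_summarized: int
) -> _SketchCompressionResult:
    try:
        session.add(summary_row)
        session.commit()
        return _SketchCompressionResult(
            summary_created=True, messages_summarized=num_summarized
        )
    except Exception as e:
        # Compression is best-effort: roll back and report the error instead of
        # letting a summarization failure break the chat flow.
        session.rollback()
        return _SketchCompressionResult(
            summary_created=False, messages_summarized=0, error=str(e)
        )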
@@ -38,6 +38,7 @@ from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import model_needs_formatting_reenabled
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
from onyx.server.query_and_chat.placement import Placement
@@ -48,7 +49,6 @@ from onyx.server.query_and_chat.streaming_models import TopLevelBranching
from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
from onyx.tools.interface import Tool
from onyx.tools.models import ChatFile
from onyx.tools.models import MemoryToolResponseSnapshot
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolCallKickoff
@@ -57,7 +57,6 @@ from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -69,18 +68,6 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()


def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
"""Detect XML-style marshaled tool calls emitted as plain text."""
if not text:
return False
lowered = text.lower()
return (
"<function_calls" in lowered
and "<invoke" in lowered
and "<parameter" in lowered
)


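# A short usage sketch of the heuristic above (inputs are illustrative): the
# predicate only fires when all three XML markers appear in the text, so
# ordinary prose or a bare JSON tool call does not trigger the fallback path.
_sketch_xml_payload = (
    '<function_calls><invoke name="internal_search">'
    '<parameter name="queries">["foo"]</parameter></invoke></function_calls>'
)
assert _looks_like_xml_tool_call_payload(_sketch_xml_payload) is True
assert _looks_like_xml_tool_call_payload("a normal prose answer") is False
assert _looks_like_xml_tool_call_payload(None) is False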
def _should_keep_bedrock_tool_definitions(
llm: object, simple_chat_history: list[ChatMessageSimple]
) -> bool:
@@ -135,56 +122,38 @@ def _try_fallback_tool_extraction(
reasoning_but_no_answer_or_tools = (
llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls
)
xml_tool_call_text_detected = no_tool_calls and (
_looks_like_xml_tool_call_payload(llm_step_result.answer)
or _looks_like_xml_tool_call_payload(llm_step_result.raw_answer)
or _looks_like_xml_tool_call_payload(llm_step_result.reasoning)
)
should_try_fallback = (
(tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls)
or reasoning_but_no_answer_or_tools
or xml_tool_call_text_detected
)
tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls
) or reasoning_but_no_answer_or_tools

if not should_try_fallback:
return llm_step_result, False

# Try to extract from answer first, then fall back to reasoning
extracted_tool_calls: list[ToolCallKickoff] = []

if llm_step_result.answer:
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.answer,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if (
not extracted_tool_calls
and llm_step_result.raw_answer
and llm_step_result.raw_answer != llm_step_result.answer
):
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.raw_answer,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if not extracted_tool_calls and llm_step_result.reasoning:
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.reasoning,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)

if extracted_tool_calls:
logger.info(
f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
"as fallback"
f"as fallback (tool_choice was REQUIRED but no tool calls returned)"
)
return (
LlmStepResult(
reasoning=llm_step_result.reasoning,
answer=llm_step_result.answer,
tool_calls=extracted_tool_calls,
raw_answer=llm_step_result.raw_answer,
),
True,
)
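# A minimal sketch of the cascade implemented above: try each candidate text in
# order (answer, then raw_answer when it differs from answer, then reasoning)
# and keep the first non-empty extraction. The extractor callable is a stand-in
# for extract_tool_calls_from_response_text with its other arguments bound.
from collections.abc import Callable
from typing import Any


def _sketch_first_successful_extraction(
    candidate_texts: list[str | None],
    extract: Callable[[str], list[Any]],
) -> list[Any]:
    for text in candidate_texts:
        if not text:
            continue
        calls = extract(text)
        if calls:
            return calls
    return []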
@@ -482,42 +451,7 @@ def construct_message_history(
if reminder_message:
result.append(reminder_message)

return _drop_orphaned_tool_call_responses(result)


def _drop_orphaned_tool_call_responses(
messages: list[ChatMessageSimple],
) -> list[ChatMessageSimple]:
"""Drop tool response messages whose tool_call_id is not in prior assistant tool calls.

This can happen when history truncation drops an ASSISTANT tool-call message but
leaves a later TOOL_CALL_RESPONSE message in context. Some providers (e.g. Ollama)
reject such history with an "unexpected tool call id" error.
"""
known_tool_call_ids: set[str] = set()
sanitized: list[ChatMessageSimple] = []

for msg in messages:
if msg.message_type == MessageType.ASSISTANT and msg.tool_calls:
for tool_call in msg.tool_calls:
known_tool_call_ids.add(tool_call.tool_call_id)
sanitized.append(msg)
continue

if msg.message_type == MessageType.TOOL_CALL_RESPONSE:
if msg.tool_call_id and msg.tool_call_id in known_tool_call_ids:
sanitized.append(msg)
else:
logger.debug(
"Dropping orphaned tool response with tool_call_id=%s while "
"constructing message history",
msg.tool_call_id,
)
continue

sanitized.append(msg)

return sanitized
return result


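# A self-contained sketch of the orphan-dropping rule above, using stand-in
# message objects rather than the real ChatMessageSimple: a tool response whose
# tool_call_id was never announced by a prior assistant tool call is removed,
# everything else is kept in order.
from dataclasses import dataclass, field


@dataclass
class _SketchHistoryMsg:
    message_type: str
    tool_call_id: str | None = None
    tool_call_ids: list[str] = field(default_factory=list)


def _sketch_drop_orphans(messages: list[_SketchHistoryMsg]) -> list[_SketchHistoryMsg]:
    known: set[str] = set()
    kept: list[_SketchHistoryMsg] = []
    for m in messages:
        if m.message_type == "assistant" and m.tool_call_ids:
            known.update(m.tool_call_ids)
            kept.append(m)
        elif m.message_type == "tool_response" and m.tool_call_id not in known:
            # Orphaned: its assistant tool-call message was truncated out of history.
            continue
        else:
            kept.append(m)
    return kept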
def _create_file_tool_metadata_message(
|
||||
@@ -592,7 +526,6 @@ def run_llm_loop(
|
||||
forced_tool_id: int | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
chat_files: list[ChatFile] | None = None,
|
||||
include_citations: bool = True,
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] | None = None,
|
||||
inject_memories_in_prompt: bool = True,
|
||||
@@ -652,7 +585,6 @@ def run_llm_loop(
|
||||
ran_image_gen: bool = False
|
||||
just_ran_web_search: bool = False
|
||||
has_called_search_tool: bool = False
|
||||
code_interpreter_file_generated: bool = False
|
||||
fallback_extraction_attempted: bool = False
|
||||
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
|
||||
|
||||
@@ -662,7 +594,6 @@ def run_llm_loop(
|
||||
|
||||
reasoning_cycles = 0
|
||||
for llm_cycle_count in range(MAX_LLM_CYCLES):
|
||||
# Handling tool calls based on cycle count and past cycle conditions
|
||||
out_of_cycles = llm_cycle_count == MAX_LLM_CYCLES - 1
|
||||
if forced_tool_id:
|
||||
# Needs to be just the single one because the "required" currently doesn't have a specified tool, just a binary
|
||||
@@ -684,7 +615,6 @@ def run_llm_loop(
|
||||
tool_choice = ToolChoiceOptions.AUTO
|
||||
final_tools = tools
|
||||
|
||||
# Handling the system prompt and custom agent prompt
|
||||
# The section below calculates the available tokens for history a bit more accurately
|
||||
# now that project files are loaded in.
|
||||
if persona and persona.replace_base_system_prompt:
|
||||
@@ -702,14 +632,12 @@ def run_llm_loop(
|
||||
else:
|
||||
# If it's an empty string, we assume the user does not want to include it as an empty System message
|
||||
if default_base_system_prompt:
|
||||
open_ai_formatting_enabled = model_needs_formatting_reenabled(
|
||||
llm.config.model_name
|
||||
)
|
||||
|
||||
prompt_memory_context = (
|
||||
user_memory_context
|
||||
if inject_memories_in_prompt
|
||||
else (
|
||||
user_memory_context.without_memories()
|
||||
if user_memory_context
|
||||
else None
|
||||
)
|
||||
user_memory_context if inject_memories_in_prompt else None
|
||||
)
|
||||
system_prompt_str = build_system_prompt(
|
||||
base_system_prompt=default_base_system_prompt,
|
||||
@@ -718,6 +646,7 @@ def run_llm_loop(
|
||||
tools=tools,
|
||||
should_cite_documents=should_cite_documents
|
||||
or always_cite_documents,
|
||||
open_ai_formatting_enabled=open_ai_formatting_enabled,
|
||||
)
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=system_prompt_str,
|
||||
@@ -763,7 +692,6 @@ def run_llm_loop(
|
||||
),
|
||||
include_citation_reminder=should_cite_documents
|
||||
or always_cite_documents,
|
||||
include_file_reminder=code_interpreter_file_generated,
|
||||
is_last_cycle=out_of_cycles,
|
||||
)
|
||||
|
||||
@@ -876,7 +804,6 @@ def run_llm_loop(
|
||||
next_citation_num=citation_processor.get_next_citation_number(),
|
||||
max_concurrent_tools=None,
|
||||
skip_search_query_expansion=has_called_search_tool,
|
||||
chat_files=chat_files,
|
||||
url_snippet_map=extract_url_snippet_map(gathered_documents or []),
|
||||
inject_memories_in_prompt=inject_memories_in_prompt,
|
||||
)
|
||||
@@ -903,18 +830,6 @@ def run_llm_loop(
|
||||
if tool_call.tool_name == SearchTool.NAME:
|
||||
has_called_search_tool = True
|
||||
|
||||
# Track if code interpreter generated files with download links
|
||||
if (
|
||||
tool_call.tool_name == PythonTool.NAME
|
||||
and not code_interpreter_file_generated
|
||||
):
|
||||
try:
|
||||
parsed = json.loads(tool_response.llm_facing_response)
|
||||
if parsed.get("generated_files"):
|
||||
code_interpreter_file_generated = True
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
|
||||
# Build a mapping of tool names to tool objects for getting tool_id
|
||||
tools_by_name = {tool.name: tool for tool in final_tools}
|
||||
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Sequence
|
||||
from html import unescape
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -20,7 +18,6 @@ from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
@@ -39,8 +36,6 @@ from onyx.llm.models import ToolCall
|
||||
from onyx.llm.models import ToolMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.prompt_cache.processor import process_with_prompt_cache
|
||||
from onyx.llm.utils import model_needs_formatting_reenabled
|
||||
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_CLOSE
|
||||
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_OPEN
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -59,112 +54,6 @@ from onyx.utils.text_processing import find_all_json_objects
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_XML_INVOKE_BLOCK_RE = re.compile(
|
||||
r"<invoke\b(?P<attrs>[^>]*)>(?P<body>.*?)</invoke>",
|
||||
re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
_XML_PARAMETER_RE = re.compile(
|
||||
r"<parameter\b(?P<attrs>[^>]*)>(?P<value>.*?)</parameter>",
|
||||
re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
_FUNCTION_CALLS_OPEN_MARKER = "<function_calls"
|
||||
_FUNCTION_CALLS_CLOSE_MARKER = "</function_calls>"
|
||||
|
||||
|
||||
class _XmlToolCallContentFilter:
|
||||
"""Streaming filter that strips XML-style tool call payload blocks from text."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pending = ""
|
||||
self._inside_function_calls_block = False
|
||||
|
||||
def process(self, content: str) -> str:
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
self._pending += content
|
||||
output_parts: list[str] = []
|
||||
|
||||
while self._pending:
|
||||
pending_lower = self._pending.lower()
|
||||
|
||||
if self._inside_function_calls_block:
|
||||
end_idx = pending_lower.find(_FUNCTION_CALLS_CLOSE_MARKER)
|
||||
if end_idx == -1:
|
||||
# Keep buffering until we see the close marker.
|
||||
return "".join(output_parts)
|
||||
|
||||
# Drop the whole function_calls block.
|
||||
self._pending = self._pending[
|
||||
end_idx + len(_FUNCTION_CALLS_CLOSE_MARKER) :
|
||||
]
|
||||
self._inside_function_calls_block = False
|
||||
continue
|
||||
|
||||
start_idx = _find_function_calls_open_marker(pending_lower)
|
||||
if start_idx == -1:
|
||||
# Keep only a possible prefix of "<function_calls" in the buffer so
|
||||
# marker splits across chunks are handled correctly.
|
||||
tail_len = _matching_open_marker_prefix_len(self._pending)
|
||||
emit_upto = len(self._pending) - tail_len
|
||||
if emit_upto > 0:
|
||||
output_parts.append(self._pending[:emit_upto])
|
||||
self._pending = self._pending[emit_upto:]
|
||||
return "".join(output_parts)
|
||||
|
||||
if start_idx > 0:
|
||||
output_parts.append(self._pending[:start_idx])
|
||||
|
||||
# Enter block-stripping mode and keep scanning for close marker.
|
||||
self._pending = self._pending[start_idx:]
|
||||
self._inside_function_calls_block = True
|
||||
|
||||
return "".join(output_parts)
|
||||
|
||||
def flush(self) -> str:
|
||||
if self._inside_function_calls_block:
|
||||
# Drop any incomplete block at stream end.
|
||||
self._pending = ""
|
||||
self._inside_function_calls_block = False
|
||||
return ""
|
||||
|
||||
remaining = self._pending
|
||||
self._pending = ""
|
||||
return remaining
|
||||
|
||||
|
||||
def _matching_open_marker_prefix_len(text: str) -> int:
|
||||
"""Return longest suffix of text that matches prefix of "<function_calls"."""
|
||||
max_len = min(len(text), len(_FUNCTION_CALLS_OPEN_MARKER) - 1)
|
||||
text_lower = text.lower()
|
||||
marker_lower = _FUNCTION_CALLS_OPEN_MARKER
|
||||
|
||||
for candidate_len in range(max_len, 0, -1):
|
||||
if text_lower.endswith(marker_lower[:candidate_len]):
|
||||
return candidate_len
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def _is_valid_function_calls_open_follower(char: str | None) -> bool:
|
||||
return char is None or char in {">", " ", "\t", "\n", "\r"}
|
||||
|
||||
|
||||
def _find_function_calls_open_marker(text_lower: str) -> int:
|
||||
"""Find '<function_calls' with a valid tag boundary follower."""
|
||||
search_from = 0
|
||||
while True:
|
||||
idx = text_lower.find(_FUNCTION_CALLS_OPEN_MARKER, search_from)
|
||||
if idx == -1:
|
||||
return -1
|
||||
|
||||
follower_pos = idx + len(_FUNCTION_CALLS_OPEN_MARKER)
|
||||
follower = text_lower[follower_pos] if follower_pos < len(text_lower) else None
|
||||
if _is_valid_function_calls_open_follower(follower):
|
||||
return idx
|
||||
|
||||
search_from = idx + 1
|
||||
|
||||
|
||||
def _sanitize_llm_output(value: str) -> str:
|
||||
"""Remove characters that PostgreSQL's text/JSONB types cannot store.
|
||||
@@ -381,7 +270,14 @@ def _extract_tool_call_kickoffs(
|
||||
tab_index_calculated = 0
|
||||
for tool_call_data in id_to_tool_call_map.values():
|
||||
if tool_call_data.get("id") and tool_call_data.get("name"):
|
||||
tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
|
||||
try:
|
||||
tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
|
||||
except json.JSONDecodeError:
|
||||
# If parsing fails, try empty dict, most tools would fail though
|
||||
logger.error(
|
||||
f"Failed to parse tool call arguments: {tool_call_data['arguments']}"
|
||||
)
|
||||
tool_args = {}
|
||||
|
||||
tool_calls.append(
|
||||
ToolCallKickoff(
|
||||
@@ -409,9 +305,8 @@ def extract_tool_calls_from_response_text(
|
||||
"""Extract tool calls from LLM response text by matching JSON against tool definitions.
|
||||
|
||||
This is a fallback mechanism for when the LLM was expected to return tool calls
|
||||
but didn't use the proper tool call format. It searches for tool calls embedded
|
||||
in response text (JSON first, then XML-like invoke blocks) that match available
|
||||
tool definitions.
|
||||
but didn't use the proper tool call format. It searches for JSON objects in the
|
||||
response text that match the structure of available tools.
|
||||
|
||||
Args:
|
||||
response_text: The LLM's text response to search for tool calls
|
||||
@@ -436,9 +331,10 @@ def extract_tool_calls_from_response_text(
|
||||
if not tool_name_to_def:
|
||||
return []
|
||||
|
||||
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
|
||||
# Find all JSON objects in the response text
|
||||
json_objects = find_all_json_objects(response_text)
|
||||
|
||||
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
|
||||
prev_json_obj: dict[str, Any] | None = None
|
||||
prev_tool_call: tuple[str, dict[str, Any]] | None = None
|
||||
|
||||
@@ -466,14 +362,6 @@ def extract_tool_calls_from_response_text(
|
||||
prev_json_obj = json_obj
|
||||
prev_tool_call = matched_tool_call
|
||||
|
||||
# Some providers/models emit XML-style function calls instead of JSON objects.
|
||||
# Keep this as a fallback behind JSON extraction to preserve current behavior.
|
||||
if not matched_tool_calls:
|
||||
matched_tool_calls = _extract_xml_tool_calls_from_response_text(
|
||||
response_text=response_text,
|
||||
tool_name_to_def=tool_name_to_def,
|
||||
)
|
||||
|
||||
tool_calls: list[ToolCallKickoff] = []
|
||||
for tab_index, (tool_name, tool_args) in enumerate(matched_tool_calls):
|
||||
tool_calls.append(
|
||||
@@ -496,88 +384,6 @@ def extract_tool_calls_from_response_text(
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _extract_xml_tool_calls_from_response_text(
|
||||
response_text: str,
|
||||
tool_name_to_def: dict[str, dict],
|
||||
) -> list[tuple[str, dict[str, Any]]]:
|
||||
"""Extract XML-style tool calls from response text.
|
||||
|
||||
Supports formats such as:
|
||||
<function_calls>
|
||||
<invoke name="internal_search">
|
||||
<parameter name="queries" string="false">["foo"]</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
"""
|
||||
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
|
||||
|
||||
for invoke_match in _XML_INVOKE_BLOCK_RE.finditer(response_text):
|
||||
invoke_attrs = invoke_match.group("attrs")
|
||||
tool_name = _extract_xml_attribute(invoke_attrs, "name")
|
||||
if not tool_name or tool_name not in tool_name_to_def:
|
||||
continue
|
||||
|
||||
tool_args: dict[str, Any] = {}
|
||||
invoke_body = invoke_match.group("body")
|
||||
for parameter_match in _XML_PARAMETER_RE.finditer(invoke_body):
|
||||
parameter_attrs = parameter_match.group("attrs")
|
||||
parameter_name = _extract_xml_attribute(parameter_attrs, "name")
|
||||
if not parameter_name:
|
||||
continue
|
||||
|
||||
string_attr = _extract_xml_attribute(parameter_attrs, "string")
|
||||
tool_args[parameter_name] = _parse_xml_parameter_value(
|
||||
raw_value=parameter_match.group("value"),
|
||||
string_attr=string_attr,
|
||||
)
|
||||
|
||||
matched_tool_calls.append((tool_name, tool_args))
|
||||
|
||||
return matched_tool_calls
|
||||
|
||||
|
||||
def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None:
|
||||
"""Extract a single XML-style attribute value from a tag attribute string."""
|
||||
attr_match = re.search(
|
||||
rf"""\b{re.escape(attr_name)}\s*=\s*(['"])(.*?)\1""",
|
||||
attrs,
|
||||
flags=re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
if not attr_match:
|
||||
return None
|
||||
return _sanitize_llm_output(unescape(attr_match.group(2).strip()))
|
||||
|
||||
|
||||
def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any:
|
||||
"""Parse a parameter value from XML-style tool call payloads."""
|
||||
value = _sanitize_llm_output(unescape(raw_value).strip())
|
||||
|
||||
if string_attr and string_attr.lower() == "true":
|
||||
return value
|
||||
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
|
||||
def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Extract and parse an arguments/parameters value from a tool-call-like object.
|
||||
|
||||
Looks for "arguments" or "parameters" keys, handles JSON-string values,
|
||||
and returns a dict if successful, or None otherwise.
|
||||
"""
|
||||
arguments = obj.get("arguments", obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
return arguments
|
||||
return None
|
||||
|
||||
|
||||
def _try_match_json_to_tool(
|
||||
json_obj: dict[str, Any],
|
||||
tool_name_to_def: dict[str, dict],
|
||||
@@ -600,8 +406,13 @@ def _try_match_json_to_tool(
|
||||
# Format 1: Direct tool call format {"name": "...", "arguments": {...}}
|
||||
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
|
||||
tool_name = json_obj["name"]
|
||||
arguments = _resolve_tool_arguments(json_obj)
|
||||
if arguments is not None:
|
||||
arguments = json_obj.get("arguments", json_obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 2: Function call format {"function": {"name": "...", "arguments": {...}}}
|
||||
@@ -609,8 +420,13 @@ def _try_match_json_to_tool(
|
||||
func_obj = json_obj["function"]
|
||||
if "name" in func_obj and func_obj["name"] in tool_name_to_def:
|
||||
tool_name = func_obj["name"]
|
||||
arguments = _resolve_tool_arguments(func_obj)
|
||||
if arguments is not None:
|
||||
arguments = func_obj.get("arguments", func_obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 3: Tool name as key {"tool_name": {...arguments...}}
|
||||
@@ -677,107 +493,6 @@ def _extract_nested_arguments_obj(
|
||||
return None
|
||||
|
||||
|
||||
def _build_structured_assistant_message(msg: ChatMessageSimple) -> AssistantMessage:
|
||||
tool_calls_list: list[ToolCall] | None = None
|
||||
if msg.tool_calls:
|
||||
tool_calls_list = [
|
||||
ToolCall(
|
||||
id=tc.tool_call_id,
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=tc.tool_name,
|
||||
arguments=json.dumps(tc.tool_arguments),
|
||||
),
|
||||
)
|
||||
for tc in msg.tool_calls
|
||||
]
|
||||
|
||||
return AssistantMessage(
|
||||
role="assistant",
|
||||
content=msg.message or None,
|
||||
tool_calls=tool_calls_list,
|
||||
)
|
||||
|
||||
|
||||
def _build_structured_tool_response_message(msg: ChatMessageSimple) -> ToolMessage:
|
||||
if not msg.tool_call_id:
|
||||
raise ValueError(
|
||||
"Tool call response message encountered but tool_call_id is not available. "
|
||||
f"Message: {msg}"
|
||||
)
|
||||
|
||||
return ToolMessage(
|
||||
role="tool",
|
||||
content=msg.message,
|
||||
tool_call_id=msg.tool_call_id,
|
||||
)
|
||||
|
||||
|
||||
class _HistoryMessageFormatter:
|
||||
def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
|
||||
raise NotImplementedError
|
||||
|
||||
def format_tool_response_message(
|
||||
self, msg: ChatMessageSimple
|
||||
) -> ToolMessage | UserMessage:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _DefaultHistoryMessageFormatter(_HistoryMessageFormatter):
|
||||
def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
|
||||
return _build_structured_assistant_message(msg)
|
||||
|
||||
def format_tool_response_message(self, msg: ChatMessageSimple) -> ToolMessage:
|
||||
return _build_structured_tool_response_message(msg)
|
||||
|
||||
|
||||
class _OllamaHistoryMessageFormatter(_HistoryMessageFormatter):
|
||||
def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
|
||||
if not msg.tool_calls:
|
||||
return _build_structured_assistant_message(msg)
|
||||
|
||||
tool_call_lines = [
|
||||
(
|
||||
f"[Tool Call] name={tc.tool_name} "
|
||||
f"id={tc.tool_call_id} args={json.dumps(tc.tool_arguments)}"
|
||||
)
|
||||
for tc in msg.tool_calls
|
||||
]
|
||||
assistant_content = (
|
||||
"\n".join([msg.message, *tool_call_lines])
|
||||
if msg.message
|
||||
else "\n".join(tool_call_lines)
|
||||
)
|
||||
return AssistantMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
tool_calls=None,
|
||||
)
|
||||
|
||||
def format_tool_response_message(self, msg: ChatMessageSimple) -> UserMessage:
|
||||
if not msg.tool_call_id:
|
||||
raise ValueError(
|
||||
"Tool call response message encountered but tool_call_id is not available. "
|
||||
f"Message: {msg}"
|
||||
)
|
||||
|
||||
return UserMessage(
|
||||
role="user",
|
||||
content=f"[Tool Result] id={msg.tool_call_id}\n{msg.message}",
|
||||
)
|
||||
|
||||
|
||||
_DEFAULT_HISTORY_MESSAGE_FORMATTER = _DefaultHistoryMessageFormatter()
|
||||
_OLLAMA_HISTORY_MESSAGE_FORMATTER = _OllamaHistoryMessageFormatter()
|
||||
|
||||
|
||||
def _get_history_message_formatter(llm_config: LLMConfig) -> _HistoryMessageFormatter:
|
||||
if llm_config.model_provider == LlmProviderNames.OLLAMA_CHAT:
|
||||
return _OLLAMA_HISTORY_MESSAGE_FORMATTER
|
||||
|
||||
return _DEFAULT_HISTORY_MESSAGE_FORMATTER
|
||||
|
||||
|
||||
def translate_history_to_llm_format(
|
||||
history: list[ChatMessageSimple],
|
||||
llm_config: LLMConfig,
|
||||
@@ -788,10 +503,6 @@ def translate_history_to_llm_format(
|
||||
handling different message types and image files for multimodal support.
|
||||
"""
|
||||
messages: list[ChatCompletionMessage] = []
|
||||
history_message_formatter = _get_history_message_formatter(llm_config)
|
||||
# Note: cacheability is computed from pre-translation ChatMessageSimple types.
|
||||
# Some providers flatten tool history into plain assistant/user text, so this split
|
||||
# may be less semantically meaningful, but it remains safe and order-preserving.
|
||||
last_cacheable_msg_idx = -1
|
||||
all_previous_msgs_cacheable = True
|
||||
|
||||
@@ -873,27 +584,45 @@ def translate_history_to_llm_format(
|
||||
messages.append(reminder_msg)
|
||||
|
||||
elif msg.message_type == MessageType.ASSISTANT:
|
||||
messages.append(history_message_formatter.format_assistant_message(msg))
|
||||
tool_calls_list: list[ToolCall] | None = None
|
||||
if msg.tool_calls:
|
||||
tool_calls_list = [
|
||||
ToolCall(
|
||||
id=tc.tool_call_id,
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=tc.tool_name,
|
||||
arguments=json.dumps(tc.tool_arguments),
|
||||
),
|
||||
)
|
||||
for tc in msg.tool_calls
|
||||
]
|
||||
|
||||
assistant_msg = AssistantMessage(
|
||||
role="assistant",
|
||||
content=msg.message or None,
|
||||
tool_calls=tool_calls_list,
|
||||
)
|
||||
messages.append(assistant_msg)
|
||||
|
||||
elif msg.message_type == MessageType.TOOL_CALL_RESPONSE:
|
||||
messages.append(history_message_formatter.format_tool_response_message(msg))
|
||||
if not msg.tool_call_id:
|
||||
raise ValueError(
|
||||
f"Tool call response message encountered but tool_call_id is not available. Message: {msg}"
|
||||
)
|
||||
|
||||
tool_msg = ToolMessage(
|
||||
role="tool",
|
||||
content=msg.message,
|
||||
tool_call_id=msg.tool_call_id,
|
||||
)
|
||||
messages.append(tool_msg)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unknown message type {msg.message_type} in history. Skipping message."
|
||||
)
|
||||
|
||||
# Apply model-specific formatting when translating to LLM format (e.g. OpenAI
|
||||
# reasoning models need CODE_BLOCK_MARKDOWN prefix for correct markdown generation)
|
||||
if model_needs_formatting_reenabled(llm_config.model_name):
|
||||
for i, m in enumerate(messages):
|
||||
if isinstance(m, SystemMessage):
|
||||
messages[i] = SystemMessage(
|
||||
role="system",
|
||||
content=CODE_BLOCK_MARKDOWN + m.content,
|
||||
)
|
||||
break
|
||||
|
||||
# prompt caching: rely on should_cache in ChatMessageSimple to
|
||||
# pick the split point for the cacheable prefix and suffix
|
||||
if last_cacheable_msg_idx != -1:
|
||||
@@ -956,8 +685,7 @@ def run_llm_step_pkt_generator(
|
||||
tool_definitions: List of tool definitions available to the LLM.
|
||||
tool_choice: Tool choice configuration (e.g., "auto", "required", "none").
|
||||
llm: Language model interface to use for generation.
|
||||
placement: Placement info (turn_index, tab_index, sub_turn_index) for
|
||||
positioning packets in the conversation UI.
|
||||
turn_index: Current turn index in the conversation.
|
||||
state_container: Container for storing chat state (reasoning, answers).
|
||||
citation_processor: Optional processor for extracting and formatting citations
|
||||
from the response. If provided, processes tokens to identify citations.
|
||||
@@ -969,14 +697,7 @@ def run_llm_step_pkt_generator(
|
||||
custom_token_processor: Optional callable that processes each token delta
|
||||
before yielding. Receives (delta, processor_state) and returns
|
||||
(modified_delta, new_processor_state). Can return None for delta to skip.
|
||||
max_tokens: Optional maximum number of tokens for the LLM response.
|
||||
use_existing_tab_index: If True, use the tab_index from placement for all
|
||||
tool calls instead of auto-incrementing.
|
||||
is_deep_research: If True, treat content before tool calls as reasoning
|
||||
when tool_choice is REQUIRED.
|
||||
pre_answer_processing_time: Optional time spent processing before the
|
||||
answer started, recorded in state_container for analytics.
|
||||
timeout_override: Optional timeout override for the LLM call.
|
||||
sub_turn_index: Optional sub-turn index for nested tool/agent calls.
|
||||
|
||||
Yields:
|
||||
Packet: Streaming packets containing:
|
||||
@@ -1002,15 +723,8 @@ def run_llm_step_pkt_generator(
|
||||
tab_index = placement.tab_index
|
||||
sub_turn_index = placement.sub_turn_index
|
||||
|
||||
def _current_placement() -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
)
|
||||
|
||||
llm_msg_history = translate_history_to_llm_format(history, llm.config)
|
||||
has_reasoned = False
|
||||
has_reasoned = 0
|
||||
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
logger.debug(
|
||||
@@ -1022,8 +736,6 @@ def run_llm_step_pkt_generator(
|
||||
answer_start = False
|
||||
accumulated_reasoning = ""
|
||||
accumulated_answer = ""
|
||||
accumulated_raw_answer = ""
|
||||
xml_tool_call_content_filter = _XmlToolCallContentFilter()
|
||||
|
||||
processor_state: Any = None
|
||||
|
||||
@@ -1039,112 +751,6 @@ def run_llm_step_pkt_generator(
|
||||
)
|
||||
stream_start_time = time.monotonic()
|
||||
first_action_recorded = False
|
||||
|
||||
def _emit_citation_results(
|
||||
results: Generator[str | CitationInfo, None, None],
|
||||
) -> Generator[Packet, None, None]:
|
||||
"""Yield packets for citation processor results (str or CitationInfo)."""
|
||||
nonlocal accumulated_answer
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=result,
|
||||
)
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(result.citation_number)
|
||||
|
||||
def _close_reasoning_if_active() -> Generator[Packet, None, None]:
|
||||
"""Emit ReasoningDone and increment turns if reasoning is in progress."""
|
||||
nonlocal reasoning_start
|
||||
nonlocal has_reasoned
|
||||
nonlocal turn_index
|
||||
nonlocal sub_turn_index
|
||||
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = True
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
|
||||
def _emit_content_chunk(content_chunk: str) -> Generator[Packet, None, None]:
|
||||
nonlocal accumulated_answer
|
||||
nonlocal accumulated_reasoning
|
||||
nonlocal answer_start
|
||||
nonlocal reasoning_start
|
||||
nonlocal turn_index
|
||||
nonlocal sub_turn_index
|
||||
|
||||
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
|
||||
# about which tool to call, not an actual answer to the user.
|
||||
# Treat this content as reasoning instead of answer.
|
||||
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
|
||||
accumulated_reasoning += content_chunk
|
||||
if state_container:
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=ReasoningDelta(reasoning=content_chunk),
|
||||
)
|
||||
reasoning_start = True
|
||||
return
|
||||
|
||||
# Normal flow for AUTO or NONE tool choice
|
||||
yield from _close_reasoning_if_active()
|
||||
|
||||
if not answer_start:
|
||||
# Store pre-answer processing time in state container for save_chat
|
||||
if state_container and pre_answer_processing_time is not None:
|
||||
state_container.set_pre_answer_processing_time(
|
||||
pre_answer_processing_time
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_documents,
|
||||
pre_answer_processing_seconds=pre_answer_processing_time,
|
||||
),
|
||||
)
|
||||
answer_start = True
|
||||
|
||||
if citation_processor:
|
||||
yield from _emit_citation_results(
|
||||
citation_processor.process_token(content_chunk)
|
||||
)
|
||||
else:
|
||||
accumulated_answer += content_chunk
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=AgentResponseDelta(content=content_chunk),
|
||||
)
|
||||
|
||||
for packet in llm.stream(
|
||||
prompt=llm_msg_history,
|
||||
tools=tool_definitions,
|
||||
@@ -1203,34 +809,152 @@ def run_llm_step_pkt_generator(
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDelta(reasoning=delta.reasoning_content),
|
||||
)
|
||||
reasoning_start = True
|
||||
|
||||
if delta.content:
|
||||
# Keep raw content for fallback extraction. Display content can be
|
||||
# filtered and, in deep-research REQUIRED mode, routed as reasoning.
|
||||
accumulated_raw_answer += delta.content
|
||||
filtered_content = xml_tool_call_content_filter.process(delta.content)
|
||||
if filtered_content:
|
||||
yield from _emit_content_chunk(filtered_content)
|
||||
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
|
||||
# about which tool to call, not an actual answer to the user.
|
||||
# Treat this content as reasoning instead of answer.
|
||||
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
|
||||
# Treat content as reasoning when we know tool calls are coming
|
||||
accumulated_reasoning += delta.content
|
||||
if state_container:
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDelta(reasoning=delta.content),
|
||||
)
|
||||
reasoning_start = True
|
||||
else:
|
||||
# Normal flow for AUTO or NONE tool choice
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
|
||||
if not answer_start:
|
||||
# Store pre-answer processing time in state container for save_chat
|
||||
if state_container and pre_answer_processing_time is not None:
|
||||
state_container.set_pre_answer_processing_time(
|
||||
pre_answer_processing_time
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_documents,
|
||||
pre_answer_processing_seconds=pre_answer_processing_time,
|
||||
),
|
||||
)
|
||||
answer_start = True
|
||||
|
||||
if citation_processor:
|
||||
for result in citation_processor.process_token(delta.content):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(
|
||||
accumulated_answer
|
||||
)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
# Track emitted citation for saving
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(
|
||||
result.citation_number
|
||||
)
|
||||
else:
|
||||
# When citation_processor is None, use delta.content directly without modification
|
||||
accumulated_answer += delta.content
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=delta.content),
|
||||
)
|
||||
|
||||
if delta.tool_calls:
|
||||
yield from _close_reasoning_if_active()
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
|
||||
# Flush any tail text buffered while checking for split "<function_calls" markers.
|
||||
filtered_content_tail = xml_tool_call_content_filter.flush()
|
||||
if filtered_content_tail:
|
||||
yield from _emit_content_chunk(filtered_content_tail)
|
||||
|
||||
# Flush custom token processor to get any final tool calls
|
||||
if custom_token_processor:
|
||||
flush_delta, processor_state = custom_token_processor(None, processor_state)
|
||||
@@ -1286,14 +1010,50 @@ def run_llm_step_pkt_generator(
|
||||
|
||||
# This may happen if the custom token processor is used to modify other packets into reasoning
|
||||
# Then there won't necessarily be anything else to come after the reasoning tokens
|
||||
yield from _close_reasoning_if_active()
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(turn_index, sub_turn_index)
|
||||
reasoning_start = False
|
||||
|
||||
# Flush any remaining content from citation processor
|
||||
# Reasoning is always first so this should use the post-incremented value of turn_index
|
||||
# Note that this doesn't need to handle any sub-turns as those docs will not have citations
|
||||
# as clickable items and will be stripped out instead.
|
||||
if citation_processor:
|
||||
yield from _emit_citation_results(citation_processor.process_token(None))
|
||||
for result in citation_processor.process_token(None):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
# Track emitted citation for saving
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(result.citation_number)
|
||||
|
||||
# Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
|
||||
# Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)
|
||||
@@ -1315,9 +1075,8 @@ def run_llm_step_pkt_generator(
|
||||
reasoning=accumulated_reasoning if accumulated_reasoning else None,
|
||||
answer=accumulated_answer if accumulated_answer else None,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
raw_answer=accumulated_raw_answer if accumulated_raw_answer else None,
|
||||
),
|
||||
has_reasoned,
|
||||
bool(has_reasoned),
|
||||
)
|
||||
|
||||
|
||||
@@ -1372,4 +1131,4 @@ def run_llm_step(
|
||||
emitter.emit(packet)
|
||||
except StopIteration as e:
|
||||
llm_step_result, has_reasoned = e.value
|
||||
return llm_step_result, has_reasoned
|
||||
return llm_step_result, bool(has_reasoned)
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
@@ -16,6 +20,54 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
FINISHED = "finished"
|
||||
|
||||
|
||||
class StreamType(Enum):
|
||||
SUB_QUESTIONS = "sub_questions"
|
||||
SUB_ANSWER = "sub_answer"
|
||||
MAIN_ANSWER = "main_answer"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
stop_reason: StreamStopReason
|
||||
|
||||
stream_type: StreamType = StreamType.MAIN_ANSWER
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
data["stop_reason"] = self.stop_reason.name
|
||||
return data
|
||||
|
||||
|
||||
class UserKnowledgeFilePacket(BaseModel):
|
||||
user_files: list[FileDescriptor]
|
||||
|
||||
|
||||
class RelevanceAnalysis(BaseModel):
|
||||
relevant: bool
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class DocumentRelevance(BaseModel):
|
||||
"""Contains all relevance information for a given search"""
|
||||
|
||||
relevance_summaries: dict[str, RelevanceAnalysis]
|
||||
|
||||
|
||||
class OnyxAnswerPiece(BaseModel):
|
||||
# A small piece of a complete answer. Used for streaming back answers.
|
||||
answer_piece: str | None # if None, specifies the end of an Answer
|
||||
|
||||
|
||||
class MessageResponseIDInfo(BaseModel):
|
||||
user_message_id: int | None
|
||||
reserved_assistant_message_id: int
|
||||
|
||||
|
||||
class StreamingError(BaseModel):
|
||||
error: str
|
||||
stack_trace: str | None = None
|
||||
@@ -26,11 +78,23 @@ class StreamingError(BaseModel):
|
||||
details: dict | None = None # Additional context (tool name, model name, etc.)
|
||||
|
||||
|
||||
class OnyxAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
|
||||
class FileChatDisplay(BaseModel):
|
||||
file_ids: list[str]
|
||||
|
||||
|
||||
class CustomToolResponse(BaseModel):
|
||||
response: ToolResultType
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class ProjectSearchConfig(BaseModel):
|
||||
"""Configuration for search tool availability in project context."""
|
||||
|
||||
@@ -38,27 +102,71 @@ class ProjectSearchConfig(BaseModel):
|
||||
disable_forced_tool: bool
|
||||
|
||||
|
||||
class PromptOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
datetime_aware: bool = True
|
||||
include_citations: bool = True
|
||||
|
||||
|
||||
class PersonaOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
|
||||
# Note: prompt_ids removed - prompts are now embedded in personas
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
AnswerQuestionPossibleReturn = (
|
||||
OnyxAnswerPiece
|
||||
| CitationInfo
|
||||
| FileChatDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
|
||||
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
|
||||
AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
|
||||
|
||||
|
||||
class LLMMetricsContainer(BaseModel):
|
||||
prompt_tokens: int
|
||||
response_tokens: int
|
||||
|
||||
|
||||
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
|
||||
|
||||
|
||||
AnswerStreamPart = (
|
||||
Packet
|
||||
| StreamStopInfo
|
||||
| MessageResponseIDInfo
|
||||
| StreamingError
|
||||
| UserKnowledgeFilePacket
|
||||
| CreateChatSessionID
|
||||
)
|
||||
|
||||
AnswerStream = Iterator[AnswerStreamPart]
|
||||
|
||||
|
||||
class ToolCallResponse(BaseModel):
|
||||
"""Tool call with full details for non-streaming response."""
|
||||
|
||||
tool_name: str
|
||||
tool_arguments: dict[str, Any]
|
||||
tool_result: str
|
||||
search_docs: list[SearchDoc] | None = None
|
||||
generated_images: list[GeneratedImage] | None = None
|
||||
# Reasoning that led to the tool call
|
||||
pre_reasoning: str | None = None
|
||||
|
||||
|
||||
class ChatBasicResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str
|
||||
@@ -71,11 +179,20 @@ class ChatBasicResponse(BaseModel):
|
||||
citation_info: list[CitationInfo]
|
||||
|
||||
|
||||
class ToolCallResponse(BaseModel):
|
||||
"""Tool call with full details for non-streaming response."""
|
||||
|
||||
tool_name: str
|
||||
tool_arguments: dict[str, Any]
|
||||
tool_result: str
|
||||
search_docs: list[SearchDoc] | None = None
|
||||
generated_images: list[GeneratedImage] | None = None
|
||||
# Reasoning that led to the tool call
|
||||
pre_reasoning: str | None = None
|
||||
|
||||
|
||||
class ChatFullResponse(BaseModel):
|
||||
"""Complete non-streaming response with all available data.
|
||||
NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
|
||||
experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
"""
|
||||
"""Complete non-streaming response with all available data."""
|
||||
|
||||
# Core response fields
|
||||
answer: str
|
||||
@@ -185,6 +302,3 @@ class LlmStepResult(BaseModel):
|
||||
reasoning: str | None
|
||||
answer: str | None
|
||||
tool_calls: list[ToolCallKickoff] | None
|
||||
# Raw LLM text before any display-oriented filtering/sanitization.
|
||||
# Used for fallback tool-call extraction when providers emit calls as text.
|
||||
raw_answer: str | None = None
|
||||
|
||||
@@ -4,6 +4,7 @@ An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from contextvars import Token
|
||||
@@ -36,6 +37,7 @@ from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ProjectSearchConfig
|
||||
from onyx.chat.models import StreamingError
|
||||
@@ -79,7 +81,8 @@ from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import OptionalSearchSetting
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
@@ -88,7 +91,6 @@ from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
|
||||
from onyx.tools.constants import SEARCH_TOOL_ID
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import CustomToolConfig
|
||||
@@ -169,29 +171,6 @@ def _should_enable_slack_search(
|
||||
)
|
||||
|
||||
|
||||
def _convert_loaded_files_to_chat_files(
|
||||
loaded_files: list[ChatLoadedFile],
|
||||
) -> list[ChatFile]:
|
||||
"""Convert ChatLoadedFile objects to ChatFile for tool usage (e.g., PythonTool).
|
||||
|
||||
Args:
|
||||
loaded_files: List of ChatLoadedFile objects from the chat history
|
||||
|
||||
Returns:
|
||||
List of ChatFile objects that can be passed to tools
|
||||
"""
|
||||
chat_files = []
|
||||
for loaded_file in loaded_files:
|
||||
if len(loaded_file.content) > 0:
|
||||
chat_files.append(
|
||||
ChatFile(
|
||||
filename=loaded_file.filename or f"file_{loaded_file.file_id}",
|
||||
content=loaded_file.content,
|
||||
)
|
||||
)
|
||||
return chat_files
|
||||
|
||||
|
||||
def _extract_project_file_texts_and_images(
|
||||
project_id: int | None,
|
||||
user_id: UUID | None,
|
||||
@@ -454,6 +433,7 @@ def handle_stream_message_objects(
|
||||
external_state_container: ChatStateContainer | None = None,
|
||||
) -> AnswerStream:
|
||||
tenant_id = get_current_tenant_id()
|
||||
processing_start_time = time.monotonic()
|
||||
mock_response_token: Token[str | None] | None = None
|
||||
|
||||
llm: LLM | None = None
|
||||
@@ -635,27 +615,16 @@ def handle_stream_message_objects(
|
||||
|
||||
user_memory_context = get_memories(user, db_session)
|
||||
|
||||
# This is the custom prompt which may come from the Agent or Project. We fetch it earlier because the inner loop
|
||||
# (run_llm_loop and run_deep_research_llm_loop) should not need to be aware of the Chat History in the DB form processed
|
||||
# here, however we need this early for token reservation.
|
||||
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
# When use_memories is disabled, strip memories from the prompt context
|
||||
# but keep user info/preferences. The full context is still passed
|
||||
# When use_memories is disabled, don't inject memories into the prompt
|
||||
# or count them in token reservation, but still pass the full context
|
||||
# to the LLM loop for memory tool persistence.
|
||||
prompt_memory_context = (
|
||||
user_memory_context
|
||||
if user.use_memories
|
||||
else user_memory_context.without_memories()
|
||||
)
|
||||
|
||||
max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
|
||||
custom_agent_prompt or ""
|
||||
)
|
||||
prompt_memory_context = user_memory_context if user.use_memories else None
|
||||
|
||||
reserved_token_count = calculate_reserved_tokens(
|
||||
db_session=db_session,
|
||||
persona_system_prompt=max_reserved_system_prompt_tokens_str,
|
||||
persona_system_prompt=custom_agent_prompt or "",
|
||||
token_counter=token_counter,
|
||||
files=new_msg_req.file_descriptors,
|
||||
user_memory_context=prompt_memory_context,
|
||||
@@ -757,9 +726,6 @@ def handle_stream_message_objects(
|
||||
# load all files needed for this chat chain in memory
|
||||
files = load_all_chat_files(chat_history, db_session)
|
||||
|
||||
# Convert loaded files to ChatFile format for tools like PythonTool
|
||||
chat_files_for_tools = _convert_loaded_files_to_chat_files(files)
|
||||
|
||||
# TODO Need to think of some way to support selected docs from the sidebar
|
||||
|
||||
# Reserve a message id for the assistant response for frontend to track packets
|
||||
@@ -854,6 +820,7 @@ def handle_stream_message_objects(
|
||||
assistant_message=assistant_response,
|
||||
llm=llm,
|
||||
reserved_tokens=reserved_token_count,
|
||||
processing_start_time=processing_start_time,
|
||||
)
|
||||
|
||||
# The stream generator can resume on a different worker thread after early yields.
|
||||
@@ -910,7 +877,6 @@ def handle_stream_message_objects(
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
@@ -987,6 +953,7 @@ def llm_loop_completion_handle(
|
||||
assistant_message: ChatMessage,
|
||||
llm: LLM,
|
||||
reserved_tokens: int,
|
||||
processing_start_time: float | None = None, # noqa: ARG001
|
||||
) -> None:
|
||||
chat_session_id = assistant_message.chat_session_id
|
||||
|
||||
@@ -1049,6 +1016,68 @@ def llm_loop_completion_handle(
|
||||
)
|
||||
|
||||
|
||||
def stream_chat_message_objects(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
bypass_acl: bool = False,
|
||||
# Additional context that should be included in the chat history, for example:
|
||||
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
|
||||
# messages. Both of the below are used for Slack
|
||||
# NOTE: is not stored in the database, only passed in to the LLM as context
|
||||
additional_context: str | None = None,
|
||||
# Slack context for federated Slack search
|
||||
slack_context: SlackContext | None = None,
|
||||
) -> AnswerStream:
|
||||
forced_tool_id = (
|
||||
new_msg_req.forced_tool_ids[0] if new_msg_req.forced_tool_ids else None
|
||||
)
|
||||
if (
|
||||
new_msg_req.retrieval_options
|
||||
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
|
||||
):
|
||||
all_tools = get_tools(db_session)
|
||||
|
||||
search_tool_id = next(
|
||||
(tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
|
||||
None,
|
||||
)
|
||||
forced_tool_id = search_tool_id
|
||||
|
||||
translated_new_msg_req = SendMessageRequest(
|
||||
message=new_msg_req.message,
|
||||
llm_override=new_msg_req.llm_override,
|
||||
mock_llm_response=new_msg_req.mock_llm_response,
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
forced_tool_id=forced_tool_id,
|
||||
file_descriptors=new_msg_req.file_descriptors,
|
||||
internal_search_filters=(
|
||||
new_msg_req.retrieval_options.filters
|
||||
if new_msg_req.retrieval_options
|
||||
else None
|
||||
),
|
||||
deep_research=new_msg_req.deep_research,
|
||||
parent_message_id=new_msg_req.parent_message_id,
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
origin=new_msg_req.origin,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
)
|
||||
return handle_stream_message_objects(
|
||||
new_msg_req=translated_new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
bypass_acl=bypass_acl,
|
||||
additional_context=additional_context,
|
||||
slack_context=slack_context,
|
||||
)
|
||||
|
||||
|
||||
def remove_answer_citations(answer: str) -> str:
    pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"
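The body of `remove_answer_citations` is cut off at the hunk boundary; the pattern itself is shown above, and applying it with `re.sub` is an assumption about how the function finishes:

```python
import re

# Pattern copied from remove_answer_citations above; using re.sub to strip the
# citations is an assumption about the truncated function body.
pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"

answer = "Onyx supports federated search [[1]](https://docs.onyx.app/search)."
print(re.sub(pattern, "", answer))
# -> "Onyx supports federated search."
```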
@@ -9,14 +9,13 @@ from onyx.db.persona import get_default_behavior_persona
|
||||
from onyx.db.user_file import calculate_user_files_token_count
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.prompts.chat_prompts import CITATION_REMINDER
|
||||
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
|
||||
from onyx.prompts.chat_prompts import FILE_REMINDER
|
||||
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
|
||||
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
|
||||
from onyx.prompts.prompt_utils import get_company_context
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
|
||||
from onyx.prompts.prompt_utils import replace_reminder_tag
|
||||
from onyx.prompts.tool_prompts import GENERATE_IMAGE_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import MEMORY_GUIDANCE
|
||||
@@ -26,12 +25,7 @@ from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT
|
||||
from onyx.prompts.user_info import TEAM_INFORMATION_PROMPT
|
||||
from onyx.prompts.user_info import USER_INFORMATION_HEADER
|
||||
from onyx.prompts.user_info import USER_MEMORIES_PROMPT
|
||||
from onyx.prompts.user_info import USER_PREFERENCES_PROMPT
|
||||
from onyx.prompts.user_info import USER_ROLE_PROMPT
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
@@ -126,7 +120,6 @@ def calculate_reserved_tokens(
|
||||
def build_reminder_message(
|
||||
reminder_text: str | None,
|
||||
include_citation_reminder: bool,
|
||||
include_file_reminder: bool,
|
||||
is_last_cycle: bool,
|
||||
) -> str | None:
|
||||
reminder = reminder_text.strip() if reminder_text else ""
|
||||
@@ -134,65 +127,10 @@ def build_reminder_message(
|
||||
reminder += "\n\n" + LAST_CYCLE_CITATION_REMINDER
|
||||
if include_citation_reminder:
|
||||
reminder += "\n\n" + CITATION_REMINDER
|
||||
if include_file_reminder:
|
||||
reminder += "\n\n" + FILE_REMINDER
|
||||
reminder = reminder.strip()
|
||||
return reminder if reminder else None
|
||||
|
||||
|
||||
def _build_user_information_section(
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
company_context: str | None,
|
||||
) -> str:
|
||||
"""Build the complete '# User Information' section with all sub-sections
|
||||
in the correct order: Basic Info → Team Info → Preferences → Memories."""
|
||||
sections: list[str] = []
|
||||
|
||||
if user_memory_context:
|
||||
ctx = user_memory_context
|
||||
has_basic_info = ctx.user_info.name or ctx.user_info.email or ctx.user_info.role
|
||||
|
||||
if has_basic_info:
|
||||
role_line = (
|
||||
USER_ROLE_PROMPT.format(user_role=ctx.user_info.role).strip()
|
||||
if ctx.user_info.role
|
||||
else ""
|
||||
)
|
||||
if role_line:
|
||||
role_line = "\n" + role_line
|
||||
sections.append(
|
||||
BASIC_INFORMATION_PROMPT.format(
|
||||
user_name=ctx.user_info.name or "",
|
||||
user_email=ctx.user_info.email or "",
|
||||
user_role=role_line,
|
||||
)
|
||||
)
|
||||
|
||||
if company_context:
|
||||
sections.append(
|
||||
TEAM_INFORMATION_PROMPT.format(team_information=company_context.strip())
|
||||
)
|
||||
|
||||
if user_memory_context:
|
||||
ctx = user_memory_context
|
||||
|
||||
if ctx.user_preferences:
|
||||
sections.append(
|
||||
USER_PREFERENCES_PROMPT.format(user_preferences=ctx.user_preferences)
|
||||
)
|
||||
|
||||
if ctx.memories:
|
||||
formatted_memories = "\n".join(f"- {memory}" for memory in ctx.memories)
|
||||
sections.append(
|
||||
USER_MEMORIES_PROMPT.format(user_memories=formatted_memories)
|
||||
)
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
return USER_INFORMATION_HEADER + "".join(sections)
|
||||
|
||||
|
||||
def build_system_prompt(
|
||||
base_system_prompt: str,
|
||||
datetime_aware: bool = False,
|
||||
@@ -200,12 +138,18 @@ def build_system_prompt(
|
||||
tools: Sequence[Tool] | None = None,
|
||||
should_cite_documents: bool = False,
|
||||
include_all_guidance: bool = False,
|
||||
open_ai_formatting_enabled: bool = False,
|
||||
) -> str:
|
||||
"""Should only be called with the default behavior system prompt.
|
||||
If the user has replaced the default behavior prompt with their custom agent prompt, do not call this function.
|
||||
"""
|
||||
system_prompt = handle_onyx_date_awareness(base_system_prompt, datetime_aware)
|
||||
|
||||
# See https://simonwillison.net/tags/markdown/ for context on why this is needed
|
||||
# for OpenAI reasoning models to have correct markdown generation
|
||||
if open_ai_formatting_enabled:
|
||||
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
|
||||
|
||||
# Replace citation guidance placeholder if present
|
||||
system_prompt, should_append_citation_guidance = replace_citation_guidance_tag(
|
||||
system_prompt,
|
||||
@@ -213,14 +157,16 @@ def build_system_prompt(
|
||||
include_all_guidance=include_all_guidance,
|
||||
)
|
||||
|
||||
# Replace reminder tag placeholder if present
|
||||
system_prompt = replace_reminder_tag(system_prompt)
|
||||
|
||||
company_context = get_company_context()
|
||||
user_info_section = _build_user_information_section(
|
||||
user_memory_context, company_context
|
||||
formatted_user_context = (
|
||||
user_memory_context.as_formatted_prompt() if user_memory_context else ""
|
||||
)
|
||||
system_prompt += user_info_section
|
||||
if company_context or formatted_user_context:
|
||||
system_prompt += USER_INFORMATION_HEADER
|
||||
if company_context:
|
||||
system_prompt += company_context
|
||||
if formatted_user_context:
|
||||
system_prompt += formatted_user_context
|
||||
|
||||
# Append citation guidance after company context if placeholder was not present
|
||||
# This maintains backward compatibility and ensures citations are always enforced when needed
|
||||
|
||||
@@ -251,9 +251,7 @@ DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int(
|
||||
os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50
|
||||
)
|
||||
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get(
|
||||
"OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!"
|
||||
)
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
|
||||
USING_AWS_MANAGED_OPENSEARCH = (
|
||||
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
|
||||
)
|
||||
@@ -265,18 +263,6 @@ OPENSEARCH_PROFILING_DISABLED = (
|
||||
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# When enabled, OpenSearch returns detailed score breakdowns for each hit.
|
||||
# Useful for debugging and tuning search relevance. Has ~10-30% performance overhead according to documentation.
|
||||
# Seems for Hybrid Search in practice, the impact is actually more like 1000x slower.
|
||||
OPENSEARCH_EXPLAIN_ENABLED = (
|
||||
os.environ.get("OPENSEARCH_EXPLAIN_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Analyzer used for full-text fields (title, content). Use OpenSearch built-in analyzer
|
||||
# names (e.g. "english", "standard", "german"). Affects stemming and tokenization;
|
||||
# existing indices need reindexing after a change.
|
||||
OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "english"
|
||||
|
||||
# This is the "base" config for now, the idea is that at least for our dev
|
||||
# environments we always want to be dual indexing into both OpenSearch and Vespa
|
||||
# to stress test the new codepaths. Only enable this if there is some instance
|
||||
@@ -284,9 +270,6 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
# NOTE: This effectively does nothing anymore, admins can now toggle whether
|
||||
# retrieval is through OpenSearch. This value is only used as a final fallback
|
||||
# in case that doesn't work for whatever reason.
|
||||
# Given that the "base" config above is true, this enables whether we want to
|
||||
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
|
||||
# in the event we see issues with OpenSearch retrieval in our dev environments.
|
||||
@@ -642,14 +625,6 @@ SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
|
||||
)
|
||||
|
||||
# When True, group sync enumerates every Azure AD group in the tenant (expensive).
|
||||
# When False (default), only groups found in site role assignments are synced.
|
||||
# Can be overridden per-connector via the "exhaustive_ad_enumeration" key in
|
||||
# connector_specific_config.
|
||||
SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION = (
|
||||
os.environ.get("SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION", "").lower() == "true"
|
||||
)
|
||||
|
||||
BLOB_STORAGE_SIZE_THRESHOLD = int(
|
||||
os.environ.get("BLOB_STORAGE_SIZE_THRESHOLD", 20 * 1024 * 1024)
|
||||
)
|
||||
@@ -1002,7 +977,6 @@ API_KEY_HASH_ROUNDS = (
|
||||
# MCP Server Configs
|
||||
#####
|
||||
MCP_SERVER_ENABLED = os.environ.get("MCP_SERVER_ENABLED", "").lower() == "true"
|
||||
MCP_SERVER_HOST = os.environ.get("MCP_SERVER_HOST", "0.0.0.0")
|
||||
MCP_SERVER_PORT = int(os.environ.get("MCP_SERVER_PORT") or 8090)
|
||||
|
||||
# CORS origins for MCP clients (comma-separated)
|
||||
|
||||
@@ -157,17 +157,6 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min

CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)

# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)

# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500

CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)

CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)

@@ -454,9 +443,6 @@ class OnyxRedisLocks:
    # User file processing
    USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
    USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
    # Short-lived key set when a task is enqueued; cleared when the worker picks it up.
    # Prevents the beat from re-enqueuing the same file while a task is already queued.
    USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
    USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
    USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
    USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
@@ -1,5 +1,4 @@
|
||||
import contextvars
|
||||
import re
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import Future
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
@@ -15,7 +14,6 @@ from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
@@ -64,44 +62,11 @@ class AirtableClientNotSetUpError(PermissionError):
        super().__init__("Airtable Client is not set up, was load_credentials called?")


# Matches URLs like https://airtable.com/appXXX/tblYYY/viwZZZ?blocks=hide
# Captures: base_id (appXXX), table_id (tblYYY), and optionally view_id (viwZZZ)
_AIRTABLE_URL_PATTERN = re.compile(
    r"https?://airtable\.com/(app[A-Za-z0-9]+)/(tbl[A-Za-z0-9]+)(?:/(viw[A-Za-z0-9]+))?",
)


def parse_airtable_url(
    url: str,
) -> tuple[str, str, str | None]:
    """Parse an Airtable URL into (base_id, table_id, view_id).

    Accepts URLs like:
        https://airtable.com/appXXX/tblYYY
        https://airtable.com/appXXX/tblYYY/viwZZZ
        https://airtable.com/appXXX/tblYYY/viwZZZ?blocks=hide

    Returns:
        (base_id, table_id, view_id or None)

    Raises:
        ValueError if the URL doesn't match the expected format.
    """
    match = _AIRTABLE_URL_PATTERN.search(url.strip())
    if not match:
        raise ValueError(
            f"Could not parse Airtable URL: '{url}'. "
            "Expected format: https://airtable.com/appXXX/tblYYY[/viwZZZ]"
        )
    return match.group(1), match.group(2), match.group(3)
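A quick usage sketch of the URL parsing above, reusing the same regex with made-up base, table, and view ids:

```python
import re

# Regex copied from _AIRTABLE_URL_PATTERN above; the ids below are placeholders.
pattern = re.compile(
    r"https?://airtable\.com/(app[A-Za-z0-9]+)/(tbl[A-Za-z0-9]+)(?:/(viw[A-Za-z0-9]+))?"
)

for url in (
    "https://airtable.com/app12AbC/tbl9Xyz",
    "https://airtable.com/app12AbC/tbl9Xyz/viw777?blocks=hide",
):
    m = pattern.search(url)
    assert m is not None
    print(m.group(1), m.group(2), m.group(3))
# -> app12AbC tbl9Xyz None
# -> app12AbC tbl9Xyz viw777
```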
class AirtableConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
base_id: str = "",
|
||||
table_name_or_id: str = "",
|
||||
airtable_url: str = "",
|
||||
base_id: str,
|
||||
table_name_or_id: str,
|
||||
treat_all_non_attachment_fields_as_metadata: bool = False,
|
||||
view_id: str | None = None,
|
||||
share_id: str | None = None,
|
||||
@@ -110,33 +75,16 @@ class AirtableConnector(LoadConnector):
|
||||
"""Initialize an AirtableConnector.
|
||||
|
||||
Args:
|
||||
base_id: The ID of the Airtable base (not required when airtable_url is set)
|
||||
table_name_or_id: The name or ID of the table (not required when airtable_url is set)
|
||||
airtable_url: An Airtable URL to parse base_id, table_id, and view_id from.
|
||||
Overrides base_id, table_name_or_id, and view_id if provided.
|
||||
base_id: The ID of the Airtable base to connect to
|
||||
table_name_or_id: The name or ID of the table to index
|
||||
treat_all_non_attachment_fields_as_metadata: If True, all fields except attachments will be treated as metadata.
|
||||
If False, only fields with types in DEFAULT_METADATA_FIELD_TYPES will be treated as metadata.
|
||||
view_id: Optional ID of a specific view to use
|
||||
share_id: Optional ID of a "share" to use for generating record URLs
|
||||
share_id: Optional ID of a "share" to use for generating record URLs (https://airtable.com/developers/web/api/list-shares)
|
||||
batch_size: Number of records to process in each batch
|
||||
|
||||
Mode is auto-detected: if a specific table is identified (via URL or
|
||||
base_id + table_name_or_id), the connector indexes that single table.
|
||||
Otherwise, it discovers and indexes all accessible bases and tables.
|
||||
"""
|
||||
# If a URL is provided, parse it to extract base_id, table_id, and view_id
|
||||
if airtable_url:
|
||||
parsed_base_id, parsed_table_id, parsed_view_id = parse_airtable_url(
|
||||
airtable_url
|
||||
)
|
||||
base_id = parsed_base_id
|
||||
table_name_or_id = parsed_table_id
|
||||
if parsed_view_id:
|
||||
view_id = parsed_view_id
|
||||
|
||||
self.base_id = base_id
|
||||
self.table_name_or_id = table_name_or_id
|
||||
self.index_all = not (base_id and table_name_or_id)
|
||||
self.view_id = view_id
|
||||
self.share_id = share_id
|
||||
self.batch_size = batch_size
|
||||
@@ -155,33 +103,6 @@ class AirtableConnector(LoadConnector):
|
||||
raise AirtableClientNotSetUpError()
|
||||
return self._airtable_client
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.index_all:
|
||||
try:
|
||||
bases = self.airtable_client.bases()
|
||||
if not bases:
|
||||
raise ConnectorValidationError(
|
||||
"No bases found. Ensure your API token has access to at least one base."
|
||||
)
|
||||
except ConnectorValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(f"Failed to list Airtable bases: {e}")
|
||||
else:
|
||||
if not self.base_id or not self.table_name_or_id:
|
||||
raise ConnectorValidationError(
|
||||
"A valid Airtable URL or base_id and table_name_or_id are required "
|
||||
"when not using index_all mode."
|
||||
)
|
||||
try:
|
||||
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
|
||||
table.schema()
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to access table '{self.table_name_or_id}' "
|
||||
f"in base '{self.base_id}': {e}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_record_url(
|
||||
cls,
|
||||
@@ -346,7 +267,6 @@ class AirtableConnector(LoadConnector):
|
||||
field_name: str,
|
||||
field_info: Any,
|
||||
field_type: str,
|
||||
base_id: str,
|
||||
table_id: str,
|
||||
view_id: str | None,
|
||||
record_id: str,
|
||||
@@ -371,7 +291,7 @@ class AirtableConnector(LoadConnector):
|
||||
field_name=field_name,
|
||||
field_info=field_info,
|
||||
field_type=field_type,
|
||||
base_id=base_id,
|
||||
base_id=self.base_id,
|
||||
table_id=table_id,
|
||||
view_id=view_id,
|
||||
record_id=record_id,
|
||||
@@ -406,17 +326,15 @@ class AirtableConnector(LoadConnector):
|
||||
record: RecordDict,
|
||||
table_schema: TableSchema,
|
||||
primary_field_name: str | None,
|
||||
base_id: str,
|
||||
base_name: str | None = None,
|
||||
) -> Document | None:
|
||||
"""Process a single Airtable record into a Document.
|
||||
|
||||
Args:
|
||||
record: The Airtable record to process
|
||||
table_schema: Schema information for the table
|
||||
table_name: Name of the table
|
||||
table_id: ID of the table
|
||||
primary_field_name: Name of the primary field, if any
|
||||
base_id: The ID of the base this record belongs to
|
||||
base_name: The name of the base (used in semantic ID for index_all mode)
|
||||
|
||||
Returns:
|
||||
Document object representing the record
|
||||
@@ -449,7 +367,6 @@ class AirtableConnector(LoadConnector):
|
||||
field_name=field_name,
|
||||
field_info=field_val,
|
||||
field_type=field_type,
|
||||
base_id=base_id,
|
||||
table_id=table_id,
|
||||
view_id=view_id,
|
||||
record_id=record_id,
|
||||
@@ -462,26 +379,11 @@ class AirtableConnector(LoadConnector):
|
||||
logger.warning(f"No sections found for record {record_id}")
|
||||
return None
|
||||
|
||||
# Include base name in semantic ID only in index_all mode
|
||||
if self.index_all and base_name:
|
||||
semantic_id = (
|
||||
f"{base_name} > {table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
else f"{base_name} > {table_name}"
|
||||
)
|
||||
else:
|
||||
semantic_id = (
|
||||
f"{table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
else table_name
|
||||
)
|
||||
|
||||
# Build hierarchy source_path for Craft file system subdirectory structure.
|
||||
# This creates: airtable/{base_name}/{table_name}/record.json
|
||||
source_path: list[str] = []
|
||||
if base_name:
|
||||
source_path.append(base_name)
|
||||
source_path.append(table_name)
|
||||
semantic_id = (
|
||||
f"{table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
else table_name
|
||||
)
|
||||
|
||||
return Document(
|
||||
id=f"airtable__{record_id}",
|
||||
@@ -489,39 +391,19 @@ class AirtableConnector(LoadConnector):
|
||||
source=DocumentSource.AIRTABLE,
|
||||
semantic_identifier=semantic_id,
|
||||
metadata=metadata,
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": source_path,
|
||||
"base_id": base_id,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
**({"base_name": base_name} if base_name else {}),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def _resolve_base_name(self, base_id: str) -> str | None:
|
||||
"""Try to resolve a human-readable base name from the API."""
|
||||
try:
|
||||
for base_info in self.airtable_client.bases():
|
||||
if base_info.id == base_id:
|
||||
return base_info.name
|
||||
except Exception:
|
||||
logger.debug(f"Could not resolve base name for {base_id}")
|
||||
return None
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Fetch all records from the table.
|
||||
|
||||
def _index_table(
|
||||
self,
|
||||
base_id: str,
|
||||
table_name_or_id: str,
|
||||
base_name: str | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Index all records from a single table. Yields batches of Documents."""
|
||||
# Resolve base name for hierarchy if not provided
|
||||
if base_name is None:
|
||||
base_name = self._resolve_base_name(base_id)
|
||||
NOTE: Airtable does not support filtering by time updated, so
|
||||
we have to fetch all records every time.
|
||||
"""
|
||||
if not self.airtable_client:
|
||||
raise AirtableClientNotSetUpError()
|
||||
|
||||
table = self.airtable_client.table(base_id, table_name_or_id)
|
||||
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
|
||||
records = table.all()
|
||||
|
||||
table_schema = table.schema()
|
||||
@@ -533,25 +415,21 @@ class AirtableConnector(LoadConnector):
|
||||
primary_field_name = field.name
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Processing {len(records)} records from table "
|
||||
f"'{table_schema.name}' in base '{base_name or base_id}'."
|
||||
)
|
||||
|
||||
if not records:
|
||||
return
|
||||
logger.info(f"Starting to process Airtable records for {table.name}.")
|
||||
|
||||
# Process records in parallel batches using ThreadPoolExecutor
|
||||
PARALLEL_BATCH_SIZE = 8
|
||||
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
|
||||
record_documents: list[Document | HierarchyNode] = []
|
||||
|
||||
# Process records in batches
|
||||
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
|
||||
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
|
||||
record_documents: list[Document | HierarchyNode] = []
|
||||
record_documents = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit batch tasks
|
||||
future_to_record: dict[Future[Document | None], RecordDict] = {}
|
||||
future_to_record: dict[Future, RecordDict] = {}
|
||||
for record in batch_records:
|
||||
# Capture the current context so that the thread gets the current tenant ID
|
||||
current_context = contextvars.copy_context()
|
||||
@@ -562,8 +440,6 @@ class AirtableConnector(LoadConnector):
|
||||
record=record,
|
||||
table_schema=table_schema,
|
||||
primary_field_name=primary_field_name,
|
||||
base_id=base_id,
|
||||
base_name=base_name,
|
||||
)
|
||||
] = record
|
||||
|
||||
@@ -578,58 +454,9 @@ class AirtableConnector(LoadConnector):
|
||||
logger.exception(f"Failed to process record {record['id']}")
|
||||
raise e
|
||||
|
||||
if record_documents:
|
||||
yield record_documents
|
||||
yield record_documents
|
||||
record_documents = []
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Fetch all records from one or all tables.
|
||||
|
||||
NOTE: Airtable does not support filtering by time updated, so
|
||||
we have to fetch all records every time.
|
||||
"""
|
||||
if not self.airtable_client:
|
||||
raise AirtableClientNotSetUpError()
|
||||
|
||||
if self.index_all:
|
||||
yield from self._load_all()
|
||||
else:
|
||||
yield from self._index_table(
|
||||
base_id=self.base_id,
|
||||
table_name_or_id=self.table_name_or_id,
|
||||
)
|
||||
|
||||
def _load_all(self) -> GenerateDocumentsOutput:
|
||||
"""Discover all bases and tables, then index everything."""
|
||||
bases = self.airtable_client.bases()
|
||||
logger.info(f"Discovered {len(bases)} Airtable base(s).")
|
||||
|
||||
for base_info in bases:
|
||||
base_id = base_info.id
|
||||
base_name = base_info.name
|
||||
logger.info(f"Listing tables for base '{base_name}' ({base_id}).")
|
||||
|
||||
try:
|
||||
base = self.airtable_client.base(base_id)
|
||||
tables = base.tables()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to list tables for base '{base_name}' ({base_id}), skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"Found {len(tables)} table(s) in base '{base_name}'.")
|
||||
|
||||
for table in tables:
|
||||
try:
|
||||
yield from self._index_table(
|
||||
base_id=base_id,
|
||||
table_name_or_id=table.id,
|
||||
base_name=base_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to index table '{table.name}' ({table.id}) "
|
||||
f"in base '{base_name}' ({base_id}), skipping."
|
||||
)
|
||||
continue
|
||||
# Yield any remaining records
|
||||
if record_documents:
|
||||
yield record_documents
|
||||
|
||||
@@ -905,10 +905,6 @@ class ConfluenceConnector(
|
||||
self.confluence_client, self.is_cloud
|
||||
)
|
||||
|
||||
# Yield space hierarchy nodes first
|
||||
for node in self._yield_space_hierarchy_nodes():
|
||||
doc_metadata_list.append(node)
|
||||
|
||||
def get_external_access(
|
||||
doc_id: str, restrictions: dict[str, Any], ancestors: list[dict[str, Any]]
|
||||
) -> ExternalAccess | None:
|
||||
@@ -923,10 +919,6 @@ class ConfluenceConnector(
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
# Yield ancestor hierarchy nodes for this page
|
||||
for node in self._yield_ancestor_hierarchy_nodes(page):
|
||||
doc_metadata_list.append(node)
|
||||
|
||||
page_id = _get_page_id(page)
|
||||
page_restrictions = page.get("restrictions") or {}
|
||||
page_space_key = page.get("space", {}).get("key")
|
||||
@@ -947,7 +939,6 @@ class ConfluenceConnector(
|
||||
)
|
||||
|
||||
# Query attachments for each page
|
||||
page_hierarchy_node_yielded = False
|
||||
attachment_query = self._construct_attachment_query(_get_page_id(page))
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_query,
|
||||
@@ -961,14 +952,6 @@ class ConfluenceConnector(
|
||||
):
|
||||
continue
|
||||
|
||||
# If this page has valid attachments and we haven't yielded it as a
|
||||
# hierarchy node yet, do so now (attachments are children of the page)
|
||||
if not page_hierarchy_node_yielded:
|
||||
page_node = self._maybe_yield_page_hierarchy_node(page)
|
||||
if page_node:
|
||||
doc_metadata_list.append(page_node)
|
||||
page_hierarchy_node_yielded = True
|
||||
|
||||
attachment_restrictions = attachment.get("restrictions", {})
|
||||
if not attachment_restrictions:
|
||||
attachment_restrictions = page_restrictions or {}
|
||||
|
||||
@@ -46,7 +46,6 @@ from onyx.connectors.google_drive.file_retrieval import get_external_access_for_
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
from onyx.connectors.google_drive.file_retrieval import get_shared_drive_name
|
||||
from onyx.connectors.google_drive.file_retrieval import has_link_only_permission
|
||||
from onyx.connectors.google_drive.models import DriveRetrievalStage
|
||||
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint
|
||||
@@ -157,7 +156,10 @@ def _is_shared_drive_root(folder: GoogleDriveFileType) -> bool:
|
||||
return False
|
||||
|
||||
# For shared drive content, the root has id == driveId
|
||||
return bool(drive_id and folder_id == drive_id)
|
||||
if drive_id and folder_id == drive_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _public_access() -> ExternalAccess:
|
||||
@@ -614,16 +616,6 @@ class GoogleDriveConnector(
|
||||
# empty parents due to permission limitations)
|
||||
# Check shared drive root first (simple ID comparison)
|
||||
if _is_shared_drive_root(folder):
|
||||
# files().get() returns 'Drive' for shared drive roots;
|
||||
# fetch the real name via drives().get().
|
||||
# Try both the retriever and admin since the admin may
|
||||
# not have access to private shared drives.
|
||||
drive_name = self._get_shared_drive_name(
|
||||
current_id, file.user_email
|
||||
)
|
||||
if drive_name:
|
||||
node.display_name = drive_name
|
||||
node.node_type = HierarchyNodeType.SHARED_DRIVE
|
||||
reached_terminal = True
|
||||
break
|
||||
|
||||
@@ -699,15 +691,6 @@ class GoogleDriveConnector(
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_shared_drive_name(self, drive_id: str, retriever_email: str) -> str | None:
|
||||
"""Fetch the name of a shared drive, trying both the retriever and admin."""
|
||||
for email in {retriever_email, self.primary_admin_email}:
|
||||
svc = get_drive_service(self.creds, email)
|
||||
name = get_shared_drive_name(svc, drive_id)
|
||||
if name:
|
||||
return name
|
||||
return None
|
||||
|
||||
def get_all_drive_ids(self) -> set[str]:
|
||||
return self._get_all_drives_for_user(self.primary_admin_email)
|
||||
|
||||
|
||||
@@ -154,26 +154,6 @@ def _get_hierarchy_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
return HIERARCHY_FIELDS
|
||||
|
||||
|
||||
def get_shared_drive_name(
|
||||
service: Resource,
|
||||
drive_id: str,
|
||||
) -> str | None:
|
||||
"""Fetch the actual name of a shared drive via the drives().get() API.
|
||||
|
||||
The files().get() API returns 'Drive' as the name for shared drive root
|
||||
folders. Only drives().get() returns the real user-assigned name.
|
||||
"""
|
||||
try:
|
||||
drive = service.drives().get(driveId=drive_id, fields="name").execute()
|
||||
return drive.get("name")
|
||||
except HttpError as e:
|
||||
if e.resp.status in (403, 404):
|
||||
logger.debug(f"Cannot access drive {drive_id}: {e}")
|
||||
else:
|
||||
raise
|
||||
return None
|
||||
|
||||
|
||||
def get_external_access_for_folder(
|
||||
folder: GoogleDriveFileType,
|
||||
google_domain: str,
|
||||
|
||||
File diff suppressed because it is too large
@@ -50,15 +50,12 @@ class TeamsCheckpoint(ConnectorCheckpoint):
    todo_team_ids: list[str] | None = None


DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"


class TeamsConnector(
    CheckpointedConnectorWithPermSync[TeamsCheckpoint],
    SlimConnectorWithPermSync,
):
    MAX_WORKERS = 10
    AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"

    def __init__(
        self,
@@ -66,15 +63,11 @@ class TeamsConnector(
        # are not necessarily guaranteed to be unique
        teams: list[str] = [],
        max_workers: int = MAX_WORKERS,
        authority_host: str = DEFAULT_AUTHORITY_HOST,
        graph_api_host: str = DEFAULT_GRAPH_API_HOST,
    ) -> None:
        self.graph_client: GraphClient | None = None
        self.msal_app: msal.ConfidentialClientApplication | None = None
        self.max_workers = max_workers
        self.requested_team_list: list[str] = teams
        self.authority_host = authority_host.rstrip("/")
        self.graph_api_host = graph_api_host.rstrip("/")

    # impls for BaseConnector

@@ -83,7 +76,7 @@ class TeamsConnector(
        teams_client_secret = credentials["teams_client_secret"]
        teams_directory_id = credentials["teams_directory_id"]

        authority_url = f"{self.authority_host}/{teams_directory_id}"
        authority_url = f"{TeamsConnector.AUTHORITY_URL_PREFIX}{teams_directory_id}"
        self.msal_app = msal.ConfidentialClientApplication(
            authority=authority_url,
            client_id=teams_client_id,
@@ -98,7 +91,7 @@ class TeamsConnector(
            raise RuntimeError("MSAL app is not initialized")

        token = self.msal_app.acquire_token_for_client(
            scopes=[f"{self.graph_api_host}/.default"]
            scopes=["https://graph.microsoft.com/.default"]
        )

        if not isinstance(token, dict):
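A standalone sketch of what the configurable hosts above feed into: the MSAL authority URL and the Graph scope. The tenant id below is a placeholder, and this helper is illustrative rather than the connector's actual code path:

```python
# Illustrative only; mirrors the f-strings in the diff above.
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"


def build_auth_targets(
    teams_directory_id: str,
    authority_host: str = DEFAULT_AUTHORITY_HOST,
    graph_api_host: str = DEFAULT_GRAPH_API_HOST,
) -> tuple[str, list[str]]:
    authority_url = f"{authority_host.rstrip('/')}/{teams_directory_id}"
    scopes = [f"{graph_api_host.rstrip('/')}/.default"]
    return authority_url, scopes


print(build_auth_targets("00000000-0000-0000-0000-000000000000"))
# ('https://login.microsoftonline.com/00000000-...', ['https://graph.microsoft.com/.default'])
```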
@@ -32,7 +32,6 @@ from onyx.context.search.federated.slack_search_utils import should_include_mess
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.document import DocumentSource
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.document_index_utils import (
|
||||
get_multipass_config,
|
||||
@@ -906,15 +905,13 @@ def convert_slack_score(slack_score: float) -> float:
|
||||
def slack_retrieval(
|
||||
query: ChunkIndexRequest,
|
||||
access_token: str,
|
||||
db_session: Session | None = None,
|
||||
db_session: Session,
|
||||
connector: FederatedConnectorDetail | None = None, # noqa: ARG001
|
||||
entities: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
slack_event_context: SlackContext | None = None,
|
||||
bot_token: str | None = None, # Add bot token parameter
|
||||
team_id: str | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB query (no session needed)
|
||||
search_settings: SearchSettings | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
Main entry point for Slack federated search with entity filtering.
|
||||
@@ -928,7 +925,7 @@ def slack_retrieval(
|
||||
Args:
|
||||
query: Search query object
|
||||
access_token: User OAuth access token
|
||||
db_session: Database session (optional if search_settings provided)
|
||||
db_session: Database session
|
||||
connector: Federated connector detail (unused, kept for backwards compat)
|
||||
entities: Connector-level config (entity filtering configuration)
|
||||
limit: Maximum number of results
|
||||
@@ -1156,10 +1153,7 @@ def slack_retrieval(
|
||||
|
||||
# chunk index docs into doc aware chunks
|
||||
# a single index doc can get split into multiple chunks
|
||||
if search_settings is None:
|
||||
if db_session is None:
|
||||
raise ValueError("Either db_session or search_settings must be provided")
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
embedder = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from uuid import UUID

from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator

from onyx.configs.constants import DocumentSource
from onyx.db.models import SearchSettings
@@ -96,6 +97,21 @@ class IndexFilters(BaseFilters, UserFileFilters, AssistantKnowledgeFilters):
    tenant_id: str | None = None


class ChunkContext(BaseModel):
    # If not specified (None), picked up from Persona settings if there is space
    # if specified (even if 0), it always uses the specified number of chunks above and below
    chunks_above: int | None = None
    chunks_below: int | None = None
    full_doc: bool = False

    @field_validator("chunks_above", "chunks_below")
    @classmethod
    def check_non_negative(cls, value: int, field: Any) -> int:
        if value is not None and value < 0:
            raise ValueError(f"{field.name} must be non-negative")
        return value


class BasicChunkRequest(BaseModel):
    query: str
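A simplified re-creation of the `ChunkContext` semantics (None means "use the persona default"; explicit values, including 0, are honored but must be non-negative). This stand-in is not the Onyx model and uses the plain pydantic v2 validator signature:

```python
from pydantic import BaseModel, ValidationError, field_validator


class ChunkContext(BaseModel):
    chunks_above: int | None = None
    chunks_below: int | None = None
    full_doc: bool = False

    @field_validator("chunks_above", "chunks_below")
    @classmethod
    def check_non_negative(cls, value: int | None) -> int | None:
        if value is not None and value < 0:
            raise ValueError("must be non-negative")
        return value


print(ChunkContext(chunks_above=0, chunks_below=2))  # explicit values kept, even 0
try:
    ChunkContext(chunks_above=-1)
except ValidationError as e:
    print("rejected:", e.errors()[0]["msg"])
```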
@@ -18,10 +18,8 @@ from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.natural_language_processing.english_stopwords import strip_stopwords
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.secondary_llm_flows.source_filter import extract_source_filter
|
||||
from onyx.secondary_llm_flows.time_filter import extract_time_filter
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -43,7 +41,7 @@ def _build_index_filters(
|
||||
user_file_ids: list[UUID] | None,
|
||||
persona_document_sets: list[str] | None,
|
||||
persona_time_cutoff: datetime | None,
|
||||
db_session: Session | None = None,
|
||||
db_session: Session,
|
||||
auto_detect_filters: bool = False,
|
||||
query: str | None = None,
|
||||
llm: LLM | None = None,
|
||||
@@ -51,8 +49,6 @@ def _build_index_filters(
|
||||
# Assistant knowledge filters
|
||||
attached_document_ids: list[str] | None = None,
|
||||
hierarchy_node_ids: list[int] | None = None,
|
||||
# Pre-fetched ACL filters (skips DB query when provided)
|
||||
acl_filters: list[str] | None = None,
|
||||
) -> IndexFilters:
|
||||
if auto_detect_filters and (llm is None or query is None):
|
||||
raise RuntimeError("LLM and query are required for auto detect filters")
|
||||
@@ -107,14 +103,9 @@ def _build_index_filters(
|
||||
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
|
||||
logger.debug("Added USER_FILE to source_filter for user knowledge search")
|
||||
|
||||
if bypass_acl:
|
||||
user_acl_filters = None
|
||||
elif acl_filters is not None:
|
||||
user_acl_filters = acl_filters
|
||||
else:
|
||||
if db_session is None:
|
||||
raise ValueError("Either db_session or acl_filters must be provided")
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
@@ -261,15 +252,11 @@ def search_pipeline(
|
||||
user: User,
|
||||
# Used for default filters and settings
|
||||
persona: Persona | None,
|
||||
db_session: Session | None = None,
|
||||
db_session: Session,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB queries (no session needed)
|
||||
acl_filters: list[str] | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
user_uploaded_persona_files: list[UUID] | None = (
|
||||
[user_file.id for user_file in persona.user_files] if persona else None
|
||||
@@ -310,7 +297,6 @@ def search_pipeline(
|
||||
bypass_acl=chunk_search_request.bypass_acl,
|
||||
attached_document_ids=attached_document_ids,
|
||||
hierarchy_node_ids=hierarchy_node_ids,
|
||||
acl_filters=acl_filters,
|
||||
)
|
||||
|
||||
query_keywords = strip_stopwords(chunk_search_request.query)
|
||||
@@ -329,8 +315,6 @@ def search_pipeline(
|
||||
user_id=user.id if user else None,
|
||||
document_index=document_index,
|
||||
db_session=db_session,
|
||||
embedding_model=embedding_model,
|
||||
prefetched_federated_retrieval_infos=prefetched_federated_retrieval_infos,
|
||||
)
|
||||
|
||||
# For some specific connectors like Salesforce, a user that has access to an object doesn't mean
|
||||
|
||||
@@ -14,11 +14,9 @@ from onyx.context.search.utils import get_query_embedding
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.federated_connectors.federated_retrieval import (
|
||||
get_federated_retrieval_functions,
|
||||
)
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
@@ -52,14 +50,9 @@ def combine_retrieval_results(
|
||||
def _embed_and_search(
|
||||
query_request: ChunkIndexRequest,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
db_session: Session,
|
||||
) -> list[InferenceChunk]:
|
||||
query_embedding = get_query_embedding(
|
||||
query_request.query,
|
||||
db_session=db_session,
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
query_embedding = get_query_embedding(query_request.query, db_session)
|
||||
|
||||
hybrid_alpha = query_request.hybrid_alpha or HYBRID_ALPHA
|
||||
|
||||
@@ -85,9 +78,7 @@ def search_chunks(
|
||||
query_request: ChunkIndexRequest,
|
||||
user_id: UUID | None,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
db_session: Session,
|
||||
) -> list[InferenceChunk]:
|
||||
run_queries: list[tuple[Callable, tuple]] = []
|
||||
|
||||
@@ -97,22 +88,14 @@ def search_chunks(
|
||||
else None
|
||||
)
|
||||
|
||||
# Federated retrieval — use pre-fetched if available, otherwise query DB
|
||||
if prefetched_federated_retrieval_infos is not None:
|
||||
federated_retrieval_infos = prefetched_federated_retrieval_infos
|
||||
else:
|
||||
if db_session is None:
|
||||
raise ValueError(
|
||||
"Either db_session or prefetched_federated_retrieval_infos "
|
||||
"must be provided"
|
||||
)
|
||||
federated_retrieval_infos = get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
# Federated retrieval
|
||||
federated_retrieval_infos = get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
|
||||
federated_sources = set(
|
||||
federated_retrieval_info.source.to_non_federated_source()
|
||||
@@ -131,10 +114,7 @@ def search_chunks(
|
||||
|
||||
if normal_search_enabled:
|
||||
run_queries.append(
|
||||
(
|
||||
_embed_and_search,
|
||||
(query_request, document_index, db_session, embedding_model),
|
||||
)
|
||||
(_embed_and_search, (query_request, document_index, db_session))
|
||||
)
|
||||
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
|
||||
|
||||
@@ -64,34 +64,23 @@ def inference_section_from_single_chunk(
    )


def get_query_embeddings(
    queries: list[str],
    db_session: Session | None = None,
    embedding_model: EmbeddingModel | None = None,
) -> list[Embedding]:
    if embedding_model is None:
        if db_session is None:
            raise ValueError("Either db_session or embedding_model must be provided")
        search_settings = get_current_search_settings(db_session)
        embedding_model = EmbeddingModel.from_db_model(
            search_settings=search_settings,
            server_host=MODEL_SERVER_HOST,
            server_port=MODEL_SERVER_PORT,
        )
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
    search_settings = get_current_search_settings(db_session)

    query_embedding = embedding_model.encode(queries, text_type=EmbedTextType.QUERY)
    model = EmbeddingModel.from_db_model(
        search_settings=search_settings,
        # The below are globally set, this flow always uses the indexing one
        server_host=MODEL_SERVER_HOST,
        server_port=MODEL_SERVER_PORT,
    )

    query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
    return query_embedding


@log_function_time(print_only=True, debug_only=True)
def get_query_embedding(
    query: str,
    db_session: Session | None = None,
    embedding_model: EmbeddingModel | None = None,
) -> Embedding:
    return get_query_embeddings(
        [query], db_session=db_session, embedding_model=embedding_model
    )[0]
def get_query_embedding(query: str, db_session: Session) -> Embedding:
    return get_query_embeddings([query], db_session)[0]
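The change above lets callers pass either a prefetched embedding model or a `db_session` used to build one lazily. A toy, runnable sketch of that fallback pattern, in which every type is a stand-in rather than an Onyx class:

```python
from dataclasses import dataclass


@dataclass
class FakeEmbeddingModel:
    dim: int = 3

    def encode(self, texts: list[str]) -> list[list[float]]:
        # Deterministic toy embedding, length-based, just for the example.
        return [[float(len(t))] * self.dim for t in texts]


def embed_queries(
    queries: list[str],
    db_session: object | None = None,
    embedding_model: FakeEmbeddingModel | None = None,
) -> list[list[float]]:
    if embedding_model is None:
        if db_session is None:
            raise ValueError("Either db_session or embedding_model must be provided")
        # In Onyx this is where search settings would be read from the DB to
        # construct the real EmbeddingModel; here we just fake it.
        embedding_model = FakeEmbeddingModel()
    return embedding_model.encode(queries)


print(embed_queries(["hello"], embedding_model=FakeEmbeddingModel()))
print(embed_queries(["hello"], db_session=object()))
```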
def convert_inference_sections_to_search_docs(
@@ -4,7 +4,6 @@ from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

from onyx.auth.api_key import ApiKeyDescriptor
@@ -55,7 +54,6 @@ async def fetch_user_for_api_key(
        select(User)
        .join(ApiKey, ApiKey.user_id == User.id)
        .where(ApiKey.hashed_api_key == hashed_api_key)
        .options(selectinload(User.memories))
    )
@@ -13,7 +13,6 @@ from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

from onyx.auth.schemas import UserRole
@@ -98,11 +97,6 @@ async def get_user_count(only_admin_users: bool = False) -> int:

# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
    async def _get_user(self, statement: Select) -> UP | None:
        statement = statement.options(selectinload(User.memories))
        results = await self.session.execute(statement)
        return results.unique().scalar_one_or_none()

    async def create(
        self,
        create_dict: Dict[str, Any],
@@ -19,6 +19,7 @@ from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

from onyx.chat.models import DocumentRelevance
from onyx.configs.chat_configs import HARD_DELETE_CHATS
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceSection
@@ -671,6 +672,27 @@ def set_as_latest_chat_message(
    db_session.commit()


def update_search_docs_table_with_relevance(
    db_session: Session,
    reference_db_search_docs: list[DBSearchDoc],
    relevance_summary: DocumentRelevance,
) -> None:
    for search_doc in reference_db_search_docs:
        relevance_data = relevance_summary.relevance_summaries.get(
            search_doc.document_id
        )
        if relevance_data is not None:
            db_session.execute(
                update(DBSearchDoc)
                .where(DBSearchDoc.id == search_doc.id)
                .values(
                    is_relevant=relevance_data.relevant,
                    relevance_explanation=relevance_data.content,
                )
            )
    db_session.commit()


def _sanitize_for_postgres(value: str) -> str:
    """Remove NUL (0x00) characters from strings as PostgreSQL doesn't allow them."""
    sanitized = value.replace("\x00", "")
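`_sanitize_for_postgres` is truncated by the hunk boundary; a complete version presumably just returns the cleaned string, as in this sketch:

```python
def sanitize_for_postgres(value: str) -> str:
    """Remove NUL (0x00) characters, which PostgreSQL text columns reject."""
    # Assumption: the truncated function ends by returning the sanitized value.
    return value.replace("\x00", "")


print(repr(sanitize_for_postgres("abc\x00def")))  # 'abcdef'
```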
@@ -116,15 +116,12 @@ def get_connector_credential_pairs_for_user(
|
||||
order_by_desc: bool = False,
|
||||
source: DocumentSource | None = None,
|
||||
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
|
||||
defer_connector_config: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
"""Get connector credential pairs for a user.
|
||||
|
||||
Args:
|
||||
processing_mode: Filter by processing mode. Defaults to REGULAR to hide
|
||||
FILE_SYSTEM connectors from standard admin UI. Pass None to get all.
|
||||
defer_connector_config: If True, skips loading Connector.connector_specific_config
|
||||
to avoid fetching large JSONB blobs when they aren't needed.
|
||||
"""
|
||||
if eager_load_user:
|
||||
assert (
|
||||
@@ -133,10 +130,7 @@ def get_connector_credential_pairs_for_user(
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if eager_load_connector:
|
||||
connector_load = selectinload(ConnectorCredentialPair.connector)
|
||||
if defer_connector_config:
|
||||
connector_load = connector_load.defer(Connector.connector_specific_config)
|
||||
stmt = stmt.options(connector_load)
|
||||
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
|
||||
|
||||
if eager_load_credential:
|
||||
load_opts = selectinload(ConnectorCredentialPair.credential)
|
||||
@@ -176,7 +170,6 @@ def get_connector_credential_pairs_for_user_parallel(
|
||||
order_by_desc: bool = False,
|
||||
source: DocumentSource | None = None,
|
||||
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
|
||||
defer_connector_config: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
return get_connector_credential_pairs_for_user(
|
||||
@@ -190,7 +183,6 @@ def get_connector_credential_pairs_for_user_parallel(
|
||||
order_by_desc=order_by_desc,
|
||||
source=source,
|
||||
processing_mode=processing_mode,
|
||||
defer_connector_config=defer_connector_config,
|
||||
)
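
A hedged call sketch for the filtering and deferral options documented in the docstring above. Only the keyword parameters visible in these hunks are named; `db_session` and `user` stand in for the function's leading arguments, which fall outside the hunks and are therefore assumptions.

```python
from sqlalchemy.orm import Session

# Illustrative only; the helper comes from the module diffed above, and the
# leading parameter names (db_session, user) are assumptions.
def list_all_cc_pairs_without_configs(db_session: Session, user):
    return get_connector_credential_pairs_for_user(
        db_session=db_session,
        user=user,
        processing_mode=None,         # None disables the REGULAR-only filter
        defer_connector_config=True,  # skip the large connector_specific_config JSONB
    )
```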


@@ -1,74 +0,0 @@
"""Base Data Access Layer (DAL) for database operations.

The DAL pattern groups related database operations into cohesive classes
with explicit session management. It supports two usage modes:

1. **External session** (FastAPI endpoints) — the caller provides a session
   whose lifecycle is managed by FastAPI's dependency injection.

2. **Self-managed session** (Celery tasks, scripts) — the DAL creates its
   own session via the tenant-aware session factory.

Subclasses add domain-specific query methods while inheriting session
management. See ``ee.onyx.db.scim.ScimDAL`` for a concrete example.

Example (FastAPI)::

    def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
        return ScimDAL(db_session)

    @router.get("/users")
    def list_users(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
        return dal.list_user_mappings(...)

Example (Celery)::

    with ScimDAL.from_tenant("tenant_abc") as dal:
        dal.create_user_mapping(...)
        dal.commit()
"""

from __future__ import annotations

from collections.abc import Generator
from contextlib import contextmanager

from sqlalchemy.orm import Session

from onyx.db.engine.sql_engine import get_session_with_tenant


class DAL:
    """Base Data Access Layer.

    Holds a SQLAlchemy session and provides transaction control helpers.
    Subclasses add domain-specific query methods.
    """

    def __init__(self, db_session: Session) -> None:
        self._session = db_session

    @property
    def session(self) -> Session:
        """Direct access to the underlying session for advanced use cases."""
        return self._session

    def commit(self) -> None:
        self._session.commit()

    def flush(self) -> None:
        self._session.flush()

    def rollback(self) -> None:
        self._session.rollback()

    @classmethod
    @contextmanager
    def from_tenant(cls, tenant_id: str) -> Generator["DAL", None, None]:
        """Create a DAL with a self-managed session for the given tenant.

        The session is automatically closed when the context manager exits.
        The caller must explicitly call ``commit()`` to persist changes.
        """
        with get_session_with_tenant(tenant_id=tenant_id) as session:
            yield cls(session)
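
To complement the module docstring's examples, a minimal illustrative subclass of the base class above. `Widget`, `WidgetDAL`, and the tenant id are made up for the sketch and are not part of the diffed code; only `DAL`, `session`, `commit()`, and `from_tenant()` come from the module shown.

```python
from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Widget(Base):  # placeholder ORM model purely for this sketch
    __tablename__ = "widget"
    id: Mapped[int] = mapped_column(primary_key=True)


class WidgetDAL(DAL):
    """Hypothetical domain DAL showing how subclasses add query methods."""

    def list_widget_ids(self) -> list[int]:
        # self.session is the property exposed by the base DAL above.
        return list(self.session.scalars(select(Widget.id)).all())


# Celery-style usage, mirroring the "Example (Celery)" block in the docstring:
with WidgetDAL.from_tenant("tenant_abc") as dal:  # tenant id is illustrative
    ids = dal.list_widget_ids()
    dal.commit()
```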

@@ -554,19 +554,10 @@ def fetch_all_document_sets_for_user(
    stmt = (
        select(DocumentSetDBModel)
        .distinct()
        .options(
            selectinload(DocumentSetDBModel.connector_credential_pairs).selectinload(
                ConnectorCredentialPair.connector
            ),
            selectinload(DocumentSetDBModel.users),
            selectinload(DocumentSetDBModel.groups),
            selectinload(DocumentSetDBModel.federated_connectors).selectinload(
                FederatedConnector__DocumentSet.federated_connector
            ),
        )
        .options(selectinload(DocumentSetDBModel.federated_connectors))
    )
    stmt = _add_user_filters(stmt, user, get_editable=get_editable)
    return db_session.scalars(stmt).unique().all()
    return db_session.scalars(stmt).all()


def fetch_documents_for_document_set_paginated(

@@ -232,12 +232,6 @@ class BuildSessionStatus(str, PyEnum):
    IDLE = "idle"


class SharingScope(str, PyEnum):
    PRIVATE = "private"
    PUBLIC_ORG = "public_org"
    PUBLIC_GLOBAL = "public_global"


class SandboxStatus(str, PyEnum):
    PROVISIONING = "provisioning"
    RUNNING = "running"
@@ -302,4 +296,4 @@ class HierarchyNodeType(str, PyEnum):
class LLMModelFlowType(str, PyEnum):
    CHAT = "chat"
    VISION = "vision"
    CONTEXTUAL_RAG = "contextual_rag"
    EMBEDDINGS = "embeddings"

Some files were not shown because too many files have changed in this diff.