Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-19 00:35:46 +00:00)

Comparing 5 commits: main...fix-admin-
Commits: 116cae60c2, eef232d747, 3cdbf3882a, 77ad40057b, 969e3d3041

@@ -1 +0,0 @@
../.cursor/skills

@@ -1,248 +0,0 @@
---
name: playwright-e2e-tests
description: Write and maintain Playwright end-to-end tests for the Onyx application. Use when creating new E2E tests, debugging test failures, adding test coverage, or when the user mentions Playwright, E2E tests, or browser testing.
---

# Playwright E2E Tests

## Project Layout

- **Tests**: `web/tests/e2e/` — organized by feature (`auth/`, `admin/`, `chat/`, `assistants/`, `connectors/`, `mcp/`)
- **Config**: `web/playwright.config.ts`
- **Utilities**: `web/tests/e2e/utils/`
- **Constants**: `web/tests/e2e/constants.ts`
- **Global setup**: `web/tests/e2e/global-setup.ts`
- **Output**: `web/output/playwright/`

## Imports

Always use absolute imports with the `@tests/e2e/` prefix — never relative paths (`../`, `../../`). The alias is defined in `web/tsconfig.json` and resolves to `web/tests/`.

```typescript
import { loginAs } from "@tests/e2e/utils/auth";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
import { TEST_ADMIN_CREDENTIALS } from "@tests/e2e/constants";
```

All new files should be `.ts`, not `.js`.

## Running Tests

```bash
# Run a specific test file
npx playwright test web/tests/e2e/chat/default_assistant.spec.ts

# Run a specific project
npx playwright test --project admin
npx playwright test --project exclusive
```

## Test Projects

| Project | Description | Parallelism |
|---------|-------------|-------------|
| `admin` | Standard tests (excludes `@exclusive`) | Parallel |
| `exclusive` | Serial, slower tests (tagged `@exclusive`) | 1 worker |

All tests use `admin_auth.json` storage state by default (pre-authenticated admin session).
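For orientation only, here is a minimal sketch of how these two projects could be wired up in `playwright.config.ts`, assuming the split is driven by the `@exclusive` tag via Playwright's `grep`/`grepInvert` project options — the actual config in `web/playwright.config.ts` may be organized differently:

```typescript
// Hypothetical excerpt — check web/playwright.config.ts for the real settings.
import { defineConfig } from "@playwright/test";

export default defineConfig({
  globalSetup: "./tests/e2e/global-setup.ts",
  use: { storageState: "admin_auth.json" }, // every test starts pre-authenticated as admin
  projects: [
    // Standard tests: run in parallel, skip anything tagged @exclusive
    { name: "admin", grepInvert: /@exclusive/ },
    // Slow/serial tests: only @exclusive-tagged titles (the 1-worker limit is configured elsewhere)
    { name: "exclusive", grep: /@exclusive/, fullyParallel: false },
  ],
});
```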
## Authentication

Global setup (`global-setup.ts`) runs automatically before all tests and handles:

- Server readiness check (polls health endpoint, 60s timeout)
- Provisioning test users: admin, admin2, and a **pool of worker users** (`worker0@example.com` through `worker7@example.com`) (idempotent)
- API login + saving storage states: `admin_auth.json`, `admin2_auth.json`, and `worker{N}_auth.json` for each worker user
- Setting display name to `"worker"` for each worker user
- Promoting admin2 to admin role
- Ensuring a public LLM provider exists

Both test projects set `storageState: "admin_auth.json"`, so **every test starts pre-authenticated as admin with no login code needed**.

When a test needs a different user, use API-based login — never drive the login UI:

```typescript
import { loginAs } from "@tests/e2e/utils/auth";

await page.context().clearCookies();
await loginAs(page, "admin2");

// Log in as the worker-specific user (preferred for test isolation):
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
```

## Test Structure

Tests start pre-authenticated as admin — navigate and test directly:

```typescript
import { test, expect } from "@playwright/test";

test.describe("Feature Name", () => {
  test("should describe expected behavior clearly", async ({ page }) => {
    await page.goto("/app");
    await page.waitForLoadState("networkidle");
    // Already authenticated as admin — go straight to testing
  });
});
```

**User isolation** — tests that modify visible app state (creating assistants, sending chat messages, pinning items) should run as a **worker-specific user** and clean up resources in `afterAll`. Global setup provisions a pool of worker users (`worker0@example.com` through `worker7@example.com`). `loginAsWorkerUser` maps `testInfo.workerIndex` to a pool slot via modulo, so retry workers (which get incrementing indices beyond the pool size) safely reuse existing users. This ensures parallel workers never share user state, keeps usernames deterministic for screenshots, and avoids cross-contamination:

```typescript
import { test } from "@playwright/test";
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";

test.beforeEach(async ({ page }, testInfo) => {
  await page.context().clearCookies();
  await loginAsWorkerUser(page, testInfo.workerIndex);
});
```

If the test requires admin privileges *and* modifies visible state, use `"admin2"` instead — it's a pre-provisioned admin account that keeps the primary `"admin"` clean for other parallel tests. Switch to `"admin"` only for privileged setup (creating providers, configuring tools), then back to the worker user for the actual test. See `chat/default_assistant.spec.ts` for a full example.

`loginAsRandomUser` exists for the rare case where the test requires a brand-new user (e.g. onboarding flows). Avoid it elsewhere — it produces non-deterministic usernames that complicate screenshots.

**API resource setup** — only when tests need to create backend resources (image gen configs, web search providers, MCP servers). Use `beforeAll`/`afterAll` with `OnyxApiClient` to create and clean up. See `chat/default_assistant.spec.ts` or `mcp/mcp_oauth_flow.spec.ts` for examples. This is uncommon (~4 of 37 test files).
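As a rough illustration of that pattern, here is a hedged sketch of API-based setup/teardown. Only `createAssistant()`/`deleteAssistant()` come from the method list below; the constructor wiring, the assistant fields, and the returned ID type are assumptions — check `utils/onyxApiClient.ts` for the real signatures:

```typescript
import { test, request } from "@playwright/test";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";

// Hypothetical wiring — the real client may be constructed differently.
let api: OnyxApiClient;
let assistantId: number;

test.beforeAll(async () => {
  const ctx = await request.newContext(); // assumption: the client wraps an APIRequestContext
  api = new OnyxApiClient(ctx);
  assistantId = await api.createAssistant({ name: "E2E Example Assistant" });
});

test.afterAll(async () => {
  // Delete by ID so retries and parallel workers never collide on leftover state
  await api.deleteAssistant(assistantId);
});
```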
## Key Utilities

### `OnyxApiClient` (`@tests/e2e/utils/onyxApiClient`)

Backend API client for test setup/teardown. Key methods:

- **Connectors**: `createFileConnector()`, `deleteCCPair()`, `pauseConnector()`
- **LLM Providers**: `ensurePublicProvider()`, `createRestrictedProvider()`, `setProviderAsDefault()`
- **Assistants**: `createAssistant()`, `deleteAssistant()`, `findAssistantByName()`
- **User Groups**: `createUserGroup()`, `deleteUserGroup()`, `setUserRole()`
- **Tools**: `createWebSearchProvider()`, `createImageGenerationConfig()`
- **Chat**: `createChatSession()`, `deleteChatSession()`

### `chatActions` (`@tests/e2e/utils/chatActions`)

- `sendMessage(page, message)` — sends a message and waits for AI response
- `startNewChat(page)` — clicks new-chat button and waits for intro
- `verifyDefaultAssistantIsChosen(page)` — checks Onyx logo is visible
- `verifyAssistantIsChosen(page, name)` — checks assistant name display
- `switchModel(page, modelName)` — switches LLM model via popover
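A short usage sketch combining these helpers, assuming they are named exports of `chatActions`; the assistant and model names below are placeholders:

```typescript
import { test } from "@playwright/test";
import {
  startNewChat,
  sendMessage,
  verifyAssistantIsChosen,
  switchModel,
} from "@tests/e2e/utils/chatActions";

test("sends a message with a specific assistant and model", async ({ page }) => {
  await page.goto("/app");
  await startNewChat(page);
  await verifyAssistantIsChosen(page, "Custom Assistant"); // placeholder assistant name
  await switchModel(page, "gpt-4o"); // placeholder model name
  await sendMessage(page, "Hello!"); // waits for the AI response before returning
});
```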
### `visualRegression` (`@tests/e2e/utils/visualRegression`)

- `expectScreenshot(page, { name, mask?, hide?, fullPage? })`
- `expectElementScreenshot(locator, { name, mask?, hide? })`
- Controlled by `VISUAL_REGRESSION=true` env var
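For example, a hedged sketch that masks a dynamic element before capturing — the test IDs are illustrative (not real Onyx selectors), and the assumption that `mask` takes an array of locators mirrors Playwright's own screenshot options:

```typescript
import { test } from "@playwright/test";
import { expectScreenshot, expectElementScreenshot } from "@tests/e2e/utils/visualRegression";

test("chat page renders consistently", async ({ page }) => {
  await page.goto("/app");
  // Full-page shot; mask anything that changes between runs (e.g. timestamps)
  await expectScreenshot(page, {
    name: "chat-page",
    mask: [page.getByTestId("chat-timestamp")], // illustrative test ID
  });
  // Element-level shot of a single component
  await expectElementScreenshot(page.getByTestId("app-sidebar"), { name: "app-sidebar" });
});
```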
### `theme` (`@tests/e2e/utils/theme`)

- `THEMES` — `["light", "dark"] as const` array for iterating over both themes
- `setThemeBeforeNavigation(page, theme)` — sets `next-themes` theme via `localStorage` before navigation

When tests need light/dark screenshots, loop over `THEMES` at the `test.describe` level and call `setThemeBeforeNavigation` in `beforeEach` **before** any `page.goto()`. Include the theme in screenshot names. See `admin/admin_pages.spec.ts` or `chat/chat_message_rendering.spec.ts` for examples:

```typescript
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";

for (const theme of THEMES) {
  test.describe(`Feature (${theme} mode)`, () => {
    test.beforeEach(async ({ page }) => {
      await setThemeBeforeNavigation(page, theme);
    });

    test("renders correctly", async ({ page }) => {
      await page.goto("/app");
      await expectScreenshot(page, { name: `feature-${theme}` });
    });
  });
}
```

### `tools` (`@tests/e2e/utils/tools`)

- `TOOL_IDS` — centralized `data-testid` selectors for tool options
- `openActionManagement(page)` — opens the tool management popover
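A minimal sketch of toggling a tool via these helpers — the `TOOL_IDS.webSearch` key used here is hypothetical; check `utils/tools.ts` for the real keys:

```typescript
import { test } from "@playwright/test";
import { TOOL_IDS, openActionManagement } from "@tests/e2e/utils/tools";

test("enables a tool from the action management popover", async ({ page }) => {
  await page.goto("/app");
  await openActionManagement(page);
  // TOOL_IDS.webSearch is illustrative — use whichever key actually exists in utils/tools.ts
  await page.getByTestId(TOOL_IDS.webSearch).click();
});
```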
## Locator Strategy

Use locators in this priority order:

1. **`data-testid` / `aria-label`** — preferred for Onyx components
   ```typescript
   page.getByTestId("AppSidebar/new-session")
   page.getByLabel("admin-page-title")
   ```

2. **Role-based** — for standard HTML elements
   ```typescript
   page.getByRole("button", { name: "Create" })
   page.getByRole("dialog")
   ```

3. **Text/Label** — for visible text content
   ```typescript
   page.getByText("Custom Assistant")
   page.getByLabel("Email")
   ```

4. **CSS selectors** — last resort, only when the above won't work
   ```typescript
   page.locator('input[name="name"]')
   page.locator("#onyx-chat-input-textarea")
   ```

**Never use** `page.locator` with complex CSS/XPath when a built-in locator works.

## Assertions

Use web-first assertions — they auto-retry until the condition is met:

```typescript
// Visibility
await expect(page.getByTestId("onyx-logo")).toBeVisible({ timeout: 5000 });

// Text content
await expect(page.getByTestId("assistant-name-display")).toHaveText("My Assistant");

// Count
await expect(page.locator('[data-testid="onyx-ai-message"]')).toHaveCount(2, { timeout: 30000 });

// URL
await expect(page).toHaveURL(/chatId=/);

// Element state
await expect(toggle).toBeChecked();
await expect(button).toBeEnabled();
```

**Never use** `assert` statements or hardcoded `page.waitForTimeout()`.

## Waiting Strategy

```typescript
// Wait for load state after navigation
await page.goto("/app");
await page.waitForLoadState("networkidle");

// Wait for specific element
await page.getByTestId("chat-intro").waitFor({ state: "visible", timeout: 10000 });

// Wait for URL change
await page.waitForFunction(() => window.location.href.includes("chatId="), null, { timeout: 10000 });

// Wait for network response
await page.waitForResponse(resp => resp.url().includes("/api/chat") && resp.status() === 200);
```

## Best Practices

1. **Descriptive test names** — clearly state expected behavior: `"should display greeting message when opening new chat"`
2. **API-first setup** — use `OnyxApiClient` for backend state; reserve UI interactions for the behavior under test
3. **User isolation** — tests that modify visible app state (sidebar, chat history) should run as the worker-specific user via `loginAsWorkerUser(page, testInfo.workerIndex)` (not admin) and clean up resources in `afterAll`. Each parallel worker gets its own user, preventing cross-contamination. Reserve `loginAsRandomUser` for flows that require a brand-new user (e.g. onboarding)
4. **DRY helpers** — extract reusable logic into `utils/` with JSDoc comments
5. **No hardcoded waits** — use `waitFor`, `waitForLoadState`, or web-first assertions
6. **Parallel-safe** — no shared mutable state between tests. Prefer static, human-readable names (e.g. `"E2E-CMD Chat 1"`) and clean up resources by ID in `afterAll`. This keeps screenshots deterministic and avoids needing to mask/hide dynamic text. Only fall back to timestamps (`\`test-${Date.now()}\``) when resources cannot be reliably cleaned up or when name collisions across parallel workers would cause functional failures
7. **Error context** — catch and re-throw with useful debug info (page text, URL, etc.); see the sketch after this list
8. **Tag slow tests** — mark serial/slow tests with `@exclusive` in the test title
9. **Visual regression** — use `expectScreenshot()` for UI consistency checks
10. **Minimal comments** — only comment to clarify non-obvious intent; never restate what the next line of code does
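A hedged sketch of the error-context practice from item 7 — the selector and error wording are illustrative, only the Playwright calls themselves are standard API:

```typescript
import type { Page } from "@playwright/test";

// Wrap a flaky interaction so failures carry page context in CI logs.
async function clickCreateButton(page: Page): Promise<void> {
  try {
    await page.getByRole("button", { name: "Create" }).click({ timeout: 5000 });
  } catch (error) {
    const bodyText = await page.locator("body").innerText();
    throw new Error(
      `Failed to click Create on ${page.url()}.\n` +
        `Visible page text (truncated): ${bodyText.slice(0, 500)}\n` +
        `Original error: ${error}`
    );
  }
}
```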
.github/workflows/nightly-scan-licenses.yml (vendored, new file, 151 lines)
@@ -0,0 +1,151 @@
# Scan for problematic software licenses

# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389

name: 'Nightly - Scan licenses'
on:
  # schedule:
  #   - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
  workflow_dispatch: # Allows manual triggering

permissions:
  actions: read
  contents: read

jobs:
  scan-licenses:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
    timeout-minutes: 45
    permissions:
      actions: read
      contents: read
      security-events: write

    steps:
      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Set up Python
        uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
        with:
          python-version: '3.11'
          cache: 'pip'
          cache-dependency-path: |
            backend/requirements/default.txt
            backend/requirements/dev.txt
            backend/requirements/model_server.txt

      - name: Get explicit and transitive dependencies
        run: |
          python -m pip install --upgrade pip
          pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
          pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
          pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
          pip freeze > requirements-all.txt

      - name: Check python
        id: license_check_report
        uses: pilosus/action-pip-license-checker@e909b0226ff49d3235c99c4585bc617f49fff16a # ratchet:pilosus/action-pip-license-checker@v3
        with:
          requirements: 'requirements-all.txt'
          fail: 'Copyleft'
          exclude: '(?i)^(pylint|aio[-_]*).*'

      - name: Print report
        if: always()
        env:
          REPORT: ${{ steps.license_check_report.outputs.report }}
        run: echo "$REPORT"

      - name: Install npm dependencies
        working-directory: ./web
        run: npm ci

      # be careful enabling the sarif and upload as it may spam the security tab
      # with a huge amount of items. Work out the issues before enabling upload.
      # - name: Run Trivy vulnerability scanner in repo mode
      #   if: always()
      #   uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
      #   with:
      #     scan-type: fs
      #     scan-ref: .
      #     scanners: license
      #     format: table
      #     severity: HIGH,CRITICAL
      #     # format: sarif
      #     # output: trivy-results.sarif
      #
      # # - name: Upload Trivy scan results to GitHub Security tab
      # #   uses: github/codeql-action/upload-sarif@v3
      # #   with:
      # #     sarif_file: trivy-results.sarif

  scan-trivy:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
    timeout-minutes: 45

    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      # Backend
      - name: Pull backend docker image
        run: docker pull onyxdotapp/onyx-backend:latest

      - name: Run Trivy vulnerability scanner on backend
        uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
        env:
          TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
          TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
        with:
          image-ref: onyxdotapp/onyx-backend:latest
          scanners: license
          severity: HIGH,CRITICAL
          vuln-type: library
          exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow

      # Web server
      - name: Pull web server docker image
        run: docker pull onyxdotapp/onyx-web-server:latest

      - name: Run Trivy vulnerability scanner on web server
        uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
        env:
          TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
          TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
        with:
          image-ref: onyxdotapp/onyx-web-server:latest
          scanners: license
          severity: HIGH,CRITICAL
          vuln-type: library
          exit-code: 0

      # Model server
      - name: Pull model server docker image
        run: docker pull onyxdotapp/onyx-model-server:latest

      - name: Run Trivy vulnerability scanner
        uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
        env:
          TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
          TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
        with:
          image-ref: onyxdotapp/onyx-model-server:latest
          scanners: license
          severity: HIGH,CRITICAL
          vuln-type: library
          exit-code: 0
.github/workflows/pr-helm-chart-testing.yml (vendored, 3 lines changed)
@@ -41,7 +41,8 @@ jobs:
          version: v3.19.0

      - name: Set up chart-testing
        uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
        # NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
        uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
        with:
          uv_version: "0.9.9"

.github/workflows/pr-playwright-tests.yml (vendored, 9 lines changed)
@@ -22,9 +22,6 @@ env:
  SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
  GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
  FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
  GOOGLE_PSE_API_KEY: ${{ secrets.GOOGLE_PSE_API_KEY }}
  GOOGLE_PSE_SEARCH_ENGINE_ID: ${{ secrets.GOOGLE_PSE_SEARCH_ENGINE_ID }}

  # for federated slack tests
  SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
@@ -303,7 +300,6 @@
          # TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
          LICENSE_ENFORCEMENT_ENABLED=false
          AUTH_TYPE=basic
          INTEGRATION_TESTS_MODE=true
          GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
          EXA_API_KEY=${EXA_API_KEY_VALUE}
          REQUIRE_EMAIL_VERIFICATION=false
@@ -593,10 +589,7 @@
  # Post a single combined visual regression comment after all matrix jobs finish
  visual-regression-comment:
    needs: [playwright-tests]
    if: >-
      always() &&
      github.event_name == 'pull_request' &&
      needs.playwright-tests.result != 'cancelled'
    if: always() && github.event_name == 'pull_request'
    runs-on: ubuntu-slim
    timeout-minutes: 5
    permissions:

.github/workflows/zizmor.yml (vendored, 13 lines changed)
@@ -5,8 +5,6 @@ on:
    branches: ["main"]
  pull_request:
    branches: ["**"]
    paths:
      - ".github/**"

permissions: {}

@@ -23,18 +21,29 @@ jobs:
        with:
          persist-credentials: false

      - name: Detect changes
        id: filter
        uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
        with:
          filters: |
            zizmor:
              - '.github/**'

      - name: Install the latest version of uv
        if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
        uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
        with:
          enable-cache: false
          version: "0.9.9"

      - name: Run zizmor
        if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
        run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Upload SARIF file
        if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
        uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
        with:
          sarif_file: results.sarif

.gitignore (vendored, 1 line changed)
@@ -7,7 +7,6 @@
.zed
.cursor
!/.cursor/mcp.json
!/.cursor/skills/

# macos
.DS_store

@@ -263,15 +263,9 @@ def refresh_license_cache(

    try:
        payload = verify_license_signature(license_record.license_data)
        # Derive source from payload: manual licenses lack stripe_customer_id
        source: LicenseSource = (
            LicenseSource.AUTO_FETCH
            if payload.stripe_customer_id
            else LicenseSource.MANUAL_UPLOAD
        )
        return update_license_cache(
            payload,
            source=source,
            source=LicenseSource.AUTO_FETCH,
            tenant_id=tenant_id,
        )
    except ValueError as e:

@@ -1,604 +0,0 @@
"""SCIM Data Access Layer.

All database operations for SCIM provisioning — token management, user
mappings, and group mappings. Extends the base DAL (see ``onyx.db.dal``).

Usage from FastAPI::

    def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
        return ScimDAL(db_session)

    @router.post("/tokens")
    def create_token(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
        token = dal.create_token(name=..., hashed_token=..., ...)
        dal.commit()
        return token

Usage from background tasks::

    with ScimDAL.from_tenant("tenant_abc") as dal:
        mapping = dal.create_user_mapping(external_id="idp-123", user_id=uid)
        dal.commit()
"""

from __future__ import annotations

from uuid import UUID

from sqlalchemy import delete as sa_delete
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import SQLColumnExpression
from sqlalchemy.dialects.postgresql import insert as pg_insert

from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from onyx.db.dal import DAL
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger

logger = setup_logger()


class ScimDAL(DAL):
    """Data Access Layer for SCIM provisioning operations.

    Methods mutate but do NOT commit — call ``dal.commit()`` explicitly
    when you want to persist changes. This follows the existing ``_no_commit``
    convention and lets callers batch multiple operations into one transaction.
    """

    # ------------------------------------------------------------------
    # Token operations
    # ------------------------------------------------------------------

    def create_token(
        self,
        name: str,
        hashed_token: str,
        token_display: str,
        created_by_id: UUID,
    ) -> ScimToken:
        """Create a new SCIM bearer token.

        Only one token is active at a time — this method automatically revokes
        all existing active tokens before creating the new one.
        """
        # Revoke any currently active tokens
        active_tokens = list(
            self._session.scalars(
                select(ScimToken).where(ScimToken.is_active.is_(True))
            ).all()
        )
        for t in active_tokens:
            t.is_active = False

        token = ScimToken(
            name=name,
            hashed_token=hashed_token,
            token_display=token_display,
            created_by_id=created_by_id,
        )
        self._session.add(token)
        self._session.flush()
        return token

    def get_active_token(self) -> ScimToken | None:
        """Return the single currently active token, or None."""
        return self._session.scalar(
            select(ScimToken).where(ScimToken.is_active.is_(True))
        )

    def get_token_by_hash(self, hashed_token: str) -> ScimToken | None:
        """Look up a token by its SHA-256 hash."""
        return self._session.scalar(
            select(ScimToken).where(ScimToken.hashed_token == hashed_token)
        )

    def revoke_token(self, token_id: int) -> None:
        """Deactivate a token by ID.

        Raises:
            ValueError: If the token does not exist.
        """
        token = self._session.get(ScimToken, token_id)
        if not token:
            raise ValueError(f"SCIM token with id {token_id} not found")
        token.is_active = False

    def update_token_last_used(self, token_id: int) -> None:
        """Update the last_used_at timestamp for a token."""
        token = self._session.get(ScimToken, token_id)
        if token:
            token.last_used_at = func.now()  # type: ignore[assignment]

    # ------------------------------------------------------------------
    # User mapping operations
    # ------------------------------------------------------------------

    def create_user_mapping(
        self,
        external_id: str,
        user_id: UUID,
    ) -> ScimUserMapping:
        """Create a mapping between a SCIM externalId and an Onyx user."""
        mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
        self._session.add(mapping)
        self._session.flush()
        return mapping

    def get_user_mapping_by_external_id(
        self, external_id: str
    ) -> ScimUserMapping | None:
        """Look up a user mapping by the IdP's external identifier."""
        return self._session.scalar(
            select(ScimUserMapping).where(ScimUserMapping.external_id == external_id)
        )

    def get_user_mapping_by_user_id(self, user_id: UUID) -> ScimUserMapping | None:
        """Look up a user mapping by the Onyx user ID."""
        return self._session.scalar(
            select(ScimUserMapping).where(ScimUserMapping.user_id == user_id)
        )

    def list_user_mappings(
        self,
        start_index: int = 1,
        count: int = 100,
    ) -> tuple[list[ScimUserMapping], int]:
        """List user mappings with SCIM-style pagination.

        Args:
            start_index: 1-based start index (SCIM convention).
            count: Maximum number of results to return.

        Returns:
            A tuple of (mappings, total_count).
        """
        total = (
            self._session.scalar(select(func.count()).select_from(ScimUserMapping)) or 0
        )

        offset = max(start_index - 1, 0)
        mappings = list(
            self._session.scalars(
                select(ScimUserMapping)
                .order_by(ScimUserMapping.id)
                .offset(offset)
                .limit(count)
            ).all()
        )

        return mappings, total

    def update_user_mapping_external_id(
        self,
        mapping_id: int,
        external_id: str,
    ) -> ScimUserMapping:
        """Update the external ID on a user mapping.

        Raises:
            ValueError: If the mapping does not exist.
        """
        mapping = self._session.get(ScimUserMapping, mapping_id)
        if not mapping:
            raise ValueError(f"SCIM user mapping with id {mapping_id} not found")
        mapping.external_id = external_id
        return mapping

    def delete_user_mapping(self, mapping_id: int) -> None:
        """Delete a user mapping by ID. No-op if already deleted."""
        mapping = self._session.get(ScimUserMapping, mapping_id)
        if not mapping:
            logger.warning("SCIM user mapping %d not found during delete", mapping_id)
            return
        self._session.delete(mapping)

    # ------------------------------------------------------------------
    # User query operations
    # ------------------------------------------------------------------

    def get_user(self, user_id: UUID) -> User | None:
        """Fetch a user by ID."""
        return self._session.scalar(
            select(User).where(User.id == user_id)  # type: ignore[arg-type]
        )

    def get_user_by_email(self, email: str) -> User | None:
        """Fetch a user by email (case-insensitive)."""
        return self._session.scalar(
            select(User).where(func.lower(User.email) == func.lower(email))
        )

    def add_user(self, user: User) -> None:
        """Add a new user to the session and flush to assign an ID."""
        self._session.add(user)
        self._session.flush()

    def update_user(
        self,
        user: User,
        *,
        email: str | None = None,
        is_active: bool | None = None,
        personal_name: str | None = None,
    ) -> None:
        """Update user attributes. Only sets fields that are provided."""
        if email is not None:
            user.email = email
        if is_active is not None:
            user.is_active = is_active
        if personal_name is not None:
            user.personal_name = personal_name

    def deactivate_user(self, user: User) -> None:
        """Mark a user as inactive."""
        user.is_active = False

    def list_users(
        self,
        scim_filter: ScimFilter | None,
        start_index: int = 1,
        count: int = 100,
    ) -> tuple[list[tuple[User, str | None]], int]:
        """Query users with optional SCIM filter and pagination.

        Returns:
            A tuple of (list of (user, external_id) pairs, total_count).

        Raises:
            ValueError: If the filter uses an unsupported attribute.
        """
        query = select(User).where(
            User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
        )

        if scim_filter:
            attr = scim_filter.attribute.lower()
            if attr == "username":
                # arg-type: fastapi-users types User.email as str, not a column expression
                # assignment: union return type widens but query is still Select[tuple[User]]
                query = _apply_scim_string_op(query, User.email, scim_filter)  # type: ignore[arg-type, assignment]
            elif attr == "active":
                query = query.where(
                    User.is_active.is_(scim_filter.value.lower() == "true")  # type: ignore[attr-defined]
                )
            elif attr == "externalid":
                mapping = self.get_user_mapping_by_external_id(scim_filter.value)
                if not mapping:
                    return [], 0
                query = query.where(User.id == mapping.user_id)  # type: ignore[arg-type]
            else:
                raise ValueError(
                    f"Unsupported filter attribute: {scim_filter.attribute}"
                )

        # Count total matching rows first, then paginate. SCIM uses 1-based
        # indexing (RFC 7644 §3.4.2), so we convert to a 0-based offset.
        total = (
            self._session.scalar(select(func.count()).select_from(query.subquery()))
            or 0
        )

        offset = max(start_index - 1, 0)
        users = list(
            self._session.scalars(
                query.order_by(User.id).offset(offset).limit(count)  # type: ignore[arg-type]
            ).all()
        )

        # Batch-fetch external IDs to avoid N+1 queries
        ext_id_map = self._get_user_external_ids([u.id for u in users])
        return [(u, ext_id_map.get(u.id)) for u in users], total

    def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
        """Create, update, or delete the external ID mapping for a user."""
        mapping = self.get_user_mapping_by_user_id(user_id)
        if new_external_id:
            if mapping:
                if mapping.external_id != new_external_id:
                    mapping.external_id = new_external_id
            else:
                self.create_user_mapping(external_id=new_external_id, user_id=user_id)
        elif mapping:
            self.delete_user_mapping(mapping.id)

    def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
        """Batch-fetch external IDs for a list of user IDs."""
        if not user_ids:
            return {}
        mappings = self._session.scalars(
            select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
        ).all()
        return {m.user_id: m.external_id for m in mappings}

    # ------------------------------------------------------------------
    # Group mapping operations
    # ------------------------------------------------------------------

    def create_group_mapping(
        self,
        external_id: str,
        user_group_id: int,
    ) -> ScimGroupMapping:
        """Create a mapping between a SCIM externalId and an Onyx user group."""
        mapping = ScimGroupMapping(external_id=external_id, user_group_id=user_group_id)
        self._session.add(mapping)
        self._session.flush()
        return mapping

    def get_group_mapping_by_external_id(
        self, external_id: str
    ) -> ScimGroupMapping | None:
        """Look up a group mapping by the IdP's external identifier."""
        return self._session.scalar(
            select(ScimGroupMapping).where(ScimGroupMapping.external_id == external_id)
        )

    def get_group_mapping_by_group_id(
        self, user_group_id: int
    ) -> ScimGroupMapping | None:
        """Look up a group mapping by the Onyx user group ID."""
        return self._session.scalar(
            select(ScimGroupMapping).where(
                ScimGroupMapping.user_group_id == user_group_id
            )
        )

    def list_group_mappings(
        self,
        start_index: int = 1,
        count: int = 100,
    ) -> tuple[list[ScimGroupMapping], int]:
        """List group mappings with SCIM-style pagination.

        Args:
            start_index: 1-based start index (SCIM convention).
            count: Maximum number of results to return.

        Returns:
            A tuple of (mappings, total_count).
        """
        total = (
            self._session.scalar(select(func.count()).select_from(ScimGroupMapping))
            or 0
        )

        offset = max(start_index - 1, 0)
        mappings = list(
            self._session.scalars(
                select(ScimGroupMapping)
                .order_by(ScimGroupMapping.id)
                .offset(offset)
                .limit(count)
            ).all()
        )

        return mappings, total

    def delete_group_mapping(self, mapping_id: int) -> None:
        """Delete a group mapping by ID. No-op if already deleted."""
        mapping = self._session.get(ScimGroupMapping, mapping_id)
        if not mapping:
            logger.warning("SCIM group mapping %d not found during delete", mapping_id)
            return
        self._session.delete(mapping)

    # ------------------------------------------------------------------
    # Group query operations
    # ------------------------------------------------------------------

    def get_group(self, group_id: int) -> UserGroup | None:
        """Fetch a group by ID, returning None if deleted or missing."""
        group = self._session.get(UserGroup, group_id)
        if group and group.is_up_for_deletion:
            return None
        return group

    def get_group_by_name(self, name: str) -> UserGroup | None:
        """Fetch a group by exact name."""
        return self._session.scalar(select(UserGroup).where(UserGroup.name == name))

    def add_group(self, group: UserGroup) -> None:
        """Add a new group to the session and flush to assign an ID."""
        self._session.add(group)
        self._session.flush()

    def update_group(
        self,
        group: UserGroup,
        *,
        name: str | None = None,
    ) -> None:
        """Update group attributes and set the modification timestamp."""
        if name is not None:
            group.name = name
        group.time_last_modified_by_user = func.now()

    def delete_group(self, group: UserGroup) -> None:
        """Delete a group from the session."""
        self._session.delete(group)

    def list_groups(
        self,
        scim_filter: ScimFilter | None,
        start_index: int = 1,
        count: int = 100,
    ) -> tuple[list[tuple[UserGroup, str | None]], int]:
        """Query groups with optional SCIM filter and pagination.

        Returns:
            A tuple of (list of (group, external_id) pairs, total_count).

        Raises:
            ValueError: If the filter uses an unsupported attribute.
        """
        query = select(UserGroup).where(UserGroup.is_up_for_deletion.is_(False))

        if scim_filter:
            attr = scim_filter.attribute.lower()
            if attr == "displayname":
                # assignment: union return type widens but query is still Select[tuple[UserGroup]]
                query = _apply_scim_string_op(query, UserGroup.name, scim_filter)  # type: ignore[assignment]
            elif attr == "externalid":
                mapping = self.get_group_mapping_by_external_id(scim_filter.value)
                if not mapping:
                    return [], 0
                query = query.where(UserGroup.id == mapping.user_group_id)
            else:
                raise ValueError(
                    f"Unsupported filter attribute: {scim_filter.attribute}"
                )

        total = (
            self._session.scalar(select(func.count()).select_from(query.subquery()))
            or 0
        )

        offset = max(start_index - 1, 0)
        groups = list(
            self._session.scalars(
                query.order_by(UserGroup.id).offset(offset).limit(count)
            ).all()
        )

        ext_id_map = self._get_group_external_ids([g.id for g in groups])
        return [(g, ext_id_map.get(g.id)) for g in groups], total

    def get_group_members(self, group_id: int) -> list[tuple[UUID, str | None]]:
        """Get group members as (user_id, email) pairs."""
        rels = self._session.scalars(
            select(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
        ).all()

        user_ids = [r.user_id for r in rels if r.user_id]
        if not user_ids:
            return []

        users = self._session.scalars(
            select(User).where(User.id.in_(user_ids))  # type: ignore[attr-defined]
        ).all()
        users_by_id = {u.id: u for u in users}

        return [
            (
                r.user_id,
                users_by_id[r.user_id].email if r.user_id in users_by_id else None,
            )
            for r in rels
            if r.user_id
        ]

    def validate_member_ids(self, uuids: list[UUID]) -> list[UUID]:
        """Return the subset of UUIDs that don't exist as users.

        Returns an empty list if all IDs are valid.
        """
        if not uuids:
            return []
        existing_users = self._session.scalars(
            select(User).where(User.id.in_(uuids))  # type: ignore[attr-defined]
        ).all()
        existing_ids = {u.id for u in existing_users}
        return [uid for uid in uuids if uid not in existing_ids]

    def upsert_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
        """Add user-group relationships, ignoring duplicates."""
        if not user_ids:
            return
        self._session.execute(
            pg_insert(User__UserGroup)
            .values([{"user_id": uid, "user_group_id": group_id} for uid in user_ids])
            .on_conflict_do_nothing(
                index_elements=[
                    User__UserGroup.user_group_id,
                    User__UserGroup.user_id,
                ]
            )
        )

    def replace_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
        """Replace all members of a group."""
        self._session.execute(
            sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
        )
        self.upsert_group_members(group_id, user_ids)

    def remove_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
        """Remove specific members from a group."""
        if not user_ids:
            return
        self._session.execute(
            sa_delete(User__UserGroup).where(
                User__UserGroup.user_group_id == group_id,
                User__UserGroup.user_id.in_(user_ids),
            )
        )

    def delete_group_with_members(self, group: UserGroup) -> None:
        """Remove all member relationships and delete the group."""
        self._session.execute(
            sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group.id)
        )
        self._session.delete(group)

    def sync_group_external_id(
        self, group_id: int, new_external_id: str | None
    ) -> None:
        """Create, update, or delete the external ID mapping for a group."""
        mapping = self.get_group_mapping_by_group_id(group_id)
        if new_external_id:
            if mapping:
                if mapping.external_id != new_external_id:
                    mapping.external_id = new_external_id
            else:
                self.create_group_mapping(
                    external_id=new_external_id, user_group_id=group_id
                )
        elif mapping:
            self.delete_group_mapping(mapping.id)

    def _get_group_external_ids(self, group_ids: list[int]) -> dict[int, str]:
        """Batch-fetch external IDs for a list of group IDs."""
        if not group_ids:
            return {}
        mappings = self._session.scalars(
            select(ScimGroupMapping).where(
                ScimGroupMapping.user_group_id.in_(group_ids)
            )
        ).all()
        return {m.user_group_id: m.external_id for m in mappings}


# ---------------------------------------------------------------------------
# Module-level helpers (used by DAL methods above)
# ---------------------------------------------------------------------------


def _apply_scim_string_op(
    query: Select[tuple[User]] | Select[tuple[UserGroup]],
    column: SQLColumnExpression[str],
    scim_filter: ScimFilter,
) -> Select[tuple[User]] | Select[tuple[UserGroup]]:
    """Apply a SCIM string filter operator using SQLAlchemy column operators.

    Handles eq (case-insensitive exact), co (contains), and sw (starts with).
    SQLAlchemy's operators handle LIKE-pattern escaping internally.
    """
    val = scim_filter.value
    if scim_filter.operator == ScimFilterOperator.EQUAL:
        return query.where(func.lower(column) == val.lower())
    elif scim_filter.operator == ScimFilterOperator.CONTAINS:
        return query.where(column.icontains(val, autoescape=True))
    elif scim_filter.operator == ScimFilterOperator.STARTS_WITH:
        return query.where(column.istartswith(val, autoescape=True))
    else:
        raise ValueError(f"Unsupported string filter operator: {scim_filter.operator}")

@@ -31,7 +31,6 @@ from ee.onyx.server.query_and_chat.query_backend import (
from ee.onyx.server.query_and_chat.search_backend import router as search_router
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.scim.api import scim_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.token_rate_limits.api import (
@@ -163,11 +162,6 @@ def get_application() -> FastAPI:
    # Tenant management
    include_router_with_global_prefix_prepended(application, tenants_router)

    # SCIM 2.0 — protocol endpoints (unauthenticated by Onyx session auth;
    # they use their own SCIM bearer token auth).
    # Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
    application.include_router(scim_router)

    # Ensure all routes have auth enabled or are explicitly marked as public
    check_ee_router_auth(application)

@@ -5,11 +5,6 @@ from onyx.server.auth_check import PUBLIC_ENDPOINT_SPECS


EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
    # SCIM 2.0 service discovery — unauthenticated so IdPs can probe
    # before bearer token configuration is complete
    ("/scim/v2/ServiceProviderConfig", {"GET"}),
    ("/scim/v2/ResourceTypes", {"GET"}),
    ("/scim/v2/Schemas", {"GET"}),
    # needs to be accessible prior to user login
    ("/enterprise-settings", {"GET"}),
    ("/enterprise-settings/logo", {"GET"}),

@@ -13,7 +13,6 @@ from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session

from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.store import get_logo_filename
@@ -23,10 +22,6 @@ from ee.onyx.server.enterprise_settings.store import load_settings
from ee.onyx.server.enterprise_settings.store import store_analytics_script
from ee.onyx.server.enterprise_settings.store import store_settings
from ee.onyx.server.enterprise_settings.store import upload_logo
from ee.onyx.server.scim.auth import generate_scim_token
from ee.onyx.server.scim.models import ScimTokenCreate
from ee.onyx.server.scim.models import ScimTokenCreatedResponse
from ee.onyx.server.scim.models import ScimTokenResponse
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
@@ -203,63 +198,3 @@ def upload_custom_analytics_script(
@basic_router.get("/custom-analytics-script")
def fetch_custom_analytics_script() -> str | None:
    return load_analytics_script()


# ---------------------------------------------------------------------------
# SCIM token management
# ---------------------------------------------------------------------------


def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
    return ScimDAL(db_session)


@admin_router.get("/scim/token")
def get_active_scim_token(
    _: User = Depends(current_admin_user),
    dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenResponse:
    """Return the currently active SCIM token's metadata, or 404 if none."""
    token = dal.get_active_token()
    if not token:
        raise HTTPException(status_code=404, detail="No active SCIM token")
    return ScimTokenResponse(
        id=token.id,
        name=token.name,
        token_display=token.token_display,
        is_active=token.is_active,
        created_at=token.created_at,
        last_used_at=token.last_used_at,
    )


@admin_router.post("/scim/token", status_code=201)
def create_scim_token(
    body: ScimTokenCreate,
    user: User = Depends(current_admin_user),
    dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenCreatedResponse:
    """Create a new SCIM bearer token.

    Only one token is active at a time — creating a new token automatically
    revokes all previous tokens. The raw token value is returned exactly once
    in the response; it cannot be retrieved again.
    """
    raw_token, hashed_token, token_display = generate_scim_token()
    token = dal.create_token(
        name=body.name,
        hashed_token=hashed_token,
        token_display=token_display,
        created_by_id=user.id,
    )
    dal.commit()

    return ScimTokenCreatedResponse(
        id=token.id,
        name=token.name,
        token_display=token.token_display,
        is_active=token.is_active,
        created_at=token.created_at,
        last_used_at=token.last_used_at,
        raw_token=raw_token,
    )

@@ -1,689 +0,0 @@
"""SCIM 2.0 API endpoints (RFC 7644).

This module provides the FastAPI router for SCIM service discovery,
User CRUD, and Group CRUD. Identity providers (Okta, Azure AD) call
these endpoints to provision and manage users and groups.

Service discovery endpoints are unauthenticated — IdPs may probe them
before bearer token configuration is complete. All other endpoints
require a valid SCIM bearer token.
"""

from __future__ import annotations

from uuid import UUID

from fastapi import APIRouter
from fastapi import Depends
from fastapi import Query
from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop


# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
# changed to kebab-case.
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])

_pw_helper = PasswordHelper()


# ---------------------------------------------------------------------------
# Service Discovery Endpoints (unauthenticated)
# ---------------------------------------------------------------------------


@scim_router.get("/ServiceProviderConfig")
def get_service_provider_config() -> ScimServiceProviderConfig:
    """Advertise supported SCIM features (RFC 7643 §5)."""
    return SERVICE_PROVIDER_CONFIG


@scim_router.get("/ResourceTypes")
def get_resource_types() -> list[ScimResourceType]:
    """List available SCIM resource types (RFC 7643 §6)."""
    return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]


@scim_router.get("/Schemas")
def get_schemas() -> list[ScimSchemaDefinition]:
    """Return SCIM schema definitions (RFC 7643 §7)."""
    return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _scim_error_response(status: int, detail: str) -> JSONResponse:
    """Build a SCIM-compliant error response (RFC 7644 §3.12)."""
    body = ScimError(status=str(status), detail=detail)
    return JSONResponse(
        status_code=status,
        content=body.model_dump(exclude_none=True),
    )


def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
    """Convert an Onyx User to a SCIM User resource representation."""
    name = None
    if user.personal_name:
        parts = user.personal_name.split(" ", 1)
        name = ScimName(
            givenName=parts[0],
            familyName=parts[1] if len(parts) > 1 else None,
            formatted=user.personal_name,
        )

    return ScimUserResource(
        id=str(user.id),
        externalId=external_id,
        userName=user.email,
        name=name,
        emails=[ScimEmail(value=user.email, type="work", primary=True)],
        active=user.is_active,
        meta=ScimMeta(resourceType="User"),
    )


def _check_seat_availability(dal: ScimDAL) -> str | None:
    """Return an error message if seat limit is reached, else None."""
    check_fn = fetch_ee_implementation_or_noop(
        "onyx.db.license", "check_seat_availability", None
    )
    if check_fn is None:
        return None
    result = check_fn(dal.session, seats_needed=1)
    if not result.available:
        return result.error_message or "Seat limit reached"
    return None


def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
    """Parse *user_id* as UUID, look up the user, or return a 404 error."""
    try:
        uid = UUID(user_id)
    except ValueError:
        return _scim_error_response(404, f"User {user_id} not found")
    user = dal.get_user(uid)
    if not user:
        return _scim_error_response(404, f"User {user_id} not found")
    return user


def _scim_name_to_str(name: ScimName | None) -> str | None:
    """Extract a display name string from a SCIM name object.

    Returns None if no name is provided, so the caller can decide
    whether to update the user's personal_name.
    """
    if not name:
        return None
    return name.formatted or " ".join(
        part for part in [name.givenName, name.familyName] if part
    )


# ---------------------------------------------------------------------------
# User CRUD (RFC 7644 §3)
# ---------------------------------------------------------------------------


@scim_router.get("/Users", response_model=None)
def list_users(
    filter: str | None = Query(None),
    startIndex: int = Query(1, ge=1),
    count: int = Query(100, ge=0, le=500),
    _token: ScimToken = Depends(verify_scim_token),
    db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
    """List users with optional SCIM filter and pagination."""
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    try:
        scim_filter = parse_scim_filter(filter)
    except ValueError as e:
        return _scim_error_response(400, str(e))

    try:
        users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
    except ValueError as e:
        return _scim_error_response(400, str(e))

    resources: list[ScimUserResource | ScimGroupResource] = [
        _user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
    ]

    return ScimListResponse(
        totalResults=total,
        startIndex=startIndex,
        itemsPerPage=count,
        Resources=resources,
    )


@scim_router.get("/Users/{user_id}", response_model=None)
def get_user(
    user_id: str,
    _token: ScimToken = Depends(verify_scim_token),
    db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
    """Get a single user by ID."""
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    result = _fetch_user_or_404(user_id, dal)
    if isinstance(result, JSONResponse):
        return result
    user = result

    mapping = dal.get_user_mapping_by_user_id(user.id)
    return _user_to_scim(user, mapping.external_id if mapping else None)


@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
    user_resource: ScimUserResource,
    _token: ScimToken = Depends(verify_scim_token),
    db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
    """Create a new user from a SCIM provisioning request."""
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    email = user_resource.userName.strip().lower()

    # externalId is how the IdP correlates this user on subsequent requests.
    # Without it, the IdP can't find the user and will try to re-create,
    # hitting a 409 conflict — so we require it up front.
    if not user_resource.externalId:
        return _scim_error_response(400, "externalId is required")

    # Enforce seat limit
    seat_error = _check_seat_availability(dal)
    if seat_error:
        return _scim_error_response(403, seat_error)

    # Check for existing user
    if dal.get_user_by_email(email):
        return _scim_error_response(409, f"User with email {email} already exists")

    # Create user with a random password (SCIM users authenticate via IdP)
    personal_name = _scim_name_to_str(user_resource.name)
    user = User(
        email=email,
        hashed_password=_pw_helper.hash(_pw_helper.generate()),
        role=UserRole.BASIC,
        is_active=user_resource.active,
        is_verified=True,
        personal_name=personal_name,
    )

    try:
        dal.add_user(user)
    except IntegrityError:
        dal.rollback()
        return _scim_error_response(409, f"User with email {email} already exists")
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
dal.create_user_mapping(external_id=external_id, user_id=user.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, external_id)
|
||||
|
||||
|
||||
@scim_router.put("/Users/{user_id}", response_model=None)
|
||||
def replace_user(
|
||||
user_id: str,
|
||||
user_resource: ScimUserResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Replace a user entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
# Handle activation (need seat check) / deactivation
|
||||
if user_resource.active and not user.is_active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
dal.update_user(
|
||||
user,
|
||||
email=user_resource.userName.strip().lower(),
|
||||
is_active=user_resource.active,
|
||||
personal_name=_scim_name_to_str(user_resource.name),
|
||||
)
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
dal.sync_user_external_id(user.id, new_external_id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, new_external_id)
|
||||
|
||||
|
||||
@scim_router.patch("/Users/{user_id}", response_model=None)
|
||||
def patch_user(
|
||||
user_id: str,
|
||||
patch_request: ScimPatchRequest,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Partially update a user (RFC 7644 §3.5.2).
|
||||
|
||||
This is the primary endpoint for user deprovisioning — Okta sends
|
||||
``PATCH {"active": false}`` rather than DELETE.
|
||||
"""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
|
||||
current = _user_to_scim(user, external_id)
|
||||
|
||||
try:
|
||||
patched = apply_user_patch(patch_request.Operations, current)
|
||||
except ScimPatchError as e:
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
# Apply changes back to the DB model
|
||||
if patched.active != user.is_active:
|
||||
if patched.active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
dal.update_user(
|
||||
user,
|
||||
email=(
|
||||
patched.userName.strip().lower()
|
||||
if patched.userName.lower() != user.email
|
||||
else None
|
||||
),
|
||||
is_active=patched.active if patched.active != user.is_active else None,
|
||||
personal_name=_scim_name_to_str(patched.name),
|
||||
)
|
||||
|
||||
dal.sync_user_external_id(user.id, patched.externalId)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, patched.externalId)
|
||||
|
||||
|
||||
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
|
||||
def delete_user(
|
||||
user_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | JSONResponse:
|
||||
"""Delete a user (RFC 7644 §3.6).
|
||||
|
||||
Deactivates the user and removes the SCIM mapping. Note that Okta
|
||||
typically uses PATCH active=false instead of DELETE.
|
||||
"""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
dal.deactivate_user(user)
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
if mapping:
|
||||
dal.delete_user_mapping(mapping.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _group_to_scim(
|
||||
group: UserGroup,
|
||||
members: list[tuple[UUID, str | None]],
|
||||
external_id: str | None = None,
|
||||
) -> ScimGroupResource:
|
||||
"""Convert an Onyx UserGroup to a SCIM Group resource."""
|
||||
scim_members = [
|
||||
ScimGroupMember(value=str(uid), display=email) for uid, email in members
|
||||
]
|
||||
return ScimGroupResource(
|
||||
id=str(group.id),
|
||||
externalId=external_id,
|
||||
displayName=group.name,
|
||||
members=scim_members,
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
|
||||
"""Parse *group_id* as int, look up the group, or return a 404 error."""
|
||||
try:
|
||||
gid = int(group_id)
|
||||
except ValueError:
|
||||
return _scim_error_response(404, f"Group {group_id} not found")
|
||||
group = dal.get_group(gid)
|
||||
if not group:
|
||||
return _scim_error_response(404, f"Group {group_id} not found")
|
||||
return group
|
||||
|
||||
|
||||
def _parse_member_uuids(
|
||||
members: list[ScimGroupMember],
|
||||
) -> tuple[list[UUID], str | None]:
|
||||
"""Parse member value strings to UUIDs.
|
||||
|
||||
Returns (uuid_list, error_message). error_message is None on success.
|
||||
"""
|
||||
uuids: list[UUID] = []
|
||||
for m in members:
|
||||
try:
|
||||
uuids.append(UUID(m.value))
|
||||
except ValueError:
|
||||
return [], f"Invalid member ID: {m.value}"
|
||||
return uuids, None
|
||||
|
||||
|
||||
def _validate_and_parse_members(
|
||||
members: list[ScimGroupMember], dal: ScimDAL
|
||||
) -> tuple[list[UUID], str | None]:
|
||||
"""Parse and validate member UUIDs exist in the database.
|
||||
|
||||
Returns (uuid_list, error_message). error_message is None on success.
|
||||
"""
|
||||
uuids, err = _parse_member_uuids(members)
|
||||
if err:
|
||||
return [], err
|
||||
|
||||
if uuids:
|
||||
missing = dal.validate_member_ids(uuids)
|
||||
if missing:
|
||||
return [], f"Member(s) not found: {', '.join(str(u) for u in missing)}"
|
||||
|
||||
return uuids, None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group CRUD (RFC 7644 §3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@scim_router.get("/Groups", response_model=None)
|
||||
def list_groups(
|
||||
filter: str | None = Query(None),
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
try:
|
||||
scim_filter = parse_scim_filter(filter)
|
||||
except ValueError as e:
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
try:
|
||||
groups_with_ext_ids, total = dal.list_groups(scim_filter, startIndex, count)
|
||||
except ValueError as e:
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
resources: list[ScimUserResource | ScimGroupResource] = [
|
||||
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
|
||||
for group, ext_id in groups_with_ext_ids
|
||||
]
|
||||
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Groups/{group_id}", response_model=None)
|
||||
def get_group(
|
||||
group_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Get a single group by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
members = dal.get_group_members(group.id)
|
||||
|
||||
return _group_to_scim(group, members, mapping.external_id if mapping else None)
|
||||
|
||||
|
||||
@scim_router.post("/Groups", status_code=201, response_model=None)
|
||||
def create_group(
|
||||
group_resource: ScimGroupResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Create a new group from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
if dal.get_group_by_name(group_resource.displayName):
|
||||
return _scim_error_response(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
)
|
||||
|
||||
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
|
||||
if err:
|
||||
return _scim_error_response(400, err)
|
||||
|
||||
db_group = UserGroup(
|
||||
name=group_resource.displayName,
|
||||
is_up_to_date=True,
|
||||
time_last_modified_by_user=func.now(),
|
||||
)
|
||||
try:
|
||||
dal.add_group(db_group)
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
)
|
||||
|
||||
dal.upsert_group_members(db_group.id, member_uuids)
|
||||
|
||||
external_id = group_resource.externalId
|
||||
if external_id:
|
||||
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(db_group.id)
|
||||
return _group_to_scim(db_group, members, external_id)
|
||||
|
||||
|
||||
@scim_router.put("/Groups/{group_id}", response_model=None)
|
||||
def replace_group(
|
||||
group_id: str,
|
||||
group_resource: ScimGroupResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Replace a group entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
|
||||
if err:
|
||||
return _scim_error_response(400, err)
|
||||
|
||||
dal.update_group(group, name=group_resource.displayName)
|
||||
dal.replace_group_members(group.id, member_uuids)
|
||||
dal.sync_group_external_id(group.id, group_resource.externalId)
|
||||
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _group_to_scim(group, members, group_resource.externalId)
|
||||
|
||||
|
||||
@scim_router.patch("/Groups/{group_id}", response_model=None)
|
||||
def patch_group(
|
||||
group_id: str,
|
||||
patch_request: ScimPatchRequest,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Partially update a group (RFC 7644 §3.5.2).
|
||||
|
||||
Handles member add/remove operations from Okta and Azure AD.
|
||||
"""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
|
||||
current_members = dal.get_group_members(group.id)
|
||||
current = _group_to_scim(group, current_members, external_id)
|
||||
|
||||
try:
|
||||
patched, added_ids, removed_ids = apply_group_patch(
|
||||
patch_request.Operations, current
|
||||
)
|
||||
except ScimPatchError as e:
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
new_name = patched.displayName if patched.displayName != group.name else None
|
||||
dal.update_group(group, name=new_name)
|
||||
|
||||
if added_ids:
|
||||
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
|
||||
if add_uuids:
|
||||
missing = dal.validate_member_ids(add_uuids)
|
||||
if missing:
|
||||
return _scim_error_response(
|
||||
400,
|
||||
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
|
||||
)
|
||||
dal.upsert_group_members(group.id, add_uuids)
|
||||
|
||||
if removed_ids:
|
||||
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
|
||||
dal.remove_group_members(group.id, remove_uuids)
|
||||
|
||||
dal.sync_group_external_id(group.id, patched.externalId)
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _group_to_scim(group, members, patched.externalId)
|
||||
|
||||
|
||||
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
|
||||
def delete_group(
|
||||
group_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | JSONResponse:
|
||||
"""Delete a group (RFC 7644 §3.6)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
if mapping:
|
||||
dal.delete_group_mapping(mapping.id)
|
||||
|
||||
dal.delete_group_with_members(group)
|
||||
dal.commit()
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
def _is_valid_uuid(value: str) -> bool:
|
||||
"""Check if a string is a valid UUID."""
|
||||
try:
|
||||
UUID(value)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
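
Taken together, the endpoints above implement the usual IdP provisioning lifecycle: create with a required `externalId`, then deactivate via PATCH rather than DELETE. A minimal sketch of how a SCIM client would exercise them — the base URL, token value, and exact PatchOp payload are assumptions for illustration, not documented defaults; the `/scim/v2` path and `onyx_scim_` token prefix come from the modules below:

```python
# Hypothetical walkthrough of the User lifecycle handled by the endpoints above.
import requests

BASE = "http://localhost:8080/scim/v2"  # placeholder deployment URL
HEADERS = {"Authorization": "Bearer onyx_scim_<raw-token>"}  # token from the admin UI

# Create: userName is the email; externalId is required (400 without it).
created = requests.post(
    f"{BASE}/Users",
    headers=HEADERS,
    json={
        "userName": "jane@example.com",
        "externalId": "okta-user-123",
        "name": {"givenName": "Jane", "familyName": "Doe"},
        "active": True,
    },
).json()

# Deprovision: IdPs like Okta typically send PATCH active=false, not DELETE.
# The Operations shape follows RFC 7644 PatchOp; exact details vary by IdP.
requests.patch(
    f"{BASE}/Users/{created['id']}",
    headers=HEADERS,
    json={"Operations": [{"op": "replace", "value": {"active": False}}]},
)
```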
@@ -1,104 +0,0 @@
"""SCIM bearer token authentication.

SCIM endpoints are authenticated via bearer tokens that admins create in the
Onyx UI. This module provides:

- ``verify_scim_token``: FastAPI dependency that extracts, hashes, and
  validates the token from the Authorization header.
- ``generate_scim_token``: Creates a new cryptographically random token
  and returns the raw value, its SHA-256 hash, and a display suffix.

Token format: ``onyx_scim_<random>`` where ``<random>`` is 48 bytes of
URL-safe base64 from ``secrets.token_urlsafe``.

The hash is stored in the ``scim_token`` table; the raw value is shown to
the admin exactly once at creation time.
"""

import hashlib
import secrets

from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy.orm import Session

from ee.onyx.db.scim import ScimDAL
from onyx.auth.utils import get_hashed_bearer_token_from_request
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken

SCIM_TOKEN_PREFIX = "onyx_scim_"
SCIM_TOKEN_LENGTH = 48


def _hash_scim_token(token: str) -> str:
    """SHA-256 hash a SCIM token. No salt needed — tokens are random."""
    return hashlib.sha256(token.encode("utf-8")).hexdigest()


def generate_scim_token() -> tuple[str, str, str]:
    """Generate a new SCIM bearer token.

    Returns:
        A tuple of ``(raw_token, hashed_token, token_display)`` where
        ``token_display`` is a masked version showing only the last 4 chars.
    """
    raw_token = SCIM_TOKEN_PREFIX + secrets.token_urlsafe(SCIM_TOKEN_LENGTH)
    hashed_token = _hash_scim_token(raw_token)
    token_display = SCIM_TOKEN_PREFIX + "****" + raw_token[-4:]
    return raw_token, hashed_token, token_display


def _get_hashed_scim_token_from_request(request: Request) -> str | None:
    """Extract and hash a SCIM token from the request Authorization header."""
    return get_hashed_bearer_token_from_request(
        request,
        valid_prefixes=[SCIM_TOKEN_PREFIX],
        hash_fn=_hash_scim_token,
    )


def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
    return ScimDAL(db_session)


def verify_scim_token(
    request: Request,
    dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimToken:
    """FastAPI dependency that authenticates SCIM requests.

    Extracts the bearer token from the Authorization header, hashes it,
    looks it up in the database, and verifies it is active.

    Note:
        This dependency does NOT update ``last_used_at`` — the endpoint
        should do that via ``ScimDAL.update_token_last_used()`` so the
        timestamp write is part of the endpoint's transaction.

    Raises:
        HTTPException(401): If the token is missing, invalid, or inactive.
    """
    hashed = _get_hashed_scim_token_from_request(request)
    if not hashed:
        raise HTTPException(
            status_code=401,
            detail="Missing or invalid SCIM bearer token",
        )

    token = dal.get_token_by_hash(hashed)

    if not token:
        raise HTTPException(
            status_code=401,
            detail="Invalid SCIM bearer token",
        )

    if not token.is_active:
        raise HTTPException(
            status_code=401,
            detail="SCIM token has been revoked",
        )

    return token
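
The token round trip implied by this module can be sketched directly from its own helpers — generation happens once in the admin API, only the SHA-256 digest is stored, and verification re-hashes whatever the client presents:

```python
# Illustrative round trip for the helpers above (not part of the module itself).
raw_token, hashed_token, token_display = generate_scim_token()

assert raw_token.startswith(SCIM_TOKEN_PREFIX)          # "onyx_scim_..."
assert hashed_token == _hash_scim_token(raw_token)      # what gets persisted
assert token_display.endswith(raw_token[-4:])           # masked form for the UI

# A SCIM client then sends the raw value on every request:
#   Authorization: Bearer <raw_token>
# and verify_scim_token() re-hashes it and looks the digest up in scim_token.
```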
@@ -30,7 +30,6 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
    "urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
)
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"


# ---------------------------------------------------------------------------
@@ -196,39 +195,10 @@ class ScimServiceProviderConfig(BaseModel):
    )


class ScimSchemaAttribute(BaseModel):
    """Attribute definition within a SCIM Schema (RFC 7643 §7)."""

    name: str
    type: str
    multiValued: bool = False
    required: bool = False
    description: str = ""
    caseExact: bool = False
    mutability: str = "readWrite"
    returned: str = "default"
    uniqueness: str = "none"
    subAttributes: list["ScimSchemaAttribute"] = Field(default_factory=list)


class ScimSchemaDefinition(BaseModel):
    """SCIM Schema definition (RFC 7643 §7).

    Served at GET /scim/v2/Schemas. Describes the attributes available
    on each resource type so IdPs know which fields they can provision.
    """

    schemas: list[str] = Field(default_factory=lambda: [SCIM_SCHEMA_SCHEMA])
    id: str
    name: str
    description: str
    attributes: list[ScimSchemaAttribute] = Field(default_factory=list)


class ScimSchemaExtension(BaseModel):
    """Schema extension reference within ResourceType (RFC 7643 §6)."""

    model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
    model_config = ConfigDict(populate_by_name=True)

    schema_: str = Field(alias="schema")
    required: bool
@@ -241,7 +211,7 @@ class ScimResourceType(BaseModel):
    types are available (Users, Groups) and their respective endpoints.
    """

    model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
    model_config = ConfigDict(populate_by_name=True)

    schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
    id: str

@@ -1,144 +0,0 @@
"""Static SCIM service discovery responses (RFC 7643 §5, §6, §7).

Pre-built at import time — these never change at runtime. Separated from
api.py to keep the endpoint module focused on request handling.
"""

from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaAttribute
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig

SERVICE_PROVIDER_CONFIG = ScimServiceProviderConfig()

USER_RESOURCE_TYPE = ScimResourceType.model_validate(
    {
        "id": "User",
        "name": "User",
        "endpoint": "/scim/v2/Users",
        "description": "SCIM User resource",
        "schema": SCIM_USER_SCHEMA,
    }
)

GROUP_RESOURCE_TYPE = ScimResourceType.model_validate(
    {
        "id": "Group",
        "name": "Group",
        "endpoint": "/scim/v2/Groups",
        "description": "SCIM Group resource",
        "schema": SCIM_GROUP_SCHEMA,
    }
)

USER_SCHEMA_DEF = ScimSchemaDefinition(
    id=SCIM_USER_SCHEMA,
    name="User",
    description="SCIM core User schema",
    attributes=[
        ScimSchemaAttribute(
            name="userName",
            type="string",
            required=True,
            uniqueness="server",
            description="Unique identifier for the user, typically an email address.",
        ),
        ScimSchemaAttribute(
            name="name",
            type="complex",
            description="The components of the user's name.",
            subAttributes=[
                ScimSchemaAttribute(
                    name="givenName",
                    type="string",
                    description="The user's first name.",
                ),
                ScimSchemaAttribute(
                    name="familyName",
                    type="string",
                    description="The user's last name.",
                ),
                ScimSchemaAttribute(
                    name="formatted",
                    type="string",
                    description="The full name, including all middle names and titles.",
                ),
            ],
        ),
        ScimSchemaAttribute(
            name="emails",
            type="complex",
            multiValued=True,
            description="Email addresses for the user.",
            subAttributes=[
                ScimSchemaAttribute(
                    name="value",
                    type="string",
                    description="Email address value.",
                ),
                ScimSchemaAttribute(
                    name="type",
                    type="string",
                    description="Label for this email (e.g. 'work').",
                ),
                ScimSchemaAttribute(
                    name="primary",
                    type="boolean",
                    description="Whether this is the primary email.",
                ),
            ],
        ),
        ScimSchemaAttribute(
            name="active",
            type="boolean",
            description="Whether the user account is active.",
        ),
        ScimSchemaAttribute(
            name="externalId",
            type="string",
            description="Identifier from the provisioning client (IdP).",
            caseExact=True,
        ),
    ],
)

GROUP_SCHEMA_DEF = ScimSchemaDefinition(
    id=SCIM_GROUP_SCHEMA,
    name="Group",
    description="SCIM core Group schema",
    attributes=[
        ScimSchemaAttribute(
            name="displayName",
            type="string",
            required=True,
            description="Human-readable name for the group.",
        ),
        ScimSchemaAttribute(
            name="members",
            type="complex",
            multiValued=True,
            description="Members of the group.",
            subAttributes=[
                ScimSchemaAttribute(
                    name="value",
                    type="string",
                    description="User ID of the group member.",
                ),
                ScimSchemaAttribute(
                    name="display",
                    type="string",
                    mutability="readOnly",
                    description="Display name of the group member.",
                ),
            ],
        ),
        ScimSchemaAttribute(
            name="externalId",
            type="string",
            description="Identifier from the provisioning client (IdP).",
            caseExact=True,
        ),
    ],
)
@@ -1,13 +1,10 @@
"""EE Settings API - provides license-aware settings override."""

from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError

from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.db.license import refresh_license_cache
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -44,14 +41,6 @@ def check_ee_features_enabled() -> bool:
    tenant_id = get_current_tenant_id()
    try:
        metadata = get_cached_license_metadata(tenant_id)
        if not metadata:
            # Cache miss — warm from DB so cold-start doesn't block EE features
            try:
                with get_session_with_current_tenant() as db_session:
                    metadata = refresh_license_cache(db_session, tenant_id)
            except SQLAlchemyError as db_error:
                logger.warning(f"Failed to load license from DB: {db_error}")

        if metadata and metadata.status != _BLOCKING_STATUS:
            # Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
            return True
@@ -93,18 +82,6 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
    tenant_id = get_current_tenant_id()
    try:
        metadata = get_cached_license_metadata(tenant_id)
        if not metadata:
            # Cache miss (e.g. after TTL expiry). Fall back to DB so
            # the /settings request doesn't falsely return GATED_ACCESS
            # while the cache is cold.
            try:
                with get_session_with_current_tenant() as db_session:
                    metadata = refresh_license_cache(db_session, tenant_id)
            except SQLAlchemyError as db_error:
                logger.warning(
                    f"Failed to load license from DB for settings: {db_error}"
                )

        if metadata:
            if metadata.status == _BLOCKING_STATUS:
                settings.application_status = metadata.status
@@ -113,7 +90,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
            # Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
            settings.ee_features_enabled = True
        else:
            # No license found in cache or DB.
            # No license found.
            if ENTERPRISE_EDITION_ENABLED:
                # Legacy EE flag is set → prior EE usage (e.g. permission
                # syncing) means indexed data may need protection.

@@ -68,18 +68,6 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()


def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
    """Detect XML-style marshaled tool calls emitted as plain text."""
    if not text:
        return False
    lowered = text.lower()
    return (
        "<function_calls" in lowered
        and "<invoke" in lowered
        and "<parameter" in lowered
    )


def _should_keep_bedrock_tool_definitions(
    llm: object, simple_chat_history: list[ChatMessageSimple]
) -> bool:
@@ -134,56 +122,38 @@ def _try_fallback_tool_extraction(
    reasoning_but_no_answer_or_tools = (
        llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls
    )
    xml_tool_call_text_detected = no_tool_calls and (
        _looks_like_xml_tool_call_payload(llm_step_result.answer)
        or _looks_like_xml_tool_call_payload(llm_step_result.raw_answer)
        or _looks_like_xml_tool_call_payload(llm_step_result.reasoning)
    )
    should_try_fallback = (
        (tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls)
        or reasoning_but_no_answer_or_tools
        or xml_tool_call_text_detected
    )
        tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls
    ) or reasoning_but_no_answer_or_tools

    if not should_try_fallback:
        return llm_step_result, False

    # Try to extract from answer first, then fall back to reasoning
    extracted_tool_calls: list[ToolCallKickoff] = []

    if llm_step_result.answer:
        extracted_tool_calls = extract_tool_calls_from_response_text(
            response_text=llm_step_result.answer,
            tool_definitions=tool_defs,
            placement=Placement(turn_index=turn_index),
        )
    if (
        not extracted_tool_calls
        and llm_step_result.raw_answer
        and llm_step_result.raw_answer != llm_step_result.answer
    ):
        extracted_tool_calls = extract_tool_calls_from_response_text(
            response_text=llm_step_result.raw_answer,
            tool_definitions=tool_defs,
            placement=Placement(turn_index=turn_index),
        )
    if not extracted_tool_calls and llm_step_result.reasoning:
        extracted_tool_calls = extract_tool_calls_from_response_text(
            response_text=llm_step_result.reasoning,
            tool_definitions=tool_defs,
            placement=Placement(turn_index=turn_index),
        )

    if extracted_tool_calls:
        logger.info(
            f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
            "as fallback"
            f"as fallback (tool_choice was REQUIRED but no tool calls returned)"
        )
        return (
            LlmStepResult(
                reasoning=llm_step_result.reasoning,
                answer=llm_step_result.answer,
                tool_calls=extracted_tool_calls,
                raw_answer=llm_step_result.raw_answer,
            ),
            True,
        )
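
The trigger for this fallback path is the `_looks_like_xml_tool_call_payload` heuristic added above — it only fires when the text contains all three XML markers. A quick illustrative check of that behavior (example strings are hypothetical):

```python
# Illustrative check of the detection heuristic (not part of the diff).
assert _looks_like_xml_tool_call_payload(
    '<function_calls><invoke name="internal_search">'
    '<parameter name="queries">["foo"]</parameter></invoke></function_calls>'
)
assert not _looks_like_xml_tool_call_payload("Here is a normal answer.")
assert not _looks_like_xml_tool_call_payload(None)
```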

@@ -1,12 +1,10 @@
import json
import re
import time
import uuid
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Mapping
from collections.abc import Sequence
from html import unescape
from typing import Any
from typing import cast

@@ -58,112 +56,6 @@ from onyx.utils.text_processing import find_all_json_objects

logger = setup_logger()

_XML_INVOKE_BLOCK_RE = re.compile(
    r"<invoke\b(?P<attrs>[^>]*)>(?P<body>.*?)</invoke>",
    re.IGNORECASE | re.DOTALL,
)
_XML_PARAMETER_RE = re.compile(
    r"<parameter\b(?P<attrs>[^>]*)>(?P<value>.*?)</parameter>",
    re.IGNORECASE | re.DOTALL,
)
_FUNCTION_CALLS_OPEN_MARKER = "<function_calls"
_FUNCTION_CALLS_CLOSE_MARKER = "</function_calls>"


class _XmlToolCallContentFilter:
    """Streaming filter that strips XML-style tool call payload blocks from text."""

    def __init__(self) -> None:
        self._pending = ""
        self._inside_function_calls_block = False

    def process(self, content: str) -> str:
        if not content:
            return ""

        self._pending += content
        output_parts: list[str] = []

        while self._pending:
            pending_lower = self._pending.lower()

            if self._inside_function_calls_block:
                end_idx = pending_lower.find(_FUNCTION_CALLS_CLOSE_MARKER)
                if end_idx == -1:
                    # Keep buffering until we see the close marker.
                    return "".join(output_parts)

                # Drop the whole function_calls block.
                self._pending = self._pending[
                    end_idx + len(_FUNCTION_CALLS_CLOSE_MARKER) :
                ]
                self._inside_function_calls_block = False
                continue

            start_idx = _find_function_calls_open_marker(pending_lower)
            if start_idx == -1:
                # Keep only a possible prefix of "<function_calls" in the buffer so
                # marker splits across chunks are handled correctly.
                tail_len = _matching_open_marker_prefix_len(self._pending)
                emit_upto = len(self._pending) - tail_len
                if emit_upto > 0:
                    output_parts.append(self._pending[:emit_upto])
                self._pending = self._pending[emit_upto:]
                return "".join(output_parts)

            if start_idx > 0:
                output_parts.append(self._pending[:start_idx])

            # Enter block-stripping mode and keep scanning for close marker.
            self._pending = self._pending[start_idx:]
            self._inside_function_calls_block = True

        return "".join(output_parts)

    def flush(self) -> str:
        if self._inside_function_calls_block:
            # Drop any incomplete block at stream end.
            self._pending = ""
            self._inside_function_calls_block = False
            return ""

        remaining = self._pending
        self._pending = ""
        return remaining


def _matching_open_marker_prefix_len(text: str) -> int:
    """Return longest suffix of text that matches prefix of "<function_calls"."""
    max_len = min(len(text), len(_FUNCTION_CALLS_OPEN_MARKER) - 1)
    text_lower = text.lower()
    marker_lower = _FUNCTION_CALLS_OPEN_MARKER

    for candidate_len in range(max_len, 0, -1):
        if text_lower.endswith(marker_lower[:candidate_len]):
            return candidate_len

    return 0


def _is_valid_function_calls_open_follower(char: str | None) -> bool:
    return char is None or char in {">", " ", "\t", "\n", "\r"}


def _find_function_calls_open_marker(text_lower: str) -> int:
    """Find '<function_calls' with a valid tag boundary follower."""
    search_from = 0
    while True:
        idx = text_lower.find(_FUNCTION_CALLS_OPEN_MARKER, search_from)
        if idx == -1:
            return -1

        follower_pos = idx + len(_FUNCTION_CALLS_OPEN_MARKER)
        follower = text_lower[follower_pos] if follower_pos < len(text_lower) else None
        if _is_valid_function_calls_open_follower(follower):
            return idx

        search_from = idx + 1
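
Because the marker can be split across streamed chunks, the filter buffers any suffix that might be the start of `<function_calls`. A rough sketch of the intended behavior, assuming arbitrary chunk boundaries (illustrative only, not a test from the repository):

```python
# Minimal sketch of the streaming filter across a chunk boundary.
f = _XmlToolCallContentFilter()
out = f.process("Answer text <function")  # "<function" is held back as a possible marker prefix
out += f.process(
    '_calls>\n<invoke name="internal_search">...</invoke>\n</function_calls> done'
)
out += f.flush()
# Expected: the XML payload block is stripped and only the plain text
# ("Answer text " plus " done") is emitted to the user-visible stream.
```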


def _sanitize_llm_output(value: str) -> str:
    """Remove characters that PostgreSQL's text/JSONB types cannot store.
@@ -380,7 +272,14 @@ def _extract_tool_call_kickoffs(
    tab_index_calculated = 0
    for tool_call_data in id_to_tool_call_map.values():
        if tool_call_data.get("id") and tool_call_data.get("name"):
            tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
            try:
                tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
            except json.JSONDecodeError:
                # If parsing fails, try empty dict, most tools would fail though
                logger.error(
                    f"Failed to parse tool call arguments: {tool_call_data['arguments']}"
                )
                tool_args = {}

            tool_calls.append(
                ToolCallKickoff(
@@ -408,9 +307,8 @@ def extract_tool_calls_from_response_text(
    """Extract tool calls from LLM response text by matching JSON against tool definitions.

    This is a fallback mechanism for when the LLM was expected to return tool calls
    but didn't use the proper tool call format. It searches for tool calls embedded
    in response text (JSON first, then XML-like invoke blocks) that match available
    tool definitions.
    but didn't use the proper tool call format. It searches for JSON objects in the
    response text that match the structure of available tools.

    Args:
        response_text: The LLM's text response to search for tool calls
@@ -435,9 +333,10 @@ def extract_tool_calls_from_response_text(
    if not tool_name_to_def:
        return []

    matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
    # Find all JSON objects in the response text
    json_objects = find_all_json_objects(response_text)

    matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
    prev_json_obj: dict[str, Any] | None = None
    prev_tool_call: tuple[str, dict[str, Any]] | None = None

@@ -465,14 +364,6 @@ def extract_tool_calls_from_response_text(
            prev_json_obj = json_obj
            prev_tool_call = matched_tool_call

    # Some providers/models emit XML-style function calls instead of JSON objects.
    # Keep this as a fallback behind JSON extraction to preserve current behavior.
    if not matched_tool_calls:
        matched_tool_calls = _extract_xml_tool_calls_from_response_text(
            response_text=response_text,
            tool_name_to_def=tool_name_to_def,
        )

    tool_calls: list[ToolCallKickoff] = []
    for tab_index, (tool_name, tool_args) in enumerate(matched_tool_calls):
        tool_calls.append(
@@ -495,88 +386,6 @@ def extract_tool_calls_from_response_text(
    return tool_calls


def _extract_xml_tool_calls_from_response_text(
    response_text: str,
    tool_name_to_def: dict[str, dict],
) -> list[tuple[str, dict[str, Any]]]:
    """Extract XML-style tool calls from response text.

    Supports formats such as:
        <function_calls>
            <invoke name="internal_search">
                <parameter name="queries" string="false">["foo"]</parameter>
            </invoke>
        </function_calls>
    """
    matched_tool_calls: list[tuple[str, dict[str, Any]]] = []

    for invoke_match in _XML_INVOKE_BLOCK_RE.finditer(response_text):
        invoke_attrs = invoke_match.group("attrs")
        tool_name = _extract_xml_attribute(invoke_attrs, "name")
        if not tool_name or tool_name not in tool_name_to_def:
            continue

        tool_args: dict[str, Any] = {}
        invoke_body = invoke_match.group("body")
        for parameter_match in _XML_PARAMETER_RE.finditer(invoke_body):
            parameter_attrs = parameter_match.group("attrs")
            parameter_name = _extract_xml_attribute(parameter_attrs, "name")
            if not parameter_name:
                continue

            string_attr = _extract_xml_attribute(parameter_attrs, "string")
            tool_args[parameter_name] = _parse_xml_parameter_value(
                raw_value=parameter_match.group("value"),
                string_attr=string_attr,
            )

        matched_tool_calls.append((tool_name, tool_args))

    return matched_tool_calls


def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None:
    """Extract a single XML-style attribute value from a tag attribute string."""
    attr_match = re.search(
        rf"""\b{re.escape(attr_name)}\s*=\s*(['"])(.*?)\1""",
        attrs,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if not attr_match:
        return None
    return _sanitize_llm_output(unescape(attr_match.group(2).strip()))


def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any:
    """Parse a parameter value from XML-style tool call payloads."""
    value = _sanitize_llm_output(unescape(raw_value).strip())

    if string_attr and string_attr.lower() == "true":
        return value

    try:
        return json.loads(value)
    except json.JSONDecodeError:
        return value


def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None:
    """Extract and parse an arguments/parameters value from a tool-call-like object.

    Looks for "arguments" or "parameters" keys, handles JSON-string values,
    and returns a dict if successful, or None otherwise.
    """
    arguments = obj.get("arguments", obj.get("parameters", {}))
    if isinstance(arguments, str):
        try:
            arguments = json.loads(arguments)
        except json.JSONDecodeError:
            arguments = {}
    if isinstance(arguments, dict):
        return arguments
    return None
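
For reference, these are the object shapes the JSON matcher below is designed to recognize; the values are illustrative and "internal_search" stands in for any tool present in `tool_name_to_def`:

```python
# Shapes that _try_match_json_to_tool recognizes (illustrative values only).
examples = [
    # Format 1 — direct tool call
    {"name": "internal_search", "arguments": {"queries": ["foo"]}},
    # Format 2 — function wrapper; arguments may arrive as a JSON string
    {"function": {"name": "internal_search", "arguments": '{"queries": ["foo"]}'}},
    # Format 3 — tool name used as the key
    {"internal_search": {"queries": ["foo"]}},
]
```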
|
||||
|
||||
|
||||
def _try_match_json_to_tool(
|
||||
json_obj: dict[str, Any],
|
||||
tool_name_to_def: dict[str, dict],
|
||||
@@ -599,8 +408,13 @@ def _try_match_json_to_tool(
|
||||
# Format 1: Direct tool call format {"name": "...", "arguments": {...}}
|
||||
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
|
||||
tool_name = json_obj["name"]
|
||||
arguments = _resolve_tool_arguments(json_obj)
|
||||
if arguments is not None:
|
||||
arguments = json_obj.get("arguments", json_obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 2: Function call format {"function": {"name": "...", "arguments": {...}}}
|
||||
@@ -608,8 +422,13 @@ def _try_match_json_to_tool(
|
||||
func_obj = json_obj["function"]
|
||||
if "name" in func_obj and func_obj["name"] in tool_name_to_def:
|
||||
tool_name = func_obj["name"]
|
||||
arguments = _resolve_tool_arguments(func_obj)
|
||||
if arguments is not None:
|
||||
arguments = func_obj.get("arguments", func_obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 3: Tool name as key {"tool_name": {...arguments...}}
|
||||
@@ -879,8 +698,7 @@ def run_llm_step_pkt_generator(
|
||||
tool_definitions: List of tool definitions available to the LLM.
|
||||
tool_choice: Tool choice configuration (e.g., "auto", "required", "none").
|
||||
llm: Language model interface to use for generation.
|
||||
placement: Placement info (turn_index, tab_index, sub_turn_index) for
|
||||
positioning packets in the conversation UI.
|
||||
turn_index: Current turn index in the conversation.
|
||||
state_container: Container for storing chat state (reasoning, answers).
|
||||
citation_processor: Optional processor for extracting and formatting citations
|
||||
from the response. If provided, processes tokens to identify citations.
|
||||
@@ -892,14 +710,7 @@ def run_llm_step_pkt_generator(
|
||||
custom_token_processor: Optional callable that processes each token delta
|
||||
before yielding. Receives (delta, processor_state) and returns
|
||||
(modified_delta, new_processor_state). Can return None for delta to skip.
|
||||
max_tokens: Optional maximum number of tokens for the LLM response.
|
||||
use_existing_tab_index: If True, use the tab_index from placement for all
|
||||
tool calls instead of auto-incrementing.
|
||||
is_deep_research: If True, treat content before tool calls as reasoning
|
||||
when tool_choice is REQUIRED.
|
||||
pre_answer_processing_time: Optional time spent processing before the
|
||||
answer started, recorded in state_container for analytics.
|
||||
timeout_override: Optional timeout override for the LLM call.
|
||||
sub_turn_index: Optional sub-turn index for nested tool/agent calls.
|
||||
|
||||
Yields:
|
||||
Packet: Streaming packets containing:
|
||||
@@ -925,15 +736,8 @@ def run_llm_step_pkt_generator(
|
||||
tab_index = placement.tab_index
|
||||
sub_turn_index = placement.sub_turn_index
|
||||
|
||||
def _current_placement() -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
)
|
||||
|
||||
llm_msg_history = translate_history_to_llm_format(history, llm.config)
|
||||
has_reasoned = False
|
||||
has_reasoned = 0
|
||||
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
logger.debug(
|
||||
@@ -945,8 +749,6 @@ def run_llm_step_pkt_generator(
|
||||
answer_start = False
|
||||
accumulated_reasoning = ""
|
||||
accumulated_answer = ""
|
||||
accumulated_raw_answer = ""
|
||||
xml_tool_call_content_filter = _XmlToolCallContentFilter()
|
||||
|
||||
processor_state: Any = None
|
||||
|
||||
@@ -962,112 +764,6 @@ def run_llm_step_pkt_generator(
|
||||
)
|
||||
stream_start_time = time.monotonic()
|
||||
first_action_recorded = False
|
||||
|
||||
def _emit_citation_results(
|
||||
results: Generator[str | CitationInfo, None, None],
|
||||
) -> Generator[Packet, None, None]:
|
||||
"""Yield packets for citation processor results (str or CitationInfo)."""
|
||||
nonlocal accumulated_answer
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=result,
|
||||
)
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(result.citation_number)
|
||||
|
||||
def _close_reasoning_if_active() -> Generator[Packet, None, None]:
|
||||
"""Emit ReasoningDone and increment turns if reasoning is in progress."""
|
||||
nonlocal reasoning_start
|
||||
nonlocal has_reasoned
|
||||
nonlocal turn_index
|
||||
nonlocal sub_turn_index
|
||||
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = True
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
|
||||
def _emit_content_chunk(content_chunk: str) -> Generator[Packet, None, None]:
|
||||
nonlocal accumulated_answer
|
||||
nonlocal accumulated_reasoning
|
||||
nonlocal answer_start
|
||||
nonlocal reasoning_start
|
||||
nonlocal turn_index
|
||||
nonlocal sub_turn_index
|
||||
|
||||
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
|
||||
# about which tool to call, not an actual answer to the user.
|
||||
# Treat this content as reasoning instead of answer.
|
||||
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
|
||||
accumulated_reasoning += content_chunk
|
||||
if state_container:
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=ReasoningDelta(reasoning=content_chunk),
|
||||
)
|
||||
reasoning_start = True
|
||||
return
|
||||
|
||||
# Normal flow for AUTO or NONE tool choice
|
||||
yield from _close_reasoning_if_active()
|
||||
|
||||
if not answer_start:
|
||||
# Store pre-answer processing time in state container for save_chat
|
||||
if state_container and pre_answer_processing_time is not None:
|
||||
state_container.set_pre_answer_processing_time(
|
||||
pre_answer_processing_time
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_documents,
|
||||
pre_answer_processing_seconds=pre_answer_processing_time,
|
||||
),
|
||||
)
|
||||
answer_start = True
|
||||
|
||||
if citation_processor:
|
||||
yield from _emit_citation_results(
|
||||
citation_processor.process_token(content_chunk)
|
||||
)
|
||||
else:
|
||||
accumulated_answer += content_chunk
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
obj=AgentResponseDelta(content=content_chunk),
|
||||
)
|
||||
|
||||
for packet in llm.stream(
|
||||
prompt=llm_msg_history,
|
||||
tools=tool_definitions,
|
||||
@@ -1126,34 +822,152 @@ def run_llm_step_pkt_generator(
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
yield Packet(
|
||||
placement=_current_placement(),
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDelta(reasoning=delta.reasoning_content),
|
||||
)
|
||||
reasoning_start = True
|
||||
|
||||
if delta.content:
|
||||
# Keep raw content for fallback extraction. Display content can be
|
||||
# filtered and, in deep-research REQUIRED mode, routed as reasoning.
|
||||
accumulated_raw_answer += delta.content
|
||||
filtered_content = xml_tool_call_content_filter.process(delta.content)
|
||||
if filtered_content:
|
||||
yield from _emit_content_chunk(filtered_content)
|
||||
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
|
||||
# about which tool to call, not an actual answer to the user.
|
||||
# Treat this content as reasoning instead of answer.
|
||||
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
|
||||
# Treat content as reasoning when we know tool calls are coming
|
||||
accumulated_reasoning += delta.content
|
||||
if state_container:
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDelta(reasoning=delta.content),
|
||||
)
|
||||
reasoning_start = True
|
||||
else:
|
||||
# Normal flow for AUTO or NONE tool choice
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
|
||||
if not answer_start:
|
||||
# Store pre-answer processing time in state container for save_chat
|
||||
if state_container and pre_answer_processing_time is not None:
|
||||
state_container.set_pre_answer_processing_time(
|
||||
pre_answer_processing_time
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_documents,
|
||||
pre_answer_processing_seconds=pre_answer_processing_time,
|
||||
),
|
||||
)
|
||||
answer_start = True
|
||||
|
||||
if citation_processor:
|
||||
for result in citation_processor.process_token(delta.content):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(
|
||||
accumulated_answer
|
||||
)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
# Track emitted citation for saving
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(
|
||||
result.citation_number
|
||||
)
|
||||
else:
|
||||
# When citation_processor is None, use delta.content directly without modification
|
||||
accumulated_answer += delta.content
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=delta.content),
|
||||
)
|
||||
|
||||
if delta.tool_calls:
|
||||
yield from _close_reasoning_if_active()
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
|
||||
# Flush any tail text buffered while checking for split "<function_calls" markers.
|
||||
filtered_content_tail = xml_tool_call_content_filter.flush()
|
||||
if filtered_content_tail:
|
||||
yield from _emit_content_chunk(filtered_content_tail)
|
||||
|
||||
# Flush custom token processor to get any final tool calls
|
||||
if custom_token_processor:
|
||||
flush_delta, processor_state = custom_token_processor(None, processor_state)
|
||||
@@ -1209,14 +1023,50 @@ def run_llm_step_pkt_generator(
|
||||
|
||||
# This may happen if the custom token processor is used to modify other packets into reasoning
|
||||
# Then there won't necessarily be anything else to come after the reasoning tokens
|
||||
yield from _close_reasoning_if_active()
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(turn_index, sub_turn_index)
|
||||
reasoning_start = False
|
||||
|
||||
# Flush any remaining content from citation processor
|
||||
# Reasoning is always first so this should use the post-incremented value of turn_index
|
||||
# Note that this doesn't need to handle any sub-turns as those docs will not have citations
|
||||
# as clickable items and will be stripped out instead.
|
||||
if citation_processor:
|
||||
yield from _emit_citation_results(citation_processor.process_token(None))
|
||||
for result in citation_processor.process_token(None):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
# Track emitted citation for saving
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(result.citation_number)
|
||||
|
||||
# Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
|
||||
# Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)
|
||||
@@ -1238,9 +1088,8 @@ def run_llm_step_pkt_generator(
            reasoning=accumulated_reasoning if accumulated_reasoning else None,
            answer=accumulated_answer if accumulated_answer else None,
            tool_calls=tool_calls if tool_calls else None,
            raw_answer=accumulated_raw_answer if accumulated_raw_answer else None,
        ),
        has_reasoned,
        bool(has_reasoned),
    )


@@ -1295,4 +1144,4 @@ def run_llm_step(
            emitter.emit(packet)
    except StopIteration as e:
        llm_step_result, has_reasoned = e.value
    return llm_step_result, has_reasoned
    return llm_step_result, bool(has_reasoned)
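The `run_llm_step` hunk above relies on a standard Python mechanic: when a generator finishes, its `return` value is carried on `StopIteration.value`. A minimal sketch of that pattern, with illustrative names rather than the actual Onyx classes:

```python
# Minimal sketch of the pattern used above: a generator's `return` value is
# delivered via StopIteration.value when the generator is driven manually.
# All names here are illustrative, not the real Onyx packet types.
from collections.abc import Generator


def pkt_generator() -> Generator[str, None, tuple[str, bool]]:
    yield "packet-1"
    yield "packet-2"
    return "final-result", True  # becomes StopIteration.value


def run_step() -> tuple[str, bool]:
    gen = pkt_generator()
    while True:
        try:
            packet = next(gen)
            print(f"emit: {packet}")  # stand-in for emitter.emit(packet)
        except StopIteration as e:
            result, has_reasoned = e.value
            return result, bool(has_reasoned)
```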

@@ -185,6 +185,3 @@ class LlmStepResult(BaseModel):
    reasoning: str | None
    answer: str | None
    tool_calls: list[ToolCallKickoff] | None
    # Raw LLM text before any display-oriented filtering/sanitization.
    # Used for fallback tool-call extraction when providers emit calls as text.
    raw_answer: str | None = None
@@ -46,7 +46,6 @@ from onyx.connectors.google_drive.file_retrieval import get_external_access_for_
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.file_retrieval import get_shared_drive_name
from onyx.connectors.google_drive.file_retrieval import has_link_only_permission
from onyx.connectors.google_drive.models import DriveRetrievalStage
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint

@@ -157,7 +156,10 @@ def _is_shared_drive_root(folder: GoogleDriveFileType) -> bool:
        return False

    # For shared drive content, the root has id == driveId
    return bool(drive_id and folder_id == drive_id)
    if drive_id and folder_id == drive_id:
        return True

    return False


def _public_access() -> ExternalAccess:

@@ -614,16 +616,6 @@ class GoogleDriveConnector(
                # empty parents due to permission limitations)
                # Check shared drive root first (simple ID comparison)
                if _is_shared_drive_root(folder):
                    # files().get() returns 'Drive' for shared drive roots;
                    # fetch the real name via drives().get().
                    # Try both the retriever and admin since the admin may
                    # not have access to private shared drives.
                    drive_name = self._get_shared_drive_name(
                        current_id, file.user_email
                    )
                    if drive_name:
                        node.display_name = drive_name
                    node.node_type = HierarchyNodeType.SHARED_DRIVE
                    reached_terminal = True
                    break

@@ -699,15 +691,6 @@ class GoogleDriveConnector(
            )
            return None

    def _get_shared_drive_name(self, drive_id: str, retriever_email: str) -> str | None:
        """Fetch the name of a shared drive, trying both the retriever and admin."""
        for email in {retriever_email, self.primary_admin_email}:
            svc = get_drive_service(self.creds, email)
            name = get_shared_drive_name(svc, drive_id)
            if name:
                return name
        return None

    def get_all_drive_ids(self) -> set[str]:
        return self._get_all_drives_for_user(self.primary_admin_email)


@@ -154,26 +154,6 @@ def _get_hierarchy_fields_for_file_type(field_type: DriveFileFieldType) -> str:
    return HIERARCHY_FIELDS


def get_shared_drive_name(
    service: Resource,
    drive_id: str,
) -> str | None:
    """Fetch the actual name of a shared drive via the drives().get() API.

    The files().get() API returns 'Drive' as the name for shared drive root
    folders. Only drives().get() returns the real user-assigned name.
    """
    try:
        drive = service.drives().get(driveId=drive_id, fields="name").execute()
        return drive.get("name")
    except HttpError as e:
        if e.resp.status in (403, 404):
            logger.debug(f"Cannot access drive {drive_id}: {e}")
        else:
            raise
    return None


def get_external_access_for_folder(
    folder: GoogleDriveFileType,
    google_domain: str,
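The hunks above resolve a shared drive's real display name through `drives().get()`, because `files().get()` only reports the generic "Drive" for a shared drive root. A minimal usage sketch, assuming a `googleapiclient` Drive service; the credential and fallback handling shown here are illustrative, not the connector's actual wiring:

```python
# Hedged sketch of using drives().get() to recover a shared drive's name.
# The service construction is an assumption for illustration only.
from googleapiclient.discovery import build  # type: ignore
from googleapiclient.errors import HttpError  # type: ignore


def resolve_drive_display_name(creds, drive_id: str) -> str:
    service = build("drive", "v3", credentials=creds)
    try:
        # drives().get() returns the user-assigned name; files().get() would
        # only report "Drive" for a shared drive root.
        drive = service.drives().get(driveId=drive_id, fields="name").execute()
        name = drive.get("name")
    except HttpError:
        # 403/404 simply means this identity cannot see the drive.
        name = None
    # Fall back to something stable when the drive is not accessible.
    return name or f"Shared Drive {drive_id}"
```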
@@ -1,74 +0,0 @@
"""Base Data Access Layer (DAL) for database operations.

The DAL pattern groups related database operations into cohesive classes
with explicit session management. It supports two usage modes:

1. **External session** (FastAPI endpoints) — the caller provides a session
   whose lifecycle is managed by FastAPI's dependency injection.

2. **Self-managed session** (Celery tasks, scripts) — the DAL creates its
   own session via the tenant-aware session factory.

Subclasses add domain-specific query methods while inheriting session
management. See ``ee.onyx.db.scim.ScimDAL`` for a concrete example.

Example (FastAPI)::

    def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
        return ScimDAL(db_session)

    @router.get("/users")
    def list_users(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
        return dal.list_user_mappings(...)

Example (Celery)::

    with ScimDAL.from_tenant("tenant_abc") as dal:
        dal.create_user_mapping(...)
        dal.commit()
"""

from __future__ import annotations

from collections.abc import Generator
from contextlib import contextmanager

from sqlalchemy.orm import Session

from onyx.db.engine.sql_engine import get_session_with_tenant


class DAL:
    """Base Data Access Layer.

    Holds a SQLAlchemy session and provides transaction control helpers.
    Subclasses add domain-specific query methods.
    """

    def __init__(self, db_session: Session) -> None:
        self._session = db_session

    @property
    def session(self) -> Session:
        """Direct access to the underlying session for advanced use cases."""
        return self._session

    def commit(self) -> None:
        self._session.commit()

    def flush(self) -> None:
        self._session.flush()

    def rollback(self) -> None:
        self._session.rollback()

    @classmethod
    @contextmanager
    def from_tenant(cls, tenant_id: str) -> Generator["DAL", None, None]:
        """Create a DAL with a self-managed session for the given tenant.

        The session is automatically closed when the context manager exits.
        The caller must explicitly call ``commit()`` to persist changes.
        """
        with get_session_with_tenant(tenant_id=tenant_id) as session:
            yield cls(session)
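The DAL base class above carries the session handling, so a feature-specific subclass only adds query methods. A hedged sketch of such a subclass; the `notes` table and queries are hypothetical and not part of the Onyx codebase:

```python
# Hypothetical subclass sketch; the table and statements are illustrative only.
from sqlalchemy import text


class NotesDAL(DAL):
    """Illustrative DAL for a hypothetical `notes` table."""

    def count_notes_for_user(self, user_id: str) -> int:
        stmt = text("SELECT count(*) FROM notes WHERE user_id = :user_id")
        return int(self.session.execute(stmt, {"user_id": user_id}).scalar_one())

    def add_note(self, user_id: str, body: str) -> None:
        stmt = text("INSERT INTO notes (user_id, body) VALUES (:user_id, :body)")
        self.session.execute(stmt, {"user_id": user_id, "body": body})


# Celery-style usage with a self-managed, tenant-scoped session:
# with NotesDAL.from_tenant("tenant_abc") as dal:
#     dal.add_note(user_id="u1", body="hello")
#     dal.commit()
```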
@@ -430,7 +430,7 @@ def fetch_existing_models(

def fetch_existing_llm_providers(
    db_session: Session,
    flow_type_filter: list[LLMModelFlowType],
    flow_types: list[LLMModelFlowType],
    only_public: bool = False,
    exclude_image_generation_providers: bool = True,
) -> list[LLMProviderModel]:

@@ -438,27 +438,30 @@ def fetch_existing_llm_providers(

    Args:
        db_session: Database session
        flow_type_filter: List of flow types to filter by, empty list for no filter
        flow_types: List of flow types to filter by
        only_public: If True, only return public providers
        exclude_image_generation_providers: If True, exclude providers that are
            used for image generation configs
    """
    stmt = select(LLMProviderModel)

    if flow_type_filter:
        providers_with_flows = (
            select(ModelConfiguration.llm_provider_id)
            .join(LLMModelFlow)
            .where(LLMModelFlow.llm_model_flow_type.in_(flow_type_filter))
            .distinct()
        )
        stmt = stmt.where(LLMProviderModel.id.in_(providers_with_flows))
    providers_with_flows = (
        select(ModelConfiguration.llm_provider_id)
        .join(LLMModelFlow)
        .where(LLMModelFlow.llm_model_flow_type.in_(flow_types))
        .distinct()
    )

    if exclude_image_generation_providers:
        stmt = select(LLMProviderModel).where(
            LLMProviderModel.id.in_(providers_with_flows)
        )
    else:
        image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
            ImageGenerationConfig
        )
        stmt = stmt.where(~LLMProviderModel.id.in_(image_gen_provider_ids))
        stmt = select(LLMProviderModel).where(
            LLMProviderModel.id.in_(providers_with_flows)
            | LLMProviderModel.id.in_(image_gen_provider_ids)
        )

    stmt = stmt.options(
        selectinload(LLMProviderModel.model_configurations),
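The hunks above replace the optional `flow_type_filter` with a required `flow_types` list, so callers always state which flows they need. A hedged sketch of a caller, reusing the names from the diff; the wrapper function and the assumption that the provider model exposes a `name` attribute are illustrative:

```python
# Hedged caller sketch; fetch_existing_llm_providers and LLMModelFlowType are
# the names from the diff above, the rest is illustrative.
from sqlalchemy.orm import Session


def chat_capable_provider_names(db_session: Session) -> list[str]:
    providers = fetch_existing_llm_providers(
        db_session=db_session,
        flow_types=[LLMModelFlowType.CHAT, LLMModelFlowType.VISION],
        only_public=True,
        exclude_image_generation_providers=True,
    )
    # Assumes the provider model exposes a `name` attribute.
    return [provider.name for provider in providers]
```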
@@ -794,15 +797,13 @@ def sync_auto_mode_models(
|
||||
changes += 1
|
||||
else:
|
||||
# Add new model - all models from GitHub config are visible
|
||||
insert_new_model_configuration__no_commit(
|
||||
db_session=db_session,
|
||||
new_model = ModelConfiguration(
|
||||
llm_provider_id=provider.id,
|
||||
model_name=model_config.name,
|
||||
supported_flows=[LLMModelFlowType.CHAT],
|
||||
is_visible=True,
|
||||
max_input_tokens=None,
|
||||
name=model_config.name,
|
||||
display_name=model_config.display_name,
|
||||
is_visible=True,
|
||||
)
|
||||
db_session.add(new_model)
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
|
||||
@@ -63,10 +63,6 @@ if TYPE_CHECKING:
|
||||
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
|
||||
LEGACY_MAX_TOKENS_KWARG = "max_tokens"
|
||||
STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"
|
||||
_VERTEX_ANTHROPIC_MODELS_REJECTING_OUTPUT_CONFIG = (
|
||||
"claude-opus-4-5",
|
||||
"claude-opus-4-6",
|
||||
)
|
||||
|
||||
|
||||
class LLMTimeoutError(Exception):
|
||||
@@ -92,14 +88,6 @@ def _prompt_to_dicts(prompt: LanguageModelInput) -> list[dict[str, Any]]:
|
||||
return [prompt.model_dump(exclude_none=True)]
|
||||
|
||||
|
||||
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
|
||||
normalized_model_name = model_name.lower()
|
||||
return any(
|
||||
blocked_model in normalized_model_name
|
||||
for blocked_model in _VERTEX_ANTHROPIC_MODELS_REJECTING_OUTPUT_CONFIG
|
||||
)
|
||||
|
||||
|
||||
class LitellmLLM(LLM):
|
||||
"""Uses Litellm library to allow easy configuration to use a multitude of LLMs
|
||||
See https://python.langchain.com/docs/integrations/chat/litellm"""
|
||||
@@ -278,11 +266,10 @@ class LitellmLLM(LLM):
|
||||
is_ollama = self._model_provider == LlmProviderNames.OLLAMA_CHAT
|
||||
is_mistral = self._model_provider == LlmProviderNames.MISTRAL
|
||||
is_vertex_ai = self._model_provider == LlmProviderNames.VERTEX_AI
|
||||
# Some Vertex Anthropic models reject output_config.
|
||||
# Keep this guard until LiteLLM/Vertex accept the field for these models.
|
||||
is_vertex_model_rejecting_output_config = (
|
||||
is_vertex_ai
|
||||
and _is_vertex_model_rejecting_output_config(self.config.model_name)
|
||||
# Vertex Anthropic Opus 4.5 rejects output_config.
|
||||
# Keep this guard until LiteLLM/Vertex accept the field for this model.
|
||||
is_vertex_opus_4_5 = (
|
||||
is_vertex_ai and "claude-opus-4-5" in self.config.model_name.lower()
|
||||
)
|
||||
|
||||
#########################
|
||||
@@ -314,7 +301,7 @@ class LitellmLLM(LLM):
|
||||
# Temperature
|
||||
temperature = 1 if is_reasoning else self._temperature
|
||||
|
||||
if stream and not is_vertex_model_rejecting_output_config:
|
||||
if stream and not is_vertex_opus_4_5:
|
||||
optional_kwargs["stream_options"] = {"include_usage": True}
|
||||
|
||||
# Note, there is a reasoning_effort parameter in LiteLLM but it is completely jank and does not work for any
|
||||
@@ -323,7 +310,7 @@ class LitellmLLM(LLM):
|
||||
is_reasoning
|
||||
# The default of this parameter not set is surprisingly not the equivalent of an Auto but is actually Off
|
||||
and reasoning_effort != ReasoningEffort.OFF
|
||||
and not is_vertex_model_rejecting_output_config
|
||||
and not is_vertex_opus_4_5
|
||||
):
|
||||
if is_openai_model:
|
||||
# OpenAI API does not accept reasoning params for GPT 5 chat models
|
||||
|
||||
@@ -21,6 +21,7 @@ from fastapi.routing import APIRoute
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from httpx_oauth.clients.openid import BASE_SCOPES
|
||||
from httpx_oauth.clients.openid import OpenID
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
from starlette.types import Lifespan
|
||||
@@ -120,7 +121,6 @@ from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
|
||||
from onyx.server.middleware.rate_limiting import setup_auth_limiter
|
||||
from onyx.server.onyx_api.ingestion import router as onyx_api_router
|
||||
from onyx.server.pat.api import router as pat_router
|
||||
from onyx.server.prometheus_instrumentation import setup_prometheus_metrics
|
||||
from onyx.server.query_and_chat.chat_backend import router as chat_router
|
||||
from onyx.server.query_and_chat.query_backend import (
|
||||
admin_router as admin_query_router,
|
||||
@@ -563,8 +563,8 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_router_auth(application)
|
||||
|
||||
# Initialize and instrument the app with production Prometheus config
|
||||
setup_prometheus_metrics(application)
|
||||
# Initialize and instrument the app
|
||||
Instrumentator().instrument(application).expose(application)
|
||||
|
||||
use_route_function_names_as_operation_ids(application)
|
||||
|
||||
|
||||
@@ -102,9 +102,6 @@ def check_router_auth(
|
||||
current_cloud_superuser = fetch_ee_implementation_or_noop(
|
||||
"onyx.auth.users", "current_cloud_superuser"
|
||||
)
|
||||
verify_scim_token = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.scim.auth", "verify_scim_token"
|
||||
)
|
||||
|
||||
for route in application.routes:
|
||||
# explicitly marked as public
|
||||
@@ -128,7 +125,6 @@ def check_router_auth(
|
||||
or depends_fn == current_chat_accessible_user
|
||||
or depends_fn == control_plane_dep
|
||||
or depends_fn == current_cloud_superuser
|
||||
or depends_fn == verify_scim_token
|
||||
):
|
||||
found_auth = True
|
||||
break
|
||||
|
||||
@@ -4,9 +4,8 @@ This client runs `opencode acp` directly in the sandbox pod via kubernetes exec,
|
||||
using stdin/stdout for JSON-RPC communication. This bypasses the HTTP server
|
||||
and uses the native ACP subprocess protocol.
|
||||
|
||||
Each message creates an ephemeral client (start → resume_or_create_session →
|
||||
send_message → stop) to prevent concurrent processes from corrupting
|
||||
opencode's flat file session storage.
|
||||
This module includes comprehensive logging for debugging ACP communication.
|
||||
Enable logging by setting LOG_LEVEL=DEBUG or BUILD_PACKET_LOGGING=true.
|
||||
|
||||
Usage:
|
||||
client = ACPExecClient(
|
||||
@@ -14,14 +13,12 @@ Usage:
|
||||
namespace="onyx-sandboxes",
|
||||
)
|
||||
client.start(cwd="/workspace")
|
||||
session_id = client.resume_or_create_session(cwd="/workspace/sessions/abc")
|
||||
for event in client.send_message("What files are here?", session_id=session_id):
|
||||
for event in client.send_message("What files are here?"):
|
||||
print(event)
|
||||
client.stop()
|
||||
"""
|
||||
|
||||
import json
|
||||
import shlex
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
@@ -30,7 +27,6 @@ from dataclasses import field
|
||||
from queue import Empty
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from acp.schema import AgentMessageChunk
|
||||
from acp.schema import AgentPlanUpdate
|
||||
@@ -44,7 +40,6 @@ from kubernetes import client # type: ignore
|
||||
from kubernetes import config
|
||||
from kubernetes.stream import stream as k8s_stream # type: ignore
|
||||
from kubernetes.stream.ws_client import WSClient # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
|
||||
from onyx.server.features.build.api.packet_logger import get_packet_logger
|
||||
@@ -105,7 +100,7 @@ class ACPClientState:
|
||||
"""Internal state for the ACP client."""
|
||||
|
||||
initialized: bool = False
|
||||
sessions: dict[str, ACPSession] = field(default_factory=dict)
|
||||
current_session: ACPSession | None = None
|
||||
next_request_id: int = 0
|
||||
agent_capabilities: dict[str, Any] = field(default_factory=dict)
|
||||
agent_info: dict[str, Any] = field(default_factory=dict)
|
||||
@@ -160,16 +155,16 @@ class ACPExecClient:
|
||||
self._k8s_client = client.CoreV1Api()
|
||||
return self._k8s_client
|
||||
|
||||
def start(self, cwd: str = "/workspace", timeout: float = 30.0) -> None:
|
||||
"""Start the agent process via exec and initialize the ACP connection.
|
||||
|
||||
Only performs the ACP `initialize` handshake. Sessions are created
|
||||
separately via `resume_or_create_session()`.
|
||||
def start(self, cwd: str = "/workspace", timeout: float = 30.0) -> str:
|
||||
"""Start the agent process via exec and initialize a session.
|
||||
|
||||
Args:
|
||||
cwd: Working directory for the `opencode acp` process
|
||||
cwd: Working directory for the agent
|
||||
timeout: Timeout for initialization
|
||||
|
||||
Returns:
|
||||
The session ID
|
||||
|
||||
Raises:
|
||||
RuntimeError: If startup fails
|
||||
"""
|
||||
@@ -178,19 +173,8 @@ class ACPExecClient:
|
||||
|
||||
k8s = self._get_k8s_client()
|
||||
|
||||
# Start opencode acp via exec.
|
||||
# Set XDG_DATA_HOME so opencode stores session data on the shared
|
||||
# workspace volume (accessible from file-sync container for snapshots)
|
||||
# instead of the container-local ~/.local/share/ filesystem.
|
||||
data_dir = shlex.quote(f"{cwd}/.opencode-data")
|
||||
safe_cwd = shlex.quote(cwd)
|
||||
exec_command = [
|
||||
"/bin/sh",
|
||||
"-c",
|
||||
f"XDG_DATA_HOME={data_dir} exec opencode acp --cwd {safe_cwd}",
|
||||
]
|
||||
|
||||
logger.info(f"[ACP] Starting client: pod={self._pod_name} cwd={cwd}")
|
||||
# Start opencode acp via exec
|
||||
exec_command = ["opencode", "acp", "--cwd", cwd]
|
||||
|
||||
try:
|
||||
self._ws_client = k8s_stream(
|
||||
@@ -217,12 +201,15 @@ class ACPExecClient:
|
||||
# Give process a moment to start
|
||||
time.sleep(0.5)
|
||||
|
||||
# Initialize ACP connection (no session creation)
|
||||
# Initialize ACP connection
|
||||
self._initialize(timeout=timeout)
|
||||
|
||||
logger.info(f"[ACP] Client started: pod={self._pod_name}")
|
||||
# Create session
|
||||
session_id = self._create_session(cwd=cwd, timeout=timeout)
|
||||
|
||||
return session_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ACP] Client start failed: pod={self._pod_name} error={e}")
|
||||
self.stop()
|
||||
raise RuntimeError(f"Failed to start ACP exec client: {e}") from e
|
||||
|
||||
@@ -237,52 +224,56 @@ class ACPExecClient:
|
||||
|
||||
try:
|
||||
if self._ws_client.is_open():
|
||||
# Read available data
|
||||
self._ws_client.update(timeout=0.1)
|
||||
|
||||
# Read stderr - log any agent errors
|
||||
stderr_data = self._ws_client.read_stderr(timeout=0.01)
|
||||
if stderr_data:
|
||||
logger.warning(
|
||||
f"[ACP] stderr pod={self._pod_name}: "
|
||||
f"{stderr_data.strip()[:500]}"
|
||||
)
|
||||
|
||||
# Read stdout
|
||||
# Read stdout (channel 1)
|
||||
data = self._ws_client.read_stdout(timeout=0.1)
|
||||
if data:
|
||||
buffer += data
|
||||
|
||||
# Process complete lines
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
message = json.loads(line)
|
||||
# Log the raw incoming message
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN", message, context="k8s"
|
||||
)
|
||||
self._response_queue.put(message)
|
||||
except json.JSONDecodeError:
|
||||
packet_logger.log_raw(
|
||||
"JSONRPC-PARSE-ERROR-K8S",
|
||||
{
|
||||
"raw_line": line[:500],
|
||||
"error": "JSON decode failed",
|
||||
},
|
||||
)
|
||||
logger.warning(
|
||||
f"[ACP] Invalid JSON from agent: "
|
||||
f"{line[:100]}"
|
||||
f"Invalid JSON from agent: {line[:100]}"
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(f"[ACP] WebSocket closed: pod={self._pod_name}")
|
||||
packet_logger.log_raw(
|
||||
"K8S-WEBSOCKET-CLOSED",
|
||||
{"pod": self._pod_name, "namespace": self._namespace},
|
||||
)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
if not self._stop_reader.is_set():
|
||||
logger.warning(f"[ACP] Reader error: {e}, pod={self._pod_name}")
|
||||
packet_logger.log_raw(
|
||||
"K8S-READER-ERROR",
|
||||
{"error": str(e), "pod": self._pod_name},
|
||||
)
|
||||
logger.debug(f"Reader error: {e}")
|
||||
break
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the exec session and clean up."""
|
||||
session_ids = list(self._state.sessions.keys())
|
||||
logger.info(
|
||||
f"[ACP] Stopping client: pod={self._pod_name} " f"sessions={session_ids}"
|
||||
)
|
||||
self._stop_reader.set()
|
||||
|
||||
if self._ws_client is not None:
|
||||
@@ -409,150 +400,42 @@ class ACPExecClient:
|
||||
if not session_id:
|
||||
raise RuntimeError("No session ID returned from session/new")
|
||||
|
||||
self._state.sessions[session_id] = ACPSession(session_id=session_id, cwd=cwd)
|
||||
logger.info(f"[ACP] Created session: acp_session={session_id} cwd={cwd}")
|
||||
self._state.current_session = ACPSession(session_id=session_id, cwd=cwd)
|
||||
|
||||
return session_id
|
||||
|
||||
def _list_sessions(self, cwd: str, timeout: float = 10.0) -> list[dict[str, Any]]:
|
||||
"""List available ACP sessions, filtered by working directory.
|
||||
|
||||
Returns:
|
||||
List of session info dicts with keys like 'sessionId', 'cwd', 'title'.
|
||||
Empty list if session/list is not supported or fails.
|
||||
"""
|
||||
try:
|
||||
request_id = self._send_request("session/list", {"cwd": cwd})
|
||||
result = self._wait_for_response(request_id, timeout)
|
||||
sessions = result.get("sessions", [])
|
||||
logger.info(f"[ACP] session/list: {len(sessions)} sessions for cwd={cwd}")
|
||||
return sessions
|
||||
except Exception as e:
|
||||
logger.info(f"[ACP] session/list unavailable: {e}")
|
||||
return []
|
||||
|
||||
def _resume_session(self, session_id: str, cwd: str, timeout: float = 30.0) -> str:
|
||||
"""Resume an existing ACP session.
|
||||
|
||||
Args:
|
||||
session_id: The ACP session ID to resume
|
||||
cwd: Working directory for the session
|
||||
timeout: Timeout for the resume request
|
||||
|
||||
Returns:
|
||||
The session ID
|
||||
|
||||
Raises:
|
||||
RuntimeError: If resume fails
|
||||
"""
|
||||
params = {
|
||||
"sessionId": session_id,
|
||||
"cwd": cwd,
|
||||
"mcpServers": [],
|
||||
}
|
||||
|
||||
request_id = self._send_request("session/resume", params)
|
||||
result = self._wait_for_response(request_id, timeout)
|
||||
|
||||
# The response should contain the session ID
|
||||
resumed_id = result.get("sessionId", session_id)
|
||||
self._state.sessions[resumed_id] = ACPSession(session_id=resumed_id, cwd=cwd)
|
||||
|
||||
logger.info(f"[ACP] Resumed session: acp_session={resumed_id} cwd={cwd}")
|
||||
return resumed_id
|
||||
|
||||
def _try_resume_existing_session(self, cwd: str, timeout: float) -> str | None:
|
||||
"""Try to find and resume an existing session for this workspace.
|
||||
|
||||
When multiple API server replicas connect to the same sandbox pod,
|
||||
a previous replica may have already created an ACP session for this
|
||||
workspace. This method discovers and resumes that session so the
|
||||
agent retains conversation context.
|
||||
|
||||
Args:
|
||||
cwd: Working directory to search for sessions
|
||||
timeout: Timeout for ACP requests
|
||||
|
||||
Returns:
|
||||
The resumed session ID, or None if no session could be resumed
|
||||
"""
|
||||
# List sessions for this workspace directory
|
||||
sessions = self._list_sessions(cwd, timeout=min(timeout, 10.0))
|
||||
if not sessions:
|
||||
return None
|
||||
|
||||
# Pick the most recent session (first in list, assuming sorted)
|
||||
target = sessions[0]
|
||||
target_id = target.get("sessionId")
|
||||
if not target_id:
|
||||
logger.warning("[ACP] session/list returned session without sessionId")
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"[ACP] Resuming existing session: acp_session={target_id} "
|
||||
f"(found {len(sessions)})"
|
||||
)
|
||||
|
||||
try:
|
||||
return self._resume_session(target_id, cwd, timeout)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[ACP] session/resume failed for {target_id}: {e}, "
|
||||
f"falling back to session/new"
|
||||
)
|
||||
return None
|
||||
|
||||
def resume_or_create_session(self, cwd: str, timeout: float = 30.0) -> str:
|
||||
"""Resume a session from opencode's on-disk storage, or create a new one.
|
||||
|
||||
With ephemeral clients (one process per message), this always hits disk.
|
||||
Tries resume first to preserve conversation context, falls back to new.
|
||||
|
||||
Args:
|
||||
cwd: Working directory for the session
|
||||
timeout: Timeout for ACP requests
|
||||
|
||||
Returns:
|
||||
The ACP session ID
|
||||
"""
|
||||
if not self._state.initialized:
|
||||
raise RuntimeError("Client not initialized. Call start() first.")
|
||||
|
||||
# Try to resume from opencode's persisted storage
|
||||
resumed_id = self._try_resume_existing_session(cwd, timeout)
|
||||
if resumed_id:
|
||||
return resumed_id
|
||||
|
||||
# Create a new session
|
||||
return self._create_session(cwd=cwd, timeout=timeout)
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
message: str,
|
||||
session_id: str,
|
||||
timeout: float = ACP_MESSAGE_TIMEOUT,
|
||||
) -> Generator[ACPEvent, None, None]:
|
||||
"""Send a message to a specific session and stream response events.
|
||||
"""Send a message and stream response events.
|
||||
|
||||
Args:
|
||||
message: The message content to send
|
||||
session_id: The ACP session ID to send the message to
|
||||
timeout: Maximum time to wait for complete response (defaults to ACP_MESSAGE_TIMEOUT env var)
|
||||
|
||||
Yields:
|
||||
Typed ACP schema event objects
|
||||
"""
|
||||
if session_id not in self._state.sessions:
|
||||
raise RuntimeError(
|
||||
f"Unknown session {session_id}. "
|
||||
f"Known sessions: {list(self._state.sessions.keys())}"
|
||||
)
|
||||
if self._state.current_session is None:
|
||||
raise RuntimeError("No active session. Call start() first.")
|
||||
|
||||
session_id = self._state.current_session.session_id
|
||||
packet_logger = get_packet_logger()
|
||||
|
||||
logger.info(
|
||||
f"[ACP] Sending prompt: "
|
||||
f"acp_session={session_id} pod={self._pod_name} "
|
||||
f"queue_backlog={self._response_queue.qsize()}"
|
||||
# Log the start of message processing
|
||||
packet_logger.log_raw(
|
||||
"ACP-SEND-MESSAGE-START-K8S",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"pod": self._pod_name,
|
||||
"namespace": self._namespace,
|
||||
"message_preview": (
|
||||
message[:200] + "..." if len(message) > 200 else message
|
||||
),
|
||||
"timeout": timeout,
|
||||
},
|
||||
)
|
||||
|
||||
prompt_content = [{"type": "text", "text": message}]
|
||||
@@ -563,53 +446,44 @@ class ACPExecClient:
|
||||
|
||||
request_id = self._send_request("session/prompt", params)
|
||||
start_time = time.time()
|
||||
last_event_time = time.time()
|
||||
last_event_time = time.time() # Track time since last event for keepalive
|
||||
events_yielded = 0
|
||||
keepalive_count = 0
|
||||
completion_reason = "unknown"
|
||||
|
||||
while True:
|
||||
remaining = timeout - (time.time() - start_time)
|
||||
if remaining <= 0:
|
||||
completion_reason = "timeout"
|
||||
logger.warning(
|
||||
f"[ACP] Prompt timeout: "
|
||||
f"acp_session={session_id} events={events_yielded}, "
|
||||
f"sending session/cancel"
|
||||
packet_logger.log_raw(
|
||||
"ACP-TIMEOUT-K8S",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"elapsed_ms": (time.time() - start_time) * 1000,
|
||||
},
|
||||
)
|
||||
try:
|
||||
self.cancel(session_id=session_id)
|
||||
except Exception as cancel_err:
|
||||
logger.warning(
|
||||
f"[ACP] session/cancel failed on timeout: {cancel_err}"
|
||||
)
|
||||
yield Error(code=-1, message="Timeout waiting for response")
|
||||
break
|
||||
|
||||
try:
|
||||
message_data = self._response_queue.get(timeout=min(remaining, 1.0))
|
||||
last_event_time = time.time()
|
||||
last_event_time = time.time() # Reset keepalive timer on event
|
||||
except Empty:
|
||||
# Send SSE keepalive if idle
|
||||
# Check if we need to send an SSE keepalive
|
||||
idle_time = time.time() - last_event_time
|
||||
if idle_time >= SSE_KEEPALIVE_INTERVAL:
|
||||
keepalive_count += 1
|
||||
packet_logger.log_raw(
|
||||
"SSE-KEEPALIVE-YIELD",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"idle_seconds": idle_time,
|
||||
},
|
||||
)
|
||||
yield SSEKeepalive()
|
||||
last_event_time = time.time()
|
||||
last_event_time = time.time() # Reset after yielding keepalive
|
||||
continue
|
||||
|
||||
# Check for JSON-RPC response to our prompt request.
|
||||
msg_id = message_data.get("id")
|
||||
is_response = "method" not in message_data and (
|
||||
msg_id == request_id
|
||||
or (msg_id is not None and str(msg_id) == str(request_id))
|
||||
)
|
||||
if is_response:
|
||||
completion_reason = "jsonrpc_response"
|
||||
# Check for response to our prompt request
|
||||
if message_data.get("id") == request_id:
|
||||
if "error" in message_data:
|
||||
error_data = message_data["error"]
|
||||
completion_reason = "jsonrpc_error"
|
||||
logger.warning(f"[ACP] Prompt error: {error_data}")
|
||||
packet_logger.log_jsonrpc_response(
|
||||
request_id, error=error_data, context="k8s"
|
||||
)
|
||||
@@ -624,16 +498,26 @@ class ACPExecClient:
|
||||
)
|
||||
try:
|
||||
prompt_response = PromptResponse.model_validate(result)
|
||||
packet_logger.log_acp_event_yielded(
|
||||
"prompt_response", prompt_response
|
||||
)
|
||||
events_yielded += 1
|
||||
yield prompt_response
|
||||
except ValidationError as e:
|
||||
logger.error(f"[ACP] PromptResponse validation failed: {e}")
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"type": "prompt_response", "error": str(e)},
|
||||
)
|
||||
|
||||
# Log completion summary
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[ACP] Prompt complete: "
|
||||
f"reason={completion_reason} acp_session={session_id} "
|
||||
f"events={events_yielded} elapsed={elapsed_ms:.0f}ms"
|
||||
packet_logger.log_raw(
|
||||
"ACP-SEND-MESSAGE-COMPLETE-K8S",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"events_yielded": events_yielded,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
},
|
||||
)
|
||||
break
|
||||
|
||||
@@ -642,29 +526,25 @@ class ACPExecClient:
|
||||
params_data = message_data.get("params", {})
|
||||
update = params_data.get("update", {})
|
||||
|
||||
prompt_complete = False
|
||||
# Log the notification
|
||||
packet_logger.log_jsonrpc_notification(
|
||||
"session/update",
|
||||
{"update_type": update.get("sessionUpdate")},
|
||||
context="k8s",
|
||||
)
|
||||
|
||||
for event in self._process_session_update(update):
|
||||
events_yielded += 1
|
||||
# Log each yielded event
|
||||
event_type = self._get_event_type_name(event)
|
||||
packet_logger.log_acp_event_yielded(event_type, event)
|
||||
yield event
|
||||
if isinstance(event, PromptResponse):
|
||||
prompt_complete = True
|
||||
break
|
||||
|
||||
if prompt_complete:
|
||||
completion_reason = "prompt_response_via_notification"
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[ACP] Prompt complete: "
|
||||
f"reason={completion_reason} acp_session={session_id} "
|
||||
f"events={events_yielded} elapsed={elapsed_ms:.0f}ms"
|
||||
)
|
||||
break
|
||||
|
||||
# Handle requests from agent - send error response
|
||||
elif "method" in message_data and "id" in message_data:
|
||||
logger.debug(
|
||||
f"[ACP] Unsupported agent request: "
|
||||
f"method={message_data['method']}"
|
||||
packet_logger.log_raw(
|
||||
"ACP-UNSUPPORTED-REQUEST-K8S",
|
||||
{"method": message_data["method"], "id": message_data["id"]},
|
||||
)
|
||||
self._send_error_response(
|
||||
message_data["id"],
|
||||
@@ -672,49 +552,113 @@ class ACPExecClient:
|
||||
f"Method not supported: {message_data['method']}",
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
f"[ACP] Unhandled message: "
|
||||
f"id={message_data.get('id')} "
|
||||
f"method={message_data.get('method')} "
|
||||
f"keys={list(message_data.keys())}"
|
||||
)
|
||||
def _get_event_type_name(self, event: ACPEvent) -> str:
|
||||
"""Get the type name for an ACP event."""
|
||||
if isinstance(event, AgentMessageChunk):
|
||||
return "agent_message_chunk"
|
||||
elif isinstance(event, AgentThoughtChunk):
|
||||
return "agent_thought_chunk"
|
||||
elif isinstance(event, ToolCallStart):
|
||||
return "tool_call_start"
|
||||
elif isinstance(event, ToolCallProgress):
|
||||
return "tool_call_progress"
|
||||
elif isinstance(event, AgentPlanUpdate):
|
||||
return "agent_plan_update"
|
||||
elif isinstance(event, CurrentModeUpdate):
|
||||
return "current_mode_update"
|
||||
elif isinstance(event, PromptResponse):
|
||||
return "prompt_response"
|
||||
elif isinstance(event, Error):
|
||||
return "error"
|
||||
elif isinstance(event, SSEKeepalive):
|
||||
return "sse_keepalive"
|
||||
return "unknown"
|
||||
|
||||
def _process_session_update(
|
||||
self, update: dict[str, Any]
|
||||
) -> Generator[ACPEvent, None, None]:
|
||||
"""Process a session/update notification and yield typed ACP schema objects."""
|
||||
update_type = update.get("sessionUpdate")
|
||||
if not isinstance(update_type, str):
|
||||
return
|
||||
packet_logger = get_packet_logger()
|
||||
|
||||
# Map update types to their ACP schema classes.
|
||||
# Note: prompt_response is included because ACP sometimes sends it as a
|
||||
# notification WITHOUT a corresponding JSON-RPC response. We accept
|
||||
# either signal as turn completion (first one wins).
|
||||
type_map: dict[str, type[BaseModel]] = {
|
||||
"agent_message_chunk": AgentMessageChunk,
|
||||
"agent_thought_chunk": AgentThoughtChunk,
|
||||
"tool_call": ToolCallStart,
|
||||
"tool_call_update": ToolCallProgress,
|
||||
"plan": AgentPlanUpdate,
|
||||
"current_mode_update": CurrentModeUpdate,
|
||||
"prompt_response": PromptResponse,
|
||||
}
|
||||
|
||||
model_class = type_map.get(update_type)
|
||||
if model_class is not None:
|
||||
if update_type == "agent_message_chunk":
|
||||
try:
|
||||
yield cast(ACPEvent, model_class.model_validate(update))
|
||||
yield AgentMessageChunk.model_validate(update)
|
||||
except ValidationError as e:
|
||||
logger.warning(f"[ACP] Validation error for {update_type}: {e}")
|
||||
elif update_type not in (
|
||||
"user_message_chunk",
|
||||
"available_commands_update",
|
||||
"session_info_update",
|
||||
"usage_update",
|
||||
):
|
||||
logger.debug(f"[ACP] Unknown update type: {update_type}")
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "agent_thought_chunk":
|
||||
try:
|
||||
yield AgentThoughtChunk.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "user_message_chunk":
|
||||
# Echo of user message - skip but log
|
||||
packet_logger.log_raw(
|
||||
"ACP-SKIPPED-UPDATE-K8S", {"type": "user_message_chunk"}
|
||||
)
|
||||
|
||||
elif update_type == "tool_call":
|
||||
try:
|
||||
yield ToolCallStart.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "tool_call_update":
|
||||
try:
|
||||
yield ToolCallProgress.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "plan":
|
||||
try:
|
||||
yield AgentPlanUpdate.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "current_mode_update":
|
||||
try:
|
||||
yield CurrentModeUpdate.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "available_commands_update":
|
||||
# Skip command updates
|
||||
packet_logger.log_raw(
|
||||
"ACP-SKIPPED-UPDATE-K8S", {"type": "available_commands_update"}
|
||||
)
|
||||
|
||||
elif update_type == "session_info_update":
|
||||
# Skip session info updates
|
||||
packet_logger.log_raw(
|
||||
"ACP-SKIPPED-UPDATE-K8S", {"type": "session_info_update"}
|
||||
)
|
||||
|
||||
else:
|
||||
# Unknown update types are logged
|
||||
packet_logger.log_raw(
|
||||
"ACP-UNKNOWN-UPDATE-TYPE-K8S",
|
||||
{"update_type": update_type, "update": update},
|
||||
)
|
||||
|
||||
def _send_error_response(self, request_id: int, code: int, message: str) -> None:
|
||||
"""Send an error response to an agent request."""
|
||||
@@ -729,24 +673,15 @@ class ACPExecClient:
|
||||
|
||||
self._ws_client.write_stdin(json.dumps(response) + "\n")
|
||||
|
||||
def cancel(self, session_id: str | None = None) -> None:
|
||||
"""Cancel the current operation on a session.
|
||||
def cancel(self) -> None:
|
||||
"""Cancel the current operation."""
|
||||
if self._state.current_session is None:
|
||||
return
|
||||
|
||||
Args:
|
||||
session_id: The ACP session ID to cancel. If None, cancels all sessions.
|
||||
"""
|
||||
if session_id:
|
||||
if session_id in self._state.sessions:
|
||||
self._send_notification(
|
||||
"session/cancel",
|
||||
{"sessionId": session_id},
|
||||
)
|
||||
else:
|
||||
for sid in self._state.sessions:
|
||||
self._send_notification(
|
||||
"session/cancel",
|
||||
{"sessionId": sid},
|
||||
)
|
||||
self._send_notification(
|
||||
"session/cancel",
|
||||
{"sessionId": self._state.current_session.session_id},
|
||||
)
|
||||
|
||||
def health_check(self, timeout: float = 5.0) -> bool: # noqa: ARG002
|
||||
"""Check if we can exec into the pod."""
|
||||
@@ -772,6 +707,13 @@ class ACPExecClient:
|
||||
"""Check if the exec session is running."""
|
||||
return self._ws_client is not None and self._ws_client.is_open()
|
||||
|
||||
@property
|
||||
def session_id(self) -> str | None:
|
||||
"""Get the current session ID, if any."""
|
||||
if self._state.current_session:
|
||||
return self._state.current_session.session_id
|
||||
return None
|
||||
|
||||
def __enter__(self) -> "ACPExecClient":
|
||||
"""Context manager entry."""
|
||||
return self
|
||||
|
||||
@@ -50,7 +50,6 @@ from pathlib import Path
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from acp.schema import PromptResponse
|
||||
from kubernetes import client # type: ignore
|
||||
from kubernetes import config
|
||||
from kubernetes.client.rest import ApiException # type: ignore
|
||||
@@ -98,10 +97,6 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# API server pod hostname — used to identify which replica is handling a request.
|
||||
# In K8s, HOSTNAME is set to the pod name (e.g., "api-server-dpgg7").
|
||||
_API_SERVER_HOSTNAME = os.environ.get("HOSTNAME", "unknown")
|
||||
|
||||
# Constants for pod configuration
|
||||
# Note: Next.js ports are dynamically allocated from SANDBOX_NEXTJS_PORT_START to
|
||||
# SANDBOX_NEXTJS_PORT_END range, with one port per session.
|
||||
@@ -1161,9 +1156,7 @@ done
|
||||
def terminate(self, sandbox_id: UUID) -> None:
|
||||
"""Terminate a sandbox and clean up Kubernetes resources.
|
||||
|
||||
Removes session mappings for this sandbox, then deletes the
|
||||
Service and Pod. ACP clients are ephemeral (created per message),
|
||||
so there's nothing to stop here.
|
||||
Deletes the Service and Pod for the sandbox.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID to terminate
|
||||
@@ -1402,8 +1395,7 @@ echo "Session workspace setup complete"
|
||||
) -> None:
|
||||
"""Clean up a session workspace (on session delete).
|
||||
|
||||
Removes the ACP session mapping and executes kubectl exec to remove
|
||||
the session directory. The shared ACP client persists for other sessions.
|
||||
Executes kubectl exec to remove the session directory.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
@@ -1472,7 +1464,6 @@ echo "Session cleanup complete"
|
||||
the snapshot and upload to S3. Captures:
|
||||
- sessions/$session_id/outputs/ (generated artifacts, web apps)
|
||||
- sessions/$session_id/attachments/ (user uploaded files)
|
||||
- sessions/$session_id/.opencode-data/ (opencode session data for resumption)
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
@@ -1497,10 +1488,9 @@ echo "Session cleanup complete"
|
||||
f"{session_id_str}/{snapshot_id}.tar.gz"
|
||||
)
|
||||
|
||||
# Create tar and upload to S3 via file-sync container.
|
||||
# .opencode-data/ is already on the shared workspace volume because we set
|
||||
# XDG_DATA_HOME to the session directory when starting opencode (see
|
||||
# ACPExecClient.start()). No cross-container copy needed.
|
||||
# Exec into pod to create and upload snapshot (outputs + attachments)
|
||||
# Uses s5cmd pipe to stream tar.gz directly to S3
|
||||
# Only snapshot if outputs/ exists. Include attachments/ only if non-empty.
|
||||
exec_command = [
|
||||
"/bin/sh",
|
||||
"-c",
|
||||
@@ -1513,7 +1503,6 @@ if [ ! -d outputs ]; then
|
||||
fi
|
||||
dirs="outputs"
|
||||
[ -d attachments ] && [ "$(ls -A attachments 2>/dev/null)" ] && dirs="$dirs attachments"
|
||||
[ -d .opencode-data ] && [ "$(ls -A .opencode-data 2>/dev/null)" ] && dirs="$dirs .opencode-data"
|
||||
tar -czf - $dirs | /s5cmd pipe {s3_path}
|
||||
echo "SNAPSHOT_CREATED"
|
||||
""",
|
||||
@@ -1635,7 +1624,6 @@ echo "SNAPSHOT_CREATED"
|
||||
Steps:
|
||||
1. Exec s5cmd cat in file-sync container to stream snapshot from S3
|
||||
2. Pipe directly to tar for extraction in the shared workspace volume
|
||||
(.opencode-data/ is restored automatically since XDG_DATA_HOME points here)
|
||||
3. Regenerate configuration files (AGENTS.md, opencode.json, files symlink)
|
||||
4. Start the NextJS dev server
|
||||
|
||||
@@ -1819,41 +1807,6 @@ echo "Session config regeneration complete"
|
||||
)
|
||||
return exec_client.health_check(timeout=timeout)
|
||||
|
||||
def _create_ephemeral_acp_client(
|
||||
self, sandbox_id: UUID, session_path: str
|
||||
) -> ACPExecClient:
|
||||
"""Create a new ephemeral ACP client for a single message exchange.
|
||||
|
||||
Each call starts a fresh `opencode acp` process in the sandbox pod.
|
||||
The process is short-lived — stopped after the message completes.
|
||||
This prevents the bug where multiple long-lived processes (one per
|
||||
API replica) operate on the same session's flat file storage
|
||||
concurrently, causing the JSON-RPC response to be silently lost.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_path: Working directory for the session (e.g. /workspace/sessions/{id}).
|
||||
XDG_DATA_HOME is set relative to this so opencode's session data
|
||||
lives inside the snapshot directory.
|
||||
|
||||
Returns:
|
||||
A running ACPExecClient (caller must stop it when done)
|
||||
"""
|
||||
pod_name = self._get_pod_name(str(sandbox_id))
|
||||
acp_client = ACPExecClient(
|
||||
pod_name=pod_name,
|
||||
namespace=self._namespace,
|
||||
container="sandbox",
|
||||
)
|
||||
acp_client.start(cwd=session_path)
|
||||
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] Created ephemeral ACP client: "
|
||||
f"sandbox={sandbox_id} pod={pod_name} "
|
||||
f"api_pod={_API_SERVER_HOSTNAME}"
|
||||
)
|
||||
return acp_client
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
@@ -1862,12 +1815,8 @@ echo "Session config regeneration complete"
|
||||
) -> Generator[ACPEvent, None, None]:
|
||||
"""Send a message to the CLI agent and stream ACP events.
|
||||
|
||||
Creates an ephemeral `opencode acp` process for each message.
|
||||
The process resumes the session from opencode's on-disk storage,
|
||||
handles the prompt, then is stopped. This ensures only one process
|
||||
operates on a session's flat files at a time, preventing the bug
|
||||
where multiple long-lived processes (one per API replica) corrupt
|
||||
each other's in-memory state.
|
||||
Runs `opencode acp` via kubectl exec in the sandbox pod.
|
||||
The agent runs in the session-specific workspace.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
@@ -1878,103 +1827,67 @@ echo "Session config regeneration complete"
|
||||
Typed ACP schema event objects
|
||||
"""
|
||||
packet_logger = get_packet_logger()
|
||||
pod_name = self._get_pod_name(str(sandbox_id))
|
||||
session_path = f"/workspace/sessions/{session_id}"
|
||||
|
||||
# Create an ephemeral ACP client for this message
|
||||
acp_client = self._create_ephemeral_acp_client(sandbox_id, session_path)
|
||||
# Log ACP client creation
|
||||
packet_logger.log_acp_client_start(
|
||||
sandbox_id, session_id, session_path, context="k8s"
|
||||
)
|
||||
|
||||
exec_client = ACPExecClient(
|
||||
pod_name=pod_name,
|
||||
namespace=self._namespace,
|
||||
container="sandbox",
|
||||
)
|
||||
|
||||
# Log the send_message call at sandbox manager level
|
||||
packet_logger.log_session_start(session_id, sandbox_id, message)
|
||||
|
||||
events_count = 0
|
||||
try:
|
||||
# Resume (or create) the ACP session from opencode's on-disk storage
|
||||
acp_session_id = acp_client.resume_or_create_session(cwd=session_path)
|
||||
exec_client.start(cwd=session_path)
|
||||
for event in exec_client.send_message(message):
|
||||
events_count += 1
|
||||
yield event
|
||||
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] Sending message: "
|
||||
f"session={session_id} acp_session={acp_session_id} "
|
||||
f"api_pod={_API_SERVER_HOSTNAME}"
|
||||
# Log successful completion
|
||||
packet_logger.log_session_end(
|
||||
session_id, success=True, events_count=events_count
|
||||
)
|
||||
|
||||
# Log the send_message call at sandbox manager level
|
||||
packet_logger.log_session_start(session_id, sandbox_id, message)
|
||||
|
||||
events_count = 0
|
||||
got_prompt_response = False
|
||||
try:
|
||||
for event in acp_client.send_message(
|
||||
message, session_id=acp_session_id
|
||||
):
|
||||
events_count += 1
|
||||
if isinstance(event, PromptResponse):
|
||||
got_prompt_response = True
|
||||
yield event
|
||||
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] send_message completed: "
|
||||
f"session={session_id} events={events_count} "
|
||||
f"got_prompt_response={got_prompt_response}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id, success=True, events_count=events_count
|
||||
)
|
||||
except GeneratorExit:
|
||||
logger.warning(
|
||||
f"[SANDBOX-ACP] GeneratorExit: session={session_id} "
|
||||
f"events={events_count}, sending session/cancel"
|
||||
)
|
||||
try:
|
||||
acp_client.cancel(session_id=acp_session_id)
|
||||
except Exception as cancel_err:
|
||||
logger.warning(
|
||||
f"[SANDBOX-ACP] session/cancel failed on GeneratorExit: "
|
||||
f"{cancel_err}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
error="GeneratorExit: Client disconnected or stream closed by consumer",
|
||||
events_count=events_count,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[SANDBOX-ACP] Exception: session={session_id} "
|
||||
f"events={events_count} error={e}, sending session/cancel"
|
||||
)
|
||||
try:
|
||||
acp_client.cancel(session_id=acp_session_id)
|
||||
except Exception as cancel_err:
|
||||
logger.warning(
|
||||
f"[SANDBOX-ACP] session/cancel failed on Exception: "
|
||||
f"{cancel_err}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
error=f"Exception: {str(e)}",
|
||||
events_count=events_count,
|
||||
)
|
||||
raise
|
||||
except BaseException as e:
|
||||
logger.error(
|
||||
f"[SANDBOX-ACP] {type(e).__name__}: session={session_id} "
|
||||
f"error={e}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
error=f"{type(e).__name__}: {str(e) if str(e) else 'System-level interruption'}",
|
||||
events_count=events_count,
|
||||
)
|
||||
raise
|
||||
except GeneratorExit:
|
||||
# Generator was closed by consumer (client disconnect, timeout, broken pipe)
|
||||
# This is the most common failure mode for SSE streaming
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
error="GeneratorExit: Client disconnected or stream closed by consumer",
|
||||
events_count=events_count,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
# Log failure from normal exceptions
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
error=f"Exception: {str(e)}",
|
||||
events_count=events_count,
|
||||
)
|
||||
raise
|
||||
except BaseException as e:
|
||||
# Log failure from other base exceptions (SystemExit, KeyboardInterrupt, etc.)
|
||||
exception_type = type(e).__name__
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
error=f"{exception_type}: {str(e) if str(e) else 'System-level interruption'}",
|
||||
events_count=events_count,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
# Always stop the ephemeral ACP client to kill the opencode process.
|
||||
# This ensures no stale processes linger in the sandbox container.
|
||||
try:
|
||||
acp_client.stop()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[SANDBOX-ACP] Failed to stop ephemeral ACP client: "
|
||||
f"session={session_id} error={e}"
|
||||
)
|
||||
exec_client.stop()
|
||||
# Log client stop
|
||||
packet_logger.log_acp_client_stop(sandbox_id, session_id, context="k8s")
|
||||
|
||||
def list_directory(
|
||||
self, sandbox_id: UUID, session_id: UUID, path: str
|
||||
|
||||
@@ -30,28 +30,17 @@ OPENSEARCH_NOT_ENABLED_MESSAGE = (
    "OpenSearch indexing must be enabled to use this feature."
)

MIGRATION_STATUS_MESSAGE = (
    "Our records indicate that the transition to OpenSearch is still in progress. "
    "OpenSearch retrieval is necessary to use this feature. "
    "You can still use Document Sets, though! "
    "If you would like to manually switch to OpenSearch, "
    'Go to the "Document Index Migration" section in the Admin panel.'
)

router = APIRouter(prefix=HIERARCHY_NODES_PREFIX)


def _require_opensearch(db_session: Session) -> None:
    if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
    if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX or not get_opensearch_retrieval_state(
        db_session
    ):
        raise HTTPException(
            status_code=403,
            detail=OPENSEARCH_NOT_ENABLED_MESSAGE,
        )
    if not get_opensearch_retrieval_state(db_session):
        raise HTTPException(
            status_code=403,
            detail=MIGRATION_STATUS_MESSAGE,
        )

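A hedged sketch of how a route on this router could apply the guard before serving OpenSearch-backed data; the path, response shape, and `get_session` dependency are illustrative assumptions, while `router` and `_require_opensearch` come from the hunk above:

```python
# Illustrative route using the guard; names not shown in the diff are assumptions.
from fastapi import Depends
from sqlalchemy.orm import Session


@router.get("/available")
def hierarchy_nodes_available(
    db_session: Session = Depends(get_session),
) -> dict[str, bool]:
    # Raises HTTPException(403) when OpenSearch indexing or retrieval
    # is not enabled for this deployment.
    _require_opensearch(db_session)
    return {"available": True}
```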
def _get_user_access_info(
|
||||
|
||||
@@ -8,7 +8,6 @@ import httpx
from sqlalchemy.orm import Session

from onyx import __version__
from onyx.configs.app_configs import INSTANCE_TYPE
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.release_notes import create_release_notifications_for_versions
from onyx.redis.redis_pool import get_shared_redis_client

@@ -57,7 +56,7 @@ def is_version_gte(v1: str, v2: str) -> bool:


def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry]:
    """Parse MDX content into ReleaseNoteEntry objects."""
    """Parse MDX content into ReleaseNoteEntry objects for versions >= __version__."""
    all_entries = []

    update_pattern = (

@@ -83,12 +82,6 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
    if not all_entries:
        raise ValueError("Could not parse any release note entries from MDX.")

    if INSTANCE_TYPE == "cloud":
        # Cloud often runs ahead of docs release tags; always notify on latest release.
        return sorted(
            all_entries, key=lambda x: parse_version_tuple(x.version), reverse=True
        )[:1]

    # Filter to valid versions >= __version__
    if __version__ and is_valid_version(__version__):
        entries = [

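The release-notes hunk filters entries to valid versions at or above the running `__version__` and, on cloud, keeps only the newest entry. The helpers `parse_version_tuple` and `is_valid_version` are referenced but not shown in this excerpt; the sketch below uses illustrative stand-ins, not the real implementations:

```python
# Illustrative stand-ins for the version helpers referenced above; the real
# implementations live elsewhere in the codebase and may differ.
def parse_version_tuple(version: str) -> tuple[int, ...]:
    return tuple(int(part) for part in version.lstrip("v").split("."))


def is_valid_version(version: str) -> bool:
    try:
        parse_version_tuple(version)
        return True
    except ValueError:
        return False


def versions_at_or_above(versions: list[str], current: str) -> list[str]:
    # Keep only valid versions >= the running version, newest first.
    valid = [v for v in versions if is_valid_version(v)]
    kept = [v for v in valid if parse_version_tuple(v) >= parse_version_tuple(current)]
    return sorted(kept, key=parse_version_tuple, reverse=True)
```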
@@ -310,7 +310,7 @@ def list_llm_providers(
    llm_provider_list: list[LLMProviderView] = []
    for llm_provider_model in fetch_existing_llm_providers(
        db_session=db_session,
        flow_type_filter=[],
        flow_types=[LLMModelFlowType.CHAT, LLMModelFlowType.VISION],
        exclude_image_generation_providers=not include_image_gen,
    ):
        from_model_start = datetime.now(timezone.utc)

@@ -568,7 +568,9 @@ def list_llm_provider_basics(
    start_time = datetime.now(timezone.utc)
    logger.debug("Starting to fetch user-accessible LLM providers")

    all_providers = fetch_existing_llm_providers(db_session, [])
    all_providers = fetch_existing_llm_providers(
        db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
    )
    user_group_ids = fetch_user_group_ids(db_session, user)
    is_admin = user.role == UserRole.ADMIN
|
||||
|
||||
@@ -1,63 +0,0 @@
"""Prometheus instrumentation for the Onyx API server.

Provides a production-grade metrics configuration with:

- Exact HTTP status codes (no grouping into 2xx/3xx)
- In-progress request gauge broken down by handler and method
- Custom latency histogram buckets tuned for API workloads
- Request/response size tracking
- Slow request counter with configurable threshold
"""

import os

from prometheus_client import Counter
from prometheus_fastapi_instrumentator import Instrumentator
from prometheus_fastapi_instrumentator.metrics import Info
from starlette.applications import Starlette

SLOW_REQUEST_THRESHOLD_SECONDS: float = float(
os.environ.get("SLOW_REQUEST_THRESHOLD_SECONDS", "1.0")
)

_EXCLUDED_HANDLERS = [
"/health",
"/metrics",
"/openapi.json",
]

_slow_requests = Counter(
"onyx_api_slow_requests_total",
"Total requests exceeding the slow request threshold",
["method", "handler", "status"],
)


def _slow_request_callback(info: Info) -> None:
"""Increment slow request counter when duration exceeds threshold."""
if info.modified_duration > SLOW_REQUEST_THRESHOLD_SECONDS:
_slow_requests.labels(
method=info.method,
handler=info.modified_handler,
status=info.modified_status,
).inc()


def setup_prometheus_metrics(app: Starlette) -> None:
"""Configure and attach Prometheus instrumentation to the FastAPI app.

Records exact status codes, tracks in-progress requests per handler,
and counts slow requests exceeding a configurable threshold.
"""
instrumentator = Instrumentator(
should_group_status_codes=False,
should_ignore_untemplated=False,
should_group_untemplated=True,
should_instrument_requests_inprogress=True,
inprogress_labels=True,
excluded_handlers=_EXCLUDED_HANDLERS,
)

instrumentator.add(_slow_request_callback)

instrumentator.instrument(app).expose(app)
@@ -349,7 +349,6 @@ def get_chat_session(
shared_status=chat_session.shared_status,
current_temperature_override=chat_session.temperature_override,
deleted=chat_session.deleted,
owner_name=chat_session.user.personal_name if chat_session.user else None,
# Packets are now directly serialized as Packet Pydantic models
packets=replay_packet_lists,
)

@@ -224,7 +224,6 @@ class ChatSessionDetailResponse(BaseModel):
current_alternate_model: str | None
current_temperature_override: float | None
deleted: bool = False
owner_name: str | None = None
packets: list[list[Packet]]

@@ -47,7 +47,6 @@ from onyx.tools.tool_implementations.web_search.utils import (
from onyx.tools.tool_implementations.web_search.utils import MAX_CHARS_PER_URL
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.url import normalize_url as normalize_web_content_url
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id

@@ -802,9 +801,7 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
for url in all_urls:
doc_id = url_to_doc_id.get(url)
indexed_section = indexed_by_doc_id.get(doc_id) if doc_id else None
# WebContent.link is normalized (query/fragment stripped). Match on the
# same normalized form to avoid dropping successful crawl results.
crawled_section = crawled_by_url.get(normalize_web_content_url(url))
crawled_section = crawled_by_url.get(url)

if indexed_section and indexed_section.combined_content:
# Prefer indexed

@@ -2,7 +2,6 @@ import time
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import field
from dataclasses import replace
from urllib.parse import urlparse

from onyx.connectors.google_drive.connector import GoogleDriveConnector
@@ -135,25 +134,25 @@ EXPECTED_SHARED_DRIVE_1_HIERARCHY = ExpectedHierarchyNode(
children=[
ExpectedHierarchyNode(
raw_node_id=RESTRICTED_ACCESS_FOLDER_ID,
display_name="restricted_access",
display_name="restricted_access_folder",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=SHARED_DRIVE_1_ID,
),
ExpectedHierarchyNode(
raw_node_id=FOLDER_1_ID,
display_name="folder 1",
display_name="folder_1",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=SHARED_DRIVE_1_ID,
children=[
ExpectedHierarchyNode(
raw_node_id=FOLDER_1_1_ID,
display_name="folder 1-1",
display_name="folder_1_1",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=FOLDER_1_ID,
),
ExpectedHierarchyNode(
raw_node_id=FOLDER_1_2_ID,
display_name="folder 1-2",
display_name="folder_1_2",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=FOLDER_1_ID,
),
@@ -171,25 +170,25 @@ EXPECTED_SHARED_DRIVE_2_HIERARCHY = ExpectedHierarchyNode(
children=[
ExpectedHierarchyNode(
raw_node_id=SECTIONS_FOLDER_ID,
display_name="sections",
display_name="sections_folder",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=SHARED_DRIVE_2_ID,
),
ExpectedHierarchyNode(
raw_node_id=FOLDER_2_ID,
display_name="folder 2",
display_name="folder_2",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=SHARED_DRIVE_2_ID,
children=[
ExpectedHierarchyNode(
raw_node_id=FOLDER_2_1_ID,
display_name="folder 2-1",
display_name="folder_2_1",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=FOLDER_2_ID,
),
ExpectedHierarchyNode(
raw_node_id=FOLDER_2_2_ID,
display_name="folder 2-2",
display_name="folder_2_2",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=FOLDER_2_ID,
),
@@ -209,23 +208,27 @@ def flatten_hierarchy(
return result


def _node(
raw_node_id: str,
display_name: str,
node_type: HierarchyNodeType,
raw_parent_id: str | None = None,
) -> ExpectedHierarchyNode:
return ExpectedHierarchyNode(
raw_node_id=raw_node_id,
display_name=display_name,
node_type=node_type,
raw_parent_id=raw_parent_id,
)


# Flattened maps for easy lookup
EXPECTED_SHARED_DRIVE_1_NODES = flatten_hierarchy(EXPECTED_SHARED_DRIVE_1_HIERARCHY)
EXPECTED_SHARED_DRIVE_2_NODES = flatten_hierarchy(EXPECTED_SHARED_DRIVE_2_HIERARCHY)
ALL_EXPECTED_SHARED_DRIVE_NODES = {
**EXPECTED_SHARED_DRIVE_1_NODES,
**EXPECTED_SHARED_DRIVE_2_NODES,
}

# Map of folder ID to its expected parent ID
EXPECTED_PARENT_MAPPING: dict[str, str | None] = {
SHARED_DRIVE_1_ID: None,
RESTRICTED_ACCESS_FOLDER_ID: SHARED_DRIVE_1_ID,
FOLDER_1_ID: SHARED_DRIVE_1_ID,
FOLDER_1_1_ID: FOLDER_1_ID,
FOLDER_1_2_ID: FOLDER_1_ID,
SHARED_DRIVE_2_ID: None,
SECTIONS_FOLDER_ID: SHARED_DRIVE_2_ID,
FOLDER_2_ID: SHARED_DRIVE_2_ID,
FOLDER_2_1_ID: FOLDER_2_ID,
FOLDER_2_2_ID: FOLDER_2_ID,
}

EXTERNAL_SHARED_FOLDER_URL = (
"https://drive.google.com/drive/folders/1sWC7Oi0aQGgifLiMnhTjvkhRWVeDa-XS"
@@ -283,7 +286,7 @@ TEST_USER_1_MY_DRIVE_FOLDER_ID = (
)

TEST_USER_1_DRIVE_B_ID = (
"0AFskk4zfZm86Uk9PVA" # My_super_special_shared_drive_suuuper_private
"0AFskk4zfZm86Uk9PVA" # My_super_special_shared_drive_suuuuuuper_private
)
TEST_USER_1_DRIVE_B_FOLDER_ID = (
"1oIj7nigzvP5xI2F8BmibUA8R_J3AbBA-" # Child folder (silliness)
@@ -322,106 +325,6 @@ PERM_SYNC_DRIVE_ACCESS_MAPPING: dict[str, set[str]] = {
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID: {ADMIN_EMAIL, TEST_USER_1_EMAIL},
}

# ============================================================================
# NON-SHARED-DRIVE HIERARCHY NODES
# ============================================================================
# These cover My Drive roots, perm sync drives, extra shared drives,
# and standalone folders that appear in various tests.
# Display names must match what the Google Drive API actually returns.
# ============================================================================

EXPECTED_FOLDER_3 = _node(
|
||||
FOLDER_3_ID, "Folder 3", HierarchyNodeType.FOLDER, ADMIN_MY_DRIVE_ID
|
||||
)
|
||||
|
||||
EXPECTED_ADMIN_MY_DRIVE = _node(ADMIN_MY_DRIVE_ID, "My Drive", HierarchyNodeType.FOLDER)
|
||||
EXPECTED_TEST_USER_1_MY_DRIVE = _node(
|
||||
TEST_USER_1_MY_DRIVE_ID, "My Drive", HierarchyNodeType.FOLDER
|
||||
)
|
||||
EXPECTED_TEST_USER_1_MY_DRIVE_FOLDER = _node(
|
||||
TEST_USER_1_MY_DRIVE_FOLDER_ID,
|
||||
"partial_sharing",
|
||||
HierarchyNodeType.FOLDER,
|
||||
TEST_USER_1_MY_DRIVE_ID,
|
||||
)
|
||||
EXPECTED_TEST_USER_2_MY_DRIVE = _node(
|
||||
TEST_USER_2_MY_DRIVE, "My Drive", HierarchyNodeType.FOLDER
|
||||
)
|
||||
EXPECTED_TEST_USER_3_MY_DRIVE = _node(
|
||||
TEST_USER_3_MY_DRIVE_ID, "My Drive", HierarchyNodeType.FOLDER
|
||||
)
|
||||
|
||||
EXPECTED_PERM_SYNC_DRIVE_ADMIN_ONLY = _node(
|
||||
PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
|
||||
"perm_sync_drive_0dc9d8b5-e243-4c2f-8678-2235958f7d7c",
|
||||
HierarchyNodeType.SHARED_DRIVE,
|
||||
)
|
||||
EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A = _node(
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
|
||||
"perm_sync_drive_785db121-0823-4ebe-8689-ad7f52405e32",
|
||||
HierarchyNodeType.SHARED_DRIVE,
|
||||
)
|
||||
EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B = _node(
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
|
||||
"perm_sync_drive_d8dc3649-3f65-4392-b87f-4b20e0389673",
|
||||
HierarchyNodeType.SHARED_DRIVE,
|
||||
)
|
||||
|
||||
EXPECTED_TEST_USER_1_DRIVE_B = _node(
|
||||
TEST_USER_1_DRIVE_B_ID,
|
||||
"My_super_special_shared_drive_suuuper_private",
|
||||
HierarchyNodeType.SHARED_DRIVE,
|
||||
)
|
||||
EXPECTED_TEST_USER_1_DRIVE_B_FOLDER = _node(
|
||||
TEST_USER_1_DRIVE_B_FOLDER_ID,
|
||||
"silliness",
|
||||
HierarchyNodeType.FOLDER,
|
||||
TEST_USER_1_DRIVE_B_ID,
|
||||
)
|
||||
EXPECTED_TEST_USER_1_EXTRA_DRIVE_1 = _node(
|
||||
TEST_USER_1_EXTRA_DRIVE_1_ID,
|
||||
"Okay_Admin_fine_I_will_share",
|
||||
HierarchyNodeType.SHARED_DRIVE,
|
||||
)
|
||||
EXPECTED_TEST_USER_1_EXTRA_DRIVE_2 = _node(
|
||||
TEST_USER_1_EXTRA_DRIVE_2_ID, "reee test", HierarchyNodeType.SHARED_DRIVE
|
||||
)
|
||||
EXPECTED_TEST_USER_1_EXTRA_FOLDER = _node(
|
||||
TEST_USER_1_EXTRA_FOLDER_ID,
|
||||
"read only no download test",
|
||||
HierarchyNodeType.FOLDER,
|
||||
)
|
||||
|
||||
EXPECTED_PILL_FOLDER = _node(
|
||||
PILL_FOLDER_ID, "pill_folder", HierarchyNodeType.FOLDER, ADMIN_MY_DRIVE_ID
|
||||
)
|
||||
EXPECTED_EXTERNAL_SHARED_FOLDER = _node(
|
||||
EXTERNAL_SHARED_FOLDER_ID, "Onyx-test", HierarchyNodeType.FOLDER
|
||||
)
|
||||
|
||||
# Comprehensive mapping of ALL known hierarchy nodes.
|
||||
# Every retrieved node is checked against this for display_name and node_type.
|
||||
ALL_EXPECTED_HIERARCHY_NODES: dict[str, ExpectedHierarchyNode] = {
|
||||
**EXPECTED_SHARED_DRIVE_1_NODES,
|
||||
**EXPECTED_SHARED_DRIVE_2_NODES,
|
||||
FOLDER_3_ID: EXPECTED_FOLDER_3,
|
||||
ADMIN_MY_DRIVE_ID: EXPECTED_ADMIN_MY_DRIVE,
|
||||
TEST_USER_1_MY_DRIVE_ID: EXPECTED_TEST_USER_1_MY_DRIVE,
|
||||
TEST_USER_1_MY_DRIVE_FOLDER_ID: EXPECTED_TEST_USER_1_MY_DRIVE_FOLDER,
|
||||
TEST_USER_2_MY_DRIVE: EXPECTED_TEST_USER_2_MY_DRIVE,
|
||||
TEST_USER_3_MY_DRIVE_ID: EXPECTED_TEST_USER_3_MY_DRIVE,
|
||||
PERM_SYNC_DRIVE_ADMIN_ONLY_ID: EXPECTED_PERM_SYNC_DRIVE_ADMIN_ONLY,
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID: EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A,
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID: EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B,
|
||||
TEST_USER_1_DRIVE_B_ID: EXPECTED_TEST_USER_1_DRIVE_B,
|
||||
TEST_USER_1_DRIVE_B_FOLDER_ID: EXPECTED_TEST_USER_1_DRIVE_B_FOLDER,
|
||||
TEST_USER_1_EXTRA_DRIVE_1_ID: EXPECTED_TEST_USER_1_EXTRA_DRIVE_1,
|
||||
TEST_USER_1_EXTRA_DRIVE_2_ID: EXPECTED_TEST_USER_1_EXTRA_DRIVE_2,
|
||||
TEST_USER_1_EXTRA_FOLDER_ID: EXPECTED_TEST_USER_1_EXTRA_FOLDER,
|
||||
PILL_FOLDER_ID: EXPECTED_PILL_FOLDER,
|
||||
EXTERNAL_SHARED_FOLDER_ID: EXPECTED_EXTERNAL_SHARED_FOLDER,
|
||||
}
|
||||
|
||||
# Dictionary for access permissions
|
||||
# All users have access to their own My Drive as well as public files
|
||||
ACCESS_MAPPING: dict[str, list[int]] = {
|
||||
@@ -605,29 +508,28 @@ def load_connector_outputs(
|
||||
|
||||
def assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes: list[HierarchyNode],
|
||||
expected_nodes: dict[str, ExpectedHierarchyNode],
|
||||
expected_node_ids: set[str],
|
||||
expected_parent_mapping: dict[str, str | None] | None = None,
|
||||
ignorable_node_ids: set[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Assert that retrieved hierarchy nodes match expected structure.
|
||||
|
||||
Checks node IDs, display names, node types, and parent relationships
|
||||
for EVERY retrieved node (global checks).
|
||||
|
||||
Args:
|
||||
retrieved_nodes: List of HierarchyNode objects from the connector
|
||||
expected_nodes: Dict mapping raw_node_id -> ExpectedHierarchyNode with
|
||||
expected display_name, node_type, and raw_parent_id
|
||||
ignorable_node_ids: Optional set of node IDs that can be missing or extra
|
||||
without failing. Useful for non-deterministically returned nodes.
|
||||
expected_node_ids: Set of expected raw_node_ids
|
||||
expected_parent_mapping: Optional dict mapping node_id -> parent_id for parent verification
|
||||
ignorable_node_ids: Optional set of node IDs that can be missing or extra without failing.
|
||||
Useful for nodes that are non-deterministically returned by the connector.
|
||||
"""
|
||||
expected_node_ids = set(expected_nodes.keys())
|
||||
retrieved_node_ids = {node.raw_node_id for node in retrieved_nodes}
|
||||
ignorable = ignorable_node_ids or set()
|
||||
|
||||
# Calculate differences, excluding ignorable nodes
|
||||
missing = expected_node_ids - retrieved_node_ids - ignorable
|
||||
extra = retrieved_node_ids - expected_node_ids - ignorable
|
||||
|
||||
# Print discrepancies for debugging
|
||||
if missing or extra:
|
||||
print("Expected hierarchy node IDs:")
|
||||
print(sorted(expected_node_ids))
|
||||
@@ -641,146 +543,181 @@ def assert_hierarchy_nodes_match_expected(
|
||||
print("Ignorable node IDs:")
|
||||
print(sorted(ignorable))
|
||||
|
||||
assert (
|
||||
not missing and not extra
|
||||
), f"Hierarchy node mismatch. Missing: {missing}, Extra: {extra}"
|
||||
assert not missing and not extra, (
|
||||
f"Hierarchy node mismatch. " f"Missing: {missing}, " f"Extra: {extra}"
|
||||
)
|
||||
|
||||
for node in retrieved_nodes:
|
||||
if node.raw_node_id in ignorable and node.raw_node_id not in expected_nodes:
|
||||
continue
|
||||
|
||||
assert (
|
||||
node.raw_node_id in expected_nodes
|
||||
), f"Node {node.raw_node_id} ({node.display_name}) not found in expected_nodes"
|
||||
expected = expected_nodes[node.raw_node_id]
|
||||
|
||||
assert node.display_name == expected.display_name, (
|
||||
f"Display name mismatch for node {node.raw_node_id}: "
|
||||
f"expected '{expected.display_name}', got '{node.display_name}'"
|
||||
)
|
||||
assert node.node_type == expected.node_type, (
|
||||
f"Node type mismatch for node {node.raw_node_id}: "
|
||||
f"expected '{expected.node_type}', got '{node.node_type}'"
|
||||
)
|
||||
if expected.raw_parent_id is not None:
|
||||
assert node.raw_parent_id == expected.raw_parent_id, (
|
||||
# Verify parent relationships if provided
|
||||
if expected_parent_mapping is not None:
|
||||
for node in retrieved_nodes:
|
||||
if node.raw_node_id not in expected_parent_mapping:
|
||||
continue
|
||||
expected_parent = expected_parent_mapping[node.raw_node_id]
|
||||
assert node.raw_parent_id == expected_parent, (
|
||||
f"Parent mismatch for node {node.raw_node_id} ({node.display_name}): "
|
||||
f"expected parent={expected.raw_parent_id}, got parent={node.raw_parent_id}"
|
||||
f"expected parent={expected_parent}, got parent={node.raw_parent_id}"
|
||||
)
|
||||
|
||||
|
||||
def _pick(
|
||||
*node_ids: str,
|
||||
) -> dict[str, ExpectedHierarchyNode]:
|
||||
"""Pick nodes from ALL_EXPECTED_HIERARCHY_NODES by their IDs."""
|
||||
return {nid: ALL_EXPECTED_HIERARCHY_NODES[nid] for nid in node_ids}
|
||||
|
||||
|
||||
def _clear_parents(
|
||||
nodes: dict[str, ExpectedHierarchyNode],
|
||||
*node_ids: str,
|
||||
) -> dict[str, ExpectedHierarchyNode]:
|
||||
"""Return a shallow copy of nodes with the specified nodes' parents set to None.
|
||||
Useful for OAuth tests where the user can't resolve certain parents
|
||||
(e.g. a folder in another user's My Drive)."""
|
||||
result = dict(nodes)
|
||||
for nid in node_ids:
|
||||
result[nid] = replace(result[nid], raw_parent_id=None)
|
||||
return result
|
||||
|
||||
|
||||
def get_expected_hierarchy_for_shared_drives(
|
||||
include_drive_1: bool = True,
|
||||
include_drive_2: bool = True,
|
||||
include_restricted_folder: bool = True,
|
||||
) -> dict[str, ExpectedHierarchyNode]:
|
||||
"""Get expected hierarchy nodes for shared drives."""
|
||||
result: dict[str, ExpectedHierarchyNode] = {}
|
||||
) -> tuple[set[str], dict[str, str | None]]:
|
||||
"""
|
||||
Get expected hierarchy node IDs and parent mapping for shared drives.
|
||||
|
||||
Returns:
|
||||
Tuple of (expected_node_ids, expected_parent_mapping)
|
||||
"""
|
||||
expected_ids: set[str] = set()
|
||||
expected_parents: dict[str, str | None] = {}
|
||||
|
||||
if include_drive_1:
|
||||
result.update(EXPECTED_SHARED_DRIVE_1_NODES)
|
||||
if not include_restricted_folder:
|
||||
result.pop(RESTRICTED_ACCESS_FOLDER_ID, None)
|
||||
expected_ids.add(SHARED_DRIVE_1_ID)
|
||||
expected_parents[SHARED_DRIVE_1_ID] = None
|
||||
|
||||
if include_restricted_folder:
|
||||
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
|
||||
expected_parents[RESTRICTED_ACCESS_FOLDER_ID] = SHARED_DRIVE_1_ID
|
||||
|
||||
expected_ids.add(FOLDER_1_ID)
|
||||
expected_parents[FOLDER_1_ID] = SHARED_DRIVE_1_ID
|
||||
|
||||
expected_ids.add(FOLDER_1_1_ID)
|
||||
expected_parents[FOLDER_1_1_ID] = FOLDER_1_ID
|
||||
|
||||
expected_ids.add(FOLDER_1_2_ID)
|
||||
expected_parents[FOLDER_1_2_ID] = FOLDER_1_ID
|
||||
|
||||
if include_drive_2:
|
||||
result.update(EXPECTED_SHARED_DRIVE_2_NODES)
|
||||
expected_ids.add(SHARED_DRIVE_2_ID)
|
||||
expected_parents[SHARED_DRIVE_2_ID] = None
|
||||
|
||||
return result
|
||||
expected_ids.add(SECTIONS_FOLDER_ID)
|
||||
expected_parents[SECTIONS_FOLDER_ID] = SHARED_DRIVE_2_ID
|
||||
|
||||
expected_ids.add(FOLDER_2_ID)
|
||||
expected_parents[FOLDER_2_ID] = SHARED_DRIVE_2_ID
|
||||
|
||||
expected_ids.add(FOLDER_2_1_ID)
|
||||
expected_parents[FOLDER_2_1_ID] = FOLDER_2_ID
|
||||
|
||||
expected_ids.add(FOLDER_2_2_ID)
|
||||
expected_parents[FOLDER_2_2_ID] = FOLDER_2_ID
|
||||
|
||||
return expected_ids, expected_parents
|
||||
|
||||
|
||||
def get_expected_hierarchy_for_folder_1() -> dict[str, ExpectedHierarchyNode]:
|
||||
def get_expected_hierarchy_for_folder_1() -> tuple[set[str], dict[str, str | None]]:
|
||||
"""Get expected hierarchy for folder_1 and its children only."""
|
||||
return _pick(FOLDER_1_ID, FOLDER_1_1_ID, FOLDER_1_2_ID)
|
||||
return (
|
||||
{FOLDER_1_ID, FOLDER_1_1_ID, FOLDER_1_2_ID},
|
||||
{
|
||||
FOLDER_1_ID: SHARED_DRIVE_1_ID,
|
||||
FOLDER_1_1_ID: FOLDER_1_ID,
|
||||
FOLDER_1_2_ID: FOLDER_1_ID,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_expected_hierarchy_for_folder_2() -> dict[str, ExpectedHierarchyNode]:
|
||||
def get_expected_hierarchy_for_folder_2() -> tuple[set[str], dict[str, str | None]]:
|
||||
"""Get expected hierarchy for folder_2 and its children only."""
|
||||
return _pick(FOLDER_2_ID, FOLDER_2_1_ID, FOLDER_2_2_ID)
|
||||
return (
|
||||
{FOLDER_2_ID, FOLDER_2_1_ID, FOLDER_2_2_ID},
|
||||
{
|
||||
FOLDER_2_ID: SHARED_DRIVE_2_ID,
|
||||
FOLDER_2_1_ID: FOLDER_2_ID,
|
||||
FOLDER_2_2_ID: FOLDER_2_ID,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_expected_hierarchy_for_test_user_1() -> dict[str, ExpectedHierarchyNode]:
|
||||
def get_expected_hierarchy_for_test_user_1() -> tuple[set[str], dict[str, str | None]]:
|
||||
"""
|
||||
Get expected hierarchy for test_user_1's full access (OAuth).
|
||||
Get expected hierarchy for test_user_1's full access.
|
||||
|
||||
test_user_1 has access to:
|
||||
- shared_drive_1 and its contents (folder_1, folder_1_1, folder_1_2)
|
||||
- folder_3 (shared from admin's My Drive)
|
||||
- PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A and PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B
|
||||
- Additional drives/folders the user has access to
|
||||
|
||||
NOTE: Folder 3 lives in the admin's My Drive. When running as an OAuth
|
||||
connector for test_user_1, the Google Drive API won't return the parent
|
||||
for Folder 3 because the user can't access the admin's My Drive root.
|
||||
"""
|
||||
result = get_expected_hierarchy_for_shared_drives(
|
||||
# Start with shared_drive_1 hierarchy
|
||||
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
|
||||
include_drive_1=True,
|
||||
include_drive_2=False,
|
||||
include_restricted_folder=False,
|
||||
)
|
||||
result.update(
|
||||
_pick(
|
||||
FOLDER_3_ID,
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
|
||||
TEST_USER_1_MY_DRIVE_ID,
|
||||
TEST_USER_1_MY_DRIVE_FOLDER_ID,
|
||||
TEST_USER_1_DRIVE_B_ID,
|
||||
TEST_USER_1_DRIVE_B_FOLDER_ID,
|
||||
TEST_USER_1_EXTRA_DRIVE_1_ID,
|
||||
TEST_USER_1_EXTRA_DRIVE_2_ID,
|
||||
TEST_USER_1_EXTRA_FOLDER_ID,
|
||||
)
|
||||
)
|
||||
return _clear_parents(result, FOLDER_3_ID)
|
||||
|
||||
# folder_3 is shared from admin's My Drive
|
||||
expected_ids.add(FOLDER_3_ID)
|
||||
|
||||
# Perm sync drives that test_user_1 has access to
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
|
||||
expected_parents[PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID] = None
|
||||
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
|
||||
expected_parents[PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID] = None
|
||||
|
||||
# Additional drives/folders test_user_1 has access to
|
||||
expected_ids.add(TEST_USER_1_MY_DRIVE_ID)
|
||||
expected_parents[TEST_USER_1_MY_DRIVE_ID] = None
|
||||
|
||||
expected_ids.add(TEST_USER_1_MY_DRIVE_FOLDER_ID)
|
||||
expected_parents[TEST_USER_1_MY_DRIVE_FOLDER_ID] = TEST_USER_1_MY_DRIVE_ID
|
||||
|
||||
expected_ids.add(TEST_USER_1_DRIVE_B_ID)
|
||||
expected_parents[TEST_USER_1_DRIVE_B_ID] = None
|
||||
|
||||
expected_ids.add(TEST_USER_1_DRIVE_B_FOLDER_ID)
|
||||
expected_parents[TEST_USER_1_DRIVE_B_FOLDER_ID] = TEST_USER_1_DRIVE_B_ID
|
||||
|
||||
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
|
||||
expected_parents[TEST_USER_1_EXTRA_DRIVE_1_ID] = None
|
||||
|
||||
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
|
||||
expected_parents[TEST_USER_1_EXTRA_DRIVE_2_ID] = None
|
||||
|
||||
expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)
|
||||
# Parent unknown, skip adding to expected_parents
|
||||
|
||||
return expected_ids, expected_parents
|
||||
|
||||
|
||||
def get_expected_hierarchy_for_test_user_1_shared_drives_only() -> (
|
||||
dict[str, ExpectedHierarchyNode]
|
||||
tuple[set[str], dict[str, str | None]]
|
||||
):
|
||||
"""Expected hierarchy nodes when test_user_1 runs with include_shared_drives=True only."""
|
||||
result = get_expected_hierarchy_for_test_user_1()
|
||||
for nid in (
|
||||
TEST_USER_1_MY_DRIVE_ID,
|
||||
TEST_USER_1_MY_DRIVE_FOLDER_ID,
|
||||
FOLDER_3_ID,
|
||||
TEST_USER_1_EXTRA_FOLDER_ID,
|
||||
):
|
||||
result.pop(nid, None)
|
||||
return result
|
||||
expected_ids, expected_parents = get_expected_hierarchy_for_test_user_1()
|
||||
|
||||
# This mode should not include My Drive roots/folders.
|
||||
expected_ids.discard(TEST_USER_1_MY_DRIVE_ID)
|
||||
expected_ids.discard(TEST_USER_1_MY_DRIVE_FOLDER_ID)
|
||||
|
||||
# don't include shared with me
|
||||
expected_ids.discard(FOLDER_3_ID)
|
||||
expected_ids.discard(TEST_USER_1_EXTRA_FOLDER_ID)
|
||||
|
||||
return expected_ids, expected_parents
|
||||
|
||||
|
||||
def get_expected_hierarchy_for_test_user_1_shared_with_me_only() -> (
|
||||
dict[str, ExpectedHierarchyNode]
|
||||
tuple[set[str], dict[str, str | None]]
|
||||
):
|
||||
"""Expected hierarchy nodes when test_user_1 runs with include_files_shared_with_me=True only."""
|
||||
return _clear_parents(
|
||||
_pick(FOLDER_3_ID, TEST_USER_1_EXTRA_FOLDER_ID),
|
||||
FOLDER_3_ID,
|
||||
)
|
||||
expected_ids: set[str] = {FOLDER_3_ID, TEST_USER_1_EXTRA_FOLDER_ID}
|
||||
expected_parents: dict[str, str | None] = {}
|
||||
return expected_ids, expected_parents
|
||||
|
||||
|
||||
def get_expected_hierarchy_for_test_user_1_my_drive_only() -> (
|
||||
dict[str, ExpectedHierarchyNode]
|
||||
tuple[set[str], dict[str, str | None]]
|
||||
):
|
||||
"""Expected hierarchy nodes when test_user_1 runs with include_my_drives=True only."""
|
||||
return _pick(TEST_USER_1_MY_DRIVE_ID, TEST_USER_1_MY_DRIVE_FOLDER_ID)
|
||||
expected_ids: set[str] = {TEST_USER_1_MY_DRIVE_ID, TEST_USER_1_MY_DRIVE_FOLDER_ID}
|
||||
expected_parents: dict[str, str | None] = {
|
||||
TEST_USER_1_MY_DRIVE_ID: None,
|
||||
TEST_USER_1_MY_DRIVE_FOLDER_ID: TEST_USER_1_MY_DRIVE_ID,
|
||||
}
|
||||
return expected_ids, expected_parents
|
||||
|
||||
@@ -3,11 +3,12 @@ from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import _pick
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_MY_DRIVE_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
ADMIN_MY_DRIVE_ID,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
assert_expected_docs_in_retrieved_docs,
|
||||
)
|
||||
@@ -15,15 +16,21 @@ from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
assert_hierarchy_nodes_match_expected,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
|
||||
@@ -40,15 +47,18 @@ from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import PILL_FOLDER_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
PILL_FOLDER_ID,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
RESTRICTED_ACCESS_FOLDER_ID,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
TEST_USER_1_EXTRA_DRIVE_1_ID,
|
||||
)
|
||||
@@ -80,6 +90,7 @@ def test_include_all(
|
||||
)
|
||||
output = load_connector_outputs(connector)
|
||||
|
||||
# Should get everything in shared and admin's My Drive with oauth
|
||||
expected_file_ids = (
|
||||
ADMIN_FILE_IDS
|
||||
+ ADMIN_FOLDER_3_FILE_IDS
|
||||
@@ -98,28 +109,33 @@ def test_include_all(
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
expected_nodes = get_expected_hierarchy_for_shared_drives(
|
||||
# Verify hierarchy nodes for shared drives
|
||||
# When include_shared_drives=True, we get ALL shared drives the admin has access to
|
||||
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
|
||||
include_drive_1=True,
|
||||
include_drive_2=True,
|
||||
# Restricted folder may not always be retrieved due to access limitations
|
||||
include_restricted_folder=False,
|
||||
)
|
||||
expected_nodes.update(
|
||||
_pick(
|
||||
PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
|
||||
TEST_USER_1_EXTRA_DRIVE_1_ID,
|
||||
TEST_USER_1_EXTRA_DRIVE_2_ID,
|
||||
ADMIN_MY_DRIVE_ID,
|
||||
PILL_FOLDER_ID,
|
||||
RESTRICTED_ACCESS_FOLDER_ID,
|
||||
TEST_USER_1_EXTRA_FOLDER_ID,
|
||||
FOLDER_3_ID,
|
||||
)
|
||||
)
|
||||
|
||||
# Add additional shared drives that admin has access to
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
|
||||
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
|
||||
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
|
||||
expected_ids.add(ADMIN_MY_DRIVE_ID)
|
||||
expected_ids.add(PILL_FOLDER_ID)
|
||||
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
|
||||
expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)
|
||||
|
||||
# My Drive folders
|
||||
expected_ids.add(FOLDER_3_ID)
|
||||
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=expected_nodes,
|
||||
expected_node_ids=expected_ids,
|
||||
expected_parent_mapping=expected_parents,
|
||||
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
|
||||
)
|
||||
|
||||
@@ -144,6 +160,7 @@ def test_include_shared_drives_only(
|
||||
)
|
||||
output = load_connector_outputs(connector)
|
||||
|
||||
# Should only get shared drives
|
||||
expected_file_ids = (
|
||||
SHARED_DRIVE_1_FILE_IDS
|
||||
+ FOLDER_1_FILE_IDS
|
||||
@@ -160,24 +177,26 @@ def test_include_shared_drives_only(
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
expected_nodes = get_expected_hierarchy_for_shared_drives(
|
||||
# Verify hierarchy nodes - should include both shared drives and their folders
|
||||
# When include_shared_drives=True, we get ALL shared drives admin has access to
|
||||
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
|
||||
include_drive_1=True,
|
||||
include_drive_2=True,
|
||||
include_restricted_folder=False,
|
||||
)
|
||||
expected_nodes.update(
|
||||
_pick(
|
||||
PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
|
||||
TEST_USER_1_EXTRA_DRIVE_1_ID,
|
||||
TEST_USER_1_EXTRA_DRIVE_2_ID,
|
||||
RESTRICTED_ACCESS_FOLDER_ID,
|
||||
)
|
||||
)
|
||||
|
||||
# Add additional shared drives that admin has access to
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
|
||||
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
|
||||
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
|
||||
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
|
||||
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=expected_nodes,
|
||||
expected_node_ids=expected_ids,
|
||||
expected_parent_mapping=expected_parents,
|
||||
)
|
||||
|
||||
|
||||
@@ -201,21 +220,24 @@ def test_include_my_drives_only(
|
||||
)
|
||||
output = load_connector_outputs(connector)
|
||||
|
||||
# Should only get primary_admins My Drive because we are impersonating them
|
||||
expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=output.documents,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
expected_nodes = _pick(
|
||||
# Verify hierarchy nodes - My Drive should yield folder_3 as a hierarchy node
|
||||
# Also includes admin's My Drive root and folders shared with admin
|
||||
expected_ids = {
|
||||
FOLDER_3_ID,
|
||||
ADMIN_MY_DRIVE_ID,
|
||||
PILL_FOLDER_ID,
|
||||
TEST_USER_1_EXTRA_FOLDER_ID,
|
||||
)
|
||||
}
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=expected_nodes,
|
||||
expected_node_ids=expected_ids,
|
||||
)
|
||||
|
||||
|
||||
@@ -251,14 +273,17 @@ def test_drive_one_only(
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
expected_nodes = get_expected_hierarchy_for_shared_drives(
|
||||
# Verify hierarchy nodes - should only include shared_drive_1 and its folders
|
||||
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
|
||||
include_drive_1=True,
|
||||
include_drive_2=False,
|
||||
include_restricted_folder=False,
|
||||
)
|
||||
# Restricted folder is non-deterministically returned by the connector
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=expected_nodes,
|
||||
expected_node_ids=expected_ids,
|
||||
expected_parent_mapping=expected_parents,
|
||||
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
|
||||
)
|
||||
|
||||
@@ -299,15 +324,33 @@ def test_folder_and_shared_drive(
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
expected_nodes = get_expected_hierarchy_for_shared_drives(
|
||||
include_drive_1=True,
|
||||
include_drive_2=True,
|
||||
include_restricted_folder=False,
|
||||
)
|
||||
expected_nodes.pop(SECTIONS_FOLDER_ID, None)
|
||||
# Verify hierarchy nodes - shared_drive_1 and folder_2 with children
|
||||
# SHARED_DRIVE_2_ID is included because folder_2's parent is shared_drive_2
|
||||
expected_ids = {
|
||||
SHARED_DRIVE_1_ID,
|
||||
FOLDER_1_ID,
|
||||
FOLDER_1_1_ID,
|
||||
FOLDER_1_2_ID,
|
||||
SHARED_DRIVE_2_ID,
|
||||
FOLDER_2_ID,
|
||||
FOLDER_2_1_ID,
|
||||
FOLDER_2_2_ID,
|
||||
}
|
||||
expected_parents = {
|
||||
SHARED_DRIVE_1_ID: None,
|
||||
FOLDER_1_ID: SHARED_DRIVE_1_ID,
|
||||
FOLDER_1_1_ID: FOLDER_1_ID,
|
||||
FOLDER_1_2_ID: FOLDER_1_ID,
|
||||
SHARED_DRIVE_2_ID: None,
|
||||
FOLDER_2_ID: SHARED_DRIVE_2_ID,
|
||||
FOLDER_2_1_ID: FOLDER_2_ID,
|
||||
FOLDER_2_2_ID: FOLDER_2_ID,
|
||||
}
|
||||
# Restricted folder is non-deterministically returned
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=expected_nodes,
|
||||
expected_node_ids=expected_ids,
|
||||
expected_parent_mapping=expected_parents,
|
||||
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
|
||||
)
|
||||
|
||||
@@ -327,6 +370,7 @@ def test_folders_only(
|
||||
FOLDER_2_2_URL,
|
||||
FOLDER_3_URL,
|
||||
]
|
||||
# This should get converted to a drive request and spit out a warning in the logs
|
||||
shared_drive_urls = [
|
||||
FOLDER_1_1_URL,
|
||||
]
|
||||
@@ -353,16 +397,23 @@ def test_folders_only(
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
expected_nodes = get_expected_hierarchy_for_shared_drives(
|
||||
include_drive_1=True,
|
||||
include_drive_2=True,
|
||||
include_restricted_folder=False,
|
||||
)
|
||||
expected_nodes.pop(SECTIONS_FOLDER_ID, None)
|
||||
expected_nodes.update(_pick(ADMIN_MY_DRIVE_ID, FOLDER_3_ID))
|
||||
# Verify hierarchy nodes - specific folders requested plus their parent nodes
|
||||
# The connector walks up the hierarchy to include parent drives/folders
|
||||
expected_ids = {
|
||||
SHARED_DRIVE_1_ID,
|
||||
FOLDER_1_ID,
|
||||
FOLDER_1_1_ID,
|
||||
FOLDER_1_2_ID,
|
||||
SHARED_DRIVE_2_ID,
|
||||
FOLDER_2_ID,
|
||||
FOLDER_2_1_ID,
|
||||
FOLDER_2_2_ID,
|
||||
ADMIN_MY_DRIVE_ID,
|
||||
FOLDER_3_ID,
|
||||
}
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=expected_nodes,
|
||||
expected_node_ids=expected_ids,
|
||||
)
|
||||
|
||||
|
||||
@@ -395,8 +446,9 @@ def test_personal_folders_only(
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
expected_nodes = _pick(FOLDER_3_ID, ADMIN_MY_DRIVE_ID)
|
||||
# Verify hierarchy nodes - folder_3 and its parent (admin's My Drive root)
|
||||
expected_ids = {FOLDER_3_ID, ADMIN_MY_DRIVE_ID}
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=expected_nodes,
|
||||
expected_node_ids=expected_ids,
|
||||
)
|
||||
|
||||
@@ -14,10 +14,11 @@ from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import SortOrder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import _pick
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ACCESS_MAPPING
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_MY_DRIVE_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
ADMIN_MY_DRIVE_ID,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
assert_hierarchy_nodes_match_expected,
|
||||
)
|
||||
@@ -261,35 +262,37 @@ def test_gdrive_perm_sync_with_real_data(
|
||||
hierarchy_connector = _build_connector(google_drive_service_acct_connector_factory)
|
||||
output = load_connector_outputs(hierarchy_connector, include_permissions=True)
|
||||
|
||||
expected_nodes = get_expected_hierarchy_for_shared_drives(
|
||||
# Verify the expected shared drives hierarchy
|
||||
# When include_shared_drives=True and include_my_drives=True, we get ALL drives
|
||||
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
|
||||
include_drive_1=True,
|
||||
include_drive_2=True,
|
||||
include_restricted_folder=False,
|
||||
)
|
||||
expected_nodes.update(
|
||||
_pick(
|
||||
PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
|
||||
TEST_USER_1_MY_DRIVE_ID,
|
||||
TEST_USER_1_MY_DRIVE_FOLDER_ID,
|
||||
TEST_USER_1_DRIVE_B_ID,
|
||||
TEST_USER_1_DRIVE_B_FOLDER_ID,
|
||||
TEST_USER_1_EXTRA_DRIVE_1_ID,
|
||||
TEST_USER_1_EXTRA_DRIVE_2_ID,
|
||||
ADMIN_MY_DRIVE_ID,
|
||||
TEST_USER_2_MY_DRIVE,
|
||||
TEST_USER_3_MY_DRIVE_ID,
|
||||
PILL_FOLDER_ID,
|
||||
RESTRICTED_ACCESS_FOLDER_ID,
|
||||
TEST_USER_1_EXTRA_FOLDER_ID,
|
||||
EXTERNAL_SHARED_FOLDER_ID,
|
||||
FOLDER_3_ID,
|
||||
)
|
||||
)
|
||||
|
||||
# Add additional shared drives in the organization
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
|
||||
expected_ids.add(TEST_USER_1_MY_DRIVE_ID)
|
||||
expected_ids.add(TEST_USER_1_MY_DRIVE_FOLDER_ID)
|
||||
expected_ids.add(TEST_USER_1_DRIVE_B_ID)
|
||||
expected_ids.add(TEST_USER_1_DRIVE_B_FOLDER_ID)
|
||||
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
|
||||
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
|
||||
expected_ids.add(ADMIN_MY_DRIVE_ID)
|
||||
expected_ids.add(TEST_USER_2_MY_DRIVE)
|
||||
expected_ids.add(TEST_USER_3_MY_DRIVE_ID)
|
||||
expected_ids.add(PILL_FOLDER_ID)
|
||||
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
|
||||
expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)
|
||||
expected_ids.add(EXTERNAL_SHARED_FOLDER_ID)
|
||||
expected_ids.add(FOLDER_3_ID)
|
||||
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=expected_nodes,
|
||||
expected_node_ids=expected_ids,
|
||||
expected_parent_mapping=expected_parents,
|
||||
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
|
||||
)
|
||||
|
||||
|
||||
@@ -4,11 +4,12 @@ from unittest.mock import patch
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import _pick
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_MY_DRIVE_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
ADMIN_MY_DRIVE_ID,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
assert_expected_docs_in_retrieved_docs,
|
||||
)
|
||||
@@ -28,15 +29,21 @@ from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
EXTERNAL_SHARED_FOLDER_URL,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
|
||||
@@ -67,10 +74,11 @@ from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
RESTRICTED_ACCESS_FOLDER_URL,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
TEST_USER_1_DRIVE_B_FOLDER_ID,
|
||||
)
|
||||
@@ -148,35 +156,39 @@ def test_include_all(
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
expected_nodes = get_expected_hierarchy_for_shared_drives(
|
||||
# Verify hierarchy nodes for shared drives
|
||||
# When include_shared_drives=True, we get ALL shared drives in the organization
|
||||
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
|
||||
include_drive_1=True,
|
||||
include_drive_2=True,
|
||||
include_restricted_folder=False,
|
||||
)
|
||||
expected_nodes.update(
|
||||
_pick(
|
||||
PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
|
||||
TEST_USER_1_MY_DRIVE_ID,
|
||||
TEST_USER_1_MY_DRIVE_FOLDER_ID,
|
||||
TEST_USER_1_DRIVE_B_ID,
|
||||
TEST_USER_1_DRIVE_B_FOLDER_ID,
|
||||
TEST_USER_1_EXTRA_DRIVE_1_ID,
|
||||
TEST_USER_1_EXTRA_DRIVE_2_ID,
|
||||
ADMIN_MY_DRIVE_ID,
|
||||
TEST_USER_2_MY_DRIVE,
|
||||
TEST_USER_3_MY_DRIVE_ID,
|
||||
PILL_FOLDER_ID,
|
||||
RESTRICTED_ACCESS_FOLDER_ID,
|
||||
TEST_USER_1_EXTRA_FOLDER_ID,
|
||||
EXTERNAL_SHARED_FOLDER_ID,
|
||||
FOLDER_3_ID,
|
||||
)
|
||||
)
|
||||
|
||||
# Add additional shared drives in the organization
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
|
||||
expected_ids.add(TEST_USER_1_MY_DRIVE_ID)
|
||||
expected_ids.add(TEST_USER_1_MY_DRIVE_FOLDER_ID)
|
||||
expected_ids.add(TEST_USER_1_DRIVE_B_ID)
|
||||
expected_ids.add(TEST_USER_1_DRIVE_B_FOLDER_ID)
|
||||
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
|
||||
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
|
||||
expected_ids.add(ADMIN_MY_DRIVE_ID)
|
||||
expected_ids.add(TEST_USER_2_MY_DRIVE)
|
||||
expected_ids.add(TEST_USER_3_MY_DRIVE_ID)
|
||||
expected_ids.add(PILL_FOLDER_ID)
|
||||
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
|
||||
expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)
|
||||
expected_ids.add(EXTERNAL_SHARED_FOLDER_ID)
|
||||
|
||||
# My Drive folders
|
||||
expected_ids.add(FOLDER_3_ID)
|
||||
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=expected_nodes,
|
||||
expected_node_ids=expected_ids,
|
||||
expected_parent_mapping=expected_parents,
|
||||
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
|
||||
)
|
||||
|
||||
@@ -282,26 +294,28 @@ def test_include_shared_drives_only(
|
||||
# TODO: switch to 54 when restricted access issue is resolved
|
||||
assert len(output.documents) == 51 or len(output.documents) == 52
|
||||
|
||||
expected_nodes = get_expected_hierarchy_for_shared_drives(
|
||||
# Verify hierarchy nodes - should include both shared drives and their folders
|
||||
# When include_shared_drives=True, we get ALL shared drives in the organization
|
||||
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
|
||||
include_drive_1=True,
|
||||
include_drive_2=True,
|
||||
include_restricted_folder=False,
|
||||
)
|
||||
expected_nodes.update(
|
||||
_pick(
|
||||
PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
|
||||
TEST_USER_1_DRIVE_B_ID,
|
||||
TEST_USER_1_DRIVE_B_FOLDER_ID,
|
||||
TEST_USER_1_EXTRA_DRIVE_1_ID,
|
||||
TEST_USER_1_EXTRA_DRIVE_2_ID,
|
||||
RESTRICTED_ACCESS_FOLDER_ID,
|
||||
)
|
||||
)
|
||||
|
||||
# Add additional shared drives in the organization
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
|
||||
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
|
||||
expected_ids.add(TEST_USER_1_DRIVE_B_ID)
|
||||
expected_ids.add(TEST_USER_1_DRIVE_B_FOLDER_ID)
|
||||
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
|
||||
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
|
||||
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
|
||||
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=expected_nodes,
|
||||
expected_node_ids=expected_ids,
|
||||
expected_parent_mapping=expected_parents,
|
||||
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
|
||||
)
|
||||
|
||||
@@ -339,7 +353,9 @@ def test_include_my_drives_only(
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
expected_nodes = _pick(
|
||||
# Verify hierarchy nodes - My Drive roots and folders for all users
|
||||
# Service account impersonates all users, so it sees all My Drives
|
||||
expected_ids = {
|
||||
FOLDER_3_ID,
|
||||
ADMIN_MY_DRIVE_ID,
|
||||
TEST_USER_1_MY_DRIVE_ID,
|
||||
@@ -349,10 +365,10 @@ def test_include_my_drives_only(
|
||||
PILL_FOLDER_ID,
|
||||
TEST_USER_1_EXTRA_FOLDER_ID,
|
||||
EXTERNAL_SHARED_FOLDER_ID,
|
||||
)
|
||||
}
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=expected_nodes,
|
||||
expected_node_ids=expected_ids,
|
||||
)
|
||||
|
||||
|
||||
@@ -389,14 +405,17 @@ def test_drive_one_only(
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
expected_nodes = get_expected_hierarchy_for_shared_drives(
|
||||
# Verify hierarchy nodes - should only include shared_drive_1 and its folders
|
||||
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
|
||||
include_drive_1=True,
|
||||
include_drive_2=False,
|
||||
include_restricted_folder=False,
|
||||
)
|
||||
# Restricted folder is non-deterministically returned
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=expected_nodes,
|
||||
expected_node_ids=expected_ids,
|
||||
expected_parent_mapping=expected_parents,
|
||||
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
|
||||
)
|
||||
|
||||
@@ -438,15 +457,33 @@ def test_folder_and_shared_drive(
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
expected_nodes = get_expected_hierarchy_for_shared_drives(
|
||||
include_drive_1=True,
|
||||
include_drive_2=True,
|
||||
include_restricted_folder=False,
|
||||
)
|
||||
expected_nodes.pop(SECTIONS_FOLDER_ID, None)
|
||||
# Verify hierarchy nodes - shared_drive_1 and folder_2 with children
|
||||
# SHARED_DRIVE_2_ID is included because folder_2's parent is shared_drive_2
|
||||
expected_ids = {
|
||||
SHARED_DRIVE_1_ID,
|
||||
FOLDER_1_ID,
|
||||
FOLDER_1_1_ID,
|
||||
FOLDER_1_2_ID,
|
||||
SHARED_DRIVE_2_ID,
|
||||
FOLDER_2_ID,
|
||||
FOLDER_2_1_ID,
|
||||
FOLDER_2_2_ID,
|
||||
}
|
||||
expected_parents = {
|
||||
SHARED_DRIVE_1_ID: None,
|
||||
FOLDER_1_ID: SHARED_DRIVE_1_ID,
|
||||
FOLDER_1_1_ID: FOLDER_1_ID,
|
||||
FOLDER_1_2_ID: FOLDER_1_ID,
|
||||
SHARED_DRIVE_2_ID: None,
|
||||
FOLDER_2_ID: SHARED_DRIVE_2_ID,
|
||||
FOLDER_2_1_ID: FOLDER_2_ID,
|
||||
FOLDER_2_2_ID: FOLDER_2_ID,
|
||||
}
|
||||
# Restricted folder is non-deterministically returned
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=expected_nodes,
|
||||
expected_node_ids=expected_ids,
|
||||
expected_parent_mapping=expected_parents,
|
||||
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
|
||||
)
|
||||
|
||||
@@ -493,16 +530,23 @@ def test_folders_only(
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
expected_nodes = get_expected_hierarchy_for_shared_drives(
|
||||
include_drive_1=True,
|
||||
include_drive_2=True,
|
||||
include_restricted_folder=False,
|
||||
)
|
||||
expected_nodes.pop(SECTIONS_FOLDER_ID, None)
|
||||
expected_nodes.update(_pick(ADMIN_MY_DRIVE_ID, FOLDER_3_ID))
|
||||
# Verify hierarchy nodes - specific folders requested plus their parent nodes
|
||||
# The connector walks up the hierarchy to include parent drives/folders
|
||||
expected_ids = {
|
||||
SHARED_DRIVE_1_ID,
|
||||
FOLDER_1_ID,
|
||||
FOLDER_1_1_ID,
|
||||
FOLDER_1_2_ID,
|
||||
SHARED_DRIVE_2_ID,
|
||||
FOLDER_2_ID,
|
||||
FOLDER_2_1_ID,
|
||||
FOLDER_2_2_ID,
|
||||
ADMIN_MY_DRIVE_ID,
|
||||
FOLDER_3_ID,
|
||||
}
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=expected_nodes,
|
||||
expected_node_ids=expected_ids,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,8 +4,6 @@ from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.models import Document
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import _clear_parents
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import _pick
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
assert_expected_docs_in_retrieved_docs,
|
||||
@@ -53,6 +51,8 @@ def _check_for_error(
|
||||
retrieved_failures = output.failures
|
||||
assert len(retrieved_failures) <= 1
|
||||
|
||||
# current behavior is to fail silently for 403s; leaving this here for when we revert
|
||||
# if all 403s get fixed
|
||||
if len(retrieved_failures) == 1:
|
||||
fail_msg = retrieved_failures[0].failure_message
|
||||
assert "HttpError 403" in fail_msg
|
||||
@@ -83,11 +83,14 @@ def test_all(
|
||||
output = load_connector_outputs(connector)
|
||||
|
||||
expected_file_ids = (
|
||||
# These are the files from my drive
|
||||
TEST_USER_1_FILE_IDS
|
||||
# These are the files from shared drives
|
||||
+ SHARED_DRIVE_1_FILE_IDS
|
||||
+ FOLDER_1_FILE_IDS
|
||||
+ FOLDER_1_1_FILE_IDS
|
||||
+ FOLDER_1_2_FILE_IDS
|
||||
# These are the files shared with me from admin
|
||||
+ ADMIN_FOLDER_3_FILE_IDS
|
||||
+ list(range(0, 2))
|
||||
)
|
||||
@@ -99,9 +102,13 @@ def test_all(
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
# Verify hierarchy nodes - test_user_1 has access to shared_drive_1, folder_3,
|
||||
# perm sync drives, and additional drives/folders
|
||||
expected_ids, expected_parents = get_expected_hierarchy_for_test_user_1()
|
||||
assert_hierarchy_nodes_match_expected(
|
||||
retrieved_nodes=output.hierarchy_nodes,
|
||||
expected_nodes=get_expected_hierarchy_for_test_user_1(),
|
||||
expected_node_ids=expected_ids,
|
||||
expected_parent_mapping=expected_parents,
|
||||
)
|
||||
|
||||
|
||||
@@ -126,6 +133,7 @@ def test_shared_drives_only(
output = load_connector_outputs(connector)

expected_file_ids = (
# These are the files from shared drives
SHARED_DRIVE_1_FILE_IDS
+ FOLDER_1_FILE_IDS
+ FOLDER_1_1_FILE_IDS
@@ -138,9 +146,14 @@ def test_shared_drives_only(
expected_file_ids=expected_file_ids,
)

# Verify hierarchy nodes - test_user_1 sees multiple shared drives/folders
expected_ids, expected_parents = (
get_expected_hierarchy_for_test_user_1_shared_drives_only()
)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
expected_nodes=get_expected_hierarchy_for_test_user_1_shared_drives_only(),
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
)

@@ -164,15 +177,24 @@ def test_shared_with_me_only(
)
output = load_connector_outputs(connector)

expected_file_ids = ADMIN_FOLDER_3_FILE_IDS + list(range(0, 2))
expected_file_ids = (
# These are the files shared with me from admin
ADMIN_FOLDER_3_FILE_IDS
+ list(range(0, 2))
)
assert_expected_docs_in_retrieved_docs(
retrieved_docs=output.documents,
expected_file_ids=expected_file_ids,
)

# Verify hierarchy nodes - shared-with-me folders
expected_ids, expected_parents = (
get_expected_hierarchy_for_test_user_1_shared_with_me_only()
)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
expected_nodes=get_expected_hierarchy_for_test_user_1_shared_with_me_only(),
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
)

@@ -196,15 +218,21 @@ def test_my_drive_only(
)
output = load_connector_outputs(connector)

# These are the files from my drive
expected_file_ids = TEST_USER_1_FILE_IDS
assert_expected_docs_in_retrieved_docs(
retrieved_docs=output.documents,
expected_file_ids=expected_file_ids,
)

# Verify hierarchy nodes - My Drive root + its folder(s)
expected_ids, expected_parents = (
get_expected_hierarchy_for_test_user_1_my_drive_only()
)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
expected_nodes=get_expected_hierarchy_for_test_user_1_my_drive_only(),
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
)

@@ -228,15 +256,20 @@ def test_shared_my_drive_folder(
)
output = load_connector_outputs(connector)

expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
expected_file_ids = (
# this is a folder from admin's drive that is shared with me
ADMIN_FOLDER_3_FILE_IDS
)
assert_expected_docs_in_retrieved_docs(
retrieved_docs=output.documents,
expected_file_ids=expected_file_ids,
)

# Verify hierarchy nodes - only folder_3
expected_ids = {FOLDER_3_ID}
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
expected_nodes=_clear_parents(_pick(FOLDER_3_ID), FOLDER_3_ID),
expected_node_ids=expected_ids,
)

@@ -266,9 +299,16 @@ def test_shared_drive_folder(
expected_file_ids=expected_file_ids,
)

# Verify hierarchy nodes - includes shared drive root + folder_1 subtree
expected_ids = {SHARED_DRIVE_1_ID, FOLDER_1_ID, FOLDER_1_1_ID, FOLDER_1_2_ID}
expected_parents: dict[str, str | None] = {
SHARED_DRIVE_1_ID: None,
FOLDER_1_ID: SHARED_DRIVE_1_ID,
FOLDER_1_1_ID: FOLDER_1_ID,
FOLDER_1_2_ID: FOLDER_1_ID,
}
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
expected_nodes=_pick(
SHARED_DRIVE_1_ID, FOLDER_1_ID, FOLDER_1_1_ID, FOLDER_1_2_ID
),
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
)

@@ -1,83 +0,0 @@
"""Fixtures for testing DAL classes against a real PostgreSQL database.

These fixtures build on the db_session and tenant_context fixtures from
the parent conftest (tests/external_dependency_unit/conftest.py).

Requires a running Postgres instance. Run with::

python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/db/
"""

from collections.abc import Callable
from collections.abc import Generator
from uuid import UUID
from uuid import uuid4

import pytest
from sqlalchemy.orm import Session

from ee.onyx.db.scim import ScimDAL
from onyx.db.models import ScimToken
from onyx.db.models import UserGroup


@pytest.fixture
def scim_dal(db_session: Session) -> ScimDAL:
"""A ScimDAL backed by the real test database session."""
return ScimDAL(db_session)


@pytest.fixture
def scim_token_factory(
db_session: Session,
) -> Generator[Callable[..., ScimToken], None, None]:
"""Factory that creates ScimToken rows and cleans them up after the test."""
created_ids: list[int] = []

def _create(
name: str = "test-token",
hashed_token: str | None = None,
token_display: str = "onyx_scim_****test",
created_by_id: UUID | None = None,
) -> ScimToken:
token = ScimToken(
name=name,
hashed_token=hashed_token or uuid4().hex,
token_display=token_display,
created_by_id=created_by_id or uuid4(),
)
db_session.add(token)
db_session.flush()
created_ids.append(token.id)
return token

yield _create

for token_id in created_ids:
obj = db_session.get(ScimToken, token_id)
if obj:
db_session.delete(obj)
db_session.commit()


@pytest.fixture
def user_group_factory(
db_session: Session,
) -> Generator[Callable[..., UserGroup], None, None]:
"""Factory that creates UserGroup rows for testing group mappings."""
created_ids: list[int] = []

def _create(name: str | None = None) -> UserGroup:
group = UserGroup(name=name or f"test-group-{uuid4().hex[:8]}")
db_session.add(group)
db_session.flush()
created_ids.append(group.id)
return group

yield _create

for group_id in created_ids:
obj = db_session.get(UserGroup, group_id)
if obj:
db_session.delete(obj)
db_session.commit()
@@ -553,7 +553,7 @@ class TestDefaultProviderEndpoint:

try:
existing_providers = fetch_existing_llm_providers(
db_session, flow_type_filter=[LLMModelFlowType.CHAT]
db_session, flow_types=[LLMModelFlowType.CHAT]
)
provider_names_to_restore: list[str] = []

@@ -14,12 +14,9 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session

from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import sync_auto_mode_models
from onyx.db.llm import update_default_provider
from onyx.db.models import UserRole
from onyx.llm.constants import LlmProviderNames
@@ -609,95 +606,3 @@ class TestAutoModeSyncFeature:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_1_name)
|
||||
_cleanup_provider(db_session, provider_2_name)
|
||||
|
||||
|
||||
class TestAutoModeMissingFlows:
|
||||
"""Regression test: sync_auto_mode_models must create LLMModelFlow rows
|
||||
for every ModelConfiguration it inserts, otherwise the provider vanishes
|
||||
from listing queries that join through LLMModelFlow."""
|
||||
|
||||
def test_sync_auto_mode_creates_flow_rows(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""
|
||||
Steps:
|
||||
1. Create a provider with no model configs (empty shell).
|
||||
2. Call sync_auto_mode_models to add models from a mock config.
|
||||
3. Assert every new ModelConfiguration has at least one LLMModelFlow.
|
||||
4. Assert fetch_existing_llm_providers (which joins through
|
||||
LLMModelFlow) returns the provider.
|
||||
"""
|
||||
mock_recommendations = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Create provider with no model configs
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 2: Run sync_auto_mode_models (simulating the periodic sync)
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=mock_recommendations,
|
||||
)
|
||||
|
||||
# Step 3: Every ModelConfiguration must have at least one LLMModelFlow
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
synced_model_names = {mc.name for mc in provider.model_configurations}
|
||||
assert "gpt-4o" in synced_model_names
|
||||
assert "gpt-4o-mini" in synced_model_names
|
||||
|
||||
for mc in provider.model_configurations:
|
||||
assert len(mc.llm_model_flows) > 0, (
|
||||
f"ModelConfiguration '{mc.name}' (id={mc.id}) has no "
|
||||
f"LLMModelFlow rows — it will be invisible to listing queries"
|
||||
)
|
||||
|
||||
flow_types = {f.llm_model_flow_type for f in mc.llm_model_flows}
|
||||
assert (
|
||||
LLMModelFlowType.CHAT in flow_types
|
||||
), f"ModelConfiguration '{mc.name}' is missing a CHAT flow"
|
||||
|
||||
# Step 4: The provider must appear in fetch_existing_llm_providers
|
||||
listed_providers = fetch_existing_llm_providers(
|
||||
db_session=db_session,
|
||||
flow_type_filter=[LLMModelFlowType.CHAT],
|
||||
)
|
||||
listed_provider_names = {p.name for p in listed_providers}
|
||||
assert provider_name in listed_provider_names, (
|
||||
f"Provider '{provider_name}' not returned by "
|
||||
f"fetch_existing_llm_providers — models are missing flow rows"
|
||||
)
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
@@ -86,16 +86,12 @@ class TestApplyLicenseStatusToSettings:
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", True)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@patch("ee.onyx.server.settings.api.refresh_license_cache", return_value=None)
@patch("ee.onyx.server.settings.api.get_session_with_current_tenant")
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
def test_no_license_with_ee_flag_gates_access(
self,
mock_get_metadata: MagicMock,
mock_get_tenant: MagicMock,
_mock_get_session: MagicMock,
_mock_refresh: MagicMock,
base_settings: Settings,
) -> None:
"""No license + ENTERPRISE_EDITION_ENABLED=true → GATED_ACCESS."""
@@ -111,16 +107,12 @@ class TestApplyLicenseStatusToSettings:
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", False)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@patch("ee.onyx.server.settings.api.refresh_license_cache", return_value=None)
@patch("ee.onyx.server.settings.api.get_session_with_current_tenant")
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
def test_no_license_without_ee_flag_allows_community(
self,
mock_get_metadata: MagicMock,
mock_get_tenant: MagicMock,
_mock_get_session: MagicMock,
_mock_refresh: MagicMock,
base_settings: Settings,
) -> None:
"""No license + ENTERPRISE_EDITION_ENABLED=false → community mode (no gating)."""

@@ -996,114 +996,6 @@ class TestFallbackToolExtraction:
|
||||
assert result.tool_calls[0].tool_args == {"queries": ["beta"]}
|
||||
assert result.tool_calls[0].placement == Placement(turn_index=5)
|
||||
|
||||
def test_extracts_xml_style_invoke_from_answer_when_required(self) -> None:
|
||||
llm_step_result = LlmStepResult(
|
||||
reasoning=None,
|
||||
answer=(
|
||||
'<function_calls><invoke name="internal_search">'
|
||||
'<parameter name="queries" string="false">'
|
||||
'["Onyx documentation", "Onyx docs", "Onyx platform"]'
|
||||
"</parameter></invoke></function_calls>"
|
||||
),
|
||||
tool_calls=None,
|
||||
)
|
||||
|
||||
result, attempted = _try_fallback_tool_extraction(
|
||||
llm_step_result=llm_step_result,
|
||||
tool_choice=ToolChoiceOptions.REQUIRED,
|
||||
fallback_extraction_attempted=False,
|
||||
tool_defs=self._tool_defs(),
|
||||
turn_index=7,
|
||||
)
|
||||
|
||||
assert attempted is True
|
||||
assert result.tool_calls is not None
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].tool_name == "internal_search"
|
||||
assert result.tool_calls[0].tool_args == {
|
||||
"queries": ["Onyx documentation", "Onyx docs", "Onyx platform"]
|
||||
}
|
||||
assert result.tool_calls[0].placement == Placement(turn_index=7)
|
||||
|
||||
def test_extracts_xml_style_invoke_from_answer_when_auto(self) -> None:
|
||||
llm_step_result = LlmStepResult(
|
||||
reasoning=None,
|
||||
# Runtime-faithful shape: filtered answer is empty, raw answer has XML payload.
|
||||
answer=None,
|
||||
raw_answer=(
|
||||
'<function_calls><invoke name="internal_search">'
|
||||
'<parameter name="queries" string="false">'
|
||||
'["Onyx documentation", "Onyx docs", "Onyx internal docs"]'
|
||||
"</parameter></invoke></function_calls>"
|
||||
),
|
||||
tool_calls=None,
|
||||
)
|
||||
|
||||
result, attempted = _try_fallback_tool_extraction(
|
||||
llm_step_result=llm_step_result,
|
||||
tool_choice=ToolChoiceOptions.AUTO,
|
||||
fallback_extraction_attempted=False,
|
||||
tool_defs=self._tool_defs(),
|
||||
turn_index=9,
|
||||
)
|
||||
|
||||
assert attempted is True
|
||||
assert result.tool_calls is not None
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].tool_name == "internal_search"
|
||||
assert result.tool_calls[0].tool_args == {
|
||||
"queries": ["Onyx documentation", "Onyx docs", "Onyx internal docs"]
|
||||
}
|
||||
assert result.tool_calls[0].placement == Placement(turn_index=9)
|
||||
|
||||
def test_extracts_from_raw_answer_when_filtered_answer_has_no_xml(self) -> None:
|
||||
llm_step_result = LlmStepResult(
|
||||
reasoning=None,
|
||||
answer="",
|
||||
raw_answer=(
|
||||
'<function_calls><invoke name="internal_search">'
|
||||
'<parameter name="queries" string="false">'
|
||||
'["Onyx documentation", "Onyx docs"]'
|
||||
"</parameter></invoke></function_calls>"
|
||||
),
|
||||
tool_calls=None,
|
||||
)
|
||||
|
||||
result, attempted = _try_fallback_tool_extraction(
|
||||
llm_step_result=llm_step_result,
|
||||
tool_choice=ToolChoiceOptions.AUTO,
|
||||
fallback_extraction_attempted=False,
|
||||
tool_defs=self._tool_defs(),
|
||||
turn_index=10,
|
||||
)
|
||||
|
||||
assert attempted is True
|
||||
assert result.tool_calls is not None
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].tool_name == "internal_search"
|
||||
assert result.tool_calls[0].tool_args == {
|
||||
"queries": ["Onyx documentation", "Onyx docs"]
|
||||
}
|
||||
assert result.tool_calls[0].placement == Placement(turn_index=10)
|
||||
|
||||
def test_does_not_attempt_fallback_for_auto_without_tool_call_hints(self) -> None:
|
||||
llm_step_result = LlmStepResult(
|
||||
reasoning=None,
|
||||
answer="Here is a normal answer with no tool call payload.",
|
||||
tool_calls=None,
|
||||
)
|
||||
|
||||
result, attempted = _try_fallback_tool_extraction(
|
||||
llm_step_result=llm_step_result,
|
||||
tool_choice=ToolChoiceOptions.AUTO,
|
||||
fallback_extraction_attempted=False,
|
||||
tool_defs=self._tool_defs(),
|
||||
turn_index=2,
|
||||
)
|
||||
|
||||
assert result is llm_step_result
|
||||
assert attempted is False
|
||||
|
||||
def test_returns_unchanged_when_required_but_nothing_extractable(self) -> None:
|
||||
llm_step_result = LlmStepResult(
|
||||
reasoning="Need more info.",
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
"""Tests for llm_step.py, specifically sanitization and argument parsing."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from onyx.chat.llm_step import _extract_tool_call_kickoffs
|
||||
from onyx.chat.llm_step import _increment_turns
|
||||
from onyx.chat.llm_step import _parse_tool_args_to_dict
|
||||
from onyx.chat.llm_step import _resolve_tool_arguments
|
||||
from onyx.chat.llm_step import _sanitize_llm_output
|
||||
from onyx.chat.llm_step import _XmlToolCallContentFilter
|
||||
from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
|
||||
@@ -217,204 +211,3 @@ class TestExtractToolCallsFromResponseText:
|
||||
{"queries": ["alpha"]},
|
||||
{"queries": ["alpha"]},
|
||||
]
|
||||
|
||||
def test_extracts_xml_style_invoke_tool_call(self) -> None:
|
||||
response_text = """
|
||||
<function_calls>
|
||||
<invoke name="internal_search">
|
||||
<parameter name="queries" string="false">["Onyx documentation", "Onyx docs", "Onyx platform"]</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
"""
|
||||
tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=response_text,
|
||||
tool_definitions=self._tool_defs(),
|
||||
placement=self._placement(),
|
||||
)
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].tool_name == "internal_search"
|
||||
assert tool_calls[0].tool_args == {
|
||||
"queries": ["Onyx documentation", "Onyx docs", "Onyx platform"]
|
||||
}
|
||||
|
||||
def test_ignores_unknown_tool_in_xml_style_invoke(self) -> None:
|
||||
response_text = """
|
||||
<function_calls>
|
||||
<invoke name="unknown_tool">
|
||||
<parameter name="queries" string="false">["Onyx docs"]</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
"""
|
||||
tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=response_text,
|
||||
tool_definitions=self._tool_defs(),
|
||||
placement=self._placement(),
|
||||
)
|
||||
assert len(tool_calls) == 0
|
||||
|
||||
|
||||
class TestExtractToolCallKickoffs:
|
||||
"""Tests for the _extract_tool_call_kickoffs function."""
|
||||
|
||||
def test_valid_tool_call(self) -> None:
|
||||
tool_call_map = {
|
||||
0: {
|
||||
"id": "call_123",
|
||||
"name": "internal_search",
|
||||
"arguments": '{"queries": ["test"]}',
|
||||
}
|
||||
}
|
||||
result = _extract_tool_call_kickoffs(tool_call_map, turn_index=0)
|
||||
assert len(result) == 1
|
||||
assert result[0].tool_name == "internal_search"
|
||||
assert result[0].tool_args == {"queries": ["test"]}
|
||||
|
||||
def test_invalid_json_arguments_returns_empty_dict(self) -> None:
|
||||
"""Verify that malformed JSON arguments produce an empty dict
|
||||
rather than raising an exception. This confirms the dead try/except
|
||||
around _parse_tool_args_to_dict was safe to remove."""
|
||||
tool_call_map = {
|
||||
0: {
|
||||
"id": "call_bad",
|
||||
"name": "internal_search",
|
||||
"arguments": "not valid json {{{",
|
||||
}
|
||||
}
|
||||
result = _extract_tool_call_kickoffs(tool_call_map, turn_index=0)
|
||||
assert len(result) == 1
|
||||
assert result[0].tool_args == {}
|
||||
|
||||
def test_none_arguments_returns_empty_dict(self) -> None:
|
||||
tool_call_map = {
|
||||
0: {
|
||||
"id": "call_none",
|
||||
"name": "internal_search",
|
||||
"arguments": None,
|
||||
}
|
||||
}
|
||||
result = _extract_tool_call_kickoffs(tool_call_map, turn_index=0)
|
||||
assert len(result) == 1
|
||||
assert result[0].tool_args == {}
|
||||
|
||||
def test_skips_entries_missing_id_or_name(self) -> None:
|
||||
tool_call_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": None, "name": "internal_search", "arguments": "{}"},
|
||||
1: {"id": "call_1", "name": None, "arguments": "{}"},
|
||||
2: {"id": "call_2", "name": "internal_search", "arguments": "{}"},
|
||||
}
|
||||
result = _extract_tool_call_kickoffs(tool_call_map, turn_index=0)
|
||||
assert len(result) == 1
|
||||
assert result[0].tool_call_id == "call_2"
|
||||
|
||||
def test_tab_index_auto_increments(self) -> None:
|
||||
tool_call_map = {
|
||||
0: {"id": "c1", "name": "tool_a", "arguments": "{}"},
|
||||
1: {"id": "c2", "name": "tool_b", "arguments": "{}"},
|
||||
}
|
||||
result = _extract_tool_call_kickoffs(tool_call_map, turn_index=0)
|
||||
assert result[0].placement.tab_index == 0
|
||||
assert result[1].placement.tab_index == 1
|
||||
|
||||
def test_tab_index_override(self) -> None:
|
||||
tool_call_map = {
|
||||
0: {"id": "c1", "name": "tool_a", "arguments": "{}"},
|
||||
1: {"id": "c2", "name": "tool_b", "arguments": "{}"},
|
||||
}
|
||||
result = _extract_tool_call_kickoffs(tool_call_map, turn_index=0, tab_index=5)
|
||||
assert result[0].placement.tab_index == 5
|
||||
assert result[1].placement.tab_index == 5
|
||||
|
||||
|
||||
class TestXmlToolCallContentFilter:
|
||||
def test_strips_function_calls_block_single_chunk(self) -> None:
|
||||
f = _XmlToolCallContentFilter()
|
||||
output = f.process(
|
||||
"prefix "
|
||||
'<function_calls><invoke name="internal_search">'
|
||||
'<parameter name="queries" string="false">["Onyx docs"]</parameter>'
|
||||
"</invoke></function_calls> suffix"
|
||||
)
|
||||
output += f.flush()
|
||||
assert output == "prefix suffix"
|
||||
|
||||
def test_strips_function_calls_block_split_across_chunks(self) -> None:
|
||||
f = _XmlToolCallContentFilter()
|
||||
chunks = [
|
||||
"Start ",
|
||||
"<function_",
|
||||
'calls><invoke name="internal_search">',
|
||||
'<parameter name="queries" string="false">["Onyx docs"]',
|
||||
"</parameter></invoke></function_calls>",
|
||||
" End",
|
||||
]
|
||||
output = "".join(f.process(chunk) for chunk in chunks) + f.flush()
|
||||
assert output == "Start End"
|
||||
|
||||
def test_preserves_non_tool_call_xml(self) -> None:
|
||||
f = _XmlToolCallContentFilter()
|
||||
output = f.process("A <tag>value</tag> B")
|
||||
output += f.flush()
|
||||
assert output == "A <tag>value</tag> B"
|
||||
|
||||
def test_does_not_strip_similar_tag_names(self) -> None:
|
||||
f = _XmlToolCallContentFilter()
|
||||
output = f.process(
|
||||
"A <function_calls_v2><invoke>noop</invoke></function_calls_v2> B"
|
||||
)
|
||||
output += f.flush()
|
||||
assert (
|
||||
output == "A <function_calls_v2><invoke>noop</invoke></function_calls_v2> B"
|
||||
)
|
||||
|
||||
|
||||
class TestIncrementTurns:
|
||||
"""Tests for the _increment_turns helper used by _close_reasoning_if_active."""
|
||||
|
||||
def test_increments_turn_index_when_no_sub_turn(self) -> None:
|
||||
turn, sub = _increment_turns(0, None)
|
||||
assert turn == 1
|
||||
assert sub is None
|
||||
|
||||
def test_increments_sub_turn_when_present(self) -> None:
|
||||
turn, sub = _increment_turns(3, 0)
|
||||
assert turn == 3
|
||||
assert sub == 1
|
||||
|
||||
def test_increments_sub_turn_from_nonzero(self) -> None:
|
||||
turn, sub = _increment_turns(5, 2)
|
||||
assert turn == 5
|
||||
assert sub == 3
|
||||
|
||||
|
||||
class TestResolveToolArguments:
|
||||
"""Tests for the _resolve_tool_arguments helper."""
|
||||
|
||||
def test_dict_arguments(self) -> None:
|
||||
obj = {"arguments": {"queries": ["test"]}}
|
||||
assert _resolve_tool_arguments(obj) == {"queries": ["test"]}
|
||||
|
||||
def test_dict_parameters(self) -> None:
|
||||
"""Falls back to 'parameters' key when 'arguments' is missing."""
|
||||
obj = {"parameters": {"queries": ["test"]}}
|
||||
assert _resolve_tool_arguments(obj) == {"queries": ["test"]}
|
||||
|
||||
def test_arguments_takes_precedence_over_parameters(self) -> None:
|
||||
obj = {"arguments": {"a": 1}, "parameters": {"b": 2}}
|
||||
assert _resolve_tool_arguments(obj) == {"a": 1}
|
||||
|
||||
def test_json_string_arguments(self) -> None:
|
||||
obj = {"arguments": '{"queries": ["test"]}'}
|
||||
assert _resolve_tool_arguments(obj) == {"queries": ["test"]}
|
||||
|
||||
def test_invalid_json_string_returns_empty_dict(self) -> None:
|
||||
obj = {"arguments": "not valid json"}
|
||||
assert _resolve_tool_arguments(obj) == {}
|
||||
|
||||
def test_no_arguments_or_parameters_returns_empty_dict(self) -> None:
|
||||
obj = {"name": "some_tool"}
|
||||
assert _resolve_tool_arguments(obj) == {}
|
||||
|
||||
def test_non_dict_non_string_arguments_returns_none(self) -> None:
|
||||
"""When arguments resolves to a list or int, returns None."""
|
||||
assert _resolve_tool_arguments({"arguments": [1, 2, 3]}) is None
|
||||
assert _resolve_tool_arguments({"arguments": 42}) is None
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
"""Fixtures for unit-testing DAL classes with mocked sessions."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
|
||||
|
||||
def model_attrs(obj: object) -> dict[str, Any]:
|
||||
"""Extract user-set attributes from a SQLAlchemy model instance.
|
||||
|
||||
Filters out SQLAlchemy internal state (``_sa_instance_state``).
|
||||
Use this in tests to assert the full set of fields on a model object
|
||||
so that adding a new field forces the test to be updated.
|
||||
"""
|
||||
return {k: v for k, v in vars(obj).items() if not k.startswith("_")}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session() -> MagicMock:
|
||||
"""A MagicMock standing in for a SQLAlchemy Session."""
|
||||
return MagicMock(spec=Session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def scim_dal(mock_db_session: MagicMock) -> ScimDAL:
|
||||
"""A ScimDAL backed by a mock session."""
|
||||
return ScimDAL(mock_db_session)
|
||||
@@ -1,110 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.dal import DAL
|
||||
|
||||
|
||||
class TestDALSessionDelegation:
|
||||
"""Verify that DAL methods delegate correctly to the underlying session."""
|
||||
|
||||
def test_commit(self) -> None:
|
||||
session = MagicMock()
|
||||
dal = DAL(session)
|
||||
dal.commit()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_flush(self) -> None:
|
||||
session = MagicMock()
|
||||
dal = DAL(session)
|
||||
dal.flush()
|
||||
session.flush.assert_called_once()
|
||||
|
||||
def test_rollback(self) -> None:
|
||||
session = MagicMock()
|
||||
dal = DAL(session)
|
||||
dal.rollback()
|
||||
session.rollback.assert_called_once()
|
||||
|
||||
def test_session_property_exposes_underlying_session(self) -> None:
|
||||
session = MagicMock()
|
||||
dal = DAL(session)
|
||||
assert dal.session is session
|
||||
|
||||
def test_commit_propagates_exception(self) -> None:
|
||||
session = MagicMock()
|
||||
session.commit.side_effect = RuntimeError("db error")
|
||||
dal = DAL(session)
|
||||
with pytest.raises(RuntimeError, match="db error"):
|
||||
dal.commit()
|
||||
|
||||
|
||||
class TestDALFromTenant:
|
||||
"""Verify the from_tenant context manager lifecycle."""
|
||||
|
||||
@patch("onyx.db.dal.get_session_with_tenant")
|
||||
def test_yields_dal_with_tenant_session(self, mock_get_session: MagicMock) -> None:
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with DAL.from_tenant("tenant_abc") as dal:
|
||||
assert isinstance(dal, DAL)
|
||||
assert dal.session is mock_session
|
||||
|
||||
mock_get_session.assert_called_once_with(tenant_id="tenant_abc")
|
||||
|
||||
@patch("onyx.db.dal.get_session_with_tenant")
|
||||
def test_session_closed_after_context_exits(
|
||||
self, mock_get_session: MagicMock
|
||||
) -> None:
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with DAL.from_tenant("tenant_abc"):
|
||||
pass
|
||||
|
||||
mock_get_session.return_value.__exit__.assert_called_once()
|
||||
|
||||
@patch("onyx.db.dal.get_session_with_tenant")
|
||||
def test_session_closed_on_exception(self, mock_get_session: MagicMock) -> None:
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with DAL.from_tenant("tenant_abc"):
|
||||
raise ValueError("something broke")
|
||||
|
||||
mock_get_session.return_value.__exit__.assert_called_once()
|
||||
|
||||
@patch("onyx.db.dal.get_session_with_tenant")
|
||||
def test_subclass_from_tenant_returns_subclass_instance(
|
||||
self, mock_get_session: MagicMock
|
||||
) -> None:
|
||||
"""from_tenant uses cls(), so subclasses should get their own type back."""
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
class MyDAL(DAL):
|
||||
pass
|
||||
|
||||
with MyDAL.from_tenant("tenant_abc") as dal:
|
||||
assert isinstance(dal, MyDAL)
|
||||
|
||||
@patch("onyx.db.dal.get_session_with_tenant")
|
||||
def test_uncommitted_changes_not_auto_committed(
|
||||
self, mock_get_session: MagicMock
|
||||
) -> None:
|
||||
"""Exiting the context manager should NOT auto-commit."""
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with DAL.from_tenant("tenant_abc"):
|
||||
pass
|
||||
|
||||
mock_session.commit.assert_not_called()
|
||||
@@ -1,185 +0,0 @@
|
||||
import logging
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from onyx.db.models import ScimGroupMapping
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from tests.unit.onyx.db.conftest import model_attrs
|
||||
|
||||
|
||||
class TestScimDALTokens:
|
||||
"""Tests for ScimDAL token operations."""
|
||||
|
||||
def test_create_token_adds_to_session(
|
||||
self, scim_dal: ScimDAL, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
scim_dal.create_token(
|
||||
name="test",
|
||||
hashed_token="abc123",
|
||||
token_display="****abcd",
|
||||
created_by_id=user_id,
|
||||
)
|
||||
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
added_obj = mock_db_session.add.call_args[0][0]
|
||||
assert model_attrs(added_obj) == {
|
||||
"name": "test",
|
||||
"hashed_token": "abc123",
|
||||
"token_display": "****abcd",
|
||||
"created_by_id": user_id,
|
||||
}
|
||||
|
||||
def test_get_token_by_hash_queries_session(
|
||||
self, scim_dal: ScimDAL, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
token = ScimToken(
|
||||
id=1,
|
||||
name="test-token",
|
||||
hashed_token="a" * 64,
|
||||
token_display="onyx_scim_****abcd",
|
||||
is_active=True,
|
||||
created_by_id=uuid4(),
|
||||
)
|
||||
mock_db_session.scalar.return_value = token
|
||||
|
||||
result = scim_dal.get_token_by_hash("a" * 64)
|
||||
|
||||
assert result is token
|
||||
mock_db_session.scalar.assert_called_once()
|
||||
|
||||
def test_revoke_token_sets_inactive(
|
||||
self, scim_dal: ScimDAL, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
token = ScimToken(
|
||||
id=1,
|
||||
name="test-token",
|
||||
hashed_token="a" * 64,
|
||||
token_display="onyx_scim_****abcd",
|
||||
is_active=True,
|
||||
created_by_id=uuid4(),
|
||||
)
|
||||
mock_db_session.get.return_value = token
|
||||
expected = model_attrs(token) | {"is_active": False}
|
||||
|
||||
scim_dal.revoke_token(1)
|
||||
|
||||
assert model_attrs(token) == expected
|
||||
|
||||
def test_revoke_nonexistent_token_raises(
|
||||
self, scim_dal: ScimDAL, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
scim_dal.revoke_token(999)
|
||||
|
||||
|
||||
class TestScimDALUserMappings:
|
||||
"""Tests for ScimDAL user mapping operations."""
|
||||
|
||||
def test_create_user_mapping(
|
||||
self, scim_dal: ScimDAL, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
scim_dal.create_user_mapping(external_id="ext-1", user_id=user_id)
|
||||
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
added_obj = mock_db_session.add.call_args[0][0]
|
||||
assert model_attrs(added_obj) == {
|
||||
"external_id": "ext-1",
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
def test_delete_user_mapping(
|
||||
self, scim_dal: ScimDAL, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mapping = ScimUserMapping(id=1, external_id="ext-1", user_id=uuid4())
|
||||
mock_db_session.get.return_value = mapping
|
||||
|
||||
scim_dal.delete_user_mapping(1)
|
||||
|
||||
mock_db_session.delete.assert_called_once_with(mapping)
|
||||
|
||||
def test_delete_nonexistent_user_mapping_is_idempotent(
|
||||
self,
|
||||
scim_dal: ScimDAL,
|
||||
mock_db_session: MagicMock,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
mock_db_session.get.return_value = None
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
scim_dal.delete_user_mapping(999)
|
||||
|
||||
mock_db_session.delete.assert_not_called()
|
||||
assert "SCIM user mapping 999 not found" in caplog.text
|
||||
|
||||
def test_update_user_mapping_external_id(
|
||||
self, scim_dal: ScimDAL, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mapping = ScimUserMapping(id=1, external_id="old-id", user_id=uuid4())
|
||||
mock_db_session.get.return_value = mapping
|
||||
expected = model_attrs(mapping) | {"external_id": "new-id"}
|
||||
|
||||
result = scim_dal.update_user_mapping_external_id(1, "new-id")
|
||||
|
||||
assert result is mapping
|
||||
assert model_attrs(result) == expected
|
||||
|
||||
def test_update_nonexistent_user_mapping_raises(
|
||||
self, scim_dal: ScimDAL, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
scim_dal.update_user_mapping_external_id(999, "new-id")
|
||||
|
||||
|
||||
class TestScimDALGroupMappings:
|
||||
"""Tests for ScimDAL group mapping operations."""
|
||||
|
||||
def test_create_group_mapping(
|
||||
self, scim_dal: ScimDAL, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
scim_dal.create_group_mapping(external_id="ext-g1", user_group_id=5)
|
||||
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
added_obj = mock_db_session.add.call_args[0][0]
|
||||
assert model_attrs(added_obj) == {
|
||||
"external_id": "ext-g1",
|
||||
"user_group_id": 5,
|
||||
}
|
||||
|
||||
def test_delete_group_mapping(
|
||||
self, scim_dal: ScimDAL, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mapping = ScimGroupMapping(id=1, external_id="ext-g1", user_group_id=10)
|
||||
mock_db_session.get.return_value = mapping
|
||||
|
||||
scim_dal.delete_group_mapping(1)
|
||||
|
||||
mock_db_session.delete.assert_called_once_with(mapping)
|
||||
|
||||
def test_delete_nonexistent_group_mapping_is_idempotent(
|
||||
self,
|
||||
scim_dal: ScimDAL,
|
||||
mock_db_session: MagicMock,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
mock_db_session.get.return_value = None
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
scim_dal.delete_group_mapping(999)
|
||||
|
||||
mock_db_session.delete.assert_not_called()
|
||||
assert "SCIM group mapping 999 not found" in caplog.text
|
||||
@@ -25,11 +25,6 @@ from onyx.llm.models import UserMessage
from onyx.llm.multi_llm import LitellmLLM
from onyx.llm.utils import get_max_input_tokens

VERTEX_OPUS_MODELS_REJECTING_OUTPUT_CONFIG = [
"claude-opus-4-5@20251101",
"claude-opus-4-6",
]


def _create_delta(
role: str | None = None,
@@ -425,16 +420,15 @@ def test_multiple_tool_calls_streaming(default_multi_llm: LitellmLLM) -> None:
)


@pytest.mark.parametrize("model_name", VERTEX_OPUS_MODELS_REJECTING_OUTPUT_CONFIG)
def test_vertex_stream_omits_stream_options(model_name: str) -> None:
def test_vertex_stream_omits_stream_options() -> None:
llm = LitellmLLM(
api_key="test_key",
timeout=30,
model_provider=LlmProviderNames.VERTEX_AI,
model_name=model_name,
model_name="claude-opus-4-5@20251101",
max_input_tokens=get_max_input_tokens(
model_provider=LlmProviderNames.VERTEX_AI,
model_name=model_name,
model_name="claude-opus-4-5@20251101",
),
)

@@ -474,16 +468,15 @@ def test_openai_auto_reasoning_effort_maps_to_medium() -> None:
assert kwargs["reasoning"]["effort"] == "medium"


@pytest.mark.parametrize("model_name", VERTEX_OPUS_MODELS_REJECTING_OUTPUT_CONFIG)
def test_vertex_opus_omits_reasoning_effort(model_name: str) -> None:
def test_vertex_opus_4_5_omits_reasoning_effort() -> None:
llm = LitellmLLM(
api_key="test_key",
timeout=30,
model_provider=LlmProviderNames.VERTEX_AI,
model_name=model_name,
model_name="claude-opus-4-5@20251101",
max_input_tokens=get_max_input_tokens(
model_provider=LlmProviderNames.VERTEX_AI,
model_name=model_name,
model_name="claude-opus-4-5@20251101",
),
)

@@ -1,102 +0,0 @@
|
||||
"""Shared fixtures for SCIM endpoint unit tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session() -> MagicMock:
|
||||
"""A MagicMock standing in for a SQLAlchemy Session."""
|
||||
return MagicMock(spec=Session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token() -> MagicMock:
|
||||
"""A MagicMock standing in for a verified ScimToken."""
|
||||
token = MagicMock(spec=ScimToken)
|
||||
token.id = 1
|
||||
return token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dal() -> Generator[MagicMock, None, None]:
|
||||
"""Patch ScimDAL construction in api module and yield the mock instance."""
|
||||
with patch("ee.onyx.server.scim.api.ScimDAL") as cls:
|
||||
dal = cls.return_value
|
||||
# User defaults
|
||||
dal.get_user.return_value = None
|
||||
dal.get_user_by_email.return_value = None
|
||||
dal.get_user_mapping_by_user_id.return_value = None
|
||||
dal.get_user_mapping_by_external_id.return_value = None
|
||||
dal.list_users.return_value = ([], 0)
|
||||
# Group defaults
|
||||
dal.get_group.return_value = None
|
||||
dal.get_group_by_name.return_value = None
|
||||
dal.get_group_mapping_by_group_id.return_value = None
|
||||
dal.get_group_mapping_by_external_id.return_value = None
|
||||
dal.get_group_members.return_value = []
|
||||
dal.list_groups.return_value = ([], 0)
|
||||
yield dal
|
||||
|
||||
|
||||
def make_scim_user(**kwargs: Any) -> ScimUserResource:
|
||||
"""Build a ScimUserResource with sensible defaults."""
|
||||
defaults: dict[str, Any] = {
|
||||
"userName": "test@example.com",
|
||||
"externalId": "ext-default",
|
||||
"active": True,
|
||||
"name": ScimName(givenName="Test", familyName="User"),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return ScimUserResource(**defaults)
|
||||
|
||||
|
||||
def make_scim_group(**kwargs: Any) -> ScimGroupResource:
|
||||
"""Build a ScimGroupResource with sensible defaults."""
|
||||
defaults: dict[str, Any] = {"displayName": "Engineering"}
|
||||
defaults.update(kwargs)
|
||||
return ScimGroupResource(**defaults)
|
||||
|
||||
|
||||
def make_db_user(**kwargs: Any) -> MagicMock:
|
||||
"""Build a mock User ORM object with configurable attributes."""
|
||||
user = MagicMock(spec=User)
|
||||
user.id = kwargs.get("id", uuid4())
|
||||
user.email = kwargs.get("email", "test@example.com")
|
||||
user.is_active = kwargs.get("is_active", True)
|
||||
user.personal_name = kwargs.get("personal_name", "Test User")
|
||||
user.role = kwargs.get("role", UserRole.BASIC)
|
||||
return user
|
||||
|
||||
|
||||
def make_db_group(**kwargs: Any) -> MagicMock:
|
||||
"""Build a mock UserGroup ORM object with configurable attributes."""
|
||||
group = MagicMock(spec=UserGroup)
|
||||
group.id = kwargs.get("id", 1)
|
||||
group.name = kwargs.get("name", "Engineering")
|
||||
group.is_up_for_deletion = kwargs.get("is_up_for_deletion", False)
|
||||
group.is_up_to_date = kwargs.get("is_up_to_date", True)
|
||||
return group
|
||||
|
||||
|
||||
def assert_scim_error(result: object, expected_status: int) -> None:
|
||||
"""Assert *result* is a JSONResponse with the given status code."""
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == expected_status
|
||||
@@ -1,132 +0,0 @@
|
||||
"""Tests for SCIM admin token management endpoints."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.enterprise_settings.api import create_scim_token
|
||||
from ee.onyx.server.enterprise_settings.api import get_active_scim_token
|
||||
from ee.onyx.server.scim.models import ScimTokenCreate
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session() -> MagicMock:
|
||||
return MagicMock(spec=Session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def scim_dal(mock_db_session: MagicMock) -> ScimDAL:
|
||||
return ScimDAL(mock_db_session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user() -> User:
|
||||
user = User(id=uuid4(), email="admin@test.com")
|
||||
user.is_active = True
|
||||
return user
|
||||
|
||||
|
||||
def _make_token(token_id: int, name: str, *, is_active: bool = True) -> ScimToken:
|
||||
return ScimToken(
|
||||
id=token_id,
|
||||
name=name,
|
||||
hashed_token="h" * 64,
|
||||
token_display="onyx_scim_****abcd",
|
||||
is_active=is_active,
|
||||
created_by_id=uuid4(),
|
||||
created_at=datetime(2026, 1, 1),
|
||||
last_used_at=None,
|
||||
)
|
||||
|
||||
|
||||
class TestGetActiveToken:
|
||||
def test_returns_token_metadata(self, scim_dal: ScimDAL, admin_user: User) -> None:
|
||||
token = _make_token(1, "prod-token")
|
||||
scim_dal._session.scalar.return_value = token # type: ignore[attr-defined]
|
||||
|
||||
result = get_active_scim_token(_=admin_user, dal=scim_dal)
|
||||
|
||||
assert result.id == 1
|
||||
assert result.name == "prod-token"
|
||||
assert result.is_active is True
|
||||
|
||||
def test_raises_404_when_no_active_token(
|
||||
self, scim_dal: ScimDAL, admin_user: User
|
||||
) -> None:
|
||||
scim_dal._session.scalar.return_value = None # type: ignore[attr-defined]
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_active_scim_token(_=admin_user, dal=scim_dal)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
class TestCreateToken:
|
||||
@patch("ee.onyx.server.enterprise_settings.api.generate_scim_token")
|
||||
def test_creates_token_and_revokes_previous(
|
||||
self,
|
||||
mock_generate: MagicMock,
|
||||
scim_dal: ScimDAL,
|
||||
admin_user: User,
|
||||
) -> None:
|
||||
mock_generate.return_value = ("raw_token_val", "hashed_val", "****abcd")
|
||||
|
||||
# Simulate one existing active token that should get revoked
|
||||
existing = _make_token(1, "old-token", is_active=True)
|
||||
scim_dal._session.scalars.return_value.all.return_value = [existing] # type: ignore[attr-defined]
|
||||
|
||||
# Simulate DB defaults that would be set on INSERT/flush
|
||||
def fake_add(obj: ScimToken) -> None:
|
||||
obj.id = 2
|
||||
obj.is_active = True
|
||||
obj.created_at = datetime(2026, 2, 1)
|
||||
|
||||
scim_dal._session.add.side_effect = fake_add # type: ignore[attr-defined]
|
||||
|
||||
body = ScimTokenCreate(name="new-token")
|
||||
result = create_scim_token(body=body, user=admin_user, dal=scim_dal)
|
||||
|
||||
# Previous token was revoked (by create_token's internal revocation)
|
||||
assert existing.is_active is False
|
||||
|
||||
# New token returned with raw value
|
||||
assert result.raw_token == "raw_token_val"
|
||||
assert result.name == "new-token"
|
||||
assert result.is_active is True
|
||||
|
||||
# Session was committed
|
||||
scim_dal._session.commit.assert_called_once() # type: ignore[attr-defined]
|
||||
|
||||
@patch("ee.onyx.server.enterprise_settings.api.generate_scim_token")
|
||||
def test_creates_first_token_when_none_exist(
|
||||
self,
|
||||
mock_generate: MagicMock,
|
||||
scim_dal: ScimDAL,
|
||||
admin_user: User,
|
||||
) -> None:
|
||||
mock_generate.return_value = ("raw_token_val", "hashed_val", "****abcd")
|
||||
|
||||
# No existing tokens
|
||||
scim_dal._session.scalars.return_value.all.return_value = [] # type: ignore[attr-defined]
|
||||
|
||||
def fake_add(obj: ScimToken) -> None:
|
||||
obj.id = 1
|
||||
obj.is_active = True
|
||||
obj.created_at = datetime(2026, 2, 1)
|
||||
|
||||
scim_dal._session.add.side_effect = fake_add # type: ignore[attr-defined]
|
||||
|
||||
body = ScimTokenCreate(name="first-token")
|
||||
result = create_scim_token(body=body, user=admin_user, dal=scim_dal)
|
||||
|
||||
assert result.raw_token == "raw_token_val"
|
||||
assert result.name == "first-token"
|
||||
assert result.is_active is True
|
||||
@@ -1,103 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.scim.auth import _hash_scim_token
|
||||
from ee.onyx.server.scim.auth import generate_scim_token
|
||||
from ee.onyx.server.scim.auth import SCIM_TOKEN_PREFIX
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
|
||||
|
||||
class TestGenerateScimToken:
|
||||
def test_returns_three_strings(self) -> None:
|
||||
raw, hashed, display = generate_scim_token()
|
||||
assert isinstance(raw, str)
|
||||
assert isinstance(hashed, str)
|
||||
assert isinstance(display, str)
|
||||
|
||||
def test_raw_token_has_prefix(self) -> None:
|
||||
raw, _, _ = generate_scim_token()
|
||||
assert raw.startswith(SCIM_TOKEN_PREFIX)
|
||||
|
||||
def test_hash_is_sha256_hex(self) -> None:
|
||||
raw, hashed, _ = generate_scim_token()
|
||||
assert len(hashed) == 64
|
||||
assert hashed == _hash_scim_token(raw)
|
||||
|
||||
def test_display_shows_last_four_chars(self) -> None:
|
||||
raw, _, display = generate_scim_token()
|
||||
assert display.endswith(raw[-4:])
|
||||
assert "****" in display
|
||||
|
||||
def test_tokens_are_unique(self) -> None:
|
||||
tokens = {generate_scim_token()[0] for _ in range(10)}
|
||||
assert len(tokens) == 10
|
||||
|
||||
|
||||
class TestHashScimToken:
|
||||
def test_deterministic(self) -> None:
|
||||
assert _hash_scim_token("test") == _hash_scim_token("test")
|
||||
|
||||
def test_different_inputs_different_hashes(self) -> None:
|
||||
assert _hash_scim_token("a") != _hash_scim_token("b")
|
||||
|
||||
|
||||
class TestVerifyScimToken:
|
||||
def _make_request(self, auth_header: str | None = None) -> MagicMock:
|
||||
request = MagicMock()
|
||||
headers: dict[str, str] = {}
|
||||
if auth_header is not None:
|
||||
headers["Authorization"] = auth_header
|
||||
request.headers = headers
|
||||
return request
|
||||
|
||||
def _make_dal(self, token: MagicMock | None = None) -> MagicMock:
|
||||
dal = MagicMock()
|
||||
dal.get_token_by_hash.return_value = token
|
||||
return dal
|
||||
|
||||
def test_missing_header_raises_401(self) -> None:
|
||||
request = self._make_request(None)
|
||||
dal = self._make_dal()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Missing" in str(exc_info.value.detail)
|
||||
|
||||
def test_wrong_prefix_raises_401(self) -> None:
|
||||
request = self._make_request("Bearer on_some_api_key")
|
||||
dal = self._make_dal()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_token_not_in_db_raises_401(self) -> None:
|
||||
raw, _, _ = generate_scim_token()
|
||||
request = self._make_request(f"Bearer {raw}")
|
||||
dal = self._make_dal(token=None)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid" in str(exc_info.value.detail)
|
||||
|
||||
def test_inactive_token_raises_401(self) -> None:
|
||||
raw, _, _ = generate_scim_token()
|
||||
request = self._make_request(f"Bearer {raw}")
|
||||
mock_token = MagicMock()
|
||||
mock_token.is_active = False
|
||||
dal = self._make_dal(token=mock_token)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "revoked" in str(exc_info.value.detail)
|
||||
|
||||
def test_valid_token_returns_token(self) -> None:
|
||||
raw, _, _ = generate_scim_token()
|
||||
request = self._make_request(f"Bearer {raw}")
|
||||
mock_token = MagicMock()
|
||||
mock_token.is_active = True
|
||||
dal = self._make_dal(token=mock_token)
|
||||
result = verify_scim_token(request, dal)
|
||||
assert result is mock_token
|
||||
dal.get_token_by_hash.assert_called_once()
|
||||
@@ -1,633 +0,0 @@
|
||||
"""Unit tests for SCIM Group CRUD endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import Response
|
||||
|
||||
from ee.onyx.server.scim.api import create_group
|
||||
from ee.onyx.server.scim.api import delete_group
|
||||
from ee.onyx.server.scim.api import get_group
|
||||
from ee.onyx.server.scim.api import list_groups
|
||||
from ee.onyx.server.scim.api import patch_group
|
||||
from ee.onyx.server.scim.api import replace_group
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from tests.unit.onyx.server.scim.conftest import assert_scim_error
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_group
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_group
|
||||
|
||||
|
||||
class TestListGroups:
|
||||
"""Tests for GET /scim/v2/Groups."""
|
||||
|
||||
def test_empty_result(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.list_groups.return_value = ([], 0)
|
||||
|
||||
result = list_groups(
|
||||
filter=None,
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 0
|
||||
assert result.Resources == []
|
||||
|
||||
def test_unsupported_filter_returns_400(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.list_groups.side_effect = ValueError(
|
||||
"Unsupported filter attribute: userName"
|
||||
)
|
||||
|
||||
result = list_groups(
|
||||
filter='userName eq "x"',
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 400)
|
||||
|
||||
def test_returns_groups_with_members(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
group = make_db_group(id=5, name="Engineering")
|
||||
uid = uuid4()
|
||||
mock_dal.list_groups.return_value = ([(group, "ext-g-1")], 1)
|
||||
mock_dal.get_group_members.return_value = [(uid, "alice@example.com")]
|
||||
|
||||
result = list_groups(
|
||||
filter=None,
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 1
|
||||
resource = result.Resources[0]
|
||||
assert isinstance(resource, ScimGroupResource)
|
||||
assert resource.displayName == "Engineering"
|
||||
assert resource.externalId == "ext-g-1"
|
||||
assert len(resource.members) == 1
|
||||
assert resource.members[0].display == "alice@example.com"
|
||||
|
||||
|
||||
class TestGetGroup:
|
||||
"""Tests for GET /scim/v2/Groups/{group_id}."""
|
||||
|
||||
def test_returns_scim_resource(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
group = make_db_group(id=5, name="Engineering")
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = []
|
||||
|
||||
result = get_group(
|
||||
group_id="5",
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
assert result.displayName == "Engineering"
|
||||
assert result.id == "5"
|
||||
|
||||
def test_non_integer_id_returns_404(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
result = get_group(
|
||||
group_id="not-a-number",
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
|
||||
def test_not_found_returns_404(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_group.return_value = None
|
||||
|
||||
result = get_group(
|
||||
group_id="999",
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
|
||||
|
||||
class TestCreateGroup:
|
||||
"""Tests for POST /scim/v2/Groups."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api._validate_and_parse_members")
|
||||
def test_success(
|
||||
self,
|
||||
mock_validate: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = None
|
||||
mock_validate.return_value = ([], None)
|
||||
mock_dal.get_group_members.return_value = []
|
||||
|
||||
resource = make_scim_group(displayName="New Group")
|
||||
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
assert result.displayName == "New Group"
|
||||
mock_dal.add_group.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
def test_duplicate_name_returns_409(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = make_db_group()
|
||||
resource = make_scim_group()
|
||||
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 409)
|
||||
|
||||
@patch("ee.onyx.server.scim.api._validate_and_parse_members")
|
||||
def test_invalid_member_returns_400(
|
||||
self,
|
||||
mock_validate: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = None
|
||||
mock_validate.return_value = ([], "Invalid member ID: bad-uuid")
|
||||
|
||||
resource = make_scim_group(members=[ScimGroupMember(value="bad-uuid")])
|
||||
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 400)
|
||||
|
||||
@patch("ee.onyx.server.scim.api._validate_and_parse_members")
|
||||
def test_nonexistent_member_returns_400(
|
||||
self,
|
||||
mock_validate: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = None
|
||||
uid = uuid4()
|
||||
mock_validate.return_value = ([], f"Member(s) not found: {uid}")
|
||||
|
||||
resource = make_scim_group(members=[ScimGroupMember(value=str(uid))])
|
||||
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 400)
|
||||
|
||||
@patch("ee.onyx.server.scim.api._validate_and_parse_members")
|
||||
def test_creates_external_id_mapping(
|
||||
self,
|
||||
mock_validate: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = None
|
||||
mock_validate.return_value = ([], None)
|
||||
mock_dal.get_group_members.return_value = []
|
||||
|
||||
resource = make_scim_group(externalId="ext-g-123")
|
||||
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.create_group_mapping.assert_called_once()
|
||||
|
||||
|
||||
class TestReplaceGroup:
|
||||
"""Tests for PUT /scim/v2/Groups/{group_id}."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api._validate_and_parse_members")
|
||||
def test_success(
|
||||
self,
|
||||
mock_validate: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
group = make_db_group(id=5, name="Old Name")
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_validate.return_value = ([], None)
|
||||
mock_dal.get_group_members.return_value = []
|
||||
|
||||
resource = make_scim_group(displayName="New Name")
|
||||
|
||||
result = replace_group(
|
||||
group_id="5",
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.update_group.assert_called_once_with(group, name="New Name")
|
||||
mock_dal.replace_group_members.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
def test_not_found_returns_404(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_group.return_value = None
|
||||
|
||||
result = replace_group(
|
||||
group_id="999",
|
||||
group_resource=make_scim_group(),
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
|
||||
@patch("ee.onyx.server.scim.api._validate_and_parse_members")
|
||||
def test_invalid_member_returns_400(
|
||||
self,
|
||||
mock_validate: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_validate.return_value = ([], "Invalid member ID: bad")
|
||||
|
||||
resource = make_scim_group(members=[ScimGroupMember(value="bad")])
|
||||
|
||||
result = replace_group(
|
||||
group_id="5",
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 400)
|
||||
|
||||
@patch("ee.onyx.server.scim.api._validate_and_parse_members")
|
||||
def test_syncs_external_id(
|
||||
self,
|
||||
mock_validate: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_validate.return_value = ([], None)
|
||||
mock_dal.get_group_members.return_value = []
|
||||
|
||||
resource = make_scim_group(externalId="new-ext")
|
||||
|
||||
replace_group(
|
||||
group_id="5",
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
mock_dal.sync_group_external_id.assert_called_once_with(5, "new-ext")
|
||||
|
||||
|
||||
class TestPatchGroup:
|
||||
"""Tests for PATCH /scim/v2/Groups/{group_id}."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api.apply_group_patch")
|
||||
def test_rename(
|
||||
self,
|
||||
mock_apply: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
group = make_db_group(id=5, name="Old Name")
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = []
|
||||
|
||||
patched = ScimGroupResource(id="5", displayName="New Name", members=[])
|
||||
mock_apply.return_value = (patched, [], [])
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path="displayName",
|
||||
value="New Name",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_group(
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.update_group.assert_called_once_with(group, name="New Name")
|
||||
|
||||
def test_not_found_returns_404(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_group.return_value = None
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path="displayName",
|
||||
value="X",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_group(
|
||||
group_id="999",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
|
||||
@patch("ee.onyx.server.scim.api.apply_group_patch")
|
||||
def test_patch_error_returns_error_response(
|
||||
self,
|
||||
mock_apply: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = []
|
||||
|
||||
mock_apply.side_effect = ScimPatchError("Unsupported path", 400)
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path="badPath",
|
||||
value="x",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_group(
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 400)
|
||||
|
||||
@patch("ee.onyx.server.scim.api.apply_group_patch")
|
||||
def test_add_members(
|
||||
self,
|
||||
mock_apply: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = []
|
||||
mock_dal.validate_member_ids.return_value = []
|
||||
|
||||
uid = str(uuid4())
|
||||
patched = ScimGroupResource(
|
||||
id="5",
|
||||
displayName="Engineering",
|
||||
members=[ScimGroupMember(value=uid)],
|
||||
)
|
||||
mock_apply.return_value = (patched, [uid], [])
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.ADD,
|
||||
path="members",
|
||||
value=[{"value": uid}],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_group(
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.validate_member_ids.assert_called_once()
|
||||
mock_dal.upsert_group_members.assert_called_once()
|
||||
|
||||
@patch("ee.onyx.server.scim.api.apply_group_patch")
|
||||
def test_add_nonexistent_member_returns_400(
|
||||
self,
|
||||
mock_apply: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = []
|
||||
|
||||
uid = uuid4()
|
||||
patched = ScimGroupResource(
|
||||
id="5",
|
||||
displayName="Engineering",
|
||||
members=[ScimGroupMember(value=str(uid))],
|
||||
)
|
||||
mock_apply.return_value = (patched, [str(uid)], [])
|
||||
mock_dal.validate_member_ids.return_value = [uid]
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.ADD,
|
||||
path="members",
|
||||
value=[{"value": str(uid)}],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_group(
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 400)
|
||||
|
||||
@patch("ee.onyx.server.scim.api.apply_group_patch")
|
||||
def test_remove_members(
|
||||
self,
|
||||
mock_apply: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = []
|
||||
|
||||
uid = str(uuid4())
|
||||
patched = ScimGroupResource(id="5", displayName="Engineering", members=[])
|
||||
mock_apply.return_value = (patched, [], [uid])
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REMOVE,
|
||||
path=f'members[value eq "{uid}"]',
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_group(
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.remove_group_members.assert_called_once()
|
||||
|
||||
|
||||
class TestDeleteGroup:
|
||||
"""Tests for DELETE /scim/v2/Groups/{group_id}."""
|
||||
|
||||
def test_success(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
mapping = MagicMock()
|
||||
mapping.id = 1
|
||||
mock_dal.get_group_mapping_by_group_id.return_value = mapping
|
||||
|
||||
result = delete_group(
|
||||
group_id="5",
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.status_code == 204
|
||||
mock_dal.delete_group_mapping.assert_called_once_with(1)
|
||||
mock_dal.delete_group_with_members.assert_called_once_with(group)
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
def test_not_found_returns_404(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_group.return_value = None
|
||||
|
||||
result = delete_group(
|
||||
group_id="999",
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
|
||||
def test_non_integer_id_returns_404(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
result = delete_group(
|
||||
group_id="abc",
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
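The group tests above and the user tests below lean on shared helpers from `tests.unit.onyx.server.scim.conftest` (`assert_scim_error`, `make_db_group`, `make_scim_group`, `make_db_user`, `make_scim_user`, plus the `mock_dal` / `mock_token` fixtures). A minimal sketch of what two of those helpers might look like — the names come from the imports in these files, the bodies are assumptions for illustration only:

```python
# Hypothetical sketch of conftest helpers; not the actual conftest in this change.
from unittest.mock import MagicMock

from fastapi.responses import JSONResponse


def assert_scim_error(result: object, status_code: int) -> None:
    # Assumption: SCIM error paths return a JSONResponse carrying the
    # RFC 7644 error body, so checking the status code is sufficient here.
    assert isinstance(result, JSONResponse)
    assert result.status_code == status_code


def make_db_group(id: int = 1, name: str = "Group") -> MagicMock:
    # Assumption: a lightweight stand-in for the user-group ORM row,
    # exposing only the attributes the endpoints read (id, name).
    group = MagicMock()
    group.id = id
    group.name = name
    return group
```

`make_scim_group` / `make_scim_user` would build `ScimGroupResource` / `ScimUserResource` payloads in the same spirit.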
|
||||
@@ -1,521 +0,0 @@
|
||||
"""Unit tests for SCIM User CRUD endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ee.onyx.server.scim.api import create_user
|
||||
from ee.onyx.server.scim.api import delete_user
|
||||
from ee.onyx.server.scim.api import get_user
|
||||
from ee.onyx.server.scim.api import list_users
|
||||
from ee.onyx.server.scim.api import patch_user
|
||||
from ee.onyx.server.scim.api import replace_user
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from tests.unit.onyx.server.scim.conftest import assert_scim_error
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_user
|
||||
|
||||
|
||||
class TestListUsers:
|
||||
"""Tests for GET /scim/v2/Users."""
|
||||
|
||||
def test_empty_result(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.list_users.return_value = ([], 0)
|
||||
|
||||
result = list_users(
|
||||
filter=None,
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 0
|
||||
assert result.Resources == []
|
||||
|
||||
def test_returns_users_with_scim_shape(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
user = make_db_user(email="alice@example.com", personal_name="Alice Smith")
|
||||
mock_dal.list_users.return_value = ([(user, "ext-abc")], 1)
|
||||
|
||||
result = list_users(
|
||||
filter=None,
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 1
|
||||
assert len(result.Resources) == 1
|
||||
resource = result.Resources[0]
|
||||
assert isinstance(resource, ScimUserResource)
|
||||
assert resource.userName == "alice@example.com"
|
||||
assert resource.externalId == "ext-abc"
|
||||
|
||||
def test_unsupported_filter_attribute_returns_400(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.list_users.side_effect = ValueError(
|
||||
"Unsupported filter attribute: emails"
|
||||
)
|
||||
|
||||
result = list_users(
|
||||
filter='emails eq "x@y.com"',
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 400)
|
||||
|
||||
def test_invalid_filter_syntax_returns_400(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
result = list_users(
|
||||
filter="not a valid filter",
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 400)
|
||||
|
||||
|
||||
class TestGetUser:
|
||||
"""Tests for GET /scim/v2/Users/{user_id}."""
|
||||
|
||||
def test_returns_scim_resource(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
user = make_db_user(email="alice@example.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
|
||||
result = get_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "alice@example.com"
|
||||
assert result.id == str(user.id)
|
||||
|
||||
def test_invalid_uuid_returns_404(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
result = get_user(
|
||||
user_id="not-a-uuid",
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
|
||||
def test_user_not_found_returns_404(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_user.return_value = None
|
||||
|
||||
result = get_user(
|
||||
user_id=str(uuid4()),
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
|
||||
|
||||
class TestCreateUser:
|
||||
"""Tests for POST /scim/v2/Users."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_success(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(userName="new@example.com")
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "new@example.com"
|
||||
mock_dal.add_user.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
def test_missing_external_id_returns_400(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
resource = make_scim_user(externalId=None)
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 400)
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_duplicate_email_returns_409(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = make_db_user()
|
||||
resource = make_scim_user()
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 409)
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_integrity_error_returns_409(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
mock_dal.add_user.side_effect = IntegrityError("dup", {}, Exception())
|
||||
resource = make_scim_user()
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 409)
|
||||
mock_dal.rollback.assert_called_once()
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability")
|
||||
def test_seat_limit_returns_403(
|
||||
self,
|
||||
mock_seats: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_seats.return_value = "Seat limit reached"
|
||||
resource = make_scim_user()
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 403)
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_creates_external_id_mapping(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(externalId="ext-123")
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.externalId == "ext-123"
|
||||
mock_dal.create_user_mapping.assert_called_once()
|
||||
|
||||
|
||||
class TestReplaceUser:
|
||||
"""Tests for PUT /scim/v2/Users/{user_id}."""
|
||||
|
||||
def test_success(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
user = make_db_user(email="old@example.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
resource = make_scim_user(
|
||||
userName="new@example.com",
|
||||
name=ScimName(givenName="New", familyName="Name"),
|
||||
)
|
||||
|
||||
result = replace_user(
|
||||
user_id=str(user.id),
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
mock_dal.update_user.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
def test_not_found_returns_404(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_user.return_value = None
|
||||
|
||||
result = replace_user(
|
||||
user_id=str(uuid4()),
|
||||
user_resource=make_scim_user(),
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability")
|
||||
def test_reactivation_checks_seats(
|
||||
self,
|
||||
mock_seats: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
user = make_db_user(is_active=False)
|
||||
mock_dal.get_user.return_value = user
|
||||
mock_seats.return_value = "No seats"
|
||||
resource = make_scim_user(active=True)
|
||||
|
||||
result = replace_user(
|
||||
user_id=str(user.id),
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 403)
|
||||
mock_seats.assert_called_once()
|
||||
|
||||
def test_syncs_external_id(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
|
||||
resource = make_scim_user(externalId=None)
|
||||
|
||||
result = replace_user(
|
||||
user_id=str(user.id),
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
mock_dal.sync_user_external_id.assert_called_once_with(user.id, None)
|
||||
|
||||
|
||||
class TestPatchUser:
|
||||
"""Tests for PATCH /scim/v2/Users/{user_id}."""
|
||||
|
||||
def test_deactivate(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
user = make_db_user(is_active=True)
|
||||
mock_dal.get_user.return_value = user
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path="active",
|
||||
value=False,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
mock_dal.update_user.assert_called_once()
|
||||
|
||||
def test_not_found_returns_404(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_user.return_value = None
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path="active",
|
||||
value=False,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(uuid4()),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
|
||||
@patch("ee.onyx.server.scim.api.apply_user_patch")
|
||||
def test_patch_error_returns_error_response(
|
||||
self,
|
||||
mock_apply: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
mock_apply.side_effect = ScimPatchError("Bad op", 400)
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REMOVE,
|
||||
path="userName",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 400)
|
||||
|
||||
|
||||
class TestDeleteUser:
|
||||
"""Tests for DELETE /scim/v2/Users/{user_id}."""
|
||||
|
||||
def test_success(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
user = make_db_user(is_active=True)
|
||||
mock_dal.get_user.return_value = user
|
||||
mapping = MagicMock()
|
||||
mapping.id = 1
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = mapping
|
||||
|
||||
result = delete_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.status_code == 204
|
||||
mock_dal.deactivate_user.assert_called_once_with(user)
|
||||
mock_dal.delete_user_mapping.assert_called_once_with(1)
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
def test_not_found_returns_404(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
mock_dal.get_user.return_value = None
|
||||
|
||||
result = delete_user(
|
||||
user_id=str(uuid4()),
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
|
||||
def test_invalid_uuid_returns_404(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
result = delete_user(
|
||||
user_id="not-a-uuid",
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
@@ -1,171 +0,0 @@
|
||||
"""Unit tests for Prometheus instrumentation module."""
|
||||
|
||||
import threading
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from prometheus_client import CollectorRegistry
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from onyx.server.prometheus_instrumentation import _slow_request_callback
|
||||
from onyx.server.prometheus_instrumentation import setup_prometheus_metrics
|
||||
|
||||
|
||||
def _make_info(
|
||||
duration: float,
|
||||
method: str = "GET",
|
||||
handler: str = "/api/test",
|
||||
status: str = "200",
|
||||
) -> Any:
|
||||
"""Build a fake metrics Info object matching the instrumentator's Info shape."""
|
||||
return MagicMock(
|
||||
modified_duration=duration,
|
||||
method=method,
|
||||
modified_handler=handler,
|
||||
modified_status=status,
|
||||
)
|
||||
|
||||
|
||||
def test_slow_request_callback_increments_above_threshold() -> None:
|
||||
with patch("onyx.server.prometheus_instrumentation._slow_requests") as mock_counter:
|
||||
mock_labels = MagicMock()
|
||||
mock_counter.labels.return_value = mock_labels
|
||||
|
||||
info = _make_info(
|
||||
duration=2.0, method="POST", handler="/api/chat", status="200"
|
||||
)
|
||||
_slow_request_callback(info)
|
||||
|
||||
mock_counter.labels.assert_called_once_with(
|
||||
method="POST", handler="/api/chat", status="200"
|
||||
)
|
||||
mock_labels.inc.assert_called_once()
|
||||
|
||||
|
||||
def test_slow_request_callback_skips_below_threshold() -> None:
|
||||
with patch("onyx.server.prometheus_instrumentation._slow_requests") as mock_counter:
|
||||
info = _make_info(duration=0.5)
|
||||
_slow_request_callback(info)
|
||||
|
||||
mock_counter.labels.assert_not_called()
|
||||
|
||||
|
||||
def test_slow_request_callback_skips_at_exact_threshold() -> None:
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.prometheus_instrumentation.SLOW_REQUEST_THRESHOLD_SECONDS", 1.0
|
||||
),
|
||||
patch("onyx.server.prometheus_instrumentation._slow_requests") as mock_counter,
|
||||
):
|
||||
info = _make_info(duration=1.0)
|
||||
_slow_request_callback(info)
|
||||
|
||||
mock_counter.labels.assert_not_called()
|
||||
|
||||
|
||||
def test_setup_attaches_instrumentator_to_app() -> None:
|
||||
with patch("onyx.server.prometheus_instrumentation.Instrumentator") as mock_cls:
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.instrument.return_value = mock_instance
|
||||
mock_cls.return_value = mock_instance
|
||||
|
||||
app = FastAPI()
|
||||
setup_prometheus_metrics(app)
|
||||
|
||||
mock_cls.assert_called_once_with(
|
||||
should_group_status_codes=False,
|
||||
should_ignore_untemplated=False,
|
||||
should_group_untemplated=True,
|
||||
should_instrument_requests_inprogress=True,
|
||||
inprogress_labels=True,
|
||||
excluded_handlers=["/health", "/metrics", "/openapi.json"],
|
||||
)
|
||||
mock_instance.add.assert_called_once()
|
||||
mock_instance.instrument.assert_called_once_with(app)
|
||||
mock_instance.expose.assert_called_once_with(app)
|
||||
|
||||
|
||||
def test_inprogress_gauge_increments_during_request() -> None:
|
||||
"""Verify the in-progress gauge goes up while a request is in flight."""
|
||||
registry = CollectorRegistry()
|
||||
gauge = Gauge(
|
||||
"http_requests_inprogress_test",
|
||||
"In-progress requests",
|
||||
["method", "handler"],
|
||||
registry=registry,
|
||||
)
|
||||
|
||||
request_started = threading.Event()
|
||||
request_release = threading.Event()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/slow")
|
||||
def slow_endpoint() -> dict:
|
||||
gauge.labels(method="GET", handler="/slow").inc()
|
||||
request_started.set()
|
||||
request_release.wait(timeout=5)
|
||||
gauge.labels(method="GET", handler="/slow").dec()
|
||||
return {"status": "done"}
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
def make_request() -> None:
|
||||
client.get("/slow")
|
||||
|
||||
thread = threading.Thread(target=make_request)
|
||||
thread.start()
|
||||
|
||||
request_started.wait(timeout=5)
|
||||
assert gauge.labels(method="GET", handler="/slow")._value.get() == 1.0
|
||||
|
||||
request_release.set()
|
||||
thread.join(timeout=5)
|
||||
assert gauge.labels(method="GET", handler="/slow")._value.get() == 0.0
|
||||
|
||||
|
||||
def test_inprogress_gauge_tracks_concurrent_requests() -> None:
|
||||
"""Verify the gauge correctly counts multiple concurrent in-flight requests."""
|
||||
registry = CollectorRegistry()
|
||||
gauge = Gauge(
|
||||
"http_requests_inprogress_concurrent_test",
|
||||
"In-progress requests",
|
||||
["method", "handler"],
|
||||
registry=registry,
|
||||
)
|
||||
|
||||
# 3 parties: 2 request threads + main thread
|
||||
barrier = threading.Barrier(3)
|
||||
release = threading.Event()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/concurrent")
|
||||
def concurrent_endpoint() -> dict:
|
||||
gauge.labels(method="GET", handler="/concurrent").inc()
|
||||
barrier.wait(timeout=5)
|
||||
release.wait(timeout=5)
|
||||
gauge.labels(method="GET", handler="/concurrent").dec()
|
||||
return {"status": "done"}
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
def make_request() -> None:
|
||||
client.get("/concurrent")
|
||||
|
||||
t1 = threading.Thread(target=make_request)
|
||||
t2 = threading.Thread(target=make_request)
|
||||
t1.start()
|
||||
t2.start()
|
||||
|
||||
# All 3 threads meet here — both requests are in-flight
|
||||
barrier.wait(timeout=5)
|
||||
assert gauge.labels(method="GET", handler="/concurrent")._value.get() == 2.0
|
||||
|
||||
release.set()
|
||||
t1.join(timeout=5)
|
||||
t2.join(timeout=5)
|
||||
assert gauge.labels(method="GET", handler="/concurrent")._value.get() == 0.0
|
||||
@@ -5,7 +5,7 @@ home: https://www.onyx.app/
sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.4.29
version: 0.4.27
appVersion: latest
annotations:
category: Productivity

@@ -63,7 +63,7 @@ spec:
"-A",
"onyx.background.celery.versioned_apps.beat",
"beat",
{{ printf "--loglevel=%s" .Values.celery_beat.logLevel | quote }},
"--loglevel=INFO",
]
resources:
{{- toYaml .Values.celery_beat.resources | nindent 12 }}

@@ -68,7 +68,7 @@ spec:
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
{{ printf "--loglevel=%s" .Values.celery_worker_docfetching.logLevel | quote }},
"--loglevel=INFO",
"--hostname=docfetching@%n",
"-Q",
"connector_doc_fetching",

@@ -68,7 +68,7 @@ spec:
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
{{ printf "--loglevel=%s" .Values.celery_worker_docprocessing.logLevel | quote }},
"--loglevel=INFO",
"--hostname=docprocessing@%n",
"-Q",
"docprocessing",

@@ -65,7 +65,7 @@ spec:
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
{{ printf "--loglevel=%s" .Values.celery_worker_heavy.logLevel | quote }},
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,sandbox",

@@ -65,7 +65,7 @@ spec:
"-A",
"onyx.background.celery.versioned_apps.light",
"worker",
{{ printf "--loglevel=%s" .Values.celery_worker_light.logLevel | quote }},
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,opensearch_migration",

@@ -65,7 +65,7 @@ spec:
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
{{ printf "--loglevel=%s" .Values.celery_worker_monitoring.logLevel | quote }},
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",

@@ -65,7 +65,7 @@ spec:
"-A",
"onyx.background.celery.versioned_apps.primary",
"worker",
{{ printf "--loglevel=%s" .Values.celery_worker_primary.logLevel | quote }},
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery,periodic_tasks",

@@ -65,7 +65,7 @@ spec:
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
{{ printf "--loglevel=%s" .Values.celery_worker_user_file_processing.logLevel | quote }},
"--loglevel=INFO",
"--hostname=user-file-processing@%n",
"-Q",
"user_file_processing,user_file_project_sync,user_file_delete",
@@ -108,3 +108,4 @@ spec:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- end }}


@@ -88,8 +88,6 @@ spec:
value: "true"
- name: MCP_SERVER_PORT
value: "{{ .Values.mcpServer.containerPorts.server }}"
- name: MCP_SERVER_HOST
value: "{{ .Values.global.host }}"
{{- if .Values.mcpServer.corsOrigins }}
- name: MCP_SERVER_CORS_ORIGINS
value: "{{ .Values.mcpServer.corsOrigins }}"

@@ -522,7 +522,6 @@ celery_shared:

celery_beat:
replicaCount: 1
logLevel: INFO
podAnnotations: {}
podLabels:
scope: onyx-backend-celery
@@ -543,7 +542,6 @@ celery_beat:

celery_worker_heavy:
replicaCount: 1
logLevel: INFO
autoscaling:
enabled: false
minReplicas: 1
@@ -577,7 +575,6 @@ celery_worker_heavy:

celery_worker_docprocessing:
replicaCount: 1
logLevel: INFO
autoscaling:
enabled: false
minReplicas: 1
@@ -611,7 +608,6 @@ celery_worker_docprocessing:

celery_worker_light:
replicaCount: 1
logLevel: INFO
autoscaling:
enabled: false
minReplicas: 1
@@ -645,7 +641,6 @@ celery_worker_light:

celery_worker_monitoring:
replicaCount: 1
logLevel: INFO
autoscaling:
enabled: false
minReplicas: 1
@@ -679,7 +674,6 @@ celery_worker_monitoring:

celery_worker_primary:
replicaCount: 1
logLevel: INFO
autoscaling:
enabled: false
minReplicas: 1
@@ -713,7 +707,6 @@ celery_worker_primary:

celery_worker_user_file_processing:
replicaCount: 1
logLevel: INFO
autoscaling:
enabled: false
minReplicas: 1
@@ -858,7 +851,6 @@ mcpServer:

celery_worker_docfetching:
replicaCount: 1
logLevel: INFO
autoscaling:
enabled: false
minReplicas: 1

@@ -20,9 +20,9 @@ use tauri::Wry;
|
||||
use tauri::{
|
||||
webview::PageLoadPayload, AppHandle, Manager, Webview, WebviewUrl, WebviewWindowBuilder,
|
||||
};
|
||||
use url::Url;
|
||||
#[cfg(target_os = "macos")]
|
||||
use tokio::time::sleep;
|
||||
use url::Url;
|
||||
#[cfg(target_os = "macos")]
|
||||
use window_vibrancy::{apply_vibrancy, NSVisualEffectMaterial};
|
||||
|
||||
@@ -40,136 +40,6 @@ const TRAY_MENU_OPEN_APP_ID: &str = "tray_open_app";
|
||||
const TRAY_MENU_OPEN_CHAT_ID: &str = "tray_open_chat";
|
||||
const TRAY_MENU_SHOW_IN_BAR_ID: &str = "tray_show_in_menu_bar";
|
||||
const TRAY_MENU_QUIT_ID: &str = "tray_quit";
|
||||
const CHAT_LINK_INTERCEPT_SCRIPT: &str = r##"
|
||||
(() => {
|
||||
if (window.__ONYX_CHAT_LINK_INTERCEPT_INSTALLED__) {
|
||||
return;
|
||||
}
|
||||
|
||||
window.__ONYX_CHAT_LINK_INTERCEPT_INSTALLED__ = true;
|
||||
|
||||
function isChatSessionPage() {
|
||||
try {
|
||||
const currentUrl = new URL(window.location.href);
|
||||
return (
|
||||
currentUrl.pathname.startsWith("/app") &&
|
||||
currentUrl.searchParams.has("chatId")
|
||||
);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function getAllowedNavigationUrl(rawUrl) {
|
||||
try {
|
||||
const parsed = new URL(String(rawUrl), window.location.href);
|
||||
const scheme = parsed.protocol.toLowerCase();
|
||||
if (!["http:", "https:", "mailto:", "tel:"].includes(scheme)) {
|
||||
return null;
|
||||
}
|
||||
return parsed;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
async function openWithTauri(url) {
|
||||
try {
|
||||
const invoke =
|
||||
window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke;
|
||||
if (typeof invoke !== "function") {
|
||||
return false;
|
||||
}
|
||||
|
||||
await invoke("open_in_browser", { url });
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function handleChatNavigation(rawUrl) {
|
||||
const parsedUrl = getAllowedNavigationUrl(rawUrl);
|
||||
if (!parsedUrl) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const safeUrl = parsedUrl.toString();
|
||||
const scheme = parsedUrl.protocol.toLowerCase();
|
||||
if (scheme === "mailto:" || scheme === "tel:") {
|
||||
void openWithTauri(safeUrl).then((opened) => {
|
||||
if (!opened) {
|
||||
window.location.assign(safeUrl);
|
||||
}
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
window.location.assign(safeUrl);
|
||||
return true;
|
||||
}
|
||||
|
||||
document.addEventListener(
|
||||
"click",
|
||||
(event) => {
|
||||
if (!isChatSessionPage() || event.defaultPrevented) {
|
||||
return;
|
||||
}
|
||||
|
||||
const element = event.target;
|
||||
if (!(element instanceof Element)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const anchor = element.closest("a");
|
||||
if (!(anchor instanceof HTMLAnchorElement)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const target = (anchor.getAttribute("target") || "").toLowerCase();
|
||||
if (target !== "_blank") {
|
||||
return;
|
||||
}
|
||||
|
||||
const href = anchor.getAttribute("href");
|
||||
if (!href || href.startsWith("#")) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!handleChatNavigation(href)) {
|
||||
return;
|
||||
}
|
||||
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
},
|
||||
true
|
||||
);
|
||||
|
||||
const nativeWindowOpen = window.open;
|
||||
window.open = function(url, target, features) {
|
||||
const resolvedTarget = typeof target === "string" ? target.toLowerCase() : "";
|
||||
const shouldNavigateInPlace = resolvedTarget === "" || resolvedTarget === "_blank";
|
||||
|
||||
if (
|
||||
isChatSessionPage() &&
|
||||
shouldNavigateInPlace &&
|
||||
url != null &&
|
||||
String(url).length > 0
|
||||
) {
|
||||
if (!handleChatNavigation(url)) {
|
||||
return null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
if (typeof nativeWindowOpen === "function") {
|
||||
return nativeWindowOpen.call(window, url, target, features);
|
||||
}
|
||||
return null;
|
||||
};
|
||||
})();
|
||||
"##;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AppConfig {
|
||||
@@ -307,7 +177,22 @@ fn trigger_new_window(app: &AppHandle) {
|
||||
}
|
||||
|
||||
fn open_docs() {
|
||||
let _ = open_in_default_browser("https://docs.onyx.app");
|
||||
let url = "https://docs.onyx.app";
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
let _ = Command::new("open").arg(url).status();
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
let _ = Command::new("xdg-open").arg(url).status();
|
||||
}
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
let _ = Command::new("rundll32")
|
||||
.arg("url.dll,FileProtocolHandler")
|
||||
.arg(url)
|
||||
.status();
|
||||
}
|
||||
}
|
||||
|
||||
fn open_settings(app: &AppHandle) {
|
||||
@@ -334,68 +219,6 @@ fn open_settings(app: &AppHandle) {
|
||||
}
|
||||
}
|
||||
|
||||
fn same_origin(left: &Url, right: &Url) -> bool {
|
||||
left.scheme() == right.scheme()
|
||||
&& left.host_str() == right.host_str()
|
||||
&& left.port_or_known_default() == right.port_or_known_default()
|
||||
}
|
||||
|
||||
fn is_chat_session_url(url: &Url) -> bool {
|
||||
url.path().starts_with("/app") && url.query_pairs().any(|(key, _)| key == "chatId")
|
||||
}
|
||||
|
||||
fn should_open_in_external_browser(current_url: &Url, destination_url: &Url) -> bool {
|
||||
if !is_chat_session_url(current_url) {
|
||||
return false;
|
||||
}
|
||||
|
||||
match destination_url.scheme() {
|
||||
"mailto" | "tel" => true,
|
||||
"http" | "https" => !same_origin(current_url, destination_url),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn open_in_default_browser(url: &str) -> bool {
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
return Command::new("open").arg(url).status().is_ok();
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
return Command::new("xdg-open").arg(url).status().is_ok();
|
||||
}
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
return Command::new("rundll32")
|
||||
.arg("url.dll,FileProtocolHandler")
|
||||
.arg(url)
|
||||
.status()
|
||||
.is_ok();
|
||||
}
|
||||
#[allow(unreachable_code)]
|
||||
false
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn open_in_browser(url: String) -> Result<(), String> {
|
||||
let parsed_url = Url::parse(&url).map_err(|_| "Invalid URL".to_string())?;
|
||||
match parsed_url.scheme() {
|
||||
"http" | "https" | "mailto" | "tel" => {}
|
||||
_ => return Err("Unsupported URL scheme".to_string()),
|
||||
}
|
||||
|
||||
if open_in_default_browser(parsed_url.as_str()) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err("Failed to open URL in default browser".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn inject_chat_link_intercept(webview: &Webview) {
|
||||
let _ = webview.eval(CHAT_LINK_INTERCEPT_SCRIPT);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tauri Commands
|
||||
// ============================================================================
|
||||
@@ -417,8 +240,8 @@ struct BootstrapState {
|
||||
fn get_bootstrap_state(state: tauri::State<ConfigState>) -> BootstrapState {
|
||||
let server_url = state.config.read().unwrap().server_url.clone();
|
||||
let config_initialized = *state.config_initialized.read().unwrap();
|
||||
let config_exists =
|
||||
config_initialized && get_config_path().map(|path| path.exists()).unwrap_or(false);
|
||||
let config_exists = config_initialized
|
||||
&& get_config_path().map(|path| path.exists()).unwrap_or(false);
|
||||
|
||||
BootstrapState {
|
||||
server_url,
|
||||
@@ -639,13 +462,7 @@ fn setup_app_menu(app: &AppHandle) -> tauri::Result<()> {
|
||||
true,
|
||||
Some("CmdOrCtrl+Shift+N"),
|
||||
)?;
|
||||
let settings_item = MenuItem::with_id(
|
||||
app,
|
||||
"open_settings",
|
||||
"Settings...",
|
||||
true,
|
||||
Some("CmdOrCtrl+Comma"),
|
||||
)?;
|
||||
let settings_item = MenuItem::with_id(app, "open_settings", "Settings...", true, Some("CmdOrCtrl+Comma"))?;
|
||||
let docs_item = MenuItem::with_id(app, "open_docs", "Onyx Documentation", true, None::<&str>)?;
|
||||
|
||||
if let Some(file_menu) = menu
|
||||
@@ -684,7 +501,13 @@ fn setup_app_menu(app: &AppHandle) -> tauri::Result<()> {
|
||||
}
|
||||
|
||||
fn build_tray_menu(app: &AppHandle) -> tauri::Result<Menu<Wry>> {
|
||||
let open_app = MenuItem::with_id(app, TRAY_MENU_OPEN_APP_ID, "Open Onyx", true, None::<&str>)?;
|
||||
let open_app = MenuItem::with_id(
|
||||
app,
|
||||
TRAY_MENU_OPEN_APP_ID,
|
||||
"Open Onyx",
|
||||
true,
|
||||
None::<&str>,
|
||||
)?;
|
||||
let open_chat = MenuItem::with_id(
|
||||
app,
|
||||
TRAY_MENU_OPEN_CHAT_ID,
|
||||
@@ -775,27 +598,6 @@ fn main() {
|
||||
|
||||
tauri::Builder::default()
|
||||
.plugin(tauri_plugin_shell::init())
|
||||
.plugin(
|
||||
tauri::plugin::Builder::<Wry>::new("chat-external-navigation-handler")
|
||||
.on_navigation(|webview, destination_url| {
|
||||
let Ok(current_url) = webview.url() else {
|
||||
return true;
|
||||
};
|
||||
|
||||
if should_open_in_external_browser(¤t_url, destination_url) {
|
||||
if !open_in_default_browser(destination_url.as_str()) {
|
||||
eprintln!(
|
||||
"Failed to open external URL in default browser: {}",
|
||||
destination_url
|
||||
);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
})
|
||||
.build(),
|
||||
)
|
||||
.plugin(tauri_plugin_window_state::Builder::default().build())
|
||||
.manage(ConfigState {
|
||||
config: RwLock::new(config),
|
||||
@@ -807,7 +609,6 @@ fn main() {
|
||||
get_bootstrap_state,
|
||||
set_server_url,
|
||||
get_config_path_cmd,
|
||||
open_in_browser,
|
||||
open_config_file,
|
||||
open_config_directory,
|
||||
navigate_to,
|
||||
@@ -860,12 +661,10 @@ fn main() {
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.on_page_load(|webview: &Webview, _payload: &PageLoadPayload| {
|
||||
inject_chat_link_intercept(webview);
|
||||
|
||||
.on_page_load(|_webview: &Webview, _payload: &PageLoadPayload| {
|
||||
// Re-inject titlebar after every navigation/page load (macOS only)
|
||||
#[cfg(target_os = "macos")]
|
||||
let _ = webview.eval(TITLEBAR_SCRIPT);
|
||||
let _ = _webview.eval(TITLEBAR_SCRIPT);
|
||||
})
|
||||
.run(tauri::generate_context!())
|
||||
.expect("error while running tauri application");
|
||||
|
||||
@@ -1,88 +0,0 @@
# Onyx Prometheus Metrics Reference

## API Server Metrics

These metrics are exposed at `GET /metrics` on the API server.

### Built-in (via `prometheus-fastapi-instrumentator`)

| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `http_requests_total` | Counter | `method`, `status`, `handler` | Total request count |
| `http_request_duration_highr_seconds` | Histogram | _(none)_ | High-resolution latency (many buckets, no labels) |
| `http_request_duration_seconds` | Histogram | `method`, `handler` | Latency by handler (few buckets for aggregation) |
| `http_request_size_bytes` | Summary | `handler` | Incoming request content length |
| `http_response_size_bytes` | Summary | `handler` | Outgoing response content length |
| `http_requests_inprogress` | Gauge | `method`, `handler` | Currently in-flight requests |

### Custom (via `onyx.server.prometheus_instrumentation`)

| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `onyx_api_slow_requests_total` | Counter | `method`, `handler`, `status` | Requests exceeding `SLOW_REQUEST_THRESHOLD_SECONDS` (default 1s) |

### Configuration

| Env Var | Default | Description |
|---------|---------|-------------|
| `SLOW_REQUEST_THRESHOLD_SECONDS` | `1.0` | Duration threshold for slow request counting |

### Instrumentator Settings

- `should_group_status_codes=False` — Reports exact HTTP status codes (e.g. 401, 403, 500)
- `should_instrument_requests_inprogress=True` — Enables the in-progress request gauge
- `inprogress_labels=True` — Breaks down in-progress gauge by `method` and `handler`
- `excluded_handlers=["/health", "/metrics", "/openapi.json"]` — Excludes noisy endpoints from metrics
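These settings are applied when the application wires up instrumentation at startup. A minimal sketch of that wiring, assuming only what the tests earlier in this diff show (the app factory name here is illustrative, not the real Onyx entry point):

```python
from fastapi import FastAPI

from onyx.server.prometheus_instrumentation import setup_prometheus_metrics


def create_app() -> FastAPI:
    # Hypothetical factory for illustration; the assumption taken from this
    # change is only that setup_prometheus_metrics(app) is called once.
    app = FastAPI()

    # Attaches the Instrumentator with the settings listed above and
    # exposes the scrape endpoint at GET /metrics.
    setup_prometheus_metrics(app)
    return app
```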

## Example PromQL Queries

### Which endpoints are saturated right now?

```promql
# Top 10 endpoints by in-progress requests
topk(10, http_requests_inprogress)
```

### What's the P99 latency per endpoint?

```promql
# P99 latency by handler over the last 5 minutes
histogram_quantile(0.99, sum by (handler, le) (rate(http_request_duration_seconds_bucket[5m])))
```

### Which endpoints have the highest request rate?

```promql
# Requests per second by handler, top 10
topk(10, sum by (handler) (rate(http_requests_total[5m])))
```

### Which endpoints are returning errors?

```promql
# 5xx error rate by handler
sum by (handler) (rate(http_requests_total{status=~"5.."}[5m]))
```

### Slow request hotspots

```promql
# Slow requests per minute by handler
sum by (handler) (rate(onyx_api_slow_requests_total[5m])) * 60
```

### Latency trending up?

```promql
# Compare P50 latency now vs 1 hour ago
histogram_quantile(0.5, sum by (le) (rate(http_request_duration_highr_seconds_bucket[5m])))
-
histogram_quantile(0.5, sum by (le) (rate(http_request_duration_highr_seconds_bucket[5m] offset 1h)))
```

### Overall request throughput

```promql
# Total requests per second across all endpoints
sum(rate(http_requests_total[5m]))
```
|
||||
web/.gitignore
@@ -38,7 +38,6 @@ next-env.d.ts

# playwright testing temp files
/admin*_auth.json
/worker*_auth.json
/user_auth.json
/build-archive.log
/test-results

@@ -149,9 +149,9 @@ interface InteractiveBasePropsBase
|
||||
/**
|
||||
* URL to navigate to when clicked.
|
||||
*
|
||||
* Passed through Slot to the child element (typically `Interactive.Container`),
|
||||
* which renders an `<a>` tag when `href` is present. This keeps all styling
|
||||
* (backgrounds, rounding, overflow) on a single element.
|
||||
* When provided, renders an `<a>` wrapper element instead of using Radix Slot.
|
||||
* The `<a>` receives all interactive styling (hover/active/transient states)
|
||||
* and children are rendered inside it.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
@@ -261,32 +261,34 @@ function InteractiveBase({
|
||||
"aria-disabled": disabled || undefined,
|
||||
};
|
||||
|
||||
if (href) {
|
||||
const { children, onClick, ...rest } = props;
|
||||
return (
|
||||
<a
|
||||
ref={ref as React.Ref<HTMLAnchorElement>}
|
||||
href={disabled ? undefined : href}
|
||||
target={target}
|
||||
rel={target === "_blank" ? "noopener noreferrer" : undefined}
|
||||
className={classes}
|
||||
{...dataAttrs}
|
||||
{...rest}
|
||||
onClick={
|
||||
disabled ? (e: React.MouseEvent) => e.preventDefault() : onClick
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</a>
|
||||
);
|
||||
}
|
||||
|
||||
const { onClick, ...slotProps } = props;
|
||||
|
||||
// href, target, and rel are passed through Slot to the child element
|
||||
// (typically Interactive.Container), which renders an <a> when href is present.
|
||||
const linkAttrs = href
|
||||
? {
|
||||
href: disabled ? undefined : href,
|
||||
target,
|
||||
rel: target === "_blank" ? "noopener noreferrer" : undefined,
|
||||
}
|
||||
: {};
|
||||
|
||||
return (
|
||||
<Slot
|
||||
ref={ref}
|
||||
className={classes}
|
||||
{...dataAttrs}
|
||||
{...linkAttrs}
|
||||
{...slotProps}
|
||||
onClick={
|
||||
disabled && href
|
||||
? (e: React.MouseEvent) => e.preventDefault()
|
||||
: disabled
|
||||
? undefined
|
||||
: onClick
|
||||
}
|
||||
onClick={disabled ? undefined : onClick}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -314,8 +316,6 @@ interface InteractiveContainerProps
|
||||
* This keeps all styling (background, rounding, height) on a single
|
||||
* element — unlike a wrapper approach which would split them.
|
||||
*
|
||||
* Mutually exclusive with `href`.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <Interactive.Base>
|
||||
@@ -399,22 +399,15 @@ function InteractiveContainer({
  heightVariant = "lg",
  ...props
}: InteractiveContainerProps) {
  // Radix Slot injects className, style, href, target, rel, and other
  // attributes at runtime (bypassing WithoutStyles), so we extract and
  // merge them to preserve the Slot-injected values.
  // Radix Slot injects className and style at runtime (bypassing WithoutStyles),
  // so we extract and merge them to preserve the Slot-injected values.
  const {
    className: slotClassName,
    style: slotStyle,
    href,
    target,
    rel,
    ...rest
  } = props as typeof props & {
    className?: string;
    style?: React.CSSProperties;
    href?: string;
    target?: string;
    rel?: string;
  };
  const sharedProps = {
    ...rest,
@@ -430,20 +423,6 @@ function InteractiveContainer({
    style: slotStyle,
  };

  // When href is provided (via Slot from Interactive.Base), render an <a>
  // so all styling (backgrounds, rounding, overflow) lives on one element.
  if (href) {
    return (
      <a
        ref={ref as React.Ref<HTMLAnchorElement>}
        href={href}
        target={target}
        rel={rel}
        {...(sharedProps as React.HTMLAttributes<HTMLAnchorElement>)}
      />
    );
  }

  if (type) {
    // When Interactive.Base is disabled it injects aria-disabled via Slot.
    // Map that to the native disabled attribute so a <button type="submit">
@@ -8,7 +8,7 @@

/* Interactive — container */
.interactive-container {
  @apply flex items-center justify-center overflow-clip;
  @apply flex self-stretch items-center justify-center overflow-clip;
}
.interactive-container[data-border="true"] {
  @apply border;

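The Interactive hunks above move the `<a>` rendering into `Interactive.Base` itself: when `href` is set it renders the anchor directly, rather than forwarding `href`/`target`/`rel` through Radix Slot to `Interactive.Container`, so the link and its hover/active styling live on one element. A minimal usage sketch, assuming the compound-component API shown in the diff (the route and label here are made up for illustration):

```tsx
// Hypothetical call site: Interactive.Base owns the <a> when `href` is present,
// so Interactive.Container only contributes layout (height, background, rounding).
<Interactive.Base href="/app/settings" target="_blank">
  <Interactive.Container heightVariant="lg">Open settings</Interactive.Container>
</Interactive.Base>
```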
web/package-lock.json (generated): 7 changes
@@ -53,7 +53,6 @@
        "highlight.js": "^11.11.1",
        "js-cookie": "^3.0.5",
        "katex": "^0.16.17",
        "linguist-languages": "^9.3.1",
        "lodash": "^4.17.23",
        "lowlight": "^3.3.0",
        "lucide-react": "^0.454.0",
@@ -11203,12 +11202,6 @@
      "version": "1.2.4",
      "license": "MIT"
    },
    "node_modules/linguist-languages": {
      "version": "9.3.1",
      "resolved": "https://registry.npmjs.org/linguist-languages/-/linguist-languages-9.3.1.tgz",
      "integrity": "sha512-Mum2sqg3MyhgKfpulFhKZMAK/1VnV6m9vCV8YQCSqWs+pbKouKn9EqRshZjVWUaJjl6NTTDcYJk/1+C02siXEQ==",
      "license": "MIT"
    },
    "node_modules/loader-runner": {
      "version": "4.3.1",
      "license": "MIT",

@@ -68,7 +68,6 @@
    "highlight.js": "^11.11.1",
    "js-cookie": "^3.0.5",
    "katex": "^0.16.17",
    "linguist-languages": "^9.3.1",
    "lodash": "^4.17.23",
    "lowlight": "^3.3.0",
    "lucide-react": "^0.454.0",

@@ -161,16 +161,10 @@ function PlanCard({
          >
            {buttonLabel}
          </Button>
        ) : onClick ? (
        ) : (
          <Button main primary onClick={onClick} leftIcon={ButtonIcon}>
            {buttonLabel}
          </Button>
        ) : (
          <Button tertiary transient className="pointer-events-none">
            <Text mainUiAction text03>
              Included in your plan
            </Text>
          </Button>
        )}
      </div>
    </Section>
@@ -227,14 +221,12 @@ function PlanCard({

interface PlansViewProps {
  hasSubscription?: boolean;
  hasLicense?: boolean;
  onCheckout: () => void;
  hideFeatures?: boolean;
}

export default function PlansView({
  hasSubscription,
  hasLicense,
  onCheckout,
  hideFeatures,
}: PlansViewProps) {
@@ -247,10 +239,10 @@ export default function PlansView({
        "per seat/month billed annually\nor $25 per seat if billed monthly",
      buttonLabel: "Get Business Plan",
      buttonVariant: "primary",
      onClick: hasLicense ? undefined : onCheckout,
      onClick: onCheckout,
      features: BUSINESS_FEATURES,
      featuresPrefix: "Get more work done with AI for your team.",
      isCurrentPlan: !!hasSubscription,
      isCurrentPlan: hasSubscription,
    },
    {
      icon: SvgOrganization,
@@ -262,7 +254,6 @@ export default function PlansView({
      href: SALES_URL,
      features: ENTERPRISE_FEATURES,
      featuresPrefix: "Everything in Business Plan, plus:",
      isCurrentPlan: !!hasLicense && !hasSubscription,
    },
  ];

@@ -293,7 +293,6 @@ export default function BillingPage() {
    plans: (
      <PlansView
        hasSubscription={!!hasSubscription}
        hasLicense={!!licenseData?.has_license}
        onCheckout={() => changeView("checkout")}
        hideFeatures={showLicenseActivationInput}
      />

@@ -125,7 +125,6 @@ export const WebProviderSetupModal = memo(
          <FormField.Label>API Key</FormField.Label>
          <FormField.Control asChild>
            <PasswordInputTypeIn
              data-testid="web-provider-api-key-input"
              placeholder="Enter API key"
              value={apiKeyValue}
              autoFocus={apiKeyAutoFocus}

web/src/app/app/components/modal/ShareChatSessionModal.tsx (new file): 242 lines
@@ -0,0 +1,242 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import Text from "@/components/ui/text";
|
||||
import { ChatSession, ChatSessionSharedStatus } from "@/app/app/interfaces";
|
||||
import { SEARCH_PARAM_NAMES } from "@/app/app/services/searchParams";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { structureValue } from "@/lib/llm/utils";
|
||||
import { LlmDescriptor, useLlmManager } from "@/lib/hooks";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useCurrentAgent } from "@/hooks/useAgents";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { useChatSessionStore } from "@/app/app/stores/useChatSessionStore";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import { copyAll } from "@/app/app/message/copyingUtils";
|
||||
import { SvgCopy, SvgShare } from "@opal/icons";
|
||||
|
||||
function buildShareLink(chatSessionId: string) {
|
||||
const baseUrl = `${window.location.protocol}//${window.location.host}`;
|
||||
return `${baseUrl}/app/shared/${chatSessionId}`;
|
||||
}
|
||||
|
||||
async function generateShareLink(chatSessionId: string) {
|
||||
const response = await fetch(`/api/chat/chat-session/${chatSessionId}`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ sharing_status: "public" }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
return buildShareLink(chatSessionId);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
async function generateSeedLink(
|
||||
message?: string,
|
||||
assistantId?: number,
|
||||
modelOverride?: LlmDescriptor
|
||||
) {
|
||||
const baseUrl = `${window.location.protocol}//${window.location.host}`;
|
||||
const model = modelOverride
|
||||
? structureValue(
|
||||
modelOverride.name,
|
||||
modelOverride.provider,
|
||||
modelOverride.modelName
|
||||
)
|
||||
: null;
|
||||
return `${baseUrl}/app${
|
||||
message
|
||||
? `?${SEARCH_PARAM_NAMES.USER_PROMPT}=${encodeURIComponent(message)}`
|
||||
: ""
|
||||
}${
|
||||
assistantId
|
||||
? `${message ? "&" : "?"}${SEARCH_PARAM_NAMES.PERSONA_ID}=${assistantId}`
|
||||
: ""
|
||||
}${
|
||||
model
|
||||
? `${message || assistantId ? "&" : "?"}${
|
||||
SEARCH_PARAM_NAMES.STRUCTURED_MODEL
|
||||
}=${encodeURIComponent(model)}`
|
||||
: ""
|
||||
}${message ? `&${SEARCH_PARAM_NAMES.SEND_ON_LOAD}=true` : ""}`;
|
||||
}
|
||||
|
||||
async function deleteShareLink(chatSessionId: string) {
|
||||
const response = await fetch(`/api/chat/chat-session/${chatSessionId}`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ sharing_status: "private" }),
|
||||
});
|
||||
|
||||
return response.ok;
|
||||
}
|
||||
|
||||
interface ShareChatSessionModalProps {
|
||||
chatSession: ChatSession;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export default function ShareChatSessionModal({
|
||||
chatSession,
|
||||
onClose,
|
||||
}: ShareChatSessionModalProps) {
|
||||
const [shareLink, setShareLink] = useState<string>(
|
||||
chatSession.shared_status === ChatSessionSharedStatus.Public
|
||||
? buildShareLink(chatSession.id)
|
||||
: ""
|
||||
);
|
||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||
const currentAgent = useCurrentAgent();
|
||||
const searchParams = useSearchParams();
|
||||
const message = searchParams?.get(SEARCH_PARAM_NAMES.USER_PROMPT) || "";
|
||||
const llmManager = useLlmManager(chatSession, currentAgent || undefined);
|
||||
const updateCurrentChatSessionSharedStatus = useChatSessionStore(
|
||||
(state) => state.updateCurrentChatSessionSharedStatus
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgShare}
|
||||
title="Share Chat"
|
||||
onClose={onClose}
|
||||
submit={<Button onClick={onClose}>Share</Button>}
|
||||
>
|
||||
{shareLink ? (
|
||||
<div>
|
||||
<Text>
|
||||
This chat session is currently shared. Anyone in your team can
|
||||
view the message history using the following link:
|
||||
</Text>
|
||||
|
||||
<div className="flex items-center mt-2">
|
||||
{/* <CopyButton content={shareLink} /> */}
|
||||
<CopyIconButton
|
||||
getCopyText={() => shareLink}
|
||||
prominence="secondary"
|
||||
/>
|
||||
<a
|
||||
href={shareLink}
|
||||
target="_blank"
|
||||
className={cn(
|
||||
"underline mt-1 ml-1 text-sm my-auto",
|
||||
"text-action-link-05"
|
||||
)}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{shareLink}
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<Text className={cn("mb-4")}>
|
||||
Click the button below to make the chat private again.
|
||||
</Text>
|
||||
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const success = await deleteShareLink(chatSession.id);
|
||||
if (success) {
|
||||
setShareLink("");
|
||||
updateCurrentChatSessionSharedStatus(
|
||||
ChatSessionSharedStatus.Private
|
||||
);
|
||||
} else {
|
||||
alert("Failed to delete share link");
|
||||
}
|
||||
}}
|
||||
danger
|
||||
>
|
||||
Delete Share Link
|
||||
</Button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex flex-col gap-2">
|
||||
<Callout type="warning" title="Warning">
|
||||
Please make sure that all content in this chat is safe to share
|
||||
with the whole team.
|
||||
</Callout>
|
||||
<Button
|
||||
leftIcon={SvgCopy}
|
||||
onClick={async () => {
|
||||
// NOTE: for "insecure" non-https setup, the `navigator.clipboard.writeText` may fail
|
||||
// as the browser may not allow the clipboard to be accessed.
|
||||
try {
|
||||
const shareLink = await generateShareLink(chatSession.id);
|
||||
if (!shareLink) {
|
||||
alert("Failed to generate share link");
|
||||
} else {
|
||||
setShareLink(shareLink);
|
||||
updateCurrentChatSessionSharedStatus(
|
||||
ChatSessionSharedStatus.Public
|
||||
);
|
||||
copyAll(shareLink);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}}
|
||||
secondary
|
||||
>
|
||||
Generate and Copy Share Link
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Separator className={cn("my-4")} />
|
||||
|
||||
<AdvancedOptionsToggle
|
||||
showAdvancedOptions={showAdvancedOptions}
|
||||
setShowAdvancedOptions={setShowAdvancedOptions}
|
||||
title="Advanced Options"
|
||||
/>
|
||||
|
||||
{showAdvancedOptions && (
|
||||
<div className="flex flex-col gap-2">
|
||||
<Callout type="notice" title="Seed New Chat">
|
||||
Generate a link to a new chat session with the same settings as
|
||||
this chat (including the assistant and model).
|
||||
</Callout>
|
||||
<Button
|
||||
leftIcon={SvgCopy}
|
||||
onClick={async () => {
|
||||
try {
|
||||
const seedLink = await generateSeedLink(
|
||||
message,
|
||||
currentAgent?.id,
|
||||
llmManager.currentLlm
|
||||
);
|
||||
if (!seedLink) {
|
||||
toast.error("Failed to generate seed link");
|
||||
} else {
|
||||
navigator.clipboard.writeText(seedLink);
|
||||
copyAll(seedLink);
|
||||
toast.success("Link copied to clipboard!");
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
alert("Failed to generate or copy link.");
|
||||
}
|
||||
}}
|
||||
secondary
|
||||
>
|
||||
Generate and Copy Seed Link
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</ConfirmationModalLayout>
|
||||
</>
|
||||
);
|
||||
}
|
||||
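The new `ShareChatSessionModal` above toggles a session between public and private by PATCHing `/api/chat/chat-session/{id}` with a `sharing_status` body, and builds the shareable URL as `/app/shared/{id}`. A condensed sketch of that call, assuming only the endpoint and payload shape shown in the file (the helper name is illustrative and error handling is simplified):

```typescript
// Hypothetical helper mirroring generateShareLink/deleteShareLink above; the
// endpoint and body shape come from the diff, everything else is illustrative.
async function setSharingStatus(
  chatSessionId: string,
  status: "public" | "private"
): Promise<boolean> {
  const response = await fetch(`/api/chat/chat-session/${chatSessionId}`, {
    method: "PATCH",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ sharing_status: status }),
  });
  return response.ok;
}

// When made public, the share URL is `${window.location.origin}/app/shared/${chatSessionId}`.
```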
@@ -38,7 +38,7 @@ export default function ProjectChatSessionList() {
  if (!currentProjectId) return null;

  return (
    <div className="flex flex-col gap-2 px-2 w-full max-w-[800px] mx-auto mt-4">
    <div className="flex flex-col gap-2 px-2 w-full mx-auto mt-4">
      <div className="flex items-center pl-2">
        <Text as="p" text02 secondaryBody>
          Recent Chats
@@ -56,7 +56,7 @@ export default function ProjectChatSessionList() {
          No chats yet.
        </Text>
      ) : (
        <div className="flex flex-col gap-2 max-h-[46vh] overflow-y-auto overscroll-y-none">
        <div className="flex flex-col gap-2">
          {projectChats.map((chat) => (
            <Link
              key={chat.id}

@@ -186,7 +186,6 @@ export interface BackendChatSession {
  current_temperature_override: number | null;
  current_alternate_model?: string;

  owner_name: string | null;
  packets: Packet[][];
}

@@ -192,10 +192,10 @@ const HumanMessage = React.memo(function HumanMessage({
        />
      ) : typeof content === "string" ? (
        <>
          <div className="md:max-w-[37.5rem] flex basis-[100%] md:basis-auto justify-end md:order-1">
          <div className="md:max-w-[25rem] flex basis-[100%] md:basis-auto justify-end md:order-1">
            <div
              className={
                "max-w-[30rem] md:max-w-[37.5rem] whitespace-break-spaces rounded-t-16 rounded-bl-16 bg-background-tint-02 py-2 px-3"
                "max-w-[25rem] whitespace-break-spaces rounded-t-16 rounded-bl-16 bg-background-tint-02 py-2 px-3"
              }
              onCopy={(e) => {
                const selection = window.getSelection();
@@ -214,7 +214,7 @@ const HumanMessage = React.memo(function HumanMessage({
            </Text>
          </div>
        </div>
        {onEdit && !isEditing && (
        {onEdit && !isEditing && (!files || files.length === 0) && (
          <div className="flex flex-row p-1 opacity-0 group-hover:opacity-100 transition-opacity">
            <CopyIconButton
              getCopyText={() => content}
@@ -236,7 +236,7 @@ const HumanMessage = React.memo(function HumanMessage({
          <div
            className={cn(
              "my-auto",
              onEdit && !isEditing
              onEdit && !isEditing && (!files || files.length === 0)
                ? "opacity-0 group-hover:opacity-100 transition-opacity"
                : "invisible"
            )}

@@ -209,10 +209,7 @@ export default function MessageToolbar({
        <FeedbackModal {...feedbackModalProps!} />
      </modal.Provider>

      <div
        data-testid="AgentMessage/toolbar"
        className="flex md:flex-row justify-between items-center w-full transition-transform duration-300 ease-in-out transform opacity-100 pl-1"
      >
      <div className="flex md:flex-row justify-between items-center w-full transition-transform duration-300 ease-in-out transform opacity-100 pl-1">
        <TooltipGroup>
          <div className="flex items-center">
            {includeMessageSwitcher && (

@@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect, useRef, useCallback, useMemo } from "react";
|
||||
import { useState, useEffect, useRef, useCallback } from "react";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
@@ -47,12 +47,6 @@ import { useAppBackground } from "@/providers/AppBackgroundProvider";
|
||||
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import DocumentsSidebar from "@/sections/document-sidebar/DocumentsSidebar";
|
||||
import TextViewModal from "@/sections/modals/TextViewModal";
|
||||
import { personaIncludesRetrieval } from "@/app/app/services/lib";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { eeGated } from "@/ce";
|
||||
import EESearchUI from "@/ee/sections/SearchUI";
|
||||
|
||||
const SearchUI = eeGated(EESearchUI);
|
||||
|
||||
interface NRFPageProps {
|
||||
isSidePanel?: boolean;
|
||||
@@ -180,20 +174,6 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
const autoScrollEnabled = user?.preferences?.auto_scroll !== false;
|
||||
const isStreaming = currentChatState === "streaming";
|
||||
|
||||
// Query controller for search/chat classification (EE feature)
|
||||
const { submit: submitQuery, classification } = useQueryController();
|
||||
|
||||
// Determine if retrieval (search) is enabled based on the assistant
|
||||
const retrievalEnabled = useMemo(() => {
|
||||
if (liveAssistant) {
|
||||
return personaIncludesRetrieval(liveAssistant);
|
||||
}
|
||||
return false;
|
||||
}, [liveAssistant]);
|
||||
|
||||
// Check if we're in search mode
|
||||
const isSearch = classification === "search";
|
||||
|
||||
// Anchor for scroll positioning (matches ChatPage pattern)
|
||||
const anchorMessage = messageHistory.at(-2) ?? messageHistory[0];
|
||||
const anchorNodeId = anchorMessage?.nodeId;
|
||||
@@ -268,48 +248,17 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
[handleMessageSpecificFileUpload]
|
||||
);
|
||||
|
||||
// Handler for chat submission (used by query controller)
|
||||
const onChat = useCallback(
|
||||
(chatMessage: string) => {
|
||||
resetInputBar();
|
||||
// Handle submit from AppInputBar
|
||||
const handleChatInputSubmit = useCallback(
|
||||
(submittedMessage: string) => {
|
||||
if (!submittedMessage.trim()) return;
|
||||
onSubmit({
|
||||
message: chatMessage,
|
||||
message: submittedMessage,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled,
|
||||
});
|
||||
},
|
||||
[onSubmit, currentMessageFiles, deepResearchEnabled, resetInputBar]
|
||||
);
|
||||
|
||||
// Handle submit from AppInputBar - routes through query controller for search/chat classification
|
||||
const handleChatInputSubmit = useCallback(
|
||||
async (submittedMessage: string) => {
|
||||
if (!submittedMessage.trim()) return;
|
||||
// If we already have messages (chat session started), always use chat mode
|
||||
// (matches AppPage behavior where existing sessions bypass classification)
|
||||
if (hasMessages) {
|
||||
resetInputBar();
|
||||
onSubmit({
|
||||
message: submittedMessage,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled,
|
||||
});
|
||||
return;
|
||||
}
|
||||
// Use submitQuery which will classify the query and either:
|
||||
// - Route to search (sets classification to "search" and shows SearchUI)
|
||||
// - Route to chat (calls onChat callback)
|
||||
await submitQuery(submittedMessage, onChat);
|
||||
},
|
||||
[
|
||||
hasMessages,
|
||||
onSubmit,
|
||||
currentMessageFiles,
|
||||
deepResearchEnabled,
|
||||
resetInputBar,
|
||||
submitQuery,
|
||||
onChat,
|
||||
]
|
||||
[onSubmit, currentMessageFiles, deepResearchEnabled]
|
||||
);
|
||||
|
||||
// Handle resubmit last message on error
|
||||
@@ -335,12 +284,6 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
window.open(`${window.location.origin}/app`, "_blank");
|
||||
};
|
||||
|
||||
// Handle search result document click
|
||||
const handleSearchDocumentClick = useCallback(
|
||||
(doc: MinimalOnyxDocument) => setPresentingDocument(doc),
|
||||
[]
|
||||
);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
@@ -422,8 +365,8 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Welcome message - centered when no messages and not in search mode */}
|
||||
{!hasMessages && !isSearch && (
|
||||
{/* Welcome message - centered when no messages */}
|
||||
{!hasMessages && (
|
||||
<div className="relative w-full flex-1 flex flex-col items-center justify-end">
|
||||
<WelcomeMessage isDefaultAgent />
|
||||
<Spacer rem={1.5} />
|
||||
@@ -443,7 +386,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
filterManager={filterManager}
|
||||
llmManager={llmManager}
|
||||
removeDocs={() => {}}
|
||||
retrievalEnabled={retrievalEnabled}
|
||||
retrievalEnabled={false}
|
||||
selectedDocuments={[]}
|
||||
initialMessage={message}
|
||||
stopGenerating={stopGenerating}
|
||||
@@ -460,16 +403,8 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
<Spacer rem={0.5} />
|
||||
</div>
|
||||
|
||||
{/* Search results - shown when query is classified as search */}
|
||||
{isSearch && (
|
||||
<div className="flex-1 w-full max-w-[var(--app-page-main-content-width)] px-4 min-h-0 overflow-auto">
|
||||
<Spacer rem={0.75} />
|
||||
<SearchUI onDocumentClick={handleSearchDocumentClick} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Spacer to push content up when showing welcome message */}
|
||||
{!hasMessages && !isSearch && <div className="flex-1 w-full" />}
|
||||
{!hasMessages && <div className="flex-1 w-full" />}
|
||||
</div>
|
||||
)}
|
||||
</Dropzone>
|
||||
|
||||
@@ -15,7 +15,6 @@ import TextViewModal from "@/sections/modals/TextViewModal";
|
||||
import { UNNAMED_CHAT } from "@/lib/constants";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import useOnMount from "@/hooks/useOnMount";
|
||||
import SharedAppInputBar from "@/sections/input/SharedAppInputBar";
|
||||
|
||||
export interface SharedChatDisplayProps {
|
||||
chatSession: BackendChatSession | null;
|
||||
@@ -70,78 +69,65 @@ export default function SharedChatDisplay({
|
||||
/>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col h-full w-full overflow-hidden">
|
||||
<div className="flex-1 flex flex-col items-center overflow-y-auto">
|
||||
<div className="sticky top-0 z-10 flex items-center justify-between w-full bg-background-tint-01 px-8 py-4">
|
||||
<Text as="p" text04 headingH2>
|
||||
{chatSession.description || UNNAMED_CHAT}
|
||||
</Text>
|
||||
<div className="flex flex-col items-end">
|
||||
<Text as="p" text03 secondaryBody>
|
||||
Shared on {humanReadableFormat(chatSession.time_created)}
|
||||
</Text>
|
||||
{chatSession.owner_name && (
|
||||
<Text as="p" text03 secondaryBody>
|
||||
by {chatSession.owner_name}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-col items-center h-full w-full overflow-hidden overflow-y-scroll">
|
||||
<div className="sticky top-0 z-10 flex flex-col w-full bg-background-tint-01 px-8 py-4">
|
||||
<Text as="p" headingH2>
|
||||
{chatSession.description || UNNAMED_CHAT}
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
{humanReadableFormat(chatSession.time_created)}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{isMounted ? (
|
||||
<div className="w-[min(50rem,100%)]">
|
||||
{messages.map((message, i) => {
|
||||
if (message.type === "user") {
|
||||
return (
|
||||
<HumanMessage
|
||||
key={message.messageId}
|
||||
content={message.message}
|
||||
files={message.files}
|
||||
nodeId={message.nodeId}
|
||||
/>
|
||||
);
|
||||
} else if (message.type === "assistant") {
|
||||
return (
|
||||
<AgentMessage
|
||||
key={message.messageId}
|
||||
rawPackets={message.packets}
|
||||
chatState={{
|
||||
assistant: persona,
|
||||
docs: message.documents,
|
||||
citations: message.citations,
|
||||
setPresentingDocument: setPresentingDocument,
|
||||
overriddenModel: message.overridden_model,
|
||||
}}
|
||||
nodeId={message.nodeId}
|
||||
llmManager={null}
|
||||
otherMessagesCanSwitchTo={undefined}
|
||||
onMessageSelection={undefined}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
// Error message case
|
||||
return (
|
||||
<div key={message.messageId} className="py-5 ml-4 lg:px-5">
|
||||
<div className="mx-auto w-[90%] max-w-message-max">
|
||||
<p className="text-status-text-error-05 text-sm my-auto">
|
||||
{message.message}
|
||||
</p>
|
||||
</div>
|
||||
{isMounted ? (
|
||||
<div className="w-[min(50rem,100%)]">
|
||||
{messages.map((message, i) => {
|
||||
if (message.type === "user") {
|
||||
return (
|
||||
<HumanMessage
|
||||
key={message.messageId}
|
||||
content={message.message}
|
||||
files={message.files}
|
||||
nodeId={message.nodeId}
|
||||
/>
|
||||
);
|
||||
} else if (message.type === "assistant") {
|
||||
return (
|
||||
<AgentMessage
|
||||
key={message.messageId}
|
||||
rawPackets={message.packets}
|
||||
chatState={{
|
||||
assistant: persona,
|
||||
docs: message.documents,
|
||||
citations: message.citations,
|
||||
setPresentingDocument: setPresentingDocument,
|
||||
overriddenModel: message.overridden_model,
|
||||
}}
|
||||
nodeId={message.nodeId}
|
||||
llmManager={null}
|
||||
otherMessagesCanSwitchTo={undefined}
|
||||
onMessageSelection={undefined}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
// Error message case
|
||||
return (
|
||||
<div key={message.messageId} className="py-5 ml-4 lg:px-5">
|
||||
<div className="mx-auto w-[90%] max-w-message-max">
|
||||
<p className="text-status-text-error-05 text-sm my-auto">
|
||||
{message.message}
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
})}
|
||||
</div>
|
||||
) : (
|
||||
<div className="h-full w-full flex items-center justify-center">
|
||||
<OnyxInitializingLoader />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="w-full max-w-[50rem] mx-auto px-4 pb-4">
|
||||
<SharedAppInputBar />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
})}
|
||||
</div>
|
||||
) : (
|
||||
<div className="h-full w-full flex items-center justify-center">
|
||||
<OnyxInitializingLoader />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -12,8 +12,7 @@ export type AppFocusType =
  | { type: "agent" | "project" | "chat"; id: string }
  | "new-session"
  | "more-agents"
  | "user-settings"
  | "shared-chat";
  | "user-settings";

export class AppFocus {
  constructor(public value: AppFocusType) {}
@@ -30,10 +29,6 @@ export class AppFocus {
    return typeof this.value === "object" && this.value.type === "chat";
  }

  isSharedChat(): boolean {
    return this.value === "shared-chat";
  }

  isNewSession(): boolean {
    return this.value === "new-session";
  }
@@ -54,7 +49,6 @@ export class AppFocus {
    | "agent"
    | "project"
    | "chat"
    | "shared-chat"
    | "new-session"
    | "more-agents"
    | "user-settings" {
@@ -66,11 +60,6 @@ export default function useAppFocus(): AppFocus {
  const pathname = usePathname();
  const searchParams = useSearchParams();

  // Check if we're viewing a shared chat
  if (pathname.startsWith("/app/shared/")) {
    return new AppFocus("shared-chat");
  }

  // Check if we're on the user settings page
  if (pathname.startsWith("/app/settings")) {
    return new AppFocus("user-settings");

@@ -570,14 +570,6 @@ export default function useChatController({
        ? messageToResend?.message || message
        : message;

    // When editing a message that had files attached, preserve the original files.
    // Skip for regeneration — the regeneration path reuses the existing user node
    // (and its files), so merging here would send duplicates.
    const effectiveFileDescriptors = [
      ...projectFilesToFileDescriptors(currentMessageFiles),
      ...(!regenerationRequest ? messageToResend?.files ?? [] : []),
    ];

    updateChatStateAction(frozenSessionId, "loading");

    // find the parent
@@ -616,7 +608,7 @@ export default function useChatController({
      const result = buildImmediateMessages(
        parentNodeIdForMessage,
        currMessage,
        effectiveFileDescriptors,
        projectFilesToFileDescriptors(currentMessageFiles),
        messageToResend
      );
      initialUserNode = result.initialUserNode;
@@ -653,7 +645,7 @@ export default function useChatController({

    let finalMessage: BackendMessage | null = null;
    let toolCall: ToolCallMetadata | null = null;
    let files = effectiveFileDescriptors;
    let files = projectFilesToFileDescriptors(currentMessageFiles);
    let packets: Packet[] = [];
    let packetsVersion = 0;

@@ -691,7 +683,7 @@ export default function useChatController({
      updateCurrentMessageFIFO(stack, {
        signal: controller.signal,
        message: currMessage,
        fileDescriptors: effectiveFileDescriptors,
        fileDescriptors: projectFilesToFileDescriptors(currentMessageFiles),
        parentMessageId: (() => {
          const parentId =
            regenerationRequest?.parentMessage.messageId ||
@@ -764,7 +756,7 @@ export default function useChatController({
        posthog.capture("extension_chat_query", {
          extension_context: extensionContext,
          assistant_id: liveAssistant?.id,
          has_files: effectiveFileDescriptors.length > 0,
          has_files: currentMessageFiles.length > 0,
          deep_research: deepResearch,
        });
      }
@@ -907,7 +899,12 @@ export default function useChatController({
      nodeId: initialUserNode.nodeId,
      message: currMessage,
      type: "user",
      files: effectiveFileDescriptors,
      files: currentMessageFiles.map((file) => ({
        id: file.file_id,
        type: file.chat_file_type,
        name: file.name,
        user_file_id: file.id,
      })),
      toolCall: null,
      parentNodeId: parentMessage?.nodeId || SYSTEM_NODE_ID,
      packets: [],

@@ -1,63 +0,0 @@
"use client";

import { useState, useEffect } from "react";

const SELECTOR = "[data-main-container]";

interface ContainerCenter {
  centerX: number | null;
  centerY: number | null;
  hasContainerCenter: boolean;
}

function measure(el: HTMLElement): { x: number; y: number } {
  const rect = el.getBoundingClientRect();
  return { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 };
}

/**
 * Tracks the center point of the `[data-main-container]` element so that
 * portaled overlays (modals, command menus) can center relative to the main
 * content area rather than the full viewport.
 *
 * Returns `{ centerX, centerY, hasContainerCenter }`.
 * When the container is not present (e.g. pages without `AppLayouts.Root`),
 * both center values are `null` and `hasContainerCenter` is `false`, allowing
 * callers to fall back to standard viewport centering.
 *
 * Uses a lazy `useState` initializer so the first render already has the
 * correct values (no flash), and a `ResizeObserver` to stay reactive when
 * the sidebar folds/unfolds.
 */
export default function useContainerCenter(): ContainerCenter {
  const [center, setCenter] = useState<{ x: number | null; y: number | null }>(
    () => {
      if (typeof document === "undefined") return { x: null, y: null };
      const el = document.querySelector<HTMLElement>(SELECTOR);
      if (!el) return { x: null, y: null };
      const m = measure(el);
      return { x: m.x, y: m.y };
    }
  );

  useEffect(() => {
    const container = document.querySelector<HTMLElement>(SELECTOR);
    if (!container) return;

    const update = () => {
      const m = measure(container);
      setCenter({ x: m.x, y: m.y });
    };

    update();
    const observer = new ResizeObserver(update);
    observer.observe(container);
    return () => observer.disconnect();
  }, []);

  return {
    centerX: center.x,
    centerY: center.y,
    hasContainerCenter: center.x !== null && center.y !== null,
  };
}
@@ -27,7 +27,7 @@ import Button from "@/refresh-components/buttons/Button";
|
||||
import { useCallback, useMemo, useState, useEffect } from "react";
|
||||
import { useAppBackground } from "@/providers/AppBackgroundProvider";
|
||||
import { useTheme } from "next-themes";
|
||||
import ShareChatSessionModal from "@/sections/modals/ShareChatSessionModal";
|
||||
import ShareChatSessionModal from "@/app/app/components/modal/ShareChatSessionModal";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import { useProjectsContext } from "@/providers/ProjectsContext";
|
||||
@@ -112,10 +112,6 @@ function Header() {
|
||||
|
||||
const customHeaderContent =
|
||||
settings?.enterpriseSettings?.custom_header_content;
|
||||
// Some pages don't want the custom header content, namely every page except Chat, Search, and
|
||||
// NewSession. The header provides features such as the open sidebar button on mobile which pages
|
||||
// without this content still use.
|
||||
const pageWithHeaderContent = appFocus.isChat() || appFocus.isNewSession();
|
||||
|
||||
const effectiveMode: AppMode = appFocus.isNewSession() ? appMode : "chat";
|
||||
|
||||
@@ -362,7 +358,7 @@ function Header() {
|
||||
*/}
|
||||
<div className="flex-1 flex flex-col items-center overflow-hidden">
|
||||
<Text text03 className="text-center w-full">
|
||||
{pageWithHeaderContent && customHeaderContent}
|
||||
{customHeaderContent}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
@@ -379,7 +375,6 @@ function Header() {
|
||||
transient={showShareModal}
|
||||
tertiary
|
||||
onClick={() => setShowShareModal(true)}
|
||||
aria-label="share-chat-button"
|
||||
>
|
||||
Share Chat
|
||||
</Button>
|
||||
@@ -515,12 +510,8 @@ function Root({ children, enableBackground }: AppRootProps) {
|
||||
return (
|
||||
/* NOTE: Some elements, markdown tables in particular, refer to this `@container` in order to
|
||||
breakout of their immediate containers using cqw units.
|
||||
The `data-main-container` attribute is used by portaled elements (e.g. CommandMenu) to
|
||||
render inside this container so they can be centered relative to the main content area
|
||||
rather than the full viewport (which would include the sidebar).
|
||||
*/
|
||||
<div
|
||||
data-main-container
|
||||
className={cn(
|
||||
"@container flex flex-col h-full w-full relative overflow-hidden",
|
||||
showBackground && "bg-cover bg-center bg-fixed"
|
||||
@@ -573,7 +564,7 @@ function Root({ children, enableBackground }: AppRootProps) {
|
||||
)}
|
||||
|
||||
<div className="z-app-layout">
|
||||
{!appFocus.isSharedChat() && <Header />}
|
||||
<Header />
|
||||
</div>
|
||||
<div className="z-app-layout flex-1 overflow-auto h-full w-full">
|
||||
{children}
|
||||
|
||||
@@ -191,8 +191,11 @@ function SettingsHeader({
|
||||
}: SettingsHeaderProps) {
|
||||
const [showShadow, setShowShadow] = useState(false);
|
||||
const headerRef = useRef<HTMLDivElement>(null);
|
||||
const isSticky = !!rightChildren; //headers with actions are always sticky, others are not
|
||||
|
||||
useEffect(() => {
|
||||
if (!isSticky) return;
|
||||
|
||||
// IMPORTANT: This component relies on SettingsRoot having the ID "page-wrapper-scroll-container"
|
||||
// on its scrollable container. If that ID is removed or changed, the scroll shadow will not work.
|
||||
const scrollContainer = document.getElementById(
|
||||
@@ -209,14 +212,15 @@ function SettingsHeader({
|
||||
handleScroll(); // Check initial state
|
||||
|
||||
return () => scrollContainer.removeEventListener("scroll", handleScroll);
|
||||
}, []);
|
||||
}, [isSticky]);
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={headerRef}
|
||||
className={cn(
|
||||
"sticky top-0 z-settings-header w-full bg-background-tint-01",
|
||||
backButton ? "md:pt-4" : "md:pt-10"
|
||||
"w-full bg-background-tint-01",
|
||||
isSticky && "sticky top-0 z-settings-header",
|
||||
backButton ? "pt-4" : "pt-10"
|
||||
)}
|
||||
>
|
||||
{backButton && (
|
||||
@@ -256,15 +260,18 @@ function SettingsHeader({
|
||||
<Separator noPadding className="px-4" />
|
||||
</>
|
||||
)}
|
||||
<div
|
||||
className={cn(
|
||||
"absolute left-0 right-0 h-[0.5rem] pointer-events-none transition-opacity duration-300 rounded-b-08 opacity-0",
|
||||
showShadow && "opacity-100"
|
||||
)}
|
||||
style={{
|
||||
background: "linear-gradient(to bottom, var(--mask-02), transparent)",
|
||||
}}
|
||||
/>
|
||||
{isSticky && (
|
||||
<div
|
||||
className={cn(
|
||||
"absolute left-0 right-0 h-[0.5rem] pointer-events-none transition-opacity duration-300 rounded-b-08 opacity-0",
|
||||
showShadow && "opacity-100"
|
||||
)}
|
||||
style={{
|
||||
background:
|
||||
"linear-gradient(to bottom, var(--mask-02), transparent)",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -65,7 +65,6 @@ function ToastContainer() {
|
||||
|
||||
return (
|
||||
<div
|
||||
data-testid="toast-container"
|
||||
className={cn(
|
||||
"fixed bottom-4 right-4 z-[10000]",
|
||||
"flex flex-col gap-2 items-end"
|
||||
|
||||
@@ -9,7 +9,6 @@ import { Button } from "@opal/components";
|
||||
import { SvgX } from "@opal/icons";
|
||||
import { WithoutStyles } from "@/types";
|
||||
import { Section, SectionProps } from "@/layouts/general-layouts";
|
||||
import useContainerCenter from "@/hooks/useContainerCenter";
|
||||
|
||||
/**
|
||||
* Modal Root Component
|
||||
@@ -265,8 +264,6 @@ const ModalContent = React.forwardRef<
|
||||
contentRef(node);
|
||||
};
|
||||
|
||||
const { centerX, centerY, hasContainerCenter } = useContainerCenter();
|
||||
|
||||
const animationClasses = cn(
|
||||
"data-[state=open]:fade-in-0 data-[state=closed]:fade-out-0",
|
||||
"data-[state=open]:zoom-in-95 data-[state=closed]:zoom-out-95",
|
||||
@@ -274,22 +271,6 @@ const ModalContent = React.forwardRef<
|
||||
"duration-200"
|
||||
);
|
||||
|
||||
const containerStyle: React.CSSProperties | undefined = hasContainerCenter
|
||||
? ({
|
||||
left: centerX,
|
||||
top: centerY,
|
||||
"--tw-enter-translate-x": "-50%",
|
||||
"--tw-exit-translate-x": "-50%",
|
||||
"--tw-enter-translate-y": "-50%",
|
||||
"--tw-exit-translate-y": "-50%",
|
||||
} as React.CSSProperties)
|
||||
: undefined;
|
||||
|
||||
const positionClasses = cn(
|
||||
"fixed -translate-x-1/2 -translate-y-1/2",
|
||||
!hasContainerCenter && "left-1/2 top-1/2"
|
||||
);
|
||||
|
||||
const dialogEventHandlers = {
|
||||
onOpenAutoFocus: (e: Event) => {
|
||||
resetState();
|
||||
@@ -334,9 +315,8 @@ const ModalContent = React.forwardRef<
|
||||
{...dialogEventHandlers}
|
||||
>
|
||||
<div
|
||||
style={containerStyle}
|
||||
className={cn(
|
||||
positionClasses,
|
||||
"fixed left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2",
|
||||
"z-modal",
|
||||
"flex flex-col gap-4 items-center",
|
||||
"max-w-[calc(100dvw-2rem)] max-h-[calc(100dvh-2rem)]",
|
||||
@@ -354,10 +334,8 @@ const ModalContent = React.forwardRef<
|
||||
// Without bottomSlot: original single-element rendering
|
||||
<DialogPrimitive.Content
|
||||
ref={handleRef}
|
||||
style={containerStyle}
|
||||
className={cn(
|
||||
positionClasses,
|
||||
"overflow-hidden",
|
||||
"fixed left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2 overflow-hidden",
|
||||
"z-modal",
|
||||
background === "gray"
|
||||
? "bg-background-tint-01"
|
||||
|
||||
@@ -32,7 +32,7 @@ const sizeClasses = {
|
||||
container: "rounded-04 p-0.5 gap-0.5",
|
||||
},
|
||||
tag: {
|
||||
container: "rounded-08 h-[2.25rem] min-w-[2.25rem] p-2 gap-1",
|
||||
container: "rounded-08 p-1 gap-1",
|
||||
},
|
||||
} as const;
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import React, {
|
||||
} from "react";
|
||||
import * as DialogPrimitive from "@radix-ui/react-dialog";
|
||||
import * as VisuallyHidden from "@radix-ui/react-visually-hidden";
|
||||
import useContainerCenter from "@/hooks/useContainerCenter";
|
||||
import { cn } from "@/lib/utils";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
@@ -367,11 +366,10 @@ const CommandMenuContent = React.forwardRef<
|
||||
CommandMenuContentProps
|
||||
>(({ children }, ref) => {
|
||||
const { handleKeyDown } = useCommandMenuContext();
|
||||
const { centerX, hasContainerCenter } = useContainerCenter();
|
||||
|
||||
return (
|
||||
<DialogPrimitive.Portal>
|
||||
{/* Overlay - fixed to full viewport, hidden from assistive technology */}
|
||||
{/* Overlay - hidden from assistive technology */}
|
||||
<DialogPrimitive.Overlay
|
||||
aria-hidden="true"
|
||||
className={cn(
|
||||
@@ -380,23 +378,12 @@ const CommandMenuContent = React.forwardRef<
|
||||
"data-[state=open]:fade-in-0 data-[state=closed]:fade-out-0"
|
||||
)}
|
||||
/>
|
||||
{/* Content - centered within the main container when available,
|
||||
otherwise falls back to viewport centering */}
|
||||
{/* Content */}
|
||||
<DialogPrimitive.Content
|
||||
ref={ref}
|
||||
onKeyDown={handleKeyDown}
|
||||
style={
|
||||
hasContainerCenter
|
||||
? ({
|
||||
left: centerX,
|
||||
"--tw-enter-translate-x": "-50%",
|
||||
"--tw-exit-translate-x": "-50%",
|
||||
} as React.CSSProperties)
|
||||
: undefined
|
||||
}
|
||||
className={cn(
|
||||
"fixed top-[72px]",
|
||||
hasContainerCenter ? "-translate-x-1/2" : "inset-x-0 mx-auto",
|
||||
"fixed inset-x-0 top-[72px] mx-auto",
|
||||
"z-modal",
|
||||
"bg-background-tint-00 border rounded-16 shadow-2xl outline-none",
|
||||
"flex flex-col overflow-hidden",
|
||||
|
||||
@@ -33,7 +33,7 @@ import { LLMOption, LLMOptionGroup } from "./interfaces";
|
||||
|
||||
export interface LLMPopoverProps {
|
||||
llmManager: LlmManager;
|
||||
requiresImageGeneration?: boolean;
|
||||
requiresImageInput?: boolean;
|
||||
folded?: boolean;
|
||||
onSelect?: (value: string) => void;
|
||||
currentModelName?: string;
|
||||
@@ -140,6 +140,7 @@ export function groupLlmOptions(
|
||||
|
||||
export default function LLMPopover({
|
||||
llmManager,
|
||||
requiresImageInput,
|
||||
folded,
|
||||
onSelect,
|
||||
currentModelName,
|
||||
@@ -186,19 +187,23 @@ export default function LLMPopover({
|
||||
[llmProviders, currentModelName]
|
||||
);
|
||||
|
||||
// Filter options by search query
|
||||
// Filter options by vision capability (when images are uploaded) and search query
|
||||
const filteredOptions = useMemo(() => {
|
||||
if (!searchQuery.trim()) {
|
||||
return llmOptions;
|
||||
let result = llmOptions;
|
||||
if (requiresImageInput) {
|
||||
result = result.filter((opt) => opt.supportsImageInput);
|
||||
}
|
||||
const query = searchQuery.toLowerCase();
|
||||
return llmOptions.filter(
|
||||
(opt) =>
|
||||
opt.displayName.toLowerCase().includes(query) ||
|
||||
opt.modelName.toLowerCase().includes(query) ||
|
||||
(opt.vendor && opt.vendor.toLowerCase().includes(query))
|
||||
);
|
||||
}, [llmOptions, searchQuery]);
|
||||
if (searchQuery.trim()) {
|
||||
const query = searchQuery.toLowerCase();
|
||||
result = result.filter(
|
||||
(opt) =>
|
||||
opt.displayName.toLowerCase().includes(query) ||
|
||||
opt.modelName.toLowerCase().includes(query) ||
|
||||
(opt.vendor && opt.vendor.toLowerCase().includes(query))
|
||||
);
|
||||
}
|
||||
return result;
|
||||
}, [llmOptions, searchQuery, requiresImageInput]);
|
||||
|
||||
// Group options by provider using backend-provided display names and ordering
|
||||
// For aggregator providers (bedrock, openrouter, vertex_ai), flatten to "Provider/Vendor" format
|
||||
|
||||
@@ -643,7 +643,9 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
? "0fr auto 1fr"
|
||||
: appFocus.isChat()
|
||||
? "1fr auto 0fr"
|
||||
: "1fr auto 1fr",
|
||||
: appFocus.isProject()
|
||||
? "auto auto 1fr"
|
||||
: "1fr auto 1fr",
|
||||
};
|
||||
|
||||
if (!isReady) return <OnyxInitializingLoader />;
|
||||
@@ -698,7 +700,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
<FederatedOAuthModal />
|
||||
|
||||
<AppLayouts.Root enableBackground>
|
||||
<AppLayouts.Root enableBackground={!appFocus.isProject()}>
|
||||
<Dropzone
|
||||
onDrop={(acceptedFiles) =>
|
||||
handleMessageSpecificFileUpload(acceptedFiles)
|
||||
@@ -751,11 +753,13 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
{/* ProjectUI */}
|
||||
{appFocus.isProject() && (
|
||||
<ProjectContextPanel
|
||||
projectTokenCount={projectContextTokenCount}
|
||||
availableContextTokens={availableContextTokens}
|
||||
setPresentingDocument={setPresentingDocument}
|
||||
/>
|
||||
<div className="w-full max-h-[50vh] overflow-y-auto overscroll-y-none">
|
||||
<ProjectContextPanel
|
||||
projectTokenCount={projectContextTokenCount}
|
||||
availableContextTokens={availableContextTokens}
|
||||
setPresentingDocument={setPresentingDocument}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* WelcomeMessageUI */}
|
||||
@@ -872,19 +876,18 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* ProjectChatSessionsUI */}
|
||||
{appFocus.isProject() && (
|
||||
<>
|
||||
<Spacer rem={0.5} />
|
||||
<ProjectChatSessionList />
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* ── Bottom: SearchResults + SourceFilter / Suggestions ── */}
|
||||
{/* ── Bottom: SearchResults + SourceFilter / Suggestions / ProjectChatList ── */}
|
||||
<div className="row-start-3 min-h-0 overflow-hidden flex flex-col items-center w-full">
|
||||
{/* ProjectChatSessionList */}
|
||||
{appFocus.isProject() && (
|
||||
<div className="w-full max-w-[var(--app-page-main-content-width)] h-full overflow-y-auto overscroll-y-none mx-auto">
|
||||
<ProjectChatSessionList />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* SuggestionsUI */}
|
||||
<Fade
|
||||
show={
|
||||
|
||||
@@ -346,7 +346,6 @@ const ChatScrollContainer = React.memo(
|
||||
<div
|
||||
key={sessionId}
|
||||
ref={scrollContainerRef}
|
||||
data-testid="chat-scroll-container"
|
||||
className="flex flex-col flex-1 min-h-0 overflow-y-auto overflow-x-hidden default-scrollbar"
|
||||
onScroll={handleScroll}
|
||||
style={{
|
||||
|
||||
@@ -22,7 +22,7 @@ import { useForcedTools } from "@/lib/hooks/useForcedTools";
|
||||
import { useAppMode } from "@/providers/AppModeProvider";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { getFormattedDateRangeString } from "@/lib/dateUtils";
|
||||
import { truncateString, cn } from "@/lib/utils";
|
||||
import { truncateString, cn, isImageFile } from "@/lib/utils";
|
||||
import { Disabled } from "@/refresh-components/Disabled";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { SettingsContext } from "@/providers/SettingsProvider";
|
||||
@@ -383,6 +383,11 @@ const AppInputBar = React.memo(
|
||||
return currentMessageFiles.length > 1;
|
||||
}, [currentMessageFiles]);
|
||||
|
||||
const hasImageFiles = useMemo(
|
||||
() => currentMessageFiles.some((f) => isImageFile(f.name)),
|
||||
[currentMessageFiles]
|
||||
);
|
||||
|
||||
// Check if the assistant has search tools available (internal search or web search)
|
||||
// AND if deep research is globally enabled in admin settings
|
||||
const showDeepResearch = useMemo(() => {
|
||||
@@ -754,7 +759,7 @@ const AppInputBar = React.memo(
|
||||
>
|
||||
<LLMPopover
|
||||
llmManager={llmManager}
|
||||
requiresImageGeneration={false}
|
||||
requiresImageInput={hasImageFiles}
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.