Compare commits


15 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Dane Urban | c577c1bde7 | . | 2026-02-18 14:40:16 -08:00 |
| Dane Urban | 35a8e8d3ea | . | 2026-02-18 14:18:07 -08:00 |
| Dane Urban | 1d300a517f | . | 2026-02-18 14:15:27 -08:00 |
| Dane Urban | 82f7bd26f8 | Fix | 2026-02-18 14:14:54 -08:00 |
| Dane Urban | d9535b0332 | Move code | 2026-02-17 17:04:40 -08:00 |
| Danelegend | c224a8f578 | Merge branch 'main' into ci_artifacts | 2026-02-17 17:04:07 -08:00 |
| Dane Urban | 098c5ae8f8 | . | 2026-02-17 16:59:17 -08:00 |
| Dane Urban | 3d912d8e2e | Some small changes | 2026-02-17 16:53:44 -08:00 |
| Dane Urban | ff5c5431be | Merge branch 'linguist-languages' into ci_artifacts | 2026-02-17 16:22:29 -08:00 |
| Jamison Lahman | 50f1ed2e01 | npm install on a linux device | 2026-02-17 16:12:32 -08:00 |
| Dane Urban | 0b45f36742 | Add linguist-language | 2026-02-17 16:03:02 -08:00 |
| Dane Urban | e7f1786de7 | Update | 2026-02-17 15:57:24 -08:00 |
| Dane Urban | 4843169131 | nits | 2026-02-17 15:23:32 -08:00 |
| Dane Urban | cd3bffe984 | same origin url | 2026-02-17 14:27:06 -08:00 |
| Dane Urban | 2f880468cb | Display CI artifacts better | 2026-02-17 14:23:30 -08:00 |
75 changed files with 2178 additions and 6209 deletions

View File

@@ -51,9 +51,8 @@ All tests use `admin_auth.json` storage state by default (pre-authenticated admi
Global setup (`global-setup.ts`) runs automatically before all tests and handles:
- Server readiness check (polls health endpoint, 60s timeout; sketched after this list)
- Provisioning test users: admin, admin2, and a **pool of worker users** (`worker0@example.com` through `worker7@example.com`) (idempotent)
- API login + saving storage states: `admin_auth.json`, `admin2_auth.json`, and `worker{N}_auth.json` for each worker user
- Setting display name to `"worker"` for each worker user
- Provisioning test users: admin, user, admin2 (idempotent)
- API login + saving storage states: `admin_auth.json`, `user_auth.json`, `admin2_auth.json`
- Promoting admin2 to admin role
- Ensuring a public LLM provider exists
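As a concrete illustration of the readiness check above, here is a minimal sketch of such a poll; the `/health` path and the helper name are assumptions for illustration, not the actual `global-setup.ts` code:
```typescript
// Minimal sketch of a server-readiness poll (endpoint path is assumed).
async function waitForServer(baseUrl: string, timeoutMs = 60_000): Promise<void> {
  const deadline = Date.now() + timeoutMs;
  while (Date.now() < deadline) {
    try {
      const res = await fetch(`${baseUrl}/health`); // assumed health endpoint
      if (res.ok) return;
    } catch {
      // server not up yet, keep polling
    }
    await new Promise((resolve) => setTimeout(resolve, 1_000));
  }
  throw new Error(`Server at ${baseUrl} not ready after ${timeoutMs}ms`);
}
```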
@@ -65,12 +64,8 @@ When a test needs a different user, use API-based login — never drive the logi
import { loginAs } from "@tests/e2e/utils/auth";
await page.context().clearCookies();
await loginAs(page, "user");
await loginAs(page, "admin2");
// Log in as the worker-specific user (preferred for test isolation):
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
```
## Test Structure
@@ -89,21 +84,18 @@ test.describe("Feature Name", () => {
});
```
**User isolation** — tests that modify visible app state (creating assistants, sending chat messages, pinning items) should run as a **worker-specific user** and clean up resources in `afterAll`. Global setup provisions a pool of worker users (`worker0@example.com` through `worker7@example.com`). `loginAsWorkerUser` maps `testInfo.workerIndex` to a pool slot via modulo, so retry workers (which get incrementing indices beyond the pool size) safely reuse existing users. This ensures parallel workers never share user state, keeps usernames deterministic for screenshots, and avoids cross-contamination (the index-to-slot mapping is sketched after the example below):
**User isolation** — tests that modify visible app state (creating assistants, sending chat messages, pinning items) should use `loginAsRandomUser` to get a fresh user per test. This prevents side effects from leaking into other parallel tests' screenshots and assertions:
```typescript
import { test } from "@playwright/test";
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
import { loginAsRandomUser } from "@tests/e2e/utils/auth";
test.beforeEach(async ({ page }, testInfo) => {
test.beforeEach(async ({ page }) => {
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
await loginAsRandomUser(page);
});
```
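To make the modulo behavior concrete, here is a hypothetical sketch of the index-to-slot mapping; the helper name and pool size are assumptions based on the description above, and the real logic lives in `@tests/e2e/utils/auth`:
```typescript
// Hypothetical sketch of the pool-slot mapping described above.
const POOL_SIZE = 8; // worker0@example.com .. worker7@example.com

function workerEmailForIndex(workerIndex: number): string {
  // Retry workers get indices >= POOL_SIZE; modulo wraps them back
  // onto an existing provisioned user instead of a missing one.
  return `worker${workerIndex % POOL_SIZE}@example.com`;
}

// e.g. workerIndex 9 (a retry worker) reuses worker1@example.com
```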
If the test requires admin privileges *and* modifies visible state, use `"admin2"` instead — it's a pre-provisioned admin account that keeps the primary `"admin"` clean for other parallel tests. Switch to `"admin"` only for privileged setup (creating providers, configuring tools), then back to the worker user for the actual test. See `chat/default_assistant.spec.ts` for a full example.
`loginAsRandomUser` exists for the rare case where the test requires a brand-new user (e.g. onboarding flows). Avoid it elsewhere — it produces non-deterministic usernames that complicate screenshots.
Switch to admin only when privileged setup is needed (creating providers, configuring tools), then back to the isolated user for the actual test. See `chat/default_assistant.spec.ts` for a full example.
**API resource setup** — only when tests need to create backend resources (image gen configs, web search providers, MCP servers). Use `beforeAll`/`afterAll` with `OnyxApiClient` to create and clean up. See `chat/default_assistant.spec.ts` or `mcp/mcp_oauth_flow.spec.ts` for examples. This is uncommon (~4 of 37 test files).
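A minimal sketch of that `beforeAll`/`afterAll` pattern follows; the `OnyxApiClient` import path, constructor shape, and method names here are illustrative assumptions, so see the referenced spec files for real usage:
```typescript
import { test } from "@playwright/test";
import type { APIRequestContext } from "@playwright/test";
import { OnyxApiClient } from "@tests/e2e/utils/api"; // assumed import path

test.describe("Feature needing backend resources", () => {
  let requestContext: APIRequestContext;
  let api: OnyxApiClient;
  let providerId: string;

  // Create backend state once, before any test in this file runs.
  test.beforeAll(async ({ playwright }) => {
    requestContext = await playwright.request.newContext();
    api = new OnyxApiClient(requestContext); // constructor shape is an assumption
    providerId = await api.createWebSearchProvider({ name: "E2E-CMD Provider" }); // hypothetical method
  });

  // Clean up by ID so parallel workers and reruns stay isolated.
  test.afterAll(async () => {
    await api.deleteWebSearchProvider(providerId); // hypothetical method
    await requestContext.dispose();
  });

  test("uses the resource", async ({ page }) => {
    // ...drive the UI behavior under test here
  });
});
```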
@@ -134,30 +126,6 @@ Backend API client for test setup/teardown. Key methods:
- `expectElementScreenshot(locator, { name, mask?, hide? })`
- Controlled by `VISUAL_REGRESSION=true` env var
### `theme` (`@tests/e2e/utils/theme`)
- `THEMES` — `["light", "dark"] as const` array for iterating over both themes
- `setThemeBeforeNavigation(page, theme)` — sets `next-themes` theme via `localStorage` before navigation
When tests need light/dark screenshots, loop over `THEMES` at the `test.describe` level and call `setThemeBeforeNavigation` in `beforeEach` **before** any `page.goto()`. Include the theme in screenshot names. See `admin/admin_pages.spec.ts` or `chat/chat_message_rendering.spec.ts` for examples:
```typescript
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";
for (const theme of THEMES) {
test.describe(`Feature (${theme} mode)`, () => {
test.beforeEach(async ({ page }) => {
await setThemeBeforeNavigation(page, theme);
});
test("renders correctly", async ({ page }) => {
await page.goto("/app");
await expectScreenshot(page, { name: `feature-${theme}` });
});
});
}
```
### `tools` (`@tests/e2e/utils/tools`)
- `TOOL_IDS` — centralized `data-testid` selectors for tool options
@@ -238,10 +206,10 @@ await page.waitForResponse(resp => resp.url().includes("/api/chat") && resp.stat
1. **Descriptive test names** — clearly state expected behavior: `"should display greeting message when opening new chat"`
2. **API-first setup** — use `OnyxApiClient` for backend state; reserve UI interactions for the behavior under test
3. **User isolation** — tests that modify visible app state (sidebar, chat history) should run as the worker-specific user via `loginAsWorkerUser(page, testInfo.workerIndex)` (not admin) and clean up resources in `afterAll`. Each parallel worker gets its own user, preventing cross-contamination. Reserve `loginAsRandomUser` for flows that require a brand-new user (e.g. onboarding)
3. **User isolation** — tests that modify visible app state (sidebar, chat history) should use `loginAsRandomUser` for a fresh user per test, avoiding cross-test contamination. Always cleanup API-created resources in `afterAll`
4. **DRY helpers** — extract reusable logic into `utils/` with JSDoc comments
5. **No hardcoded waits** — use `waitFor`, `waitForLoadState`, or web-first assertions
6. **Parallel-safe** — no shared mutable state between tests. Prefer static, human-readable names (e.g. `"E2E-CMD Chat 1"`) and clean up resources by ID in `afterAll`. This keeps screenshots deterministic and avoids needing to mask/hide dynamic text. Only fall back to timestamps (`\`test-${Date.now()}\``) when resources cannot be reliably cleaned up or when name collisions across parallel workers would cause functional failures
6. **Parallel-safe** — no shared mutable state between tests; use unique names with timestamps (`\`test-${Date.now()}\``)
7. **Error context** — catch and re-throw with useful debug info (page text, URL, etc.); see the sketch after this list
8. **Tag slow tests** — mark serial/slow tests with `@exclusive` in the test title
9. **Visual regression** — use `expectScreenshot()` for UI consistency checks
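For practice 7, a hedged sketch of the catch-and-re-throw pattern (the helper is hypothetical, not an existing util):
```typescript
import type { Page } from "@playwright/test";

// Hypothetical helper showing practice 7: re-throw with debug context.
async function clickOrExplain(page: Page, name: string): Promise<void> {
  try {
    await page.getByRole("link", { name }).click({ timeout: 10_000 });
  } catch (err) {
    const bodyText = await page
      .locator("body")
      .innerText()
      .catch(() => "<body text unavailable>");
    throw new Error(
      `Failed to click "${name}" at ${page.url()}\n` +
        `Page text (truncated): ${bodyText.slice(0, 500)}\n` +
        `Original error: ${err}`
    );
  }
}
```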

View File

@@ -0,0 +1,151 @@
# Scan for problematic software licenses
# Trivy has its own rate-limiting issues causing this action to flake
# we worked around it by hardcoding different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
name: 'Nightly - Scan licenses'
on:
# schedule:
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
workflow_dispatch: # Allows manual triggering
permissions:
actions: read
contents: read
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
timeout-minutes: 45
permissions:
actions: read
contents: read
security-events: write
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- name: Get explicit and transitive dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
pip freeze > requirements-all.txt
- name: Check Python licenses
id: license_check_report
uses: pilosus/action-pip-license-checker@e909b0226ff49d3235c99c4585bc617f49fff16a # ratchet:pilosus/action-pip-license-checker@v3
with:
requirements: 'requirements-all.txt'
fail: 'Copyleft'
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: always()
env:
REPORT: ${{ steps.license_check_report.outputs.report }}
run: echo "$REPORT"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
# be careful enabling the SARIF output and upload as it may spam the security tab
# with a huge number of items. Work out the issues before enabling upload.
# - name: Run Trivy vulnerability scanner in repo mode
# if: always()
# uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
# with:
# scan-type: fs
# scan-ref: .
# scanners: license
# format: table
# severity: HIGH,CRITICAL
# # format: sarif
# # output: trivy-results.sarif
#
# # - name: Upload Trivy scan results to GitHub Security tab
# # uses: github/codeql-action/upload-sarif@v3
# # with:
# # sarif_file: trivy-results.sarif
scan-trivy:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# Backend
- name: Pull backend docker image
run: docker pull onyxdotapp/onyx-backend:latest
- name: Run Trivy vulnerability scanner on backend
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-backend:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
# Web server
- name: Pull web server docker image
run: docker pull onyxdotapp/onyx-web-server:latest
- name: Run Trivy vulnerability scanner on web server
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-web-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# Model server
- name: Pull model server docker image
run: docker pull onyxdotapp/onyx-model-server:latest
- name: Run Trivy vulnerability scanner on model server
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-model-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0

View File

@@ -41,7 +41,8 @@ jobs:
version: v3.19.0
- name: Set up chart-testing
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
with:
uv_version: "0.9.9"

View File

@@ -593,10 +593,7 @@ jobs:
# Post a single combined visual regression comment after all matrix jobs finish
visual-regression-comment:
needs: [playwright-tests]
if: >-
always() &&
github.event_name == 'pull_request' &&
needs.playwright-tests.result != 'cancelled'
if: always() && github.event_name == 'pull_request'
runs-on: ubuntu-slim
timeout-minutes: 5
permissions:

View File

@@ -8,11 +8,11 @@ Usage from FastAPI::
def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
@router.post("/tokens")
def create_token(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
token = dal.create_token(name=..., hashed_token=..., ...)
@router.get("/tokens")
def list_tokens(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
tokens = dal.list_tokens()
dal.commit()
return token
return tokens
Usage from background tasks::
@@ -25,26 +25,13 @@ from __future__ import annotations
from uuid import UUID
from sqlalchemy import delete as sa_delete
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import SQLColumnExpression
from sqlalchemy.dialects.postgresql import insert as pg_insert
from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from onyx.db.dal import DAL
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ScimDAL(DAL):
@@ -66,20 +53,7 @@ class ScimDAL(DAL):
token_display: str,
created_by_id: UUID,
) -> ScimToken:
"""Create a new SCIM bearer token.
Only one token is active at a time — this method automatically revokes
all existing active tokens before creating the new one.
"""
# Revoke any currently active tokens
active_tokens = list(
self._session.scalars(
select(ScimToken).where(ScimToken.is_active.is_(True))
).all()
)
for t in active_tokens:
t.is_active = False
"""Create a new SCIM bearer token."""
token = ScimToken(
name=name,
hashed_token=hashed_token,
@@ -90,18 +64,20 @@ class ScimDAL(DAL):
self._session.flush()
return token
def get_active_token(self) -> ScimToken | None:
"""Return the single currently active token, or None."""
return self._session.scalar(
select(ScimToken).where(ScimToken.is_active.is_(True))
)
def get_token_by_hash(self, hashed_token: str) -> ScimToken | None:
"""Look up a token by its SHA-256 hash."""
return self._session.scalar(
select(ScimToken).where(ScimToken.hashed_token == hashed_token)
)
def list_tokens(self) -> list[ScimToken]:
"""List all SCIM tokens, ordered by creation date descending."""
return list(
self._session.scalars(
select(ScimToken).order_by(ScimToken.created_at.desc())
).all()
)
def revoke_token(self, token_id: int) -> None:
"""Deactivate a token by ID.
@@ -195,130 +171,15 @@ class ScimDAL(DAL):
return mapping
def delete_user_mapping(self, mapping_id: int) -> None:
"""Delete a user mapping by ID. No-op if already deleted."""
mapping = self._session.get(ScimUserMapping, mapping_id)
if not mapping:
logger.warning("SCIM user mapping %d not found during delete", mapping_id)
return
self._session.delete(mapping)
# ------------------------------------------------------------------
# User query operations
# ------------------------------------------------------------------
def get_user(self, user_id: UUID) -> User | None:
"""Fetch a user by ID."""
return self._session.scalar(
select(User).where(User.id == user_id) # type: ignore[arg-type]
)
def get_user_by_email(self, email: str) -> User | None:
"""Fetch a user by email (case-insensitive)."""
return self._session.scalar(
select(User).where(func.lower(User.email) == func.lower(email))
)
def add_user(self, user: User) -> None:
"""Add a new user to the session and flush to assign an ID."""
self._session.add(user)
self._session.flush()
def update_user(
self,
user: User,
*,
email: str | None = None,
is_active: bool | None = None,
personal_name: str | None = None,
) -> None:
"""Update user attributes. Only sets fields that are provided."""
if email is not None:
user.email = email
if is_active is not None:
user.is_active = is_active
if personal_name is not None:
user.personal_name = personal_name
def deactivate_user(self, user: User) -> None:
"""Mark a user as inactive."""
user.is_active = False
def list_users(
self,
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[User, str | None]], int]:
"""Query users with optional SCIM filter and pagination.
Returns:
A tuple of (list of (user, external_id) pairs, total_count).
"""Delete a user mapping by ID.
Raises:
ValueError: If the filter uses an unsupported attribute.
ValueError: If the mapping does not exist.
"""
query = select(User).where(
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
)
if scim_filter:
attr = scim_filter.attribute.lower()
if attr == "username":
# arg-type: fastapi-users types User.email as str, not a column expression
# assignment: union return type widens but query is still Select[tuple[User]]
query = _apply_scim_string_op(query, User.email, scim_filter) # type: ignore[arg-type, assignment]
elif attr == "active":
query = query.where(
User.is_active.is_(scim_filter.value.lower() == "true") # type: ignore[attr-defined]
)
elif attr == "externalid":
mapping = self.get_user_mapping_by_external_id(scim_filter.value)
if not mapping:
return [], 0
query = query.where(User.id == mapping.user_id) # type: ignore[arg-type]
else:
raise ValueError(
f"Unsupported filter attribute: {scim_filter.attribute}"
)
# Count total matching rows first, then paginate. SCIM uses 1-based
# indexing (RFC 7644 §3.4.2), so we convert to a 0-based offset.
total = (
self._session.scalar(select(func.count()).select_from(query.subquery()))
or 0
)
offset = max(start_index - 1, 0)
users = list(
self._session.scalars(
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
).all()
)
# Batch-fetch external IDs to avoid N+1 queries
ext_id_map = self._get_user_external_ids([u.id for u in users])
return [(u, ext_id_map.get(u.id)) for u in users], total
def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
"""Create, update, or delete the external ID mapping for a user."""
mapping = self.get_user_mapping_by_user_id(user_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
else:
self.create_user_mapping(external_id=new_external_id, user_id=user_id)
elif mapping:
self.delete_user_mapping(mapping.id)
def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
"""Batch-fetch external IDs for a list of user IDs."""
if not user_ids:
return {}
mappings = self._session.scalars(
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
).all()
return {m.user_id: m.external_id for m in mappings}
mapping = self._session.get(ScimUserMapping, mapping_id)
if not mapping:
raise ValueError(f"SCIM user mapping with id {mapping_id} not found")
self._session.delete(mapping)
# ------------------------------------------------------------------
# Group mapping operations
@@ -385,220 +246,12 @@ class ScimDAL(DAL):
return mappings, total
def delete_group_mapping(self, mapping_id: int) -> None:
"""Delete a group mapping by ID. No-op if already deleted."""
mapping = self._session.get(ScimGroupMapping, mapping_id)
if not mapping:
logger.warning("SCIM group mapping %d not found during delete", mapping_id)
return
self._session.delete(mapping)
# ------------------------------------------------------------------
# Group query operations
# ------------------------------------------------------------------
def get_group(self, group_id: int) -> UserGroup | None:
"""Fetch a group by ID, returning None if deleted or missing."""
group = self._session.get(UserGroup, group_id)
if group and group.is_up_for_deletion:
return None
return group
def get_group_by_name(self, name: str) -> UserGroup | None:
"""Fetch a group by exact name."""
return self._session.scalar(select(UserGroup).where(UserGroup.name == name))
def add_group(self, group: UserGroup) -> None:
"""Add a new group to the session and flush to assign an ID."""
self._session.add(group)
self._session.flush()
def update_group(
self,
group: UserGroup,
*,
name: str | None = None,
) -> None:
"""Update group attributes and set the modification timestamp."""
if name is not None:
group.name = name
group.time_last_modified_by_user = func.now()
def delete_group(self, group: UserGroup) -> None:
"""Delete a group from the session."""
self._session.delete(group)
def list_groups(
self,
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[UserGroup, str | None]], int]:
"""Query groups with optional SCIM filter and pagination.
Returns:
A tuple of (list of (group, external_id) pairs, total_count).
"""Delete a group mapping by ID.
Raises:
ValueError: If the filter uses an unsupported attribute.
ValueError: If the mapping does not exist.
"""
query = select(UserGroup).where(UserGroup.is_up_for_deletion.is_(False))
if scim_filter:
attr = scim_filter.attribute.lower()
if attr == "displayname":
# assignment: union return type widens but query is still Select[tuple[UserGroup]]
query = _apply_scim_string_op(query, UserGroup.name, scim_filter) # type: ignore[assignment]
elif attr == "externalid":
mapping = self.get_group_mapping_by_external_id(scim_filter.value)
if not mapping:
return [], 0
query = query.where(UserGroup.id == mapping.user_group_id)
else:
raise ValueError(
f"Unsupported filter attribute: {scim_filter.attribute}"
)
total = (
self._session.scalar(select(func.count()).select_from(query.subquery()))
or 0
)
offset = max(start_index - 1, 0)
groups = list(
self._session.scalars(
query.order_by(UserGroup.id).offset(offset).limit(count)
).all()
)
ext_id_map = self._get_group_external_ids([g.id for g in groups])
return [(g, ext_id_map.get(g.id)) for g in groups], total
def get_group_members(self, group_id: int) -> list[tuple[UUID, str | None]]:
"""Get group members as (user_id, email) pairs."""
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
).all()
user_ids = [r.user_id for r in rels if r.user_id]
if not user_ids:
return []
users = self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
).all()
users_by_id = {u.id: u for u in users}
return [
(
r.user_id,
users_by_id[r.user_id].email if r.user_id in users_by_id else None,
)
for r in rels
if r.user_id
]
def validate_member_ids(self, uuids: list[UUID]) -> list[UUID]:
"""Return the subset of UUIDs that don't exist as users.
Returns an empty list if all IDs are valid.
"""
if not uuids:
return []
existing_users = self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
).all()
existing_ids = {u.id for u in existing_users}
return [uid for uid in uuids if uid not in existing_ids]
def upsert_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Add user-group relationships, ignoring duplicates."""
if not user_ids:
return
self._session.execute(
pg_insert(User__UserGroup)
.values([{"user_id": uid, "user_group_id": group_id} for uid in user_ids])
.on_conflict_do_nothing(
index_elements=[
User__UserGroup.user_group_id,
User__UserGroup.user_id,
]
)
)
def replace_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Replace all members of a group."""
self._session.execute(
sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
)
self.upsert_group_members(group_id, user_ids)
def remove_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Remove specific members from a group."""
if not user_ids:
return
self._session.execute(
sa_delete(User__UserGroup).where(
User__UserGroup.user_group_id == group_id,
User__UserGroup.user_id.in_(user_ids),
)
)
def delete_group_with_members(self, group: UserGroup) -> None:
"""Remove all member relationships and delete the group."""
self._session.execute(
sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group.id)
)
self._session.delete(group)
def sync_group_external_id(
self, group_id: int, new_external_id: str | None
) -> None:
"""Create, update, or delete the external ID mapping for a group."""
mapping = self.get_group_mapping_by_group_id(group_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
else:
self.create_group_mapping(
external_id=new_external_id, user_group_id=group_id
)
elif mapping:
self.delete_group_mapping(mapping.id)
def _get_group_external_ids(self, group_ids: list[int]) -> dict[int, str]:
"""Batch-fetch external IDs for a list of group IDs."""
if not group_ids:
return {}
mappings = self._session.scalars(
select(ScimGroupMapping).where(
ScimGroupMapping.user_group_id.in_(group_ids)
)
).all()
return {m.user_group_id: m.external_id for m in mappings}
# ---------------------------------------------------------------------------
# Module-level helpers (used by DAL methods above)
# ---------------------------------------------------------------------------
def _apply_scim_string_op(
query: Select[tuple[User]] | Select[tuple[UserGroup]],
column: SQLColumnExpression[str],
scim_filter: ScimFilter,
) -> Select[tuple[User]] | Select[tuple[UserGroup]]:
"""Apply a SCIM string filter operator using SQLAlchemy column operators.
Handles eq (case-insensitive exact), co (contains), and sw (starts with).
SQLAlchemy's operators handle LIKE-pattern escaping internally.
"""
val = scim_filter.value
if scim_filter.operator == ScimFilterOperator.EQUAL:
return query.where(func.lower(column) == val.lower())
elif scim_filter.operator == ScimFilterOperator.CONTAINS:
return query.where(column.icontains(val, autoescape=True))
elif scim_filter.operator == ScimFilterOperator.STARTS_WITH:
return query.where(column.istartswith(val, autoescape=True))
else:
raise ValueError(f"Unsupported string filter operator: {scim_filter.operator}")
mapping = self._session.get(ScimGroupMapping, mapping_id)
if not mapping:
raise ValueError(f"SCIM group mapping with id {mapping_id} not found")
self._session.delete(mapping)

View File

@@ -31,7 +31,6 @@ from ee.onyx.server.query_and_chat.query_backend import (
from ee.onyx.server.query_and_chat.search_backend import router as search_router
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.scim.api import scim_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.token_rate_limits.api import (
@@ -163,11 +162,6 @@ def get_application() -> FastAPI:
# Tenant management
include_router_with_global_prefix_prepended(application, tenants_router)
# SCIM 2.0 — protocol endpoints (unauthenticated by Onyx session auth;
# they use their own SCIM bearer token auth).
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
application.include_router(scim_router)
# Ensure all routes have auth enabled or are explicitly marked as public
check_ee_router_auth(application)

View File

@@ -5,11 +5,6 @@ from onyx.server.auth_check import PUBLIC_ENDPOINT_SPECS
EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
# SCIM 2.0 service discovery — unauthenticated so IdPs can probe
# before bearer token configuration is complete
("/scim/v2/ServiceProviderConfig", {"GET"}),
("/scim/v2/ResourceTypes", {"GET"}),
("/scim/v2/Schemas", {"GET"}),
# needs to be accessible prior to user login
("/enterprise-settings", {"GET"}),
("/enterprise-settings/logo", {"GET"}),

View File

@@ -13,7 +13,6 @@ from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.store import get_logo_filename
@@ -23,10 +22,6 @@ from ee.onyx.server.enterprise_settings.store import load_settings
from ee.onyx.server.enterprise_settings.store import store_analytics_script
from ee.onyx.server.enterprise_settings.store import store_settings
from ee.onyx.server.enterprise_settings.store import upload_logo
from ee.onyx.server.scim.auth import generate_scim_token
from ee.onyx.server.scim.models import ScimTokenCreate
from ee.onyx.server.scim.models import ScimTokenCreatedResponse
from ee.onyx.server.scim.models import ScimTokenResponse
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
@@ -203,63 +198,3 @@ def upload_custom_analytics_script(
@basic_router.get("/custom-analytics-script")
def fetch_custom_analytics_script() -> str | None:
return load_analytics_script()
# ---------------------------------------------------------------------------
# SCIM token management
# ---------------------------------------------------------------------------
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
@admin_router.get("/scim/token")
def get_active_scim_token(
_: User = Depends(current_admin_user),
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenResponse:
"""Return the currently active SCIM token's metadata, or 404 if none."""
token = dal.get_active_token()
if not token:
raise HTTPException(status_code=404, detail="No active SCIM token")
return ScimTokenResponse(
id=token.id,
name=token.name,
token_display=token.token_display,
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
)
@admin_router.post("/scim/token", status_code=201)
def create_scim_token(
body: ScimTokenCreate,
user: User = Depends(current_admin_user),
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenCreatedResponse:
"""Create a new SCIM bearer token.
Only one token is active at a time — creating a new token automatically
revokes all previous tokens. The raw token value is returned exactly once
in the response; it cannot be retrieved again.
"""
raw_token, hashed_token, token_display = generate_scim_token()
token = dal.create_token(
name=body.name,
hashed_token=hashed_token,
token_display=token_display,
created_by_id=user.id,
)
dal.commit()
return ScimTokenCreatedResponse(
id=token.id,
name=token.name,
token_display=token.token_display,
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
raw_token=raw_token,
)

View File

@@ -1,689 +0,0 @@
"""SCIM 2.0 API endpoints (RFC 7644).
This module provides the FastAPI router for SCIM service discovery,
User CRUD, and Group CRUD. Identity providers (Okta, Azure AD) call
these endpoints to provision and manage users and groups.
Service discovery endpoints are unauthenticated — IdPs may probe them
before bearer token configuration is complete. All other endpoints
require a valid SCIM bearer token.
"""
from __future__ import annotations
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Query
from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
# changed to kebab-case.
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
# ---------------------------------------------------------------------------
# Service Discovery Endpoints (unauthenticated)
# ---------------------------------------------------------------------------
@scim_router.get("/ServiceProviderConfig")
def get_service_provider_config() -> ScimServiceProviderConfig:
"""Advertise supported SCIM features (RFC 7643 §5)."""
return SERVICE_PROVIDER_CONFIG
@scim_router.get("/ResourceTypes")
def get_resource_types() -> list[ScimResourceType]:
"""List available SCIM resource types (RFC 7643 §6)."""
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
@scim_router.get("/Schemas")
def get_schemas() -> list[ScimSchemaDefinition]:
"""Return SCIM schema definitions (RFC 7643 §7)."""
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _scim_error_response(status: int, detail: str) -> JSONResponse:
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
body = ScimError(status=str(status), detail=detail)
return JSONResponse(
status_code=status,
content=body.model_dump(exclude_none=True),
)
def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
"""Convert an Onyx User to a SCIM User resource representation."""
name = None
if user.personal_name:
parts = user.personal_name.split(" ", 1)
name = ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
return ScimUserResource(
id=str(user.id),
externalId=external_id,
userName=user.email,
name=name,
emails=[ScimEmail(value=user.email, type="work", primary=True)],
active=user.is_active,
meta=ScimMeta(resourceType="User"),
)
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None."""
check_fn = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)
if check_fn is None:
return None
result = check_fn(dal.session, seats_needed=1)
if not result.available:
return result.error_message or "Seat limit reached"
return None
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
try:
uid = UUID(user_id)
except ValueError:
return _scim_error_response(404, f"User {user_id} not found")
user = dal.get_user(uid)
if not user:
return _scim_error_response(404, f"User {user_id} not found")
return user
def _scim_name_to_str(name: ScimName | None) -> str | None:
"""Extract a display name string from a SCIM name object.
Returns None if no name is provided, so the caller can decide
whether to update the user's personal_name.
"""
if not name:
return None
return name.formatted or " ".join(
part for part in [name.givenName, name.familyName] if part
)
# ---------------------------------------------------------------------------
# User CRUD (RFC 7644 §3)
# ---------------------------------------------------------------------------
@scim_router.get("/Users", response_model=None)
def list_users(
filter: str | None = Query(None),
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List users with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
try:
scim_filter = parse_scim_filter(filter)
except ValueError as e:
return _scim_error_response(400, str(e))
try:
users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
resources: list[ScimUserResource | ScimGroupResource] = [
_user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
]
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
)
@scim_router.get("/Users/{user_id}", response_model=None)
def get_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Get a single user by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
return _user_to_scim(user, mapping.external_id if mapping else None)
@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Create a new user from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
email = user_resource.userName.strip().lower()
# externalId is how the IdP correlates this user on subsequent requests.
# Without it, the IdP can't find the user and will try to re-create,
# hitting a 409 conflict — so we require it up front.
if not user_resource.externalId:
return _scim_error_response(400, "externalId is required")
# Enforce seat limit
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
# Check for existing user
if dal.get_user_by_email(email):
return _scim_error_response(409, f"User with email {email} already exists")
# Create user with a random password (SCIM users authenticate via IdP)
personal_name = _scim_name_to_str(user_resource.name)
user = User(
email=email,
hashed_password=_pw_helper.hash(_pw_helper.generate()),
role=UserRole.BASIC,
is_active=user_resource.active,
is_verified=True,
personal_name=personal_name,
)
try:
dal.add_user(user)
except IntegrityError:
dal.rollback()
return _scim_error_response(409, f"User with email {email} already exists")
# Create SCIM mapping (externalId is validated above, always present)
external_id = user_resource.externalId
dal.create_user_mapping(external_id=external_id, user_id=user.id)
dal.commit()
return _user_to_scim(user, external_id)
@scim_router.put("/Users/{user_id}", response_model=None)
def replace_user(
user_id: str,
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Replace a user entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
# Handle activation (need seat check) / deactivation
if user_resource.active and not user.is_active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
dal.update_user(
user,
email=user_resource.userName.strip().lower(),
is_active=user_resource.active,
personal_name=_scim_name_to_str(user_resource.name),
)
new_external_id = user_resource.externalId
dal.sync_user_external_id(user.id, new_external_id)
dal.commit()
return _user_to_scim(user, new_external_id)
@scim_router.patch("/Users/{user_id}", response_model=None)
def patch_user(
user_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Partially update a user (RFC 7644 §3.5.2).
This is the primary endpoint for user deprovisioning — Okta sends
``PATCH {"active": false}`` rather than DELETE.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
external_id = mapping.external_id if mapping else None
current = _user_to_scim(user, external_id)
try:
patched = apply_user_patch(patch_request.Operations, current)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
# Apply changes back to the DB model
if patched.active != user.is_active:
if patched.active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
dal.update_user(
user,
email=(
patched.userName.strip().lower()
if patched.userName.lower() != user.email
else None
),
is_active=patched.active if patched.active != user.is_active else None,
personal_name=_scim_name_to_str(patched.name),
)
dal.sync_user_external_id(user.id, patched.externalId)
dal.commit()
return _user_to_scim(user, patched.externalId)
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
def delete_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | JSONResponse:
"""Delete a user (RFC 7644 §3.6).
Deactivates the user and removes the SCIM mapping. Note that Okta
typically uses PATCH active=false instead of DELETE.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
return result
user = result
dal.deactivate_user(user)
mapping = dal.get_user_mapping_by_user_id(user.id)
if mapping:
dal.delete_user_mapping(mapping.id)
dal.commit()
return Response(status_code=204)
# ---------------------------------------------------------------------------
# Group helpers
# ---------------------------------------------------------------------------
def _group_to_scim(
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Convert an Onyx UserGroup to a SCIM Group resource."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
"""Parse *group_id* as int, look up the group, or return a 404 error."""
try:
gid = int(group_id)
except ValueError:
return _scim_error_response(404, f"Group {group_id} not found")
group = dal.get_group(gid)
if not group:
return _scim_error_response(404, f"Group {group_id} not found")
return group
def _parse_member_uuids(
members: list[ScimGroupMember],
) -> tuple[list[UUID], str | None]:
"""Parse member value strings to UUIDs.
Returns (uuid_list, error_message). error_message is None on success.
"""
uuids: list[UUID] = []
for m in members:
try:
uuids.append(UUID(m.value))
except ValueError:
return [], f"Invalid member ID: {m.value}"
return uuids, None
def _validate_and_parse_members(
members: list[ScimGroupMember], dal: ScimDAL
) -> tuple[list[UUID], str | None]:
"""Parse and validate member UUIDs exist in the database.
Returns (uuid_list, error_message). error_message is None on success.
"""
uuids, err = _parse_member_uuids(members)
if err:
return [], err
if uuids:
missing = dal.validate_member_ids(uuids)
if missing:
return [], f"Member(s) not found: {', '.join(str(u) for u in missing)}"
return uuids, None
# ---------------------------------------------------------------------------
# Group CRUD (RFC 7644 §3)
# ---------------------------------------------------------------------------
@scim_router.get("/Groups", response_model=None)
def list_groups(
filter: str | None = Query(None),
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List groups with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
try:
scim_filter = parse_scim_filter(filter)
except ValueError as e:
return _scim_error_response(400, str(e))
try:
groups_with_ext_ids, total = dal.list_groups(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
resources: list[ScimUserResource | ScimGroupResource] = [
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
for group, ext_id in groups_with_ext_ids
]
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
)
@scim_router.get("/Groups/{group_id}", response_model=None)
def get_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Get a single group by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, mapping.external_id if mapping else None)
@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Create a new group from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
if dal.get_group_by_name(group_resource.displayName):
return _scim_error_response(
409, f"Group with name '{group_resource.displayName}' already exists"
)
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
if err:
return _scim_error_response(400, err)
db_group = UserGroup(
name=group_resource.displayName,
is_up_to_date=True,
time_last_modified_by_user=func.now(),
)
try:
dal.add_group(db_group)
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"Group with name '{group_resource.displayName}' already exists"
)
dal.upsert_group_members(db_group.id, member_uuids)
external_id = group_resource.externalId
if external_id:
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
dal.commit()
members = dal.get_group_members(db_group.id)
return _group_to_scim(db_group, members, external_id)
@scim_router.put("/Groups/{group_id}", response_model=None)
def replace_group(
group_id: str,
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Replace a group entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
if err:
return _scim_error_response(400, err)
dal.update_group(group, name=group_resource.displayName)
dal.replace_group_members(group.id, member_uuids)
dal.sync_group_external_id(group.id, group_resource.externalId)
dal.commit()
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, group_resource.externalId)
@scim_router.patch("/Groups/{group_id}", response_model=None)
def patch_group(
group_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Partially update a group (RFC 7644 §3.5.2).
Handles member add/remove operations from Okta and Azure AD.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
external_id = mapping.external_id if mapping else None
current_members = dal.get_group_members(group.id)
current = _group_to_scim(group, current_members, external_id)
try:
patched, added_ids, removed_ids = apply_group_patch(
patch_request.Operations, current
)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
new_name = patched.displayName if patched.displayName != group.name else None
dal.update_group(group, name=new_name)
if added_ids:
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
if add_uuids:
missing = dal.validate_member_ids(add_uuids)
if missing:
return _scim_error_response(
400,
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
)
dal.upsert_group_members(group.id, add_uuids)
if removed_ids:
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
dal.remove_group_members(group.id, remove_uuids)
dal.sync_group_external_id(group.id, patched.externalId)
dal.commit()
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, patched.externalId)
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
def delete_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | JSONResponse:
"""Delete a group (RFC 7644 §3.6)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
if mapping:
dal.delete_group_mapping(mapping.id)
dal.delete_group_with_members(group)
dal.commit()
return Response(status_code=204)
def _is_valid_uuid(value: str) -> bool:
"""Check if a string is a valid UUID."""
try:
UUID(value)
return True
except ValueError:
return False

View File

@@ -30,7 +30,6 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
)
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
# ---------------------------------------------------------------------------
@@ -196,39 +195,10 @@ class ScimServiceProviderConfig(BaseModel):
)
class ScimSchemaAttribute(BaseModel):
"""Attribute definition within a SCIM Schema (RFC 7643 §7)."""
name: str
type: str
multiValued: bool = False
required: bool = False
description: str = ""
caseExact: bool = False
mutability: str = "readWrite"
returned: str = "default"
uniqueness: str = "none"
subAttributes: list["ScimSchemaAttribute"] = Field(default_factory=list)
class ScimSchemaDefinition(BaseModel):
"""SCIM Schema definition (RFC 7643 §7).
Served at GET /scim/v2/Schemas. Describes the attributes available
on each resource type so IdPs know which fields they can provision.
"""
schemas: list[str] = Field(default_factory=lambda: [SCIM_SCHEMA_SCHEMA])
id: str
name: str
description: str
attributes: list[ScimSchemaAttribute] = Field(default_factory=list)
class ScimSchemaExtension(BaseModel):
"""Schema extension reference within ResourceType (RFC 7643 §6)."""
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
model_config = ConfigDict(populate_by_name=True)
schema_: str = Field(alias="schema")
required: bool
@@ -241,7 +211,7 @@ class ScimResourceType(BaseModel):
types are available (Users, Groups) and their respective endpoints.
"""
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
model_config = ConfigDict(populate_by_name=True)
schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
id: str

View File

@@ -1,144 +0,0 @@
"""Static SCIM service discovery responses (RFC 7643 §5, §6, §7).
Pre-built at import time — these never change at runtime. Separated from
api.py to keep the endpoint module focused on request handling.
"""
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaAttribute
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
SERVICE_PROVIDER_CONFIG = ScimServiceProviderConfig()
USER_RESOURCE_TYPE = ScimResourceType.model_validate(
{
"id": "User",
"name": "User",
"endpoint": "/scim/v2/Users",
"description": "SCIM User resource",
"schema": SCIM_USER_SCHEMA,
}
)
GROUP_RESOURCE_TYPE = ScimResourceType.model_validate(
{
"id": "Group",
"name": "Group",
"endpoint": "/scim/v2/Groups",
"description": "SCIM Group resource",
"schema": SCIM_GROUP_SCHEMA,
}
)
USER_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_USER_SCHEMA,
name="User",
description="SCIM core User schema",
attributes=[
ScimSchemaAttribute(
name="userName",
type="string",
required=True,
uniqueness="server",
description="Unique identifier for the user, typically an email address.",
),
ScimSchemaAttribute(
name="name",
type="complex",
description="The components of the user's name.",
subAttributes=[
ScimSchemaAttribute(
name="givenName",
type="string",
description="The user's first name.",
),
ScimSchemaAttribute(
name="familyName",
type="string",
description="The user's last name.",
),
ScimSchemaAttribute(
name="formatted",
type="string",
description="The full name, including all middle names and titles.",
),
],
),
ScimSchemaAttribute(
name="emails",
type="complex",
multiValued=True,
description="Email addresses for the user.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="Email address value.",
),
ScimSchemaAttribute(
name="type",
type="string",
description="Label for this email (e.g. 'work').",
),
ScimSchemaAttribute(
name="primary",
type="boolean",
description="Whether this is the primary email.",
),
],
),
ScimSchemaAttribute(
name="active",
type="boolean",
description="Whether the user account is active.",
),
ScimSchemaAttribute(
name="externalId",
type="string",
description="Identifier from the provisioning client (IdP).",
caseExact=True,
),
],
)
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_GROUP_SCHEMA,
name="Group",
description="SCIM core Group schema",
attributes=[
ScimSchemaAttribute(
name="displayName",
type="string",
required=True,
description="Human-readable name for the group.",
),
ScimSchemaAttribute(
name="members",
type="complex",
multiValued=True,
description="Members of the group.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="User ID of the group member.",
),
ScimSchemaAttribute(
name="display",
type="string",
mutability="readOnly",
description="Display name of the group member.",
),
],
),
ScimSchemaAttribute(
name="externalId",
type="string",
description="Identifier from the provisioning client (IdP).",
caseExact=True,
),
],
)
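
For context on how static discovery documents like these are typically exposed, a hypothetical FastAPI sketch (the route wiring and function names are illustrative, not the project's actual api.py):

```python
from fastapi import APIRouter

router = APIRouter(prefix="/scim/v2")

@router.get("/ServiceProviderConfig")
def get_service_provider_config() -> dict:
    # Static at import time, so dumping per request is cheap and side-effect free.
    return SERVICE_PROVIDER_CONFIG.model_dump(by_alias=True, exclude_none=True)

@router.get("/ResourceTypes")
def get_resource_types() -> list[dict]:
    return [
        rt.model_dump(by_alias=True, exclude_none=True)
        for rt in (USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE)
    ]

@router.get("/Schemas")
def get_schemas() -> list[dict]:
    return [
        s.model_dump(by_alias=True, exclude_none=True)
        for s in (USER_SCHEMA_DEF, GROUP_SCHEMA_DEF)
    ]
```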

View File

@@ -68,18 +68,6 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
"""Detect XML-style marshaled tool calls emitted as plain text."""
if not text:
return False
lowered = text.lower()
return (
"<function_calls" in lowered
and "<invoke" in lowered
and "<parameter" in lowered
)
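
All three markers must be present before the heuristic fires, which keeps ordinary prose that merely mentions one tag from matching. For instance:

```python
payload = (
    '<function_calls><invoke name="search">'
    '<parameter name="q">foo</parameter></invoke></function_calls>'
)
assert _looks_like_xml_tool_call_payload(payload)
# Mentioning a single tag in prose is not enough to trip it:
assert not _looks_like_xml_tool_call_payload("the <invoke> tag is documented here")
```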
def _should_keep_bedrock_tool_definitions(
llm: object, simple_chat_history: list[ChatMessageSimple]
) -> bool:
@@ -134,56 +122,38 @@ def _try_fallback_tool_extraction(
reasoning_but_no_answer_or_tools = (
llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls
)
xml_tool_call_text_detected = no_tool_calls and (
_looks_like_xml_tool_call_payload(llm_step_result.answer)
or _looks_like_xml_tool_call_payload(llm_step_result.raw_answer)
or _looks_like_xml_tool_call_payload(llm_step_result.reasoning)
)
should_try_fallback = (
(tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls)
or reasoning_but_no_answer_or_tools
or xml_tool_call_text_detected
)
tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls
) or reasoning_but_no_answer_or_tools
if not should_try_fallback:
return llm_step_result, False
# Try to extract from answer first, then fall back to reasoning
extracted_tool_calls: list[ToolCallKickoff] = []
if llm_step_result.answer:
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.answer,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if (
not extracted_tool_calls
and llm_step_result.raw_answer
and llm_step_result.raw_answer != llm_step_result.answer
):
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.raw_answer,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if not extracted_tool_calls and llm_step_result.reasoning:
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.reasoning,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if extracted_tool_calls:
logger.info(
f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
"as fallback"
f"as fallback (tool_choice was REQUIRED but no tool calls returned)"
)
return (
LlmStepResult(
reasoning=llm_step_result.reasoning,
answer=llm_step_result.answer,
tool_calls=extracted_tool_calls,
raw_answer=llm_step_result.raw_answer,
),
True,
)

View File

@@ -1,12 +1,10 @@
import json
import re
import time
import uuid
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Mapping
from collections.abc import Sequence
from html import unescape
from typing import Any
from typing import cast
@@ -58,112 +56,6 @@ from onyx.utils.text_processing import find_all_json_objects
logger = setup_logger()
_XML_INVOKE_BLOCK_RE = re.compile(
r"<invoke\b(?P<attrs>[^>]*)>(?P<body>.*?)</invoke>",
re.IGNORECASE | re.DOTALL,
)
_XML_PARAMETER_RE = re.compile(
r"<parameter\b(?P<attrs>[^>]*)>(?P<value>.*?)</parameter>",
re.IGNORECASE | re.DOTALL,
)
_FUNCTION_CALLS_OPEN_MARKER = "<function_calls"
_FUNCTION_CALLS_CLOSE_MARKER = "</function_calls>"
class _XmlToolCallContentFilter:
"""Streaming filter that strips XML-style tool call payload blocks from text."""
def __init__(self) -> None:
self._pending = ""
self._inside_function_calls_block = False
def process(self, content: str) -> str:
if not content:
return ""
self._pending += content
output_parts: list[str] = []
while self._pending:
pending_lower = self._pending.lower()
if self._inside_function_calls_block:
end_idx = pending_lower.find(_FUNCTION_CALLS_CLOSE_MARKER)
if end_idx == -1:
# Keep buffering until we see the close marker.
return "".join(output_parts)
# Drop the whole function_calls block.
self._pending = self._pending[
end_idx + len(_FUNCTION_CALLS_CLOSE_MARKER) :
]
self._inside_function_calls_block = False
continue
start_idx = _find_function_calls_open_marker(pending_lower)
if start_idx == -1:
# Keep only a possible prefix of "<function_calls" in the buffer so
# marker splits across chunks are handled correctly.
tail_len = _matching_open_marker_prefix_len(self._pending)
emit_upto = len(self._pending) - tail_len
if emit_upto > 0:
output_parts.append(self._pending[:emit_upto])
self._pending = self._pending[emit_upto:]
return "".join(output_parts)
if start_idx > 0:
output_parts.append(self._pending[:start_idx])
# Enter block-stripping mode and keep scanning for close marker.
self._pending = self._pending[start_idx:]
self._inside_function_calls_block = True
return "".join(output_parts)
def flush(self) -> str:
if self._inside_function_calls_block:
# Drop any incomplete block at stream end.
self._pending = ""
self._inside_function_calls_block = False
return ""
remaining = self._pending
self._pending = ""
return remaining
def _matching_open_marker_prefix_len(text: str) -> int:
"""Return longest suffix of text that matches prefix of "<function_calls"."""
max_len = min(len(text), len(_FUNCTION_CALLS_OPEN_MARKER) - 1)
text_lower = text.lower()
marker_lower = _FUNCTION_CALLS_OPEN_MARKER
for candidate_len in range(max_len, 0, -1):
if text_lower.endswith(marker_lower[:candidate_len]):
return candidate_len
return 0
def _is_valid_function_calls_open_follower(char: str | None) -> bool:
return char is None or char in {">", " ", "\t", "\n", "\r"}
def _find_function_calls_open_marker(text_lower: str) -> int:
"""Find '<function_calls' with a valid tag boundary follower."""
search_from = 0
while True:
idx = text_lower.find(_FUNCTION_CALLS_OPEN_MARKER, search_from)
if idx == -1:
return -1
follower_pos = idx + len(_FUNCTION_CALLS_OPEN_MARKER)
follower = text_lower[follower_pos] if follower_pos < len(text_lower) else None
if _is_valid_function_calls_open_follower(follower):
return idx
search_from = idx + 1
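
A quick sanity sketch of the removed filter's chunk handling when the open marker is split across stream chunks (the buffered `<funct` prefix is held back until it can be classified; assumes the control flow reconstructed above):

```python
f = _XmlToolCallContentFilter()
out = f.process("Answer text <funct")  # "<funct" buffered as a possible marker prefix
out += f.process("ion_calls><invoke name='x'/></function_calls> done")
out += f.flush()
assert out == "Answer text  done"  # the whole function_calls block is stripped
```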
def _sanitize_llm_output(value: str) -> str:
"""Remove characters that PostgreSQL's text/JSONB types cannot store.
@@ -380,7 +272,14 @@ def _extract_tool_call_kickoffs(
tab_index_calculated = 0
for tool_call_data in id_to_tool_call_map.values():
if tool_call_data.get("id") and tool_call_data.get("name"):
tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
try:
tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
except json.JSONDecodeError:
# If parsing fails, fall back to an empty dict; most tools will still fail without their arguments
logger.error(
f"Failed to parse tool call arguments: {tool_call_data['arguments']}"
)
tool_args = {}
tool_calls.append(
ToolCallKickoff(
@@ -408,9 +307,8 @@ def extract_tool_calls_from_response_text(
"""Extract tool calls from LLM response text by matching JSON against tool definitions.
This is a fallback mechanism for when the LLM was expected to return tool calls
but didn't use the proper tool call format. It searches for tool calls embedded
in response text (JSON first, then XML-like invoke blocks) that match available
tool definitions.
but didn't use the proper tool call format. It searches for JSON objects in the
response text that match the structure of available tools.
Args:
response_text: The LLM's text response to search for tool calls
@@ -435,9 +333,10 @@ def extract_tool_calls_from_response_text(
if not tool_name_to_def:
return []
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
# Find all JSON objects in the response text
json_objects = find_all_json_objects(response_text)
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
prev_json_obj: dict[str, Any] | None = None
prev_tool_call: tuple[str, dict[str, Any]] | None = None
@@ -465,14 +364,6 @@ def extract_tool_calls_from_response_text(
prev_json_obj = json_obj
prev_tool_call = matched_tool_call
# Some providers/models emit XML-style function calls instead of JSON objects.
# Keep this as a fallback behind JSON extraction to preserve current behavior.
if not matched_tool_calls:
matched_tool_calls = _extract_xml_tool_calls_from_response_text(
response_text=response_text,
tool_name_to_def=tool_name_to_def,
)
tool_calls: list[ToolCallKickoff] = []
for tab_index, (tool_name, tool_args) in enumerate(matched_tool_calls):
tool_calls.append(
@@ -495,88 +386,6 @@ def extract_tool_calls_from_response_text(
return tool_calls
def _extract_xml_tool_calls_from_response_text(
response_text: str,
tool_name_to_def: dict[str, dict],
) -> list[tuple[str, dict[str, Any]]]:
"""Extract XML-style tool calls from response text.
Supports formats such as:
<function_calls>
<invoke name="internal_search">
<parameter name="queries" string="false">["foo"]</parameter>
</invoke>
</function_calls>
"""
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
for invoke_match in _XML_INVOKE_BLOCK_RE.finditer(response_text):
invoke_attrs = invoke_match.group("attrs")
tool_name = _extract_xml_attribute(invoke_attrs, "name")
if not tool_name or tool_name not in tool_name_to_def:
continue
tool_args: dict[str, Any] = {}
invoke_body = invoke_match.group("body")
for parameter_match in _XML_PARAMETER_RE.finditer(invoke_body):
parameter_attrs = parameter_match.group("attrs")
parameter_name = _extract_xml_attribute(parameter_attrs, "name")
if not parameter_name:
continue
string_attr = _extract_xml_attribute(parameter_attrs, "string")
tool_args[parameter_name] = _parse_xml_parameter_value(
raw_value=parameter_match.group("value"),
string_attr=string_attr,
)
matched_tool_calls.append((tool_name, tool_args))
return matched_tool_calls
def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None:
"""Extract a single XML-style attribute value from a tag attribute string."""
attr_match = re.search(
rf"""\b{re.escape(attr_name)}\s*=\s*(['"])(.*?)\1""",
attrs,
flags=re.IGNORECASE | re.DOTALL,
)
if not attr_match:
return None
return _sanitize_llm_output(unescape(attr_match.group(2).strip()))
def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any:
"""Parse a parameter value from XML-style tool call payloads."""
value = _sanitize_llm_output(unescape(raw_value).strip())
if string_attr and string_attr.lower() == "true":
return value
try:
return json.loads(value)
except json.JSONDecodeError:
return value
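
A worked example of the removed XML path, assuming a minimal `tool_name_to_def` with one registered tool (and that `_sanitize_llm_output` passes clean text through unchanged):

```python
payload = """<function_calls>
<invoke name="internal_search">
<parameter name="queries" string="false">["foo"]</parameter>
</invoke>
</function_calls>"""

calls = _extract_xml_tool_calls_from_response_text(
    response_text=payload,
    tool_name_to_def={"internal_search": {}},
)
# string="false" means the value is JSON-parsed, so the list survives:
assert calls == [("internal_search", {"queries": ["foo"]})]
```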
def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None:
"""Extract and parse an arguments/parameters value from a tool-call-like object.
Looks for "arguments" or "parameters" keys, handles JSON-string values,
and returns a dict if successful, or None otherwise.
"""
arguments = obj.get("arguments", obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
return arguments
return None
def _try_match_json_to_tool(
json_obj: dict[str, Any],
tool_name_to_def: dict[str, dict],
@@ -599,8 +408,13 @@ def _try_match_json_to_tool(
# Format 1: Direct tool call format {"name": "...", "arguments": {...}}
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
tool_name = json_obj["name"]
arguments = _resolve_tool_arguments(json_obj)
if arguments is not None:
arguments = json_obj.get("arguments", json_obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
return (tool_name, arguments)
# Format 2: Function call format {"function": {"name": "...", "arguments": {...}}}
@@ -608,8 +422,13 @@ def _try_match_json_to_tool(
func_obj = json_obj["function"]
if "name" in func_obj and func_obj["name"] in tool_name_to_def:
tool_name = func_obj["name"]
arguments = _resolve_tool_arguments(func_obj)
if arguments is not None:
arguments = func_obj.get("arguments", func_obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
return (tool_name, arguments)
# Format 3: Tool name as key {"tool_name": {...arguments...}}
@@ -879,8 +698,7 @@ def run_llm_step_pkt_generator(
tool_definitions: List of tool definitions available to the LLM.
tool_choice: Tool choice configuration (e.g., "auto", "required", "none").
llm: Language model interface to use for generation.
placement: Placement info (turn_index, tab_index, sub_turn_index) for
positioning packets in the conversation UI.
turn_index: Current turn index in the conversation.
state_container: Container for storing chat state (reasoning, answers).
citation_processor: Optional processor for extracting and formatting citations
from the response. If provided, processes tokens to identify citations.
@@ -892,14 +710,7 @@ def run_llm_step_pkt_generator(
custom_token_processor: Optional callable that processes each token delta
before yielding. Receives (delta, processor_state) and returns
(modified_delta, new_processor_state). Can return None for delta to skip.
max_tokens: Optional maximum number of tokens for the LLM response.
use_existing_tab_index: If True, use the tab_index from placement for all
tool calls instead of auto-incrementing.
is_deep_research: If True, treat content before tool calls as reasoning
when tool_choice is REQUIRED.
pre_answer_processing_time: Optional time spent processing before the
answer started, recorded in state_container for analytics.
timeout_override: Optional timeout override for the LLM call.
sub_turn_index: Optional sub-turn index for nested tool/agent calls.
Yields:
Packet: Streaming packets containing:
@@ -925,15 +736,8 @@ def run_llm_step_pkt_generator(
tab_index = placement.tab_index
sub_turn_index = placement.sub_turn_index
def _current_placement() -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
)
llm_msg_history = translate_history_to_llm_format(history, llm.config)
has_reasoned = False
has_reasoned = 0
if LOG_ONYX_MODEL_INTERACTIONS:
logger.debug(
@@ -945,8 +749,6 @@ def run_llm_step_pkt_generator(
answer_start = False
accumulated_reasoning = ""
accumulated_answer = ""
accumulated_raw_answer = ""
xml_tool_call_content_filter = _XmlToolCallContentFilter()
processor_state: Any = None
@@ -962,112 +764,6 @@ def run_llm_step_pkt_generator(
)
stream_start_time = time.monotonic()
first_action_recorded = False
def _emit_citation_results(
results: Generator[str | CitationInfo, None, None],
) -> Generator[Packet, None, None]:
"""Yield packets for citation processor results (str or CitationInfo)."""
nonlocal accumulated_answer
for result in results:
if isinstance(result, str):
accumulated_answer += result
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=_current_placement(),
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
placement=_current_placement(),
obj=result,
)
if state_container:
state_container.add_emitted_citation(result.citation_number)
def _close_reasoning_if_active() -> Generator[Packet, None, None]:
"""Emit ReasoningDone and increment turns if reasoning is in progress."""
nonlocal reasoning_start
nonlocal has_reasoned
nonlocal turn_index
nonlocal sub_turn_index
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = True
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
def _emit_content_chunk(content_chunk: str) -> Generator[Packet, None, None]:
nonlocal accumulated_answer
nonlocal accumulated_reasoning
nonlocal answer_start
nonlocal reasoning_start
nonlocal turn_index
nonlocal sub_turn_index
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
# about which tool to call, not an actual answer to the user.
# Treat this content as reasoning instead of answer.
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
accumulated_reasoning += content_chunk
if state_container:
state_container.set_reasoning_tokens(accumulated_reasoning)
if not reasoning_start:
yield Packet(
placement=_current_placement(),
obj=ReasoningStart(),
)
yield Packet(
placement=_current_placement(),
obj=ReasoningDelta(reasoning=content_chunk),
)
reasoning_start = True
return
# Normal flow for AUTO or NONE tool choice
yield from _close_reasoning_if_active()
if not answer_start:
# Store pre-answer processing time in state container for save_chat
if state_container and pre_answer_processing_time is not None:
state_container.set_pre_answer_processing_time(
pre_answer_processing_time
)
yield Packet(
placement=_current_placement(),
obj=AgentResponseStart(
final_documents=final_documents,
pre_answer_processing_seconds=pre_answer_processing_time,
),
)
answer_start = True
if citation_processor:
yield from _emit_citation_results(
citation_processor.process_token(content_chunk)
)
else:
accumulated_answer += content_chunk
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=_current_placement(),
obj=AgentResponseDelta(content=content_chunk),
)
for packet in llm.stream(
prompt=llm_msg_history,
tools=tool_definitions,
@@ -1126,34 +822,152 @@ def run_llm_step_pkt_generator(
state_container.set_reasoning_tokens(accumulated_reasoning)
if not reasoning_start:
yield Packet(
placement=_current_placement(),
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningStart(),
)
yield Packet(
placement=_current_placement(),
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDelta(reasoning=delta.reasoning_content),
)
reasoning_start = True
if delta.content:
# Keep raw content for fallback extraction. Display content can be
# filtered and, in deep-research REQUIRED mode, routed as reasoning.
accumulated_raw_answer += delta.content
filtered_content = xml_tool_call_content_filter.process(delta.content)
if filtered_content:
yield from _emit_content_chunk(filtered_content)
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
# about which tool to call, not an actual answer to the user.
# Treat this content as reasoning instead of answer.
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
# Treat content as reasoning when we know tool calls are coming
accumulated_reasoning += delta.content
if state_container:
state_container.set_reasoning_tokens(accumulated_reasoning)
if not reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningStart(),
)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDelta(reasoning=delta.content),
)
reasoning_start = True
else:
# Normal flow for AUTO or NONE tool choice
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
if not answer_start:
# Store pre-answer processing time in state container for save_chat
if state_container and pre_answer_processing_time is not None:
state_container.set_pre_answer_processing_time(
pre_answer_processing_time
)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseStart(
final_documents=final_documents,
pre_answer_processing_seconds=pre_answer_processing_time,
),
)
answer_start = True
if citation_processor:
for result in citation_processor.process_token(delta.content):
if isinstance(result, str):
accumulated_answer += result
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(
accumulated_answer
)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=result,
)
# Track emitted citation for saving
if state_container:
state_container.add_emitted_citation(
result.citation_number
)
else:
# When citation_processor is None, use delta.content directly without modification
accumulated_answer += delta.content
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=delta.content),
)
if delta.tool_calls:
yield from _close_reasoning_if_active()
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
# Flush any tail text buffered while checking for split "<function_calls" markers.
filtered_content_tail = xml_tool_call_content_filter.flush()
if filtered_content_tail:
yield from _emit_content_chunk(filtered_content_tail)
# Flush custom token processor to get any final tool calls
if custom_token_processor:
flush_delta, processor_state = custom_token_processor(None, processor_state)
@@ -1209,14 +1023,50 @@ def run_llm_step_pkt_generator(
# This may happen if the custom token processor is used to modify other packets into reasoning
# Then there won't necessarily be anything else to come after the reasoning tokens
yield from _close_reasoning_if_active()
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(turn_index, sub_turn_index)
reasoning_start = False
# Flush any remaining content from citation processor
# Reasoning is always first so this should use the post-incremented value of turn_index
# Note that this doesn't need to handle any sub-turns as those docs will not have citations
# as clickable items and will be stripped out instead.
if citation_processor:
yield from _emit_citation_results(citation_processor.process_token(None))
for result in citation_processor.process_token(None):
if isinstance(result, str):
accumulated_answer += result
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=result,
)
# Track emitted citation for saving
if state_container:
state_container.add_emitted_citation(result.citation_number)
# Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
# Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)
@@ -1238,9 +1088,8 @@ def run_llm_step_pkt_generator(
reasoning=accumulated_reasoning if accumulated_reasoning else None,
answer=accumulated_answer if accumulated_answer else None,
tool_calls=tool_calls if tool_calls else None,
raw_answer=accumulated_raw_answer if accumulated_raw_answer else None,
),
has_reasoned,
bool(has_reasoned),
)
@@ -1295,4 +1144,4 @@ def run_llm_step(
emitter.emit(packet)
except StopIteration as e:
llm_step_result, has_reasoned = e.value
return llm_step_result, has_reasoned
return llm_step_result, bool(has_reasoned)
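
The `StopIteration.value` dance above is the standard way to recover a generator's `return` value once it is exhausted; a standalone illustration:

```python
from collections.abc import Generator

def stream() -> Generator[str, None, tuple[str, bool]]:
    yield "packet-1"
    yield "packet-2"
    return ("final-result", True)  # surfaces as StopIteration.value

gen = stream()
try:
    while True:
        print(next(gen))
except StopIteration as e:
    result, has_reasoned = e.value  # ("final-result", True)
```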

View File

@@ -185,6 +185,3 @@ class LlmStepResult(BaseModel):
reasoning: str | None
answer: str | None
tool_calls: list[ToolCallKickoff] | None
# Raw LLM text before any display-oriented filtering/sanitization.
# Used for fallback tool-call extraction when providers emit calls as text.
raw_answer: str | None = None

View File

@@ -46,7 +46,6 @@ from onyx.connectors.google_drive.file_retrieval import get_external_access_for_
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.file_retrieval import get_shared_drive_name
from onyx.connectors.google_drive.file_retrieval import has_link_only_permission
from onyx.connectors.google_drive.models import DriveRetrievalStage
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint
@@ -157,7 +156,10 @@ def _is_shared_drive_root(folder: GoogleDriveFileType) -> bool:
return False
# For shared drive content, the root has id == driveId
return bool(drive_id and folder_id == drive_id)
if drive_id and folder_id == drive_id:
return True
return False
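
Both variants encode the same rule. Under the assumption that the folder metadata carries the Drive API's `id` and `driveId` fields, the check reduces to:

```python
# A shared drive root reports its own driveId as its file id (ids hypothetical):
root = {"id": "0AFskk4zfZm86Uk9PVA", "driveId": "0AFskk4zfZm86Uk9PVA"}
nested = {"id": "folder-123", "driveId": "0AFskk4zfZm86Uk9PVA"}
assert root["id"] == root["driveId"]      # root of the shared drive
assert nested["id"] != nested["driveId"]  # regular folder inside it
```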
def _public_access() -> ExternalAccess:
@@ -614,16 +616,6 @@ class GoogleDriveConnector(
# empty parents due to permission limitations)
# Check shared drive root first (simple ID comparison)
if _is_shared_drive_root(folder):
# files().get() returns 'Drive' for shared drive roots;
# fetch the real name via drives().get().
# Try both the retriever and admin since the admin may
# not have access to private shared drives.
drive_name = self._get_shared_drive_name(
current_id, file.user_email
)
if drive_name:
node.display_name = drive_name
node.node_type = HierarchyNodeType.SHARED_DRIVE
reached_terminal = True
break
@@ -699,15 +691,6 @@ class GoogleDriveConnector(
)
return None
def _get_shared_drive_name(self, drive_id: str, retriever_email: str) -> str | None:
"""Fetch the name of a shared drive, trying both the retriever and admin."""
for email in {retriever_email, self.primary_admin_email}:
svc = get_drive_service(self.creds, email)
name = get_shared_drive_name(svc, drive_id)
if name:
return name
return None
def get_all_drive_ids(self) -> set[str]:
return self._get_all_drives_for_user(self.primary_admin_email)

View File

@@ -154,26 +154,6 @@ def _get_hierarchy_fields_for_file_type(field_type: DriveFileFieldType) -> str:
return HIERARCHY_FIELDS
def get_shared_drive_name(
service: Resource,
drive_id: str,
) -> str | None:
"""Fetch the actual name of a shared drive via the drives().get() API.
The files().get() API returns 'Drive' as the name for shared drive root
folders. Only drives().get() returns the real user-assigned name.
"""
try:
drive = service.drives().get(driveId=drive_id, fields="name").execute()
return drive.get("name")
except HttpError as e:
if e.resp.status in (403, 404):
logger.debug(f"Cannot access drive {drive_id}: {e}")
else:
raise
return None
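
Typical call shape for the removed helper, with hypothetical credentials and drive id (both are assumptions here, not values from the codebase):

```python
svc = get_drive_service(creds, "admin@example.com")  # creds: pre-built Google credentials
name = get_shared_drive_name(svc, "0AFskk4zfZm86Uk9PVA")
if name is None:
    # 403/404 were swallowed above: the caller treats the drive as inaccessible
    ...
```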
def get_external_access_for_folder(
folder: GoogleDriveFileType,
google_domain: str,

View File

@@ -430,7 +430,7 @@ def fetch_existing_models(
def fetch_existing_llm_providers(
db_session: Session,
flow_type_filter: list[LLMModelFlowType],
flow_types: list[LLMModelFlowType],
only_public: bool = False,
exclude_image_generation_providers: bool = True,
) -> list[LLMProviderModel]:
@@ -438,27 +438,30 @@ def fetch_existing_llm_providers(
Args:
db_session: Database session
flow_type_filter: List of flow types to filter by, empty list for no filter
flow_types: List of flow types to filter by
only_public: If True, only return public providers
exclude_image_generation_providers: If True, exclude providers that are
used for image generation configs
"""
stmt = select(LLMProviderModel)
if flow_type_filter:
providers_with_flows = (
select(ModelConfiguration.llm_provider_id)
.join(LLMModelFlow)
.where(LLMModelFlow.llm_model_flow_type.in_(flow_type_filter))
.distinct()
)
stmt = stmt.where(LLMProviderModel.id.in_(providers_with_flows))
providers_with_flows = (
select(ModelConfiguration.llm_provider_id)
.join(LLMModelFlow)
.where(LLMModelFlow.llm_model_flow_type.in_(flow_types))
.distinct()
)
if exclude_image_generation_providers:
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
)
else:
image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
ImageGenerationConfig
)
stmt = stmt.where(~LLMProviderModel.id.in_(image_gen_provider_ids))
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
| LLMProviderModel.id.in_(image_gen_provider_ids)
)
stmt = stmt.options(
selectinload(LLMProviderModel.model_configurations),
@@ -794,15 +797,13 @@ def sync_auto_mode_models(
changes += 1
else:
# Add new model - all models from GitHub config are visible
insert_new_model_configuration__no_commit(
db_session=db_session,
new_model = ModelConfiguration(
llm_provider_id=provider.id,
model_name=model_config.name,
supported_flows=[LLMModelFlowType.CHAT],
is_visible=True,
max_input_tokens=None,
name=model_config.name,
display_name=model_config.display_name,
is_visible=True,
)
db_session.add(new_model)
changes += 1
# In Auto mode, default model is always set from GitHub config

View File

@@ -21,6 +21,7 @@ from fastapi.routing import APIRoute
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import BASE_SCOPES
from httpx_oauth.clients.openid import OpenID
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from starlette.types import Lifespan
@@ -120,7 +121,6 @@ from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_auth_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.pat.api import router as pat_router
from onyx.server.prometheus_instrumentation import setup_prometheus_metrics
from onyx.server.query_and_chat.chat_backend import router as chat_router
from onyx.server.query_and_chat.query_backend import (
admin_router as admin_query_router,
@@ -563,8 +563,8 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
# Ensure all routes have auth enabled or are explicitly marked as public
check_router_auth(application)
# Initialize and instrument the app with production Prometheus config
setup_prometheus_metrics(application)
# Initialize and instrument the app
Instrumentator().instrument(application).expose(application)
use_route_function_names_as_operation_ids(application)

View File

@@ -102,9 +102,6 @@ def check_router_auth(
current_cloud_superuser = fetch_ee_implementation_or_noop(
"onyx.auth.users", "current_cloud_superuser"
)
verify_scim_token = fetch_ee_implementation_or_noop(
"onyx.server.scim.auth", "verify_scim_token"
)
for route in application.routes:
# explicitly marked as public
@@ -128,7 +125,6 @@ def check_router_auth(
or depends_fn == current_chat_accessible_user
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
or depends_fn == verify_scim_token
):
found_auth = True
break

View File

@@ -8,7 +8,6 @@ import httpx
from sqlalchemy.orm import Session
from onyx import __version__
from onyx.configs.app_configs import INSTANCE_TYPE
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.release_notes import create_release_notifications_for_versions
from onyx.redis.redis_pool import get_shared_redis_client
@@ -57,7 +56,7 @@ def is_version_gte(v1: str, v2: str) -> bool:
def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry]:
"""Parse MDX content into ReleaseNoteEntry objects."""
"""Parse MDX content into ReleaseNoteEntry objects for versions >= __version__."""
all_entries = []
update_pattern = (
@@ -83,12 +82,6 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
if not all_entries:
raise ValueError("Could not parse any release note entries from MDX.")
if INSTANCE_TYPE == "cloud":
# Cloud often runs ahead of docs release tags; always notify on latest release.
return sorted(
all_entries, key=lambda x: parse_version_tuple(x.version), reverse=True
)[:1]
# Filter to valid versions >= __version__
if __version__ and is_valid_version(__version__):
entries = [

View File

@@ -310,7 +310,7 @@ def list_llm_providers(
llm_provider_list: list[LLMProviderView] = []
for llm_provider_model in fetch_existing_llm_providers(
db_session=db_session,
flow_type_filter=[],
flow_types=[LLMModelFlowType.CHAT, LLMModelFlowType.VISION],
exclude_image_generation_providers=not include_image_gen,
):
from_model_start = datetime.now(timezone.utc)
@@ -568,7 +568,9 @@ def list_llm_provider_basics(
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch user-accessible LLM providers")
all_providers = fetch_existing_llm_providers(db_session, [])
all_providers = fetch_existing_llm_providers(
db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
)
user_group_ids = fetch_user_group_ids(db_session, user)
is_admin = user.role == UserRole.ADMIN

View File

@@ -1,63 +0,0 @@
"""Prometheus instrumentation for the Onyx API server.
Provides a production-grade metrics configuration with:
- Exact HTTP status codes (no grouping into 2xx/3xx)
- In-progress request gauge broken down by handler and method
- Custom latency histogram buckets tuned for API workloads
- Request/response size tracking
- Slow request counter with configurable threshold
"""
import os
from prometheus_client import Counter
from prometheus_fastapi_instrumentator import Instrumentator
from prometheus_fastapi_instrumentator.metrics import Info
from starlette.applications import Starlette
SLOW_REQUEST_THRESHOLD_SECONDS: float = float(
os.environ.get("SLOW_REQUEST_THRESHOLD_SECONDS", "1.0")
)
_EXCLUDED_HANDLERS = [
"/health",
"/metrics",
"/openapi.json",
]
_slow_requests = Counter(
"onyx_api_slow_requests_total",
"Total requests exceeding the slow request threshold",
["method", "handler", "status"],
)
def _slow_request_callback(info: Info) -> None:
"""Increment slow request counter when duration exceeds threshold."""
if info.modified_duration > SLOW_REQUEST_THRESHOLD_SECONDS:
_slow_requests.labels(
method=info.method,
handler=info.modified_handler,
status=info.modified_status,
).inc()
def setup_prometheus_metrics(app: Starlette) -> None:
"""Configure and attach Prometheus instrumentation to the FastAPI app.
Records exact status codes, tracks in-progress requests per handler,
and counts slow requests exceeding a configurable threshold.
"""
instrumentator = Instrumentator(
should_group_status_codes=False,
should_ignore_untemplated=False,
should_group_untemplated=True,
should_instrument_requests_inprogress=True,
inprogress_labels=True,
excluded_handlers=_EXCLUDED_HANDLERS,
)
instrumentator.add(_slow_request_callback)
instrumentator.instrument(app).expose(app)
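
A minimal sketch of how the removed slow-request callback reacts, assuming `Info` exposes the attributes used above (a stand-in object keeps it self-contained):

```python
from types import SimpleNamespace

fake_info = SimpleNamespace(
    modified_duration=2.5,  # seconds; above the 1.0s default threshold
    method="GET",
    modified_handler="/api/chat",
    modified_status="200",
)
_slow_request_callback(fake_info)  # type: ignore[arg-type]
# onyx_api_slow_requests_total{method="GET",handler="/api/chat",status="200"} += 1
```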

View File

@@ -349,7 +349,6 @@ def get_chat_session(
shared_status=chat_session.shared_status,
current_temperature_override=chat_session.temperature_override,
deleted=chat_session.deleted,
owner_name=chat_session.user.personal_name if chat_session.user else None,
# Packets are now directly serialized as Packet Pydantic models
packets=replay_packet_lists,
)

View File

@@ -224,7 +224,6 @@ class ChatSessionDetailResponse(BaseModel):
current_alternate_model: str | None
current_temperature_override: float | None
deleted: bool = False
owner_name: str | None = None
packets: list[list[Packet]]

View File

@@ -2,7 +2,6 @@ import time
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import field
from dataclasses import replace
from urllib.parse import urlparse
from onyx.connectors.google_drive.connector import GoogleDriveConnector
@@ -135,25 +134,25 @@ EXPECTED_SHARED_DRIVE_1_HIERARCHY = ExpectedHierarchyNode(
children=[
ExpectedHierarchyNode(
raw_node_id=RESTRICTED_ACCESS_FOLDER_ID,
display_name="restricted_access",
display_name="restricted_access_folder",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=SHARED_DRIVE_1_ID,
),
ExpectedHierarchyNode(
raw_node_id=FOLDER_1_ID,
display_name="folder 1",
display_name="folder_1",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=SHARED_DRIVE_1_ID,
children=[
ExpectedHierarchyNode(
raw_node_id=FOLDER_1_1_ID,
display_name="folder 1-1",
display_name="folder_1_1",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=FOLDER_1_ID,
),
ExpectedHierarchyNode(
raw_node_id=FOLDER_1_2_ID,
display_name="folder 1-2",
display_name="folder_1_2",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=FOLDER_1_ID,
),
@@ -171,25 +170,25 @@ EXPECTED_SHARED_DRIVE_2_HIERARCHY = ExpectedHierarchyNode(
children=[
ExpectedHierarchyNode(
raw_node_id=SECTIONS_FOLDER_ID,
display_name="sections",
display_name="sections_folder",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=SHARED_DRIVE_2_ID,
),
ExpectedHierarchyNode(
raw_node_id=FOLDER_2_ID,
display_name="folder 2",
display_name="folder_2",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=SHARED_DRIVE_2_ID,
children=[
ExpectedHierarchyNode(
raw_node_id=FOLDER_2_1_ID,
display_name="folder 2-1",
display_name="folder_2_1",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=FOLDER_2_ID,
),
ExpectedHierarchyNode(
raw_node_id=FOLDER_2_2_ID,
display_name="folder 2-2",
display_name="folder_2_2",
node_type=HierarchyNodeType.FOLDER,
raw_parent_id=FOLDER_2_ID,
),
@@ -209,23 +208,27 @@ def flatten_hierarchy(
return result
def _node(
raw_node_id: str,
display_name: str,
node_type: HierarchyNodeType,
raw_parent_id: str | None = None,
) -> ExpectedHierarchyNode:
return ExpectedHierarchyNode(
raw_node_id=raw_node_id,
display_name=display_name,
node_type=node_type,
raw_parent_id=raw_parent_id,
)
# Flattened maps for easy lookup
EXPECTED_SHARED_DRIVE_1_NODES = flatten_hierarchy(EXPECTED_SHARED_DRIVE_1_HIERARCHY)
EXPECTED_SHARED_DRIVE_2_NODES = flatten_hierarchy(EXPECTED_SHARED_DRIVE_2_HIERARCHY)
ALL_EXPECTED_SHARED_DRIVE_NODES = {
**EXPECTED_SHARED_DRIVE_1_NODES,
**EXPECTED_SHARED_DRIVE_2_NODES,
}
# Map of folder ID to its expected parent ID
EXPECTED_PARENT_MAPPING: dict[str, str | None] = {
SHARED_DRIVE_1_ID: None,
RESTRICTED_ACCESS_FOLDER_ID: SHARED_DRIVE_1_ID,
FOLDER_1_ID: SHARED_DRIVE_1_ID,
FOLDER_1_1_ID: FOLDER_1_ID,
FOLDER_1_2_ID: FOLDER_1_ID,
SHARED_DRIVE_2_ID: None,
SECTIONS_FOLDER_ID: SHARED_DRIVE_2_ID,
FOLDER_2_ID: SHARED_DRIVE_2_ID,
FOLDER_2_1_ID: FOLDER_2_ID,
FOLDER_2_2_ID: FOLDER_2_ID,
}
EXTERNAL_SHARED_FOLDER_URL = (
"https://drive.google.com/drive/folders/1sWC7Oi0aQGgifLiMnhTjvkhRWVeDa-XS"
@@ -283,7 +286,7 @@ TEST_USER_1_MY_DRIVE_FOLDER_ID = (
)
TEST_USER_1_DRIVE_B_ID = (
"0AFskk4zfZm86Uk9PVA" # My_super_special_shared_drive_suuuper_private
"0AFskk4zfZm86Uk9PVA" # My_super_special_shared_drive_suuuuuuper_private
)
TEST_USER_1_DRIVE_B_FOLDER_ID = (
"1oIj7nigzvP5xI2F8BmibUA8R_J3AbBA-" # Child folder (silliness)
@@ -322,106 +325,6 @@ PERM_SYNC_DRIVE_ACCESS_MAPPING: dict[str, set[str]] = {
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID: {ADMIN_EMAIL, TEST_USER_1_EMAIL},
}
# ============================================================================
# NON-SHARED-DRIVE HIERARCHY NODES
# ============================================================================
# These cover My Drive roots, perm sync drives, extra shared drives,
# and standalone folders that appear in various tests.
# Display names must match what the Google Drive API actually returns.
# ============================================================================
EXPECTED_FOLDER_3 = _node(
FOLDER_3_ID, "Folder 3", HierarchyNodeType.FOLDER, ADMIN_MY_DRIVE_ID
)
EXPECTED_ADMIN_MY_DRIVE = _node(ADMIN_MY_DRIVE_ID, "My Drive", HierarchyNodeType.FOLDER)
EXPECTED_TEST_USER_1_MY_DRIVE = _node(
TEST_USER_1_MY_DRIVE_ID, "My Drive", HierarchyNodeType.FOLDER
)
EXPECTED_TEST_USER_1_MY_DRIVE_FOLDER = _node(
TEST_USER_1_MY_DRIVE_FOLDER_ID,
"partial_sharing",
HierarchyNodeType.FOLDER,
TEST_USER_1_MY_DRIVE_ID,
)
EXPECTED_TEST_USER_2_MY_DRIVE = _node(
TEST_USER_2_MY_DRIVE, "My Drive", HierarchyNodeType.FOLDER
)
EXPECTED_TEST_USER_3_MY_DRIVE = _node(
TEST_USER_3_MY_DRIVE_ID, "My Drive", HierarchyNodeType.FOLDER
)
EXPECTED_PERM_SYNC_DRIVE_ADMIN_ONLY = _node(
PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
"perm_sync_drive_0dc9d8b5-e243-4c2f-8678-2235958f7d7c",
HierarchyNodeType.SHARED_DRIVE,
)
EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A = _node(
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
"perm_sync_drive_785db121-0823-4ebe-8689-ad7f52405e32",
HierarchyNodeType.SHARED_DRIVE,
)
EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B = _node(
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
"perm_sync_drive_d8dc3649-3f65-4392-b87f-4b20e0389673",
HierarchyNodeType.SHARED_DRIVE,
)
EXPECTED_TEST_USER_1_DRIVE_B = _node(
TEST_USER_1_DRIVE_B_ID,
"My_super_special_shared_drive_suuuper_private",
HierarchyNodeType.SHARED_DRIVE,
)
EXPECTED_TEST_USER_1_DRIVE_B_FOLDER = _node(
TEST_USER_1_DRIVE_B_FOLDER_ID,
"silliness",
HierarchyNodeType.FOLDER,
TEST_USER_1_DRIVE_B_ID,
)
EXPECTED_TEST_USER_1_EXTRA_DRIVE_1 = _node(
TEST_USER_1_EXTRA_DRIVE_1_ID,
"Okay_Admin_fine_I_will_share",
HierarchyNodeType.SHARED_DRIVE,
)
EXPECTED_TEST_USER_1_EXTRA_DRIVE_2 = _node(
TEST_USER_1_EXTRA_DRIVE_2_ID, "reee test", HierarchyNodeType.SHARED_DRIVE
)
EXPECTED_TEST_USER_1_EXTRA_FOLDER = _node(
TEST_USER_1_EXTRA_FOLDER_ID,
"read only no download test",
HierarchyNodeType.FOLDER,
)
EXPECTED_PILL_FOLDER = _node(
PILL_FOLDER_ID, "pill_folder", HierarchyNodeType.FOLDER, ADMIN_MY_DRIVE_ID
)
EXPECTED_EXTERNAL_SHARED_FOLDER = _node(
EXTERNAL_SHARED_FOLDER_ID, "Onyx-test", HierarchyNodeType.FOLDER
)
# Comprehensive mapping of ALL known hierarchy nodes.
# Every retrieved node is checked against this for display_name and node_type.
ALL_EXPECTED_HIERARCHY_NODES: dict[str, ExpectedHierarchyNode] = {
**EXPECTED_SHARED_DRIVE_1_NODES,
**EXPECTED_SHARED_DRIVE_2_NODES,
FOLDER_3_ID: EXPECTED_FOLDER_3,
ADMIN_MY_DRIVE_ID: EXPECTED_ADMIN_MY_DRIVE,
TEST_USER_1_MY_DRIVE_ID: EXPECTED_TEST_USER_1_MY_DRIVE,
TEST_USER_1_MY_DRIVE_FOLDER_ID: EXPECTED_TEST_USER_1_MY_DRIVE_FOLDER,
TEST_USER_2_MY_DRIVE: EXPECTED_TEST_USER_2_MY_DRIVE,
TEST_USER_3_MY_DRIVE_ID: EXPECTED_TEST_USER_3_MY_DRIVE,
PERM_SYNC_DRIVE_ADMIN_ONLY_ID: EXPECTED_PERM_SYNC_DRIVE_ADMIN_ONLY,
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID: EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A,
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID: EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B,
TEST_USER_1_DRIVE_B_ID: EXPECTED_TEST_USER_1_DRIVE_B,
TEST_USER_1_DRIVE_B_FOLDER_ID: EXPECTED_TEST_USER_1_DRIVE_B_FOLDER,
TEST_USER_1_EXTRA_DRIVE_1_ID: EXPECTED_TEST_USER_1_EXTRA_DRIVE_1,
TEST_USER_1_EXTRA_DRIVE_2_ID: EXPECTED_TEST_USER_1_EXTRA_DRIVE_2,
TEST_USER_1_EXTRA_FOLDER_ID: EXPECTED_TEST_USER_1_EXTRA_FOLDER,
PILL_FOLDER_ID: EXPECTED_PILL_FOLDER,
EXTERNAL_SHARED_FOLDER_ID: EXPECTED_EXTERNAL_SHARED_FOLDER,
}
# Dictionary for access permissions
# All users have access to their own My Drive as well as public files
ACCESS_MAPPING: dict[str, list[int]] = {
@@ -605,29 +508,28 @@ def load_connector_outputs(
def assert_hierarchy_nodes_match_expected(
retrieved_nodes: list[HierarchyNode],
expected_nodes: dict[str, ExpectedHierarchyNode],
expected_node_ids: set[str],
expected_parent_mapping: dict[str, str | None] | None = None,
ignorable_node_ids: set[str] | None = None,
) -> None:
"""
Assert that retrieved hierarchy nodes match expected structure.
Checks node IDs, display names, node types, and parent relationships
for EVERY retrieved node (global checks).
Args:
retrieved_nodes: List of HierarchyNode objects from the connector
expected_nodes: Dict mapping raw_node_id -> ExpectedHierarchyNode with
expected display_name, node_type, and raw_parent_id
ignorable_node_ids: Optional set of node IDs that can be missing or extra
without failing. Useful for non-deterministically returned nodes.
expected_node_ids: Set of expected raw_node_ids
expected_parent_mapping: Optional dict mapping node_id -> parent_id for parent verification
ignorable_node_ids: Optional set of node IDs that can be missing or extra without failing.
Useful for nodes that are non-deterministically returned by the connector.
"""
expected_node_ids = set(expected_nodes.keys())
retrieved_node_ids = {node.raw_node_id for node in retrieved_nodes}
ignorable = ignorable_node_ids or set()
# Calculate differences, excluding ignorable nodes
missing = expected_node_ids - retrieved_node_ids - ignorable
extra = retrieved_node_ids - expected_node_ids - ignorable
# Print discrepancies for debugging
if missing or extra:
print("Expected hierarchy node IDs:")
print(sorted(expected_node_ids))
@@ -641,146 +543,181 @@ def assert_hierarchy_nodes_match_expected(
print("Ignorable node IDs:")
print(sorted(ignorable))
assert (
not missing and not extra
), f"Hierarchy node mismatch. Missing: {missing}, Extra: {extra}"
assert not missing and not extra, (
f"Hierarchy node mismatch. " f"Missing: {missing}, " f"Extra: {extra}"
)
for node in retrieved_nodes:
if node.raw_node_id in ignorable and node.raw_node_id not in expected_nodes:
continue
assert (
node.raw_node_id in expected_nodes
), f"Node {node.raw_node_id} ({node.display_name}) not found in expected_nodes"
expected = expected_nodes[node.raw_node_id]
assert node.display_name == expected.display_name, (
f"Display name mismatch for node {node.raw_node_id}: "
f"expected '{expected.display_name}', got '{node.display_name}'"
)
assert node.node_type == expected.node_type, (
f"Node type mismatch for node {node.raw_node_id}: "
f"expected '{expected.node_type}', got '{node.node_type}'"
)
if expected.raw_parent_id is not None:
assert node.raw_parent_id == expected.raw_parent_id, (
# Verify parent relationships if provided
if expected_parent_mapping is not None:
for node in retrieved_nodes:
if node.raw_node_id not in expected_parent_mapping:
continue
expected_parent = expected_parent_mapping[node.raw_node_id]
assert node.raw_parent_id == expected_parent, (
f"Parent mismatch for node {node.raw_node_id} ({node.display_name}): "
f"expected parent={expected.raw_parent_id}, got parent={node.raw_parent_id}"
f"expected parent={expected_parent}, got parent={node.raw_parent_id}"
)
def _pick(
*node_ids: str,
) -> dict[str, ExpectedHierarchyNode]:
"""Pick nodes from ALL_EXPECTED_HIERARCHY_NODES by their IDs."""
return {nid: ALL_EXPECTED_HIERARCHY_NODES[nid] for nid in node_ids}
def _clear_parents(
nodes: dict[str, ExpectedHierarchyNode],
*node_ids: str,
) -> dict[str, ExpectedHierarchyNode]:
"""Return a shallow copy of nodes with the specified nodes' parents set to None.
Useful for OAuth tests where the user can't resolve certain parents
(e.g. a folder in another user's My Drive)."""
result = dict(nodes)
for nid in node_ids:
result[nid] = replace(result[nid], raw_parent_id=None)
return result
def get_expected_hierarchy_for_shared_drives(
include_drive_1: bool = True,
include_drive_2: bool = True,
include_restricted_folder: bool = True,
) -> dict[str, ExpectedHierarchyNode]:
"""Get expected hierarchy nodes for shared drives."""
result: dict[str, ExpectedHierarchyNode] = {}
) -> tuple[set[str], dict[str, str | None]]:
"""
Get expected hierarchy node IDs and parent mapping for shared drives.
Returns:
Tuple of (expected_node_ids, expected_parent_mapping)
"""
expected_ids: set[str] = set()
expected_parents: dict[str, str | None] = {}
if include_drive_1:
result.update(EXPECTED_SHARED_DRIVE_1_NODES)
if not include_restricted_folder:
result.pop(RESTRICTED_ACCESS_FOLDER_ID, None)
expected_ids.add(SHARED_DRIVE_1_ID)
expected_parents[SHARED_DRIVE_1_ID] = None
if include_restricted_folder:
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
expected_parents[RESTRICTED_ACCESS_FOLDER_ID] = SHARED_DRIVE_1_ID
expected_ids.add(FOLDER_1_ID)
expected_parents[FOLDER_1_ID] = SHARED_DRIVE_1_ID
expected_ids.add(FOLDER_1_1_ID)
expected_parents[FOLDER_1_1_ID] = FOLDER_1_ID
expected_ids.add(FOLDER_1_2_ID)
expected_parents[FOLDER_1_2_ID] = FOLDER_1_ID
if include_drive_2:
result.update(EXPECTED_SHARED_DRIVE_2_NODES)
expected_ids.add(SHARED_DRIVE_2_ID)
expected_parents[SHARED_DRIVE_2_ID] = None
return result
expected_ids.add(SECTIONS_FOLDER_ID)
expected_parents[SECTIONS_FOLDER_ID] = SHARED_DRIVE_2_ID
expected_ids.add(FOLDER_2_ID)
expected_parents[FOLDER_2_ID] = SHARED_DRIVE_2_ID
expected_ids.add(FOLDER_2_1_ID)
expected_parents[FOLDER_2_1_ID] = FOLDER_2_ID
expected_ids.add(FOLDER_2_2_ID)
expected_parents[FOLDER_2_2_ID] = FOLDER_2_ID
return expected_ids, expected_parents
def get_expected_hierarchy_for_folder_1() -> dict[str, ExpectedHierarchyNode]:
def get_expected_hierarchy_for_folder_1() -> tuple[set[str], dict[str, str | None]]:
"""Get expected hierarchy for folder_1 and its children only."""
return _pick(FOLDER_1_ID, FOLDER_1_1_ID, FOLDER_1_2_ID)
return (
{FOLDER_1_ID, FOLDER_1_1_ID, FOLDER_1_2_ID},
{
FOLDER_1_ID: SHARED_DRIVE_1_ID,
FOLDER_1_1_ID: FOLDER_1_ID,
FOLDER_1_2_ID: FOLDER_1_ID,
},
)
def get_expected_hierarchy_for_folder_2() -> dict[str, ExpectedHierarchyNode]:
def get_expected_hierarchy_for_folder_2() -> tuple[set[str], dict[str, str | None]]:
"""Get expected hierarchy for folder_2 and its children only."""
return _pick(FOLDER_2_ID, FOLDER_2_1_ID, FOLDER_2_2_ID)
return (
{FOLDER_2_ID, FOLDER_2_1_ID, FOLDER_2_2_ID},
{
FOLDER_2_ID: SHARED_DRIVE_2_ID,
FOLDER_2_1_ID: FOLDER_2_ID,
FOLDER_2_2_ID: FOLDER_2_ID,
},
)
def get_expected_hierarchy_for_test_user_1() -> dict[str, ExpectedHierarchyNode]:
def get_expected_hierarchy_for_test_user_1() -> tuple[set[str], dict[str, str | None]]:
"""
Get expected hierarchy for test_user_1's full access (OAuth).
Get expected hierarchy for test_user_1's full access.
test_user_1 has access to:
- shared_drive_1 and its contents (folder_1, folder_1_1, folder_1_2)
- folder_3 (shared from admin's My Drive)
- PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A and PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B
- Additional drives/folders the user has access to
NOTE: Folder 3 lives in the admin's My Drive. When running as an OAuth
connector for test_user_1, the Google Drive API won't return the parent
for Folder 3 because the user can't access the admin's My Drive root.
"""
result = get_expected_hierarchy_for_shared_drives(
# Start with shared_drive_1 hierarchy
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=False,
include_restricted_folder=False,
)
result.update(
_pick(
FOLDER_3_ID,
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
TEST_USER_1_MY_DRIVE_ID,
TEST_USER_1_MY_DRIVE_FOLDER_ID,
TEST_USER_1_DRIVE_B_ID,
TEST_USER_1_DRIVE_B_FOLDER_ID,
TEST_USER_1_EXTRA_DRIVE_1_ID,
TEST_USER_1_EXTRA_DRIVE_2_ID,
TEST_USER_1_EXTRA_FOLDER_ID,
)
)
return _clear_parents(result, FOLDER_3_ID)
# folder_3 is shared from admin's My Drive
expected_ids.add(FOLDER_3_ID)
# Perm sync drives that test_user_1 has access to
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
expected_parents[PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID] = None
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
expected_parents[PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID] = None
# Additional drives/folders test_user_1 has access to
expected_ids.add(TEST_USER_1_MY_DRIVE_ID)
expected_parents[TEST_USER_1_MY_DRIVE_ID] = None
expected_ids.add(TEST_USER_1_MY_DRIVE_FOLDER_ID)
expected_parents[TEST_USER_1_MY_DRIVE_FOLDER_ID] = TEST_USER_1_MY_DRIVE_ID
expected_ids.add(TEST_USER_1_DRIVE_B_ID)
expected_parents[TEST_USER_1_DRIVE_B_ID] = None
expected_ids.add(TEST_USER_1_DRIVE_B_FOLDER_ID)
expected_parents[TEST_USER_1_DRIVE_B_FOLDER_ID] = TEST_USER_1_DRIVE_B_ID
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
expected_parents[TEST_USER_1_EXTRA_DRIVE_1_ID] = None
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
expected_parents[TEST_USER_1_EXTRA_DRIVE_2_ID] = None
expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)
# Parent unknown, skip adding to expected_parents
return expected_ids, expected_parents
def get_expected_hierarchy_for_test_user_1_shared_drives_only() -> (
dict[str, ExpectedHierarchyNode]
tuple[set[str], dict[str, str | None]]
):
"""Expected hierarchy nodes when test_user_1 runs with include_shared_drives=True only."""
result = get_expected_hierarchy_for_test_user_1()
for nid in (
TEST_USER_1_MY_DRIVE_ID,
TEST_USER_1_MY_DRIVE_FOLDER_ID,
FOLDER_3_ID,
TEST_USER_1_EXTRA_FOLDER_ID,
):
result.pop(nid, None)
return result
expected_ids, expected_parents = get_expected_hierarchy_for_test_user_1()
# This mode should not include My Drive roots/folders.
expected_ids.discard(TEST_USER_1_MY_DRIVE_ID)
expected_ids.discard(TEST_USER_1_MY_DRIVE_FOLDER_ID)
# Files shared with me are not included in this mode either
expected_ids.discard(FOLDER_3_ID)
expected_ids.discard(TEST_USER_1_EXTRA_FOLDER_ID)
return expected_ids, expected_parents
def get_expected_hierarchy_for_test_user_1_shared_with_me_only() -> (
dict[str, ExpectedHierarchyNode]
tuple[set[str], dict[str, str | None]]
):
"""Expected hierarchy nodes when test_user_1 runs with include_files_shared_with_me=True only."""
return _clear_parents(
_pick(FOLDER_3_ID, TEST_USER_1_EXTRA_FOLDER_ID),
FOLDER_3_ID,
)
expected_ids: set[str] = {FOLDER_3_ID, TEST_USER_1_EXTRA_FOLDER_ID}
expected_parents: dict[str, str | None] = {}
return expected_ids, expected_parents
def get_expected_hierarchy_for_test_user_1_my_drive_only() -> (
dict[str, ExpectedHierarchyNode]
tuple[set[str], dict[str, str | None]]
):
"""Expected hierarchy nodes when test_user_1 runs with include_my_drives=True only."""
return _pick(TEST_USER_1_MY_DRIVE_ID, TEST_USER_1_MY_DRIVE_FOLDER_ID)
expected_ids: set[str] = {TEST_USER_1_MY_DRIVE_ID, TEST_USER_1_MY_DRIVE_FOLDER_ID}
expected_parents: dict[str, str | None] = {
TEST_USER_1_MY_DRIVE_ID: None,
TEST_USER_1_MY_DRIVE_FOLDER_ID: TEST_USER_1_MY_DRIVE_ID,
}
return expected_ids, expected_parents

View File

@@ -3,11 +3,12 @@ from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from tests.daily.connectors.google_drive.consts_and_utils import _pick
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_MY_DRIVE_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
ADMIN_MY_DRIVE_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_expected_docs_in_retrieved_docs,
)
@@ -15,15 +16,21 @@ from tests.daily.connectors.google_drive.consts_and_utils import (
assert_hierarchy_nodes_match_expected,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
@@ -40,15 +47,18 @@ from tests.daily.connectors.google_drive.consts_and_utils import (
from tests.daily.connectors.google_drive.consts_and_utils import (
PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import PILL_FOLDER_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
PILL_FOLDER_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
RESTRICTED_ACCESS_FOLDER_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER_ID
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
TEST_USER_1_EXTRA_DRIVE_1_ID,
)
@@ -80,6 +90,7 @@ def test_include_all(
)
output = load_connector_outputs(connector)
# Should get everything in shared and admin's My Drive with oauth
expected_file_ids = (
ADMIN_FILE_IDS
+ ADMIN_FOLDER_3_FILE_IDS
@@ -98,28 +109,33 @@ def test_include_all(
expected_file_ids=expected_file_ids,
)
-expected_nodes = get_expected_hierarchy_for_shared_drives(
# Verify hierarchy nodes for shared drives
# When include_shared_drives=True, we get ALL shared drives the admin has access to
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=True,
# Restricted folder may not always be retrieved due to access limitations
include_restricted_folder=False,
)
-expected_nodes.update(
-_pick(
-PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
-PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
-PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
-TEST_USER_1_EXTRA_DRIVE_1_ID,
-TEST_USER_1_EXTRA_DRIVE_2_ID,
-ADMIN_MY_DRIVE_ID,
-PILL_FOLDER_ID,
-RESTRICTED_ACCESS_FOLDER_ID,
-TEST_USER_1_EXTRA_FOLDER_ID,
-FOLDER_3_ID,
-)
-)
# Add additional shared drives that admin has access to
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
expected_ids.add(ADMIN_MY_DRIVE_ID)
expected_ids.add(PILL_FOLDER_ID)
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)
# My Drive folders
expected_ids.add(FOLDER_3_ID)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=expected_nodes,
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
)
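For context on the removed side: a hedged sketch of what the old _pick and _clear_parents helpers plausibly did, inferred from call sites like the block above. The EXPECTED_PARENTS lookup table is an assumed name, not part of this diff.

EXPECTED_PARENTS: dict[str, str | None] = {}  # assumed module-level expectation map

def _pick(*node_ids: str) -> dict[str, str | None]:
    # Select a {node_id: parent_id} subset of the full expectation map.
    return {node_id: EXPECTED_PARENTS[node_id] for node_id in node_ids}

def _clear_parents(
    nodes: dict[str, str | None], *node_ids: str
) -> dict[str, str | None]:
    # Null out the parent of the given nodes, e.g. when a folder is synced as a root.
    return {
        node_id: (None if node_id in node_ids else parent)
        for node_id, parent in nodes.items()
    }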
@@ -144,6 +160,7 @@ def test_include_shared_drives_only(
)
output = load_connector_outputs(connector)
# Should only get shared drives
expected_file_ids = (
SHARED_DRIVE_1_FILE_IDS
+ FOLDER_1_FILE_IDS
@@ -160,24 +177,26 @@ def test_include_shared_drives_only(
expected_file_ids=expected_file_ids,
)
-expected_nodes = get_expected_hierarchy_for_shared_drives(
# Verify hierarchy nodes - should include both shared drives and their folders
# When include_shared_drives=True, we get ALL shared drives admin has access to
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=True,
include_restricted_folder=False,
)
-expected_nodes.update(
-_pick(
-PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
-PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
-PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
-TEST_USER_1_EXTRA_DRIVE_1_ID,
-TEST_USER_1_EXTRA_DRIVE_2_ID,
-RESTRICTED_ACCESS_FOLDER_ID,
-)
-)
# Add additional shared drives that admin has access to
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=expected_nodes,
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
)
@@ -201,21 +220,24 @@ def test_include_my_drives_only(
)
output = load_connector_outputs(connector)
# Should only get the primary admin's My Drive because we are impersonating them
expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS
assert_expected_docs_in_retrieved_docs(
retrieved_docs=output.documents,
expected_file_ids=expected_file_ids,
)
-expected_nodes = _pick(
# Verify hierarchy nodes - My Drive should yield folder_3 as a hierarchy node
# Also includes admin's My Drive root and folders shared with admin
expected_ids = {
FOLDER_3_ID,
ADMIN_MY_DRIVE_ID,
PILL_FOLDER_ID,
TEST_USER_1_EXTRA_FOLDER_ID,
-)
}
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=expected_nodes,
expected_node_ids=expected_ids,
)
@@ -251,14 +273,17 @@ def test_drive_one_only(
expected_file_ids=expected_file_ids,
)
-expected_nodes = get_expected_hierarchy_for_shared_drives(
# Verify hierarchy nodes - should only include shared_drive_1 and its folders
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=False,
include_restricted_folder=False,
)
# Restricted folder is non-deterministically returned by the connector
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=expected_nodes,
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
)
@@ -299,15 +324,33 @@ def test_folder_and_shared_drive(
expected_file_ids=expected_file_ids,
)
-expected_nodes = get_expected_hierarchy_for_shared_drives(
-include_drive_1=True,
-include_drive_2=True,
-include_restricted_folder=False,
-)
-expected_nodes.pop(SECTIONS_FOLDER_ID, None)
# Verify hierarchy nodes - shared_drive_1 and folder_2 with children
# SHARED_DRIVE_2_ID is included because folder_2's parent is shared_drive_2
expected_ids = {
SHARED_DRIVE_1_ID,
FOLDER_1_ID,
FOLDER_1_1_ID,
FOLDER_1_2_ID,
SHARED_DRIVE_2_ID,
FOLDER_2_ID,
FOLDER_2_1_ID,
FOLDER_2_2_ID,
}
expected_parents = {
SHARED_DRIVE_1_ID: None,
FOLDER_1_ID: SHARED_DRIVE_1_ID,
FOLDER_1_1_ID: FOLDER_1_ID,
FOLDER_1_2_ID: FOLDER_1_ID,
SHARED_DRIVE_2_ID: None,
FOLDER_2_ID: SHARED_DRIVE_2_ID,
FOLDER_2_1_ID: FOLDER_2_ID,
FOLDER_2_2_ID: FOLDER_2_ID,
}
# Restricted folder is non-deterministically returned
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=expected_nodes,
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
)
@@ -327,6 +370,7 @@ def test_folders_only(
FOLDER_2_2_URL,
FOLDER_3_URL,
]
# This should get converted to a drive request and spit out a warning in the logs
shared_drive_urls = [
FOLDER_1_1_URL,
]
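The comment above describes a fallback worth illustrating: a folder URL passed in shared_drive_urls is reclassified with a warning rather than failing. A hedged sketch of that flow; the predicate and resolver are passed in as callables because the connector's real helpers for this are not shown in the diff.

import logging
from collections.abc import Callable

logger = logging.getLogger(__name__)

def normalize_drive_requests(
    shared_drive_urls: list[str],
    is_drive_id: Callable[[str], bool],
    containing_drive_id: Callable[[str], str],
) -> list[str]:
    """Reclassify folder URLs that were passed as shared-drive URLs."""
    drive_ids: list[str] = []
    for url in shared_drive_urls:
        item_id = url.rstrip("/").rsplit("/", 1)[-1]  # crude URL -> id extraction
        if is_drive_id(item_id):
            drive_ids.append(item_id)
        else:
            # Folder URL in the shared-drive slot: warn and request its drive.
            logger.warning("%s is a folder URL; converting to a drive request", url)
            drive_ids.append(containing_drive_id(item_id))
    return drive_ids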
@@ -353,16 +397,23 @@ def test_folders_only(
expected_file_ids=expected_file_ids,
)
-expected_nodes = get_expected_hierarchy_for_shared_drives(
-include_drive_1=True,
-include_drive_2=True,
-include_restricted_folder=False,
-)
-expected_nodes.pop(SECTIONS_FOLDER_ID, None)
-expected_nodes.update(_pick(ADMIN_MY_DRIVE_ID, FOLDER_3_ID))
# Verify hierarchy nodes - specific folders requested plus their parent nodes
# The connector walks up the hierarchy to include parent drives/folders
expected_ids = {
SHARED_DRIVE_1_ID,
FOLDER_1_ID,
FOLDER_1_1_ID,
FOLDER_1_2_ID,
SHARED_DRIVE_2_ID,
FOLDER_2_ID,
FOLDER_2_1_ID,
FOLDER_2_2_ID,
ADMIN_MY_DRIVE_ID,
FOLDER_3_ID,
}
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=expected_nodes,
expected_node_ids=expected_ids,
)
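The comment above says the connector walks up the hierarchy to include parent drives and folders. A minimal sketch of that walk, with parent_of standing in for whatever parent lookup the connector actually performs:

from collections.abc import Callable

def collect_with_ancestors(
    requested_ids: set[str],
    parent_of: Callable[[str], str | None],
) -> set[str]:
    """Gather each requested folder plus every ancestor up to its drive root."""
    collected: set[str] = set()
    for node_id in requested_ids:
        current: str | None = node_id
        while current is not None and current not in collected:
            collected.add(current)
            current = parent_of(current)  # None once a drive root is reached
    return collected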
@@ -395,8 +446,9 @@ def test_personal_folders_only(
expected_file_ids=expected_file_ids,
)
-expected_nodes = _pick(FOLDER_3_ID, ADMIN_MY_DRIVE_ID)
# Verify hierarchy nodes - folder_3 and its parent (admin's My Drive root)
expected_ids = {FOLDER_3_ID, ADMIN_MY_DRIVE_ID}
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=expected_nodes,
expected_node_ids=expected_ids,
)

View File

@@ -14,10 +14,11 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
-from tests.daily.connectors.google_drive.consts_and_utils import _pick
from tests.daily.connectors.google_drive.consts_and_utils import ACCESS_MAPPING
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
-from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_MY_DRIVE_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
ADMIN_MY_DRIVE_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_hierarchy_nodes_match_expected,
)
@@ -261,35 +262,37 @@ def test_gdrive_perm_sync_with_real_data(
hierarchy_connector = _build_connector(google_drive_service_acct_connector_factory)
output = load_connector_outputs(hierarchy_connector, include_permissions=True)
-expected_nodes = get_expected_hierarchy_for_shared_drives(
# Verify the expected shared drives hierarchy
# When include_shared_drives=True and include_my_drives=True, we get ALL drives
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=True,
include_restricted_folder=False,
)
-expected_nodes.update(
-_pick(
-PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
-PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
-PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
-TEST_USER_1_MY_DRIVE_ID,
-TEST_USER_1_MY_DRIVE_FOLDER_ID,
-TEST_USER_1_DRIVE_B_ID,
-TEST_USER_1_DRIVE_B_FOLDER_ID,
-TEST_USER_1_EXTRA_DRIVE_1_ID,
-TEST_USER_1_EXTRA_DRIVE_2_ID,
-ADMIN_MY_DRIVE_ID,
-TEST_USER_2_MY_DRIVE,
-TEST_USER_3_MY_DRIVE_ID,
-PILL_FOLDER_ID,
-RESTRICTED_ACCESS_FOLDER_ID,
-TEST_USER_1_EXTRA_FOLDER_ID,
-EXTERNAL_SHARED_FOLDER_ID,
-FOLDER_3_ID,
-)
-)
# Add additional shared drives in the organization
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
expected_ids.add(TEST_USER_1_MY_DRIVE_ID)
expected_ids.add(TEST_USER_1_MY_DRIVE_FOLDER_ID)
expected_ids.add(TEST_USER_1_DRIVE_B_ID)
expected_ids.add(TEST_USER_1_DRIVE_B_FOLDER_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
expected_ids.add(ADMIN_MY_DRIVE_ID)
expected_ids.add(TEST_USER_2_MY_DRIVE)
expected_ids.add(TEST_USER_3_MY_DRIVE_ID)
expected_ids.add(PILL_FOLDER_ID)
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)
expected_ids.add(EXTERNAL_SHARED_FOLDER_ID)
expected_ids.add(FOLDER_3_ID)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=expected_nodes,
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
)

View File

@@ -4,11 +4,12 @@ from unittest.mock import patch
from urllib.parse import urlparse
from onyx.connectors.google_drive.connector import GoogleDriveConnector
-from tests.daily.connectors.google_drive.consts_and_utils import _pick
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
-from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_MY_DRIVE_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
ADMIN_MY_DRIVE_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_expected_docs_in_retrieved_docs,
)
@@ -28,15 +29,21 @@ from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_FOLDER_URL,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
@@ -67,10 +74,11 @@ from tests.daily.connectors.google_drive.consts_and_utils import (
RESTRICTED_ACCESS_FOLDER_URL,
)
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER_ID
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
TEST_USER_1_DRIVE_B_FOLDER_ID,
)
@@ -148,35 +156,39 @@ def test_include_all(
expected_file_ids=expected_file_ids,
)
-expected_nodes = get_expected_hierarchy_for_shared_drives(
# Verify hierarchy nodes for shared drives
# When include_shared_drives=True, we get ALL shared drives in the organization
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=True,
include_restricted_folder=False,
)
-expected_nodes.update(
-_pick(
-PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
-PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
-PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
-TEST_USER_1_MY_DRIVE_ID,
-TEST_USER_1_MY_DRIVE_FOLDER_ID,
-TEST_USER_1_DRIVE_B_ID,
-TEST_USER_1_DRIVE_B_FOLDER_ID,
-TEST_USER_1_EXTRA_DRIVE_1_ID,
-TEST_USER_1_EXTRA_DRIVE_2_ID,
-ADMIN_MY_DRIVE_ID,
-TEST_USER_2_MY_DRIVE,
-TEST_USER_3_MY_DRIVE_ID,
-PILL_FOLDER_ID,
-RESTRICTED_ACCESS_FOLDER_ID,
-TEST_USER_1_EXTRA_FOLDER_ID,
-EXTERNAL_SHARED_FOLDER_ID,
-FOLDER_3_ID,
-)
-)
# Add additional shared drives in the organization
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
expected_ids.add(TEST_USER_1_MY_DRIVE_ID)
expected_ids.add(TEST_USER_1_MY_DRIVE_FOLDER_ID)
expected_ids.add(TEST_USER_1_DRIVE_B_ID)
expected_ids.add(TEST_USER_1_DRIVE_B_FOLDER_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
expected_ids.add(ADMIN_MY_DRIVE_ID)
expected_ids.add(TEST_USER_2_MY_DRIVE)
expected_ids.add(TEST_USER_3_MY_DRIVE_ID)
expected_ids.add(PILL_FOLDER_ID)
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)
expected_ids.add(EXTERNAL_SHARED_FOLDER_ID)
# My Drive folders
expected_ids.add(FOLDER_3_ID)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=expected_nodes,
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
)
@@ -282,26 +294,28 @@ def test_include_shared_drives_only(
# TODO: switch to 54 when restricted access issue is resolved
assert len(output.documents) == 51 or len(output.documents) == 52
-expected_nodes = get_expected_hierarchy_for_shared_drives(
# Verify hierarchy nodes - should include both shared drives and their folders
# When include_shared_drives=True, we get ALL shared drives in the organization
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=True,
include_restricted_folder=False,
)
-expected_nodes.update(
-_pick(
-PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
-PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
-PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
-TEST_USER_1_DRIVE_B_ID,
-TEST_USER_1_DRIVE_B_FOLDER_ID,
-TEST_USER_1_EXTRA_DRIVE_1_ID,
-TEST_USER_1_EXTRA_DRIVE_2_ID,
-RESTRICTED_ACCESS_FOLDER_ID,
-)
-)
# Add additional shared drives in the organization
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
expected_ids.add(TEST_USER_1_DRIVE_B_ID)
expected_ids.add(TEST_USER_1_DRIVE_B_FOLDER_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=expected_nodes,
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
)
@@ -339,7 +353,9 @@ def test_include_my_drives_only(
expected_file_ids=expected_file_ids,
)
-expected_nodes = _pick(
# Verify hierarchy nodes - My Drive roots and folders for all users
# Service account impersonates all users, so it sees all My Drives
expected_ids = {
FOLDER_3_ID,
ADMIN_MY_DRIVE_ID,
TEST_USER_1_MY_DRIVE_ID,
@@ -349,10 +365,10 @@ def test_include_my_drives_only(
PILL_FOLDER_ID,
TEST_USER_1_EXTRA_FOLDER_ID,
EXTERNAL_SHARED_FOLDER_ID,
-)
}
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=expected_nodes,
expected_node_ids=expected_ids,
)
@@ -389,14 +405,17 @@ def test_drive_one_only(
expected_file_ids=expected_file_ids,
)
-expected_nodes = get_expected_hierarchy_for_shared_drives(
# Verify hierarchy nodes - should only include shared_drive_1 and its folders
expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
include_drive_1=True,
include_drive_2=False,
include_restricted_folder=False,
)
# Restricted folder is non-deterministically returned
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=expected_nodes,
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
)
@@ -438,15 +457,33 @@ def test_folder_and_shared_drive(
expected_file_ids=expected_file_ids,
)
-expected_nodes = get_expected_hierarchy_for_shared_drives(
-include_drive_1=True,
-include_drive_2=True,
-include_restricted_folder=False,
-)
-expected_nodes.pop(SECTIONS_FOLDER_ID, None)
# Verify hierarchy nodes - shared_drive_1 and folder_2 with children
# SHARED_DRIVE_2_ID is included because folder_2's parent is shared_drive_2
expected_ids = {
SHARED_DRIVE_1_ID,
FOLDER_1_ID,
FOLDER_1_1_ID,
FOLDER_1_2_ID,
SHARED_DRIVE_2_ID,
FOLDER_2_ID,
FOLDER_2_1_ID,
FOLDER_2_2_ID,
}
expected_parents = {
SHARED_DRIVE_1_ID: None,
FOLDER_1_ID: SHARED_DRIVE_1_ID,
FOLDER_1_1_ID: FOLDER_1_ID,
FOLDER_1_2_ID: FOLDER_1_ID,
SHARED_DRIVE_2_ID: None,
FOLDER_2_ID: SHARED_DRIVE_2_ID,
FOLDER_2_1_ID: FOLDER_2_ID,
FOLDER_2_2_ID: FOLDER_2_ID,
}
# Restricted folder is non-deterministically returned
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=expected_nodes,
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
)
@@ -493,16 +530,23 @@ def test_folders_only(
expected_file_ids=expected_file_ids,
)
-expected_nodes = get_expected_hierarchy_for_shared_drives(
-include_drive_1=True,
-include_drive_2=True,
-include_restricted_folder=False,
-)
-expected_nodes.pop(SECTIONS_FOLDER_ID, None)
-expected_nodes.update(_pick(ADMIN_MY_DRIVE_ID, FOLDER_3_ID))
# Verify hierarchy nodes - specific folders requested plus their parent nodes
# The connector walks up the hierarchy to include parent drives/folders
expected_ids = {
SHARED_DRIVE_1_ID,
FOLDER_1_ID,
FOLDER_1_1_ID,
FOLDER_1_2_ID,
SHARED_DRIVE_2_ID,
FOLDER_2_ID,
FOLDER_2_1_ID,
FOLDER_2_2_ID,
ADMIN_MY_DRIVE_ID,
FOLDER_3_ID,
}
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=expected_nodes,
expected_node_ids=expected_ids,
)

View File

@@ -4,8 +4,6 @@ from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import Document
-from tests.daily.connectors.google_drive.consts_and_utils import _clear_parents
-from tests.daily.connectors.google_drive.consts_and_utils import _pick
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_expected_docs_in_retrieved_docs,
@@ -53,6 +51,8 @@ def _check_for_error(
retrieved_failures = output.failures
assert len(retrieved_failures) <= 1
# current behavior is to fail silently for 403s; leaving this here for when we
# revert, once all 403s get fixed
if len(retrieved_failures) == 1:
fail_msg = retrieved_failures[0].failure_message
assert "HttpError 403" in fail_msg
@@ -83,11 +83,14 @@ def test_all(
output = load_connector_outputs(connector)
expected_file_ids = (
# These are the files from my drive
TEST_USER_1_FILE_IDS
# These are the files from shared drives
+ SHARED_DRIVE_1_FILE_IDS
+ FOLDER_1_FILE_IDS
+ FOLDER_1_1_FILE_IDS
+ FOLDER_1_2_FILE_IDS
# These are the files shared with me from admin
+ ADMIN_FOLDER_3_FILE_IDS
+ list(range(0, 2))
)
@@ -99,9 +102,13 @@ def test_all(
expected_file_ids=expected_file_ids,
)
# Verify hierarchy nodes - test_user_1 has access to shared_drive_1, folder_3,
# perm sync drives, and additional drives/folders
expected_ids, expected_parents = get_expected_hierarchy_for_test_user_1()
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=get_expected_hierarchy_for_test_user_1(),
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
)
@@ -126,6 +133,7 @@ def test_shared_drives_only(
output = load_connector_outputs(connector)
expected_file_ids = (
# These are the files from shared drives
SHARED_DRIVE_1_FILE_IDS
+ FOLDER_1_FILE_IDS
+ FOLDER_1_1_FILE_IDS
@@ -138,9 +146,14 @@ def test_shared_drives_only(
expected_file_ids=expected_file_ids,
)
# Verify hierarchy nodes - test_user_1 sees multiple shared drives/folders
expected_ids, expected_parents = (
get_expected_hierarchy_for_test_user_1_shared_drives_only()
)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=get_expected_hierarchy_for_test_user_1_shared_drives_only(),
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
)
@@ -164,15 +177,24 @@ def test_shared_with_me_only(
)
output = load_connector_outputs(connector)
-expected_file_ids = ADMIN_FOLDER_3_FILE_IDS + list(range(0, 2))
expected_file_ids = (
# These are the files shared with me from admin
ADMIN_FOLDER_3_FILE_IDS
+ list(range(0, 2))
)
assert_expected_docs_in_retrieved_docs(
retrieved_docs=output.documents,
expected_file_ids=expected_file_ids,
)
# Verify hierarchy nodes - shared-with-me folders
expected_ids, expected_parents = (
get_expected_hierarchy_for_test_user_1_shared_with_me_only()
)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=get_expected_hierarchy_for_test_user_1_shared_with_me_only(),
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
)
@@ -196,15 +218,21 @@ def test_my_drive_only(
)
output = load_connector_outputs(connector)
# These are the files from my drive
expected_file_ids = TEST_USER_1_FILE_IDS
assert_expected_docs_in_retrieved_docs(
retrieved_docs=output.documents,
expected_file_ids=expected_file_ids,
)
# Verify hierarchy nodes - My Drive root + its folder(s)
expected_ids, expected_parents = (
get_expected_hierarchy_for_test_user_1_my_drive_only()
)
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=get_expected_hierarchy_for_test_user_1_my_drive_only(),
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
)
@@ -228,15 +256,20 @@ def test_shared_my_drive_folder(
)
output = load_connector_outputs(connector)
-expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
expected_file_ids = (
# this is a folder from admin's drive that is shared with me
ADMIN_FOLDER_3_FILE_IDS
)
assert_expected_docs_in_retrieved_docs(
retrieved_docs=output.documents,
expected_file_ids=expected_file_ids,
)
# Verify hierarchy nodes - only folder_3
expected_ids = {FOLDER_3_ID}
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=_clear_parents(_pick(FOLDER_3_ID), FOLDER_3_ID),
expected_node_ids=expected_ids,
)
@@ -266,9 +299,16 @@ def test_shared_drive_folder(
expected_file_ids=expected_file_ids,
)
# Verify hierarchy nodes - includes shared drive root + folder_1 subtree
expected_ids = {SHARED_DRIVE_1_ID, FOLDER_1_ID, FOLDER_1_1_ID, FOLDER_1_2_ID}
expected_parents: dict[str, str | None] = {
SHARED_DRIVE_1_ID: None,
FOLDER_1_ID: SHARED_DRIVE_1_ID,
FOLDER_1_1_ID: FOLDER_1_ID,
FOLDER_1_2_ID: FOLDER_1_ID,
}
assert_hierarchy_nodes_match_expected(
retrieved_nodes=output.hierarchy_nodes,
-expected_nodes=_pick(
-SHARED_DRIVE_1_ID, FOLDER_1_ID, FOLDER_1_1_ID, FOLDER_1_2_ID
-),
expected_node_ids=expected_ids,
expected_parent_mapping=expected_parents,
)

View File

@@ -553,7 +553,7 @@ class TestDefaultProviderEndpoint:
try:
existing_providers = fetch_existing_llm_providers(
-db_session, flow_type_filter=[LLMModelFlowType.CHAT]
db_session, flow_types=[LLMModelFlowType.CHAT]
)
provider_names_to_restore: list[str] = []

View File

@@ -14,12 +14,9 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import sync_auto_mode_models
from onyx.db.llm import update_default_provider
from onyx.db.models import UserRole
from onyx.llm.constants import LlmProviderNames
@@ -609,95 +606,3 @@ class TestAutoModeSyncFeature:
db_session.rollback()
_cleanup_provider(db_session, provider_1_name)
_cleanup_provider(db_session, provider_2_name)
class TestAutoModeMissingFlows:
"""Regression test: sync_auto_mode_models must create LLMModelFlow rows
for every ModelConfiguration it inserts, otherwise the provider vanishes
from listing queries that join through LLMModelFlow."""
def test_sync_auto_mode_creates_flow_rows(
self,
db_session: Session,
provider_name: str,
) -> None:
"""
Steps:
1. Create a provider with no model configs (empty shell).
2. Call sync_auto_mode_models to add models from a mock config.
3. Assert every new ModelConfiguration has at least one LLMModelFlow.
4. Assert fetch_existing_llm_providers (which joins through
LLMModelFlow) returns the provider.
"""
mock_recommendations = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
try:
# Step 1: Create provider with no model configs
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 2: Run sync_auto_mode_models (simulating the periodic sync)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=mock_recommendations,
)
# Step 3: Every ModelConfiguration must have at least one LLMModelFlow
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
synced_model_names = {mc.name for mc in provider.model_configurations}
assert "gpt-4o" in synced_model_names
assert "gpt-4o-mini" in synced_model_names
for mc in provider.model_configurations:
assert len(mc.llm_model_flows) > 0, (
f"ModelConfiguration '{mc.name}' (id={mc.id}) has no "
f"LLMModelFlow rows — it will be invisible to listing queries"
)
flow_types = {f.llm_model_flow_type for f in mc.llm_model_flows}
assert (
LLMModelFlowType.CHAT in flow_types
), f"ModelConfiguration '{mc.name}' is missing a CHAT flow"
# Step 4: The provider must appear in fetch_existing_llm_providers
listed_providers = fetch_existing_llm_providers(
db_session=db_session,
flow_type_filter=[LLMModelFlowType.CHAT],
)
listed_provider_names = {p.name for p in listed_providers}
assert provider_name in listed_provider_names, (
f"Provider '{provider_name}' not returned by "
f"fetch_existing_llm_providers — models are missing flow rows"
)
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
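The removed regression test's premise is easy to state with plain data: a listing query that inner-joins through LLMModelFlow cannot return a provider whose models have no flow rows. A hedged, self-contained illustration of that failure mode; this is an analogy in plain Python, not the real SQLAlchemy query.

def list_providers_with_chat_flow(
    providers: dict[str, list[str]],
    model_flows: dict[str, set[str]],
) -> list[str]:
    # Mirrors an inner join through LLMModelFlow: no flow rows, no result row.
    return [
        name
        for name, models in providers.items()
        if any("CHAT" in model_flows.get(model, set()) for model in models)
    ]

providers = {"auto-openai": ["gpt-4o", "gpt-4o-mini"]}
assert list_providers_with_chat_flow(providers, {}) == []  # provider vanishes
assert list_providers_with_chat_flow(
    providers, {"gpt-4o": {"CHAT"}}
) == ["auto-openai"]  # visible again once flow rows exist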

View File

@@ -996,114 +996,6 @@ class TestFallbackToolExtraction:
assert result.tool_calls[0].tool_args == {"queries": ["beta"]}
assert result.tool_calls[0].placement == Placement(turn_index=5)
def test_extracts_xml_style_invoke_from_answer_when_required(self) -> None:
llm_step_result = LlmStepResult(
reasoning=None,
answer=(
'<function_calls><invoke name="internal_search">'
'<parameter name="queries" string="false">'
'["Onyx documentation", "Onyx docs", "Onyx platform"]'
"</parameter></invoke></function_calls>"
),
tool_calls=None,
)
result, attempted = _try_fallback_tool_extraction(
llm_step_result=llm_step_result,
tool_choice=ToolChoiceOptions.REQUIRED,
fallback_extraction_attempted=False,
tool_defs=self._tool_defs(),
turn_index=7,
)
assert attempted is True
assert result.tool_calls is not None
assert len(result.tool_calls) == 1
assert result.tool_calls[0].tool_name == "internal_search"
assert result.tool_calls[0].tool_args == {
"queries": ["Onyx documentation", "Onyx docs", "Onyx platform"]
}
assert result.tool_calls[0].placement == Placement(turn_index=7)
def test_extracts_xml_style_invoke_from_answer_when_auto(self) -> None:
llm_step_result = LlmStepResult(
reasoning=None,
# Runtime-faithful shape: filtered answer is empty, raw answer has XML payload.
answer=None,
raw_answer=(
'<function_calls><invoke name="internal_search">'
'<parameter name="queries" string="false">'
'["Onyx documentation", "Onyx docs", "Onyx internal docs"]'
"</parameter></invoke></function_calls>"
),
tool_calls=None,
)
result, attempted = _try_fallback_tool_extraction(
llm_step_result=llm_step_result,
tool_choice=ToolChoiceOptions.AUTO,
fallback_extraction_attempted=False,
tool_defs=self._tool_defs(),
turn_index=9,
)
assert attempted is True
assert result.tool_calls is not None
assert len(result.tool_calls) == 1
assert result.tool_calls[0].tool_name == "internal_search"
assert result.tool_calls[0].tool_args == {
"queries": ["Onyx documentation", "Onyx docs", "Onyx internal docs"]
}
assert result.tool_calls[0].placement == Placement(turn_index=9)
def test_extracts_from_raw_answer_when_filtered_answer_has_no_xml(self) -> None:
llm_step_result = LlmStepResult(
reasoning=None,
answer="",
raw_answer=(
'<function_calls><invoke name="internal_search">'
'<parameter name="queries" string="false">'
'["Onyx documentation", "Onyx docs"]'
"</parameter></invoke></function_calls>"
),
tool_calls=None,
)
result, attempted = _try_fallback_tool_extraction(
llm_step_result=llm_step_result,
tool_choice=ToolChoiceOptions.AUTO,
fallback_extraction_attempted=False,
tool_defs=self._tool_defs(),
turn_index=10,
)
assert attempted is True
assert result.tool_calls is not None
assert len(result.tool_calls) == 1
assert result.tool_calls[0].tool_name == "internal_search"
assert result.tool_calls[0].tool_args == {
"queries": ["Onyx documentation", "Onyx docs"]
}
assert result.tool_calls[0].placement == Placement(turn_index=10)
def test_does_not_attempt_fallback_for_auto_without_tool_call_hints(self) -> None:
llm_step_result = LlmStepResult(
reasoning=None,
answer="Here is a normal answer with no tool call payload.",
tool_calls=None,
)
result, attempted = _try_fallback_tool_extraction(
llm_step_result=llm_step_result,
tool_choice=ToolChoiceOptions.AUTO,
fallback_extraction_attempted=False,
tool_defs=self._tool_defs(),
turn_index=2,
)
assert result is llm_step_result
assert attempted is False
def test_returns_unchanged_when_required_but_nothing_extractable(self) -> None:
llm_step_result = LlmStepResult(
reasoning="Need more info.",

View File

@@ -1,13 +1,7 @@
"""Tests for llm_step.py, specifically sanitization and argument parsing."""
from typing import Any
from onyx.chat.llm_step import _extract_tool_call_kickoffs
from onyx.chat.llm_step import _increment_turns
from onyx.chat.llm_step import _parse_tool_args_to_dict
from onyx.chat.llm_step import _resolve_tool_arguments
from onyx.chat.llm_step import _sanitize_llm_output
from onyx.chat.llm_step import _XmlToolCallContentFilter
from onyx.chat.llm_step import extract_tool_calls_from_response_text
from onyx.server.query_and_chat.placement import Placement
@@ -217,204 +211,3 @@ class TestExtractToolCallsFromResponseText:
{"queries": ["alpha"]},
{"queries": ["alpha"]},
]
def test_extracts_xml_style_invoke_tool_call(self) -> None:
response_text = """
<function_calls>
<invoke name="internal_search">
<parameter name="queries" string="false">["Onyx documentation", "Onyx docs", "Onyx platform"]</parameter>
</invoke>
</function_calls>
"""
tool_calls = extract_tool_calls_from_response_text(
response_text=response_text,
tool_definitions=self._tool_defs(),
placement=self._placement(),
)
assert len(tool_calls) == 1
assert tool_calls[0].tool_name == "internal_search"
assert tool_calls[0].tool_args == {
"queries": ["Onyx documentation", "Onyx docs", "Onyx platform"]
}
def test_ignores_unknown_tool_in_xml_style_invoke(self) -> None:
response_text = """
<function_calls>
<invoke name="unknown_tool">
<parameter name="queries" string="false">["Onyx docs"]</parameter>
</invoke>
</function_calls>
"""
tool_calls = extract_tool_calls_from_response_text(
response_text=response_text,
tool_definitions=self._tool_defs(),
placement=self._placement(),
)
assert len(tool_calls) == 0
class TestExtractToolCallKickoffs:
"""Tests for the _extract_tool_call_kickoffs function."""
def test_valid_tool_call(self) -> None:
tool_call_map = {
0: {
"id": "call_123",
"name": "internal_search",
"arguments": '{"queries": ["test"]}',
}
}
result = _extract_tool_call_kickoffs(tool_call_map, turn_index=0)
assert len(result) == 1
assert result[0].tool_name == "internal_search"
assert result[0].tool_args == {"queries": ["test"]}
def test_invalid_json_arguments_returns_empty_dict(self) -> None:
"""Verify that malformed JSON arguments produce an empty dict
rather than raising an exception. This confirms the dead try/except
around _parse_tool_args_to_dict was safe to remove."""
tool_call_map = {
0: {
"id": "call_bad",
"name": "internal_search",
"arguments": "not valid json {{{",
}
}
result = _extract_tool_call_kickoffs(tool_call_map, turn_index=0)
assert len(result) == 1
assert result[0].tool_args == {}
def test_none_arguments_returns_empty_dict(self) -> None:
tool_call_map = {
0: {
"id": "call_none",
"name": "internal_search",
"arguments": None,
}
}
result = _extract_tool_call_kickoffs(tool_call_map, turn_index=0)
assert len(result) == 1
assert result[0].tool_args == {}
def test_skips_entries_missing_id_or_name(self) -> None:
tool_call_map: dict[int, dict[str, Any]] = {
0: {"id": None, "name": "internal_search", "arguments": "{}"},
1: {"id": "call_1", "name": None, "arguments": "{}"},
2: {"id": "call_2", "name": "internal_search", "arguments": "{}"},
}
result = _extract_tool_call_kickoffs(tool_call_map, turn_index=0)
assert len(result) == 1
assert result[0].tool_call_id == "call_2"
def test_tab_index_auto_increments(self) -> None:
tool_call_map = {
0: {"id": "c1", "name": "tool_a", "arguments": "{}"},
1: {"id": "c2", "name": "tool_b", "arguments": "{}"},
}
result = _extract_tool_call_kickoffs(tool_call_map, turn_index=0)
assert result[0].placement.tab_index == 0
assert result[1].placement.tab_index == 1
def test_tab_index_override(self) -> None:
tool_call_map = {
0: {"id": "c1", "name": "tool_a", "arguments": "{}"},
1: {"id": "c2", "name": "tool_b", "arguments": "{}"},
}
result = _extract_tool_call_kickoffs(tool_call_map, turn_index=0, tab_index=5)
assert result[0].placement.tab_index == 5
assert result[1].placement.tab_index == 5
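Taken together, these removed tests pin down the helper's contract well enough for a hedged sketch. The Placement and ToolCallKickoff shapes here are simplified stand-ins for the real models, not the production code.

import json
from dataclasses import dataclass

@dataclass
class Placement:  # stand-in for onyx.server.query_and_chat.placement.Placement
    turn_index: int
    tab_index: int = 0

@dataclass
class ToolCallKickoff:  # simplified stand-in
    tool_call_id: str
    tool_name: str
    tool_args: dict
    placement: Placement

def _extract_tool_call_kickoffs(tool_call_map, turn_index, tab_index=None):
    kickoffs = []
    for _, call in sorted(tool_call_map.items()):
        if not call.get("id") or not call.get("name"):
            continue  # entries missing id or name are skipped, per the tests
        try:
            args = json.loads(call.get("arguments") or "{}")
        except json.JSONDecodeError:
            args = {}  # malformed JSON degrades to an empty dict
        if not isinstance(args, dict):
            args = {}
        kickoffs.append(
            ToolCallKickoff(
                tool_call_id=call["id"],
                tool_name=call["name"],
                tool_args=args,
                placement=Placement(
                    turn_index=turn_index,
                    # tab_index auto-increments unless an override is given
                    tab_index=tab_index if tab_index is not None else len(kickoffs),
                ),
            )
        )
    return kickoffs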
class TestXmlToolCallContentFilter:
def test_strips_function_calls_block_single_chunk(self) -> None:
f = _XmlToolCallContentFilter()
output = f.process(
"prefix "
'<function_calls><invoke name="internal_search">'
'<parameter name="queries" string="false">["Onyx docs"]</parameter>'
"</invoke></function_calls> suffix"
)
output += f.flush()
assert output == "prefix suffix"
def test_strips_function_calls_block_split_across_chunks(self) -> None:
f = _XmlToolCallContentFilter()
chunks = [
"Start ",
"<function_",
'calls><invoke name="internal_search">',
'<parameter name="queries" string="false">["Onyx docs"]',
"</parameter></invoke></function_calls>",
" End",
]
output = "".join(f.process(chunk) for chunk in chunks) + f.flush()
assert output == "Start End"
def test_preserves_non_tool_call_xml(self) -> None:
f = _XmlToolCallContentFilter()
output = f.process("A <tag>value</tag> B")
output += f.flush()
assert output == "A <tag>value</tag> B"
def test_does_not_strip_similar_tag_names(self) -> None:
f = _XmlToolCallContentFilter()
output = f.process(
"A <function_calls_v2><invoke>noop</invoke></function_calls_v2> B"
)
output += f.flush()
assert (
output == "A <function_calls_v2><invoke>noop</invoke></function_calls_v2> B"
)
class TestIncrementTurns:
"""Tests for the _increment_turns helper used by _close_reasoning_if_active."""
def test_increments_turn_index_when_no_sub_turn(self) -> None:
turn, sub = _increment_turns(0, None)
assert turn == 1
assert sub is None
def test_increments_sub_turn_when_present(self) -> None:
turn, sub = _increment_turns(3, 0)
assert turn == 3
assert sub == 1
def test_increments_sub_turn_from_nonzero(self) -> None:
turn, sub = _increment_turns(5, 2)
assert turn == 5
assert sub == 3
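The behavior pinned down by these three tests admits a one-screen sketch (signature assumed from the imports above):

def _increment_turns(
    turn_index: int, sub_turn_index: int | None
) -> tuple[int, int | None]:
    # No sub-turn: advance the turn itself; otherwise only the sub-turn moves.
    if sub_turn_index is None:
        return turn_index + 1, None
    return turn_index, sub_turn_index + 1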
class TestResolveToolArguments:
"""Tests for the _resolve_tool_arguments helper."""
def test_dict_arguments(self) -> None:
obj = {"arguments": {"queries": ["test"]}}
assert _resolve_tool_arguments(obj) == {"queries": ["test"]}
def test_dict_parameters(self) -> None:
"""Falls back to 'parameters' key when 'arguments' is missing."""
obj = {"parameters": {"queries": ["test"]}}
assert _resolve_tool_arguments(obj) == {"queries": ["test"]}
def test_arguments_takes_precedence_over_parameters(self) -> None:
obj = {"arguments": {"a": 1}, "parameters": {"b": 2}}
assert _resolve_tool_arguments(obj) == {"a": 1}
def test_json_string_arguments(self) -> None:
obj = {"arguments": '{"queries": ["test"]}'}
assert _resolve_tool_arguments(obj) == {"queries": ["test"]}
def test_invalid_json_string_returns_empty_dict(self) -> None:
obj = {"arguments": "not valid json"}
assert _resolve_tool_arguments(obj) == {}
def test_no_arguments_or_parameters_returns_empty_dict(self) -> None:
obj = {"name": "some_tool"}
assert _resolve_tool_arguments(obj) == {}
def test_non_dict_non_string_arguments_returns_none(self) -> None:
"""When arguments resolves to a list or int, returns None."""
assert _resolve_tool_arguments({"arguments": [1, 2, 3]}) is None
assert _resolve_tool_arguments({"arguments": 42}) is None
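These cases fully determine the helper's decision table; a hedged reconstruction consistent with every assertion above:

import json
from typing import Any

def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None:
    # "arguments" wins over "parameters"; absent keys degrade to {}.
    raw = obj.get("arguments", obj.get("parameters"))
    if raw is None:
        return {}
    if isinstance(raw, dict):
        return raw
    if isinstance(raw, str):
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            return {}  # invalid JSON strings degrade to {}
        return parsed if isinstance(parsed, dict) else {}
    return None  # lists, ints, and other shapes are rejected outright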

View File

@@ -1,4 +1,3 @@
-import logging
from unittest.mock import MagicMock
from uuid import uuid4
@@ -109,20 +108,14 @@ class TestScimDALUserMappings:
mock_db_session.delete.assert_called_once_with(mapping)
-def test_delete_nonexistent_user_mapping_is_idempotent(
-self,
-scim_dal: ScimDAL,
-mock_db_session: MagicMock,
-caplog: pytest.LogCaptureFixture,
def test_delete_nonexistent_user_mapping_raises(
self, scim_dal: ScimDAL, mock_db_session: MagicMock
) -> None:
mock_db_session.get.return_value = None
-with caplog.at_level(logging.WARNING):
with pytest.raises(ValueError, match="not found"):
scim_dal.delete_user_mapping(999)
mock_db_session.delete.assert_not_called()
assert "SCIM user mapping 999 not found" in caplog.text
def test_update_user_mapping_external_id(
self, scim_dal: ScimDAL, mock_db_session: MagicMock
) -> None:
@@ -170,16 +163,10 @@ class TestScimDALGroupMappings:
mock_db_session.delete.assert_called_once_with(mapping)
-def test_delete_nonexistent_group_mapping_is_idempotent(
-self,
-scim_dal: ScimDAL,
-mock_db_session: MagicMock,
-caplog: pytest.LogCaptureFixture,
def test_delete_nonexistent_group_mapping_raises(
self, scim_dal: ScimDAL, mock_db_session: MagicMock
) -> None:
mock_db_session.get.return_value = None
-with caplog.at_level(logging.WARNING):
with pytest.raises(ValueError, match="not found"):
scim_dal.delete_group_mapping(999)
mock_db_session.delete.assert_not_called()
assert "SCIM group mapping 999 not found" in caplog.text

View File

@@ -1,102 +0,0 @@
"""Shared fixtures for SCIM endpoint unit tests."""
from __future__ import annotations
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserResource
from onyx.db.models import ScimToken
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
@pytest.fixture
def mock_db_session() -> MagicMock:
"""A MagicMock standing in for a SQLAlchemy Session."""
return MagicMock(spec=Session)
@pytest.fixture
def mock_token() -> MagicMock:
"""A MagicMock standing in for a verified ScimToken."""
token = MagicMock(spec=ScimToken)
token.id = 1
return token
@pytest.fixture
def mock_dal() -> Generator[MagicMock, None, None]:
"""Patch ScimDAL construction in api module and yield the mock instance."""
with patch("ee.onyx.server.scim.api.ScimDAL") as cls:
dal = cls.return_value
# User defaults
dal.get_user.return_value = None
dal.get_user_by_email.return_value = None
dal.get_user_mapping_by_user_id.return_value = None
dal.get_user_mapping_by_external_id.return_value = None
dal.list_users.return_value = ([], 0)
# Group defaults
dal.get_group.return_value = None
dal.get_group_by_name.return_value = None
dal.get_group_mapping_by_group_id.return_value = None
dal.get_group_mapping_by_external_id.return_value = None
dal.get_group_members.return_value = []
dal.list_groups.return_value = ([], 0)
yield dal
def make_scim_user(**kwargs: Any) -> ScimUserResource:
"""Build a ScimUserResource with sensible defaults."""
defaults: dict[str, Any] = {
"userName": "test@example.com",
"externalId": "ext-default",
"active": True,
"name": ScimName(givenName="Test", familyName="User"),
}
defaults.update(kwargs)
return ScimUserResource(**defaults)
def make_scim_group(**kwargs: Any) -> ScimGroupResource:
"""Build a ScimGroupResource with sensible defaults."""
defaults: dict[str, Any] = {"displayName": "Engineering"}
defaults.update(kwargs)
return ScimGroupResource(**defaults)
def make_db_user(**kwargs: Any) -> MagicMock:
"""Build a mock User ORM object with configurable attributes."""
user = MagicMock(spec=User)
user.id = kwargs.get("id", uuid4())
user.email = kwargs.get("email", "test@example.com")
user.is_active = kwargs.get("is_active", True)
user.personal_name = kwargs.get("personal_name", "Test User")
user.role = kwargs.get("role", UserRole.BASIC)
return user
def make_db_group(**kwargs: Any) -> MagicMock:
"""Build a mock UserGroup ORM object with configurable attributes."""
group = MagicMock(spec=UserGroup)
group.id = kwargs.get("id", 1)
group.name = kwargs.get("name", "Engineering")
group.is_up_for_deletion = kwargs.get("is_up_for_deletion", False)
group.is_up_to_date = kwargs.get("is_up_to_date", True)
return group
def assert_scim_error(result: object, expected_status: int) -> None:
"""Assert *result* is a JSONResponse with the given status code."""
assert isinstance(result, JSONResponse)
assert result.status_code == expected_status
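A hedged usage sketch of how these removed fixtures composed in the deleted endpoint tests; get_user and its signature are assumed by analogy with the group endpoints later in this diff.

from unittest.mock import MagicMock
from uuid import uuid4

def test_get_user_not_found(
    mock_db_session: MagicMock,
    mock_token: MagicMock,
    mock_dal: MagicMock,
) -> None:
    mock_dal.get_user.return_value = None  # the fixture's default
    result = get_user(  # assumed endpoint, mirroring get_group
        user_id=str(uuid4()),
        _token=mock_token,
        db_session=mock_db_session,
    )
    assert_scim_error(result, 404)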

View File

@@ -1,132 +0,0 @@
"""Tests for SCIM admin token management endpoints."""
from datetime import datetime
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.enterprise_settings.api import create_scim_token
from ee.onyx.server.enterprise_settings.api import get_active_scim_token
from ee.onyx.server.scim.models import ScimTokenCreate
from onyx.db.models import ScimToken
from onyx.db.models import User
@pytest.fixture
def mock_db_session() -> MagicMock:
return MagicMock(spec=Session)
@pytest.fixture
def scim_dal(mock_db_session: MagicMock) -> ScimDAL:
return ScimDAL(mock_db_session)
@pytest.fixture
def admin_user() -> User:
user = User(id=uuid4(), email="admin@test.com")
user.is_active = True
return user
def _make_token(token_id: int, name: str, *, is_active: bool = True) -> ScimToken:
return ScimToken(
id=token_id,
name=name,
hashed_token="h" * 64,
token_display="onyx_scim_****abcd",
is_active=is_active,
created_by_id=uuid4(),
created_at=datetime(2026, 1, 1),
last_used_at=None,
)
class TestGetActiveToken:
def test_returns_token_metadata(self, scim_dal: ScimDAL, admin_user: User) -> None:
token = _make_token(1, "prod-token")
scim_dal._session.scalar.return_value = token # type: ignore[attr-defined]
result = get_active_scim_token(_=admin_user, dal=scim_dal)
assert result.id == 1
assert result.name == "prod-token"
assert result.is_active is True
def test_raises_404_when_no_active_token(
self, scim_dal: ScimDAL, admin_user: User
) -> None:
scim_dal._session.scalar.return_value = None # type: ignore[attr-defined]
with pytest.raises(HTTPException) as exc_info:
get_active_scim_token(_=admin_user, dal=scim_dal)
assert exc_info.value.status_code == 404
class TestCreateToken:
@patch("ee.onyx.server.enterprise_settings.api.generate_scim_token")
def test_creates_token_and_revokes_previous(
self,
mock_generate: MagicMock,
scim_dal: ScimDAL,
admin_user: User,
) -> None:
mock_generate.return_value = ("raw_token_val", "hashed_val", "****abcd")
# Simulate one existing active token that should get revoked
existing = _make_token(1, "old-token", is_active=True)
scim_dal._session.scalars.return_value.all.return_value = [existing] # type: ignore[attr-defined]
# Simulate DB defaults that would be set on INSERT/flush
def fake_add(obj: ScimToken) -> None:
obj.id = 2
obj.is_active = True
obj.created_at = datetime(2026, 2, 1)
scim_dal._session.add.side_effect = fake_add # type: ignore[attr-defined]
body = ScimTokenCreate(name="new-token")
result = create_scim_token(body=body, user=admin_user, dal=scim_dal)
# Previous token was revoked (by create_token's internal revocation)
assert existing.is_active is False
# New token returned with raw value
assert result.raw_token == "raw_token_val"
assert result.name == "new-token"
assert result.is_active is True
# Session was committed
scim_dal._session.commit.assert_called_once() # type: ignore[attr-defined]
@patch("ee.onyx.server.enterprise_settings.api.generate_scim_token")
def test_creates_first_token_when_none_exist(
self,
mock_generate: MagicMock,
scim_dal: ScimDAL,
admin_user: User,
) -> None:
mock_generate.return_value = ("raw_token_val", "hashed_val", "****abcd")
# No existing tokens
scim_dal._session.scalars.return_value.all.return_value = [] # type: ignore[attr-defined]
def fake_add(obj: ScimToken) -> None:
obj.id = 1
obj.is_active = True
obj.created_at = datetime(2026, 2, 1)
scim_dal._session.add.side_effect = fake_add # type: ignore[attr-defined]
body = ScimTokenCreate(name="first-token")
result = create_scim_token(body=body, user=admin_user, dal=scim_dal)
assert result.raw_token == "raw_token_val"
assert result.name == "first-token"
assert result.is_active is True
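The mocks above imply the shape of generate_scim_token: a raw secret, a 64-character digest, and a masked display string. A hedged sketch consistent with those expectations, not the real implementation:

import hashlib
import secrets

def generate_scim_token() -> tuple[str, str, str]:
    raw_token = f"onyx_scim_{secrets.token_urlsafe(32)}"
    hashed_token = hashlib.sha256(raw_token.encode()).hexdigest()  # 64 hex chars
    token_display = f"onyx_scim_****{raw_token[-4:]}"
    return raw_token, hashed_token, token_display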

View File

@@ -1,633 +0,0 @@
"""Unit tests for SCIM Group CRUD endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
from fastapi import Response
from ee.onyx.server.scim.api import create_group
from ee.onyx.server.scim.api import delete_group
from ee.onyx.server.scim.api import get_group
from ee.onyx.server.scim.api import list_groups
from ee.onyx.server.scim.api import patch_group
from ee.onyx.server.scim.api import replace_group
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.patch import ScimPatchError
from tests.unit.onyx.server.scim.conftest import assert_scim_error
from tests.unit.onyx.server.scim.conftest import make_db_group
from tests.unit.onyx.server.scim.conftest import make_scim_group
class TestListGroups:
"""Tests for GET /scim/v2/Groups."""
def test_empty_result(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.list_groups.return_value = ([], 0)
result = list_groups(
filter=None,
startIndex=1,
count=100,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimListResponse)
assert result.totalResults == 0
assert result.Resources == []
def test_unsupported_filter_returns_400(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.list_groups.side_effect = ValueError(
"Unsupported filter attribute: userName"
)
result = list_groups(
filter='userName eq "x"',
startIndex=1,
count=100,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 400)
def test_returns_groups_with_members(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
group = make_db_group(id=5, name="Engineering")
uid = uuid4()
mock_dal.list_groups.return_value = ([(group, "ext-g-1")], 1)
mock_dal.get_group_members.return_value = [(uid, "alice@example.com")]
result = list_groups(
filter=None,
startIndex=1,
count=100,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimListResponse)
assert result.totalResults == 1
resource = result.Resources[0]
assert isinstance(resource, ScimGroupResource)
assert resource.displayName == "Engineering"
assert resource.externalId == "ext-g-1"
assert len(resource.members) == 1
assert resource.members[0].display == "alice@example.com"
class TestGetGroup:
"""Tests for GET /scim/v2/Groups/{group_id}."""
def test_returns_scim_resource(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
group = make_db_group(id=5, name="Engineering")
mock_dal.get_group.return_value = group
mock_dal.get_group_members.return_value = []
result = get_group(
group_id="5",
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimGroupResource)
assert result.displayName == "Engineering"
assert result.id == "5"
def test_non_integer_id_returns_404(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock, # noqa: ARG002
) -> None:
result = get_group(
group_id="not-a-number",
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 404)
def test_not_found_returns_404(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_group.return_value = None
result = get_group(
group_id="999",
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 404)
class TestCreateGroup:
"""Tests for POST /scim/v2/Groups."""
@patch("ee.onyx.server.scim.api._validate_and_parse_members")
def test_success(
self,
mock_validate: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_group_by_name.return_value = None
mock_validate.return_value = ([], None)
mock_dal.get_group_members.return_value = []
resource = make_scim_group(displayName="New Group")
result = create_group(
group_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimGroupResource)
assert result.displayName == "New Group"
mock_dal.add_group.assert_called_once()
mock_dal.commit.assert_called_once()
def test_duplicate_name_returns_409(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_group_by_name.return_value = make_db_group()
resource = make_scim_group()
result = create_group(
group_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 409)
@patch("ee.onyx.server.scim.api._validate_and_parse_members")
def test_invalid_member_returns_400(
self,
mock_validate: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_group_by_name.return_value = None
mock_validate.return_value = ([], "Invalid member ID: bad-uuid")
resource = make_scim_group(members=[ScimGroupMember(value="bad-uuid")])
result = create_group(
group_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 400)
@patch("ee.onyx.server.scim.api._validate_and_parse_members")
def test_nonexistent_member_returns_400(
self,
mock_validate: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_group_by_name.return_value = None
uid = uuid4()
mock_validate.return_value = ([], f"Member(s) not found: {uid}")
resource = make_scim_group(members=[ScimGroupMember(value=str(uid))])
result = create_group(
group_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 400)
@patch("ee.onyx.server.scim.api._validate_and_parse_members")
def test_creates_external_id_mapping(
self,
mock_validate: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_group_by_name.return_value = None
mock_validate.return_value = ([], None)
mock_dal.get_group_members.return_value = []
resource = make_scim_group(externalId="ext-g-123")
result = create_group(
group_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimGroupResource)
mock_dal.create_group_mapping.assert_called_once()
class TestReplaceGroup:
"""Tests for PUT /scim/v2/Groups/{group_id}."""
@patch("ee.onyx.server.scim.api._validate_and_parse_members")
def test_success(
self,
mock_validate: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
group = make_db_group(id=5, name="Old Name")
mock_dal.get_group.return_value = group
mock_validate.return_value = ([], None)
mock_dal.get_group_members.return_value = []
resource = make_scim_group(displayName="New Name")
result = replace_group(
group_id="5",
group_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimGroupResource)
mock_dal.update_group.assert_called_once_with(group, name="New Name")
mock_dal.replace_group_members.assert_called_once()
mock_dal.commit.assert_called_once()
def test_not_found_returns_404(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_group.return_value = None
result = replace_group(
group_id="999",
group_resource=make_scim_group(),
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 404)
@patch("ee.onyx.server.scim.api._validate_and_parse_members")
def test_invalid_member_returns_400(
self,
mock_validate: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
group = make_db_group(id=5)
mock_dal.get_group.return_value = group
mock_validate.return_value = ([], "Invalid member ID: bad")
resource = make_scim_group(members=[ScimGroupMember(value="bad")])
result = replace_group(
group_id="5",
group_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 400)
@patch("ee.onyx.server.scim.api._validate_and_parse_members")
def test_syncs_external_id(
self,
mock_validate: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
group = make_db_group(id=5)
mock_dal.get_group.return_value = group
mock_validate.return_value = ([], None)
mock_dal.get_group_members.return_value = []
resource = make_scim_group(externalId="new-ext")
replace_group(
group_id="5",
group_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
mock_dal.sync_group_external_id.assert_called_once_with(5, "new-ext")
class TestPatchGroup:
"""Tests for PATCH /scim/v2/Groups/{group_id}."""
@patch("ee.onyx.server.scim.api.apply_group_patch")
def test_rename(
self,
mock_apply: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
group = make_db_group(id=5, name="Old Name")
mock_dal.get_group.return_value = group
mock_dal.get_group_members.return_value = []
patched = ScimGroupResource(id="5", displayName="New Name", members=[])
mock_apply.return_value = (patched, [], [])
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op=ScimPatchOperationType.REPLACE,
path="displayName",
value="New Name",
)
]
)
result = patch_group(
group_id="5",
patch_request=patch_req,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimGroupResource)
mock_dal.update_group.assert_called_once_with(group, name="New Name")
def test_not_found_returns_404(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_group.return_value = None
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op=ScimPatchOperationType.REPLACE,
path="displayName",
value="X",
)
]
)
result = patch_group(
group_id="999",
patch_request=patch_req,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 404)
@patch("ee.onyx.server.scim.api.apply_group_patch")
def test_patch_error_returns_error_response(
self,
mock_apply: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
group = make_db_group(id=5)
mock_dal.get_group.return_value = group
mock_dal.get_group_members.return_value = []
mock_apply.side_effect = ScimPatchError("Unsupported path", 400)
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op=ScimPatchOperationType.REPLACE,
path="badPath",
value="x",
)
]
)
result = patch_group(
group_id="5",
patch_request=patch_req,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 400)
@patch("ee.onyx.server.scim.api.apply_group_patch")
def test_add_members(
self,
mock_apply: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
group = make_db_group(id=5)
mock_dal.get_group.return_value = group
mock_dal.get_group_members.return_value = []
mock_dal.validate_member_ids.return_value = []
uid = str(uuid4())
patched = ScimGroupResource(
id="5",
displayName="Engineering",
members=[ScimGroupMember(value=uid)],
)
mock_apply.return_value = (patched, [uid], [])
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op=ScimPatchOperationType.ADD,
path="members",
value=[{"value": uid}],
)
]
)
result = patch_group(
group_id="5",
patch_request=patch_req,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimGroupResource)
mock_dal.validate_member_ids.assert_called_once()
mock_dal.upsert_group_members.assert_called_once()
@patch("ee.onyx.server.scim.api.apply_group_patch")
def test_add_nonexistent_member_returns_400(
self,
mock_apply: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
group = make_db_group(id=5)
mock_dal.get_group.return_value = group
mock_dal.get_group_members.return_value = []
uid = uuid4()
patched = ScimGroupResource(
id="5",
displayName="Engineering",
members=[ScimGroupMember(value=str(uid))],
)
mock_apply.return_value = (patched, [str(uid)], [])
mock_dal.validate_member_ids.return_value = [uid]
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op=ScimPatchOperationType.ADD,
path="members",
value=[{"value": str(uid)}],
)
]
)
result = patch_group(
group_id="5",
patch_request=patch_req,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 400)
@patch("ee.onyx.server.scim.api.apply_group_patch")
def test_remove_members(
self,
mock_apply: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
group = make_db_group(id=5)
mock_dal.get_group.return_value = group
mock_dal.get_group_members.return_value = []
uid = str(uuid4())
patched = ScimGroupResource(id="5", displayName="Engineering", members=[])
mock_apply.return_value = (patched, [], [uid])
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op=ScimPatchOperationType.REMOVE,
path=f'members[value eq "{uid}"]',
)
]
)
result = patch_group(
group_id="5",
patch_request=patch_req,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimGroupResource)
mock_dal.remove_group_members.assert_called_once()
class TestDeleteGroup:
"""Tests for DELETE /scim/v2/Groups/{group_id}."""
def test_success(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
group = make_db_group(id=5)
mock_dal.get_group.return_value = group
mapping = MagicMock()
mapping.id = 1
mock_dal.get_group_mapping_by_group_id.return_value = mapping
result = delete_group(
group_id="5",
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, Response)
assert result.status_code == 204
mock_dal.delete_group_mapping.assert_called_once_with(1)
mock_dal.delete_group_with_members.assert_called_once_with(group)
mock_dal.commit.assert_called_once()
def test_not_found_returns_404(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_group.return_value = None
result = delete_group(
group_id="999",
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 404)
def test_non_integer_id_returns_404(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock, # noqa: ARG002
) -> None:
result = delete_group(
group_id="abc",
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 404)

View File

@@ -1,521 +0,0 @@
"""Unit tests for SCIM User CRUD endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.server.scim.api import create_user
from ee.onyx.server.scim.api import delete_user
from ee.onyx.server.scim.api import get_user
from ee.onyx.server.scim.api import list_users
from ee.onyx.server.scim.api import patch_user
from ee.onyx.server.scim.api import replace_user
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import ScimPatchError
from tests.unit.onyx.server.scim.conftest import assert_scim_error
from tests.unit.onyx.server.scim.conftest import make_db_user
from tests.unit.onyx.server.scim.conftest import make_scim_user
class TestListUsers:
"""Tests for GET /scim/v2/Users."""
def test_empty_result(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.list_users.return_value = ([], 0)
result = list_users(
filter=None,
startIndex=1,
count=100,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimListResponse)
assert result.totalResults == 0
assert result.Resources == []
def test_returns_users_with_scim_shape(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
user = make_db_user(email="alice@example.com", personal_name="Alice Smith")
mock_dal.list_users.return_value = ([(user, "ext-abc")], 1)
result = list_users(
filter=None,
startIndex=1,
count=100,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimListResponse)
assert result.totalResults == 1
assert len(result.Resources) == 1
resource = result.Resources[0]
assert isinstance(resource, ScimUserResource)
assert resource.userName == "alice@example.com"
assert resource.externalId == "ext-abc"
def test_unsupported_filter_attribute_returns_400(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.list_users.side_effect = ValueError(
"Unsupported filter attribute: emails"
)
result = list_users(
filter='emails eq "x@y.com"',
startIndex=1,
count=100,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 400)
def test_invalid_filter_syntax_returns_400(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock, # noqa: ARG002
) -> None:
result = list_users(
filter="not a valid filter",
startIndex=1,
count=100,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 400)
class TestGetUser:
"""Tests for GET /scim/v2/Users/{user_id}."""
def test_returns_scim_resource(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
user = make_db_user(email="alice@example.com")
mock_dal.get_user.return_value = user
result = get_user(
user_id=str(user.id),
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
assert result.userName == "alice@example.com"
assert result.id == str(user.id)
def test_invalid_uuid_returns_404(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock, # noqa: ARG002
) -> None:
result = get_user(
user_id="not-a-uuid",
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 404)
def test_user_not_found_returns_404(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_user.return_value = None
result = get_user(
user_id=str(uuid4()),
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 404)
class TestCreateUser:
"""Tests for POST /scim/v2/Users."""
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_success(
self,
mock_seats: MagicMock, # noqa: ARG002
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_user_by_email.return_value = None
resource = make_scim_user(userName="new@example.com")
result = create_user(
user_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
assert result.userName == "new@example.com"
mock_dal.add_user.assert_called_once()
mock_dal.commit.assert_called_once()
def test_missing_external_id_returns_400(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock, # noqa: ARG002
) -> None:
resource = make_scim_user(externalId=None)
result = create_user(
user_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 400)
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_duplicate_email_returns_409(
self,
mock_seats: MagicMock, # noqa: ARG002
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_user_by_email.return_value = make_db_user()
resource = make_scim_user()
result = create_user(
user_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 409)
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_integrity_error_returns_409(
self,
mock_seats: MagicMock, # noqa: ARG002
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_user_by_email.return_value = None
mock_dal.add_user.side_effect = IntegrityError("dup", {}, Exception())
resource = make_scim_user()
result = create_user(
user_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 409)
mock_dal.rollback.assert_called_once()
@patch("ee.onyx.server.scim.api._check_seat_availability")
def test_seat_limit_returns_403(
self,
mock_seats: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock, # noqa: ARG002
) -> None:
mock_seats.return_value = "Seat limit reached"
resource = make_scim_user()
result = create_user(
user_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 403)
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_creates_external_id_mapping(
self,
mock_seats: MagicMock, # noqa: ARG002
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_user_by_email.return_value = None
resource = make_scim_user(externalId="ext-123")
result = create_user(
user_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
assert result.externalId == "ext-123"
mock_dal.create_user_mapping.assert_called_once()
class TestReplaceUser:
"""Tests for PUT /scim/v2/Users/{user_id}."""
def test_success(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
user = make_db_user(email="old@example.com")
mock_dal.get_user.return_value = user
resource = make_scim_user(
userName="new@example.com",
name=ScimName(givenName="New", familyName="Name"),
)
result = replace_user(
user_id=str(user.id),
user_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
mock_dal.update_user.assert_called_once()
mock_dal.commit.assert_called_once()
def test_not_found_returns_404(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_user.return_value = None
result = replace_user(
user_id=str(uuid4()),
user_resource=make_scim_user(),
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 404)
@patch("ee.onyx.server.scim.api._check_seat_availability")
def test_reactivation_checks_seats(
self,
mock_seats: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
user = make_db_user(is_active=False)
mock_dal.get_user.return_value = user
mock_seats.return_value = "No seats"
resource = make_scim_user(active=True)
result = replace_user(
user_id=str(user.id),
user_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 403)
mock_seats.assert_called_once()
def test_syncs_external_id(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
user = make_db_user()
mock_dal.get_user.return_value = user
resource = make_scim_user(externalId=None)
result = replace_user(
user_id=str(user.id),
user_resource=resource,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
mock_dal.sync_user_external_id.assert_called_once_with(user.id, None)
class TestPatchUser:
"""Tests for PATCH /scim/v2/Users/{user_id}."""
def test_deactivate(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
user = make_db_user(is_active=True)
mock_dal.get_user.return_value = user
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op=ScimPatchOperationType.REPLACE,
path="active",
value=False,
)
]
)
result = patch_user(
user_id=str(user.id),
patch_request=patch_req,
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
mock_dal.update_user.assert_called_once()
def test_not_found_returns_404(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_user.return_value = None
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op=ScimPatchOperationType.REPLACE,
path="active",
value=False,
)
]
)
result = patch_user(
user_id=str(uuid4()),
patch_request=patch_req,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 404)
@patch("ee.onyx.server.scim.api.apply_user_patch")
def test_patch_error_returns_error_response(
self,
mock_apply: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
user = make_db_user()
mock_dal.get_user.return_value = user
mock_apply.side_effect = ScimPatchError("Bad op", 400)
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op=ScimPatchOperationType.REMOVE,
path="userName",
)
]
)
result = patch_user(
user_id=str(user.id),
patch_request=patch_req,
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 400)
class TestDeleteUser:
"""Tests for DELETE /scim/v2/Users/{user_id}."""
def test_success(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
user = make_db_user(is_active=True)
mock_dal.get_user.return_value = user
mapping = MagicMock()
mapping.id = 1
mock_dal.get_user_mapping_by_user_id.return_value = mapping
result = delete_user(
user_id=str(user.id),
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, Response)
assert result.status_code == 204
mock_dal.deactivate_user.assert_called_once_with(user)
mock_dal.delete_user_mapping.assert_called_once_with(1)
mock_dal.commit.assert_called_once()
def test_not_found_returns_404(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
mock_dal.get_user.return_value = None
result = delete_user(
user_id=str(uuid4()),
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 404)
def test_invalid_uuid_returns_404(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock, # noqa: ARG002
) -> None:
result = delete_user(
user_id="not-a-uuid",
_token=mock_token,
db_session=mock_db_session,
)
assert_scim_error(result, 404)

View File

@@ -1,171 +0,0 @@
"""Unit tests for Prometheus instrumentation module."""
import threading
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from fastapi import FastAPI
from fastapi.testclient import TestClient
from prometheus_client import CollectorRegistry
from prometheus_client import Gauge
from onyx.server.prometheus_instrumentation import _slow_request_callback
from onyx.server.prometheus_instrumentation import setup_prometheus_metrics
def _make_info(
duration: float,
method: str = "GET",
handler: str = "/api/test",
status: str = "200",
) -> Any:
"""Build a fake metrics Info object matching the instrumentator's Info shape."""
return MagicMock(
modified_duration=duration,
method=method,
modified_handler=handler,
modified_status=status,
)
def test_slow_request_callback_increments_above_threshold() -> None:
with patch("onyx.server.prometheus_instrumentation._slow_requests") as mock_counter:
mock_labels = MagicMock()
mock_counter.labels.return_value = mock_labels
info = _make_info(
duration=2.0, method="POST", handler="/api/chat", status="200"
)
_slow_request_callback(info)
mock_counter.labels.assert_called_once_with(
method="POST", handler="/api/chat", status="200"
)
mock_labels.inc.assert_called_once()
def test_slow_request_callback_skips_below_threshold() -> None:
with patch("onyx.server.prometheus_instrumentation._slow_requests") as mock_counter:
info = _make_info(duration=0.5)
_slow_request_callback(info)
mock_counter.labels.assert_not_called()
def test_slow_request_callback_skips_at_exact_threshold() -> None:
with (
patch(
"onyx.server.prometheus_instrumentation.SLOW_REQUEST_THRESHOLD_SECONDS", 1.0
),
patch("onyx.server.prometheus_instrumentation._slow_requests") as mock_counter,
):
info = _make_info(duration=1.0)
_slow_request_callback(info)
mock_counter.labels.assert_not_called()
def test_setup_attaches_instrumentator_to_app() -> None:
with patch("onyx.server.prometheus_instrumentation.Instrumentator") as mock_cls:
mock_instance = MagicMock()
mock_instance.instrument.return_value = mock_instance
mock_cls.return_value = mock_instance
app = FastAPI()
setup_prometheus_metrics(app)
mock_cls.assert_called_once_with(
should_group_status_codes=False,
should_ignore_untemplated=False,
should_group_untemplated=True,
should_instrument_requests_inprogress=True,
inprogress_labels=True,
excluded_handlers=["/health", "/metrics", "/openapi.json"],
)
mock_instance.add.assert_called_once()
mock_instance.instrument.assert_called_once_with(app)
mock_instance.expose.assert_called_once_with(app)
def test_inprogress_gauge_increments_during_request() -> None:
"""Verify the in-progress gauge goes up while a request is in flight."""
registry = CollectorRegistry()
gauge = Gauge(
"http_requests_inprogress_test",
"In-progress requests",
["method", "handler"],
registry=registry,
)
request_started = threading.Event()
request_release = threading.Event()
app = FastAPI()
@app.get("/slow")
def slow_endpoint() -> dict:
gauge.labels(method="GET", handler="/slow").inc()
request_started.set()
request_release.wait(timeout=5)
gauge.labels(method="GET", handler="/slow").dec()
return {"status": "done"}
client = TestClient(app, raise_server_exceptions=False)
def make_request() -> None:
client.get("/slow")
thread = threading.Thread(target=make_request)
thread.start()
request_started.wait(timeout=5)
assert gauge.labels(method="GET", handler="/slow")._value.get() == 1.0
request_release.set()
thread.join(timeout=5)
assert gauge.labels(method="GET", handler="/slow")._value.get() == 0.0
def test_inprogress_gauge_tracks_concurrent_requests() -> None:
"""Verify the gauge correctly counts multiple concurrent in-flight requests."""
registry = CollectorRegistry()
gauge = Gauge(
"http_requests_inprogress_concurrent_test",
"In-progress requests",
["method", "handler"],
registry=registry,
)
# 3 parties: 2 request threads + main thread
barrier = threading.Barrier(3)
release = threading.Event()
app = FastAPI()
@app.get("/concurrent")
def concurrent_endpoint() -> dict:
gauge.labels(method="GET", handler="/concurrent").inc()
barrier.wait(timeout=5)
release.wait(timeout=5)
gauge.labels(method="GET", handler="/concurrent").dec()
return {"status": "done"}
client = TestClient(app, raise_server_exceptions=False)
def make_request() -> None:
client.get("/concurrent")
t1 = threading.Thread(target=make_request)
t2 = threading.Thread(target=make_request)
t1.start()
t2.start()
# All 3 threads meet here — both requests are in-flight
barrier.wait(timeout=5)
assert gauge.labels(method="GET", handler="/concurrent")._value.get() == 2.0
release.set()
t1.join(timeout=5)
t2.join(timeout=5)
assert gauge.labels(method="GET", handler="/concurrent")._value.get() == 0.0

View File

@@ -1,88 +0,0 @@
# Onyx Prometheus Metrics Reference
## API Server Metrics
These metrics are exposed at `GET /metrics` on the API server.
### Built-in (via `prometheus-fastapi-instrumentator`)
| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `http_requests_total` | Counter | `method`, `status`, `handler` | Total request count |
| `http_request_duration_highr_seconds` | Histogram | _(none)_ | High-resolution latency (many buckets, no labels) |
| `http_request_duration_seconds` | Histogram | `method`, `handler` | Latency by handler (few buckets for aggregation) |
| `http_request_size_bytes` | Summary | `handler` | Incoming request content length |
| `http_response_size_bytes` | Summary | `handler` | Outgoing response content length |
| `http_requests_inprogress` | Gauge | `method`, `handler` | Currently in-flight requests |
### Custom (via `onyx.server.prometheus_instrumentation`)
| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `onyx_api_slow_requests_total` | Counter | `method`, `handler`, `status` | Requests exceeding `SLOW_REQUEST_THRESHOLD_SECONDS` (default 1s) |
### Configuration
| Env Var | Default | Description |
|---------|---------|-------------|
| `SLOW_REQUEST_THRESHOLD_SECONDS` | `1.0` | Duration threshold for slow request counting |
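The comparison is strict: a request whose duration exactly equals the threshold is not counted (the unit tests above pin this down). A minimal sketch of the callback, assuming the `Info` shape from `prometheus-fastapi-instrumentator` and a hand-rolled `Counter` declaration:
```python
# Sketch only: behavior reconstructed from the unit tests, not the verbatim module.
from prometheus_client import Counter

SLOW_REQUEST_THRESHOLD_SECONDS = 1.0  # read from the env var of the same name

_slow_requests = Counter(
    "onyx_api_slow_requests_total",
    "Requests that exceeded the slow-request threshold",
    ["method", "handler", "status"],
)


def _slow_request_callback(info) -> None:
    # Strictly greater-than: a request at exactly the threshold is skipped.
    if info.modified_duration > SLOW_REQUEST_THRESHOLD_SECONDS:
        _slow_requests.labels(
            method=info.method,
            handler=info.modified_handler,
            status=info.modified_status,
        ).inc()
```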
### Instrumentator Settings
- `should_group_status_codes=False` — Reports exact HTTP status codes (e.g. 401, 403, 500)
- `should_instrument_requests_inprogress=True` — Enables the in-progress request gauge
- `inprogress_labels=True` — Breaks down in-progress gauge by `method` and `handler`
- `excluded_handlers=["/health", "/metrics", "/openapi.json"]` — Excludes noisy endpoints from metrics
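Taken together, the setup amounts to roughly the following (a sketch inferred from the settings above and the unit tests; the argument to `add` is an assumption, since the tests only assert that a callback is registered):
```python
# Sketch of the wiring implied by the settings above; not the verbatim module.
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator


def setup_prometheus_metrics(app: FastAPI) -> None:
    instrumentator = Instrumentator(
        should_group_status_codes=False,
        should_ignore_untemplated=False,
        should_group_untemplated=True,
        should_instrument_requests_inprogress=True,
        inprogress_labels=True,
        excluded_handlers=["/health", "/metrics", "/openapi.json"],
    )
    instrumentator.add(_slow_request_callback)  # assumed: registers the slow-request counter
    instrumentator.instrument(app)  # attach the built-in HTTP metrics middleware
    instrumentator.expose(app)  # serve GET /metrics
```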
## Example PromQL Queries
### Which endpoints are saturated right now?
```promql
# Top 10 endpoints by in-progress requests
topk(10, http_requests_inprogress)
```
### What's the P99 latency per endpoint?
```promql
# P99 latency by handler over the last 5 minutes
histogram_quantile(0.99, sum by (handler, le) (rate(http_request_duration_seconds_bucket[5m])))
```
### Which endpoints have the highest request rate?
```promql
# Requests per second by handler, top 10
topk(10, sum by (handler) (rate(http_requests_total[5m])))
```
### Which endpoints are returning errors?
```promql
# 5xx error rate by handler
sum by (handler) (rate(http_requests_total{status=~"5.."}[5m]))
```
### Slow request hotspots
```promql
# Slow requests per minute by handler
sum by (handler) (rate(onyx_api_slow_requests_total[5m])) * 60
```
### Latency trending up?
```promql
# Compare P50 latency now vs 1 hour ago
histogram_quantile(0.5, sum by (le) (rate(http_request_duration_highr_seconds_bucket[5m])))
-
histogram_quantile(0.5, sum by (le) (rate(http_request_duration_highr_seconds_bucket[5m] offset 1h)))
```
### Overall request throughput
```promql
# Total requests per second across all endpoints
sum(rate(http_requests_total[5m]))
```

1
web/.gitignore vendored
View File

@@ -38,7 +38,6 @@ next-env.d.ts
# playwright testing temp files
/admin*_auth.json
/worker*_auth.json
/user_auth.json
/build-archive.log
/test-results

View File

@@ -0,0 +1,242 @@
"use client";
import { useState } from "react";
import Button from "@/refresh-components/buttons/Button";
import { Callout } from "@/components/ui/callout";
import Text from "@/components/ui/text";
import { ChatSession, ChatSessionSharedStatus } from "@/app/app/interfaces";
import { SEARCH_PARAM_NAMES } from "@/app/app/services/searchParams";
import { toast } from "@/hooks/useToast";
import { structureValue } from "@/lib/llm/utils";
import { LlmDescriptor, useLlmManager } from "@/lib/hooks";
import Separator from "@/refresh-components/Separator";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { cn } from "@/lib/utils";
import { useCurrentAgent } from "@/hooks/useAgents";
import { useSearchParams } from "next/navigation";
import { useChatSessionStore } from "@/app/app/stores/useChatSessionStore";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
import { copyAll } from "@/app/app/message/copyingUtils";
import { SvgCopy, SvgShare } from "@opal/icons";
function buildShareLink(chatSessionId: string) {
const baseUrl = `${window.location.protocol}//${window.location.host}`;
return `${baseUrl}/app/shared/${chatSessionId}`;
}
async function generateShareLink(chatSessionId: string) {
const response = await fetch(`/api/chat/chat-session/${chatSessionId}`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ sharing_status: "public" }),
});
if (response.ok) {
return buildShareLink(chatSessionId);
}
return null;
}
async function generateSeedLink(
message?: string,
assistantId?: number,
modelOverride?: LlmDescriptor
) {
const baseUrl = `${window.location.protocol}//${window.location.host}`;
const model = modelOverride
? structureValue(
modelOverride.name,
modelOverride.provider,
modelOverride.modelName
)
: null;
return `${baseUrl}/app${
message
? `?${SEARCH_PARAM_NAMES.USER_PROMPT}=${encodeURIComponent(message)}`
: ""
}${
assistantId
? `${message ? "&" : "?"}${SEARCH_PARAM_NAMES.PERSONA_ID}=${assistantId}`
: ""
}${
model
? `${message || assistantId ? "&" : "?"}${
SEARCH_PARAM_NAMES.STRUCTURED_MODEL
}=${encodeURIComponent(model)}`
: ""
}${message ? `&${SEARCH_PARAM_NAMES.SEND_ON_LOAD}=true` : ""}`;
}
async function deleteShareLink(chatSessionId: string) {
const response = await fetch(`/api/chat/chat-session/${chatSessionId}`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ sharing_status: "private" }),
});
return response.ok;
}
interface ShareChatSessionModalProps {
chatSession: ChatSession;
onClose: () => void;
}
export default function ShareChatSessionModal({
chatSession,
onClose,
}: ShareChatSessionModalProps) {
const [shareLink, setShareLink] = useState<string>(
chatSession.shared_status === ChatSessionSharedStatus.Public
? buildShareLink(chatSession.id)
: ""
);
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
const currentAgent = useCurrentAgent();
const searchParams = useSearchParams();
const message = searchParams?.get(SEARCH_PARAM_NAMES.USER_PROMPT) || "";
const llmManager = useLlmManager(chatSession, currentAgent || undefined);
const updateCurrentChatSessionSharedStatus = useChatSessionStore(
(state) => state.updateCurrentChatSessionSharedStatus
);
return (
<>
<ConfirmationModalLayout
icon={SvgShare}
title="Share Chat"
onClose={onClose}
submit={<Button onClick={onClose}>Share</Button>}
>
{shareLink ? (
<div>
<Text>
This chat session is currently shared. Anyone in your team can
view the message history using the following link:
</Text>
<div className="flex items-center mt-2">
{/* <CopyButton content={shareLink} /> */}
<CopyIconButton
getCopyText={() => shareLink}
prominence="secondary"
/>
<a
href={shareLink}
target="_blank"
className={cn(
"underline mt-1 ml-1 text-sm my-auto",
"text-action-link-05"
)}
rel="noreferrer"
>
{shareLink}
</a>
</div>
<Separator />
<Text className={cn("mb-4")}>
Click the button below to make the chat private again.
</Text>
<Button
onClick={async () => {
const success = await deleteShareLink(chatSession.id);
if (success) {
setShareLink("");
updateCurrentChatSessionSharedStatus(
ChatSessionSharedStatus.Private
);
} else {
alert("Failed to delete share link");
}
}}
danger
>
Delete Share Link
</Button>
</div>
) : (
<div className="flex flex-col gap-2">
<Callout type="warning" title="Warning">
Please make sure that all content in this chat is safe to share
with the whole team.
</Callout>
<Button
leftIcon={SvgCopy}
onClick={async () => {
// NOTE: on an "insecure" (non-HTTPS) setup, `navigator.clipboard.writeText` may fail
// because the browser may block clipboard access.
try {
const shareLink = await generateShareLink(chatSession.id);
if (!shareLink) {
alert("Failed to generate share link");
} else {
setShareLink(shareLink);
updateCurrentChatSessionSharedStatus(
ChatSessionSharedStatus.Public
);
copyAll(shareLink);
}
} catch (e) {
console.error(e);
}
}}
secondary
>
Generate and Copy Share Link
</Button>
</div>
)}
<Separator className={cn("my-4")} />
<AdvancedOptionsToggle
showAdvancedOptions={showAdvancedOptions}
setShowAdvancedOptions={setShowAdvancedOptions}
title="Advanced Options"
/>
{showAdvancedOptions && (
<div className="flex flex-col gap-2">
<Callout type="notice" title="Seed New Chat">
Generate a link to a new chat session with the same settings as
this chat (including the assistant and model).
</Callout>
<Button
leftIcon={SvgCopy}
onClick={async () => {
try {
const seedLink = await generateSeedLink(
message,
currentAgent?.id,
llmManager.currentLlm
);
if (!seedLink) {
toast.error("Failed to generate seed link");
} else {
navigator.clipboard.writeText(seedLink);
copyAll(seedLink);
toast.success("Link copied to clipboard!");
}
} catch (e) {
console.error(e);
alert("Failed to generate or copy link.");
}
}}
secondary
>
Generate and Copy Seed Link
</Button>
</div>
)}
</ConfirmationModalLayout>
</>
);
}

View File

@@ -186,7 +186,6 @@ export interface BackendChatSession {
current_temperature_override: number | null;
current_alternate_model?: string;
owner_name: string | null;
packets: Packet[][];
}

View File

@@ -192,10 +192,10 @@ const HumanMessage = React.memo(function HumanMessage({
/>
) : typeof content === "string" ? (
<>
<div className="md:max-w-[37.5rem] flex basis-[100%] md:basis-auto justify-end md:order-1">
<div className="md:max-w-[25rem] flex basis-[100%] md:basis-auto justify-end md:order-1">
<div
className={
"max-w-[30rem] md:max-w-[37.5rem] whitespace-break-spaces rounded-t-16 rounded-bl-16 bg-background-tint-02 py-2 px-3"
"max-w-[25rem] whitespace-break-spaces rounded-t-16 rounded-bl-16 bg-background-tint-02 py-2 px-3"
}
onCopy={(e) => {
const selection = window.getSelection();

View File

@@ -20,6 +20,16 @@ import {
import { openDocument } from "@/lib/search/utils";
import { ensureHrefProtocol } from "@/lib/utils";
function isSameOriginUrl(url: string): boolean {
if (!url.startsWith("http")) return true;
try {
if (typeof window === "undefined") return false;
return new URL(url).origin === window.location.origin;
} catch {
return false;
}
}
export const MemoizedAnchor = memo(
({
docs,
@@ -178,6 +188,26 @@ export const MemoizedLink = memo(
const url = ensureHrefProtocol(href);
// Check if the link is to a file on the backend
const isChatFile = url?.includes("/api/chat/file/") && isSameOriginUrl(url);
if (isChatFile && updatePresentingDocument) {
const fileId = url!.split("/api/chat/file/")[1]?.split(/[?#]/)[0] || "";
const filename = value?.toString() || "download";
return (
<a
onClick={() =>
updatePresentingDocument({
document_id: fileId,
semantic_identifier: filename,
} as OnyxDocument)
}
className="cursor-pointer text-link hover:text-link-hover"
>
{rest.children}
</a>
);
}
return (
<a
href={url}

View File

@@ -209,10 +209,7 @@ export default function MessageToolbar({
<FeedbackModal {...feedbackModalProps!} />
</modal.Provider>
<div
data-testid="AgentMessage/toolbar"
className="flex md:flex-row justify-between items-center w-full transition-transform duration-300 ease-in-out transform opacity-100 pl-1"
>
<div className="flex md:flex-row justify-between items-center w-full transition-transform duration-300 ease-in-out transform opacity-100 pl-1">
<TooltipGroup>
<div className="flex items-center">
{includeMessageSwitcher && (

View File

@@ -15,7 +15,6 @@ import TextViewModal from "@/sections/modals/TextViewModal";
import { UNNAMED_CHAT } from "@/lib/constants";
import Text from "@/refresh-components/texts/Text";
import useOnMount from "@/hooks/useOnMount";
import SharedAppInputBar from "@/sections/input/SharedAppInputBar";
export interface SharedChatDisplayProps {
chatSession: BackendChatSession | null;
@@ -70,78 +69,65 @@ export default function SharedChatDisplay({
/>
)}
<div className="flex flex-col h-full w-full overflow-hidden">
<div className="flex-1 flex flex-col items-center overflow-y-auto">
<div className="sticky top-0 z-10 flex items-center justify-between w-full bg-background-tint-01 px-8 py-4">
<Text as="p" text04 headingH2>
{chatSession.description || UNNAMED_CHAT}
</Text>
<div className="flex flex-col items-end">
<Text as="p" text03 secondaryBody>
Shared on {humanReadableFormat(chatSession.time_created)}
</Text>
{chatSession.owner_name && (
<Text as="p" text03 secondaryBody>
by {chatSession.owner_name}
</Text>
)}
</div>
</div>
<div className="flex flex-col items-center h-full w-full overflow-hidden overflow-y-scroll">
<div className="sticky top-0 z-10 flex flex-col w-full bg-background-tint-01 px-8 py-4">
<Text as="p" headingH2>
{chatSession.description || UNNAMED_CHAT}
</Text>
<Text as="p" text03>
{humanReadableFormat(chatSession.time_created)}
</Text>
</div>
{isMounted ? (
<div className="w-[min(50rem,100%)]">
{messages.map((message, i) => {
if (message.type === "user") {
return (
<HumanMessage
key={message.messageId}
content={message.message}
files={message.files}
nodeId={message.nodeId}
/>
);
} else if (message.type === "assistant") {
return (
<AgentMessage
key={message.messageId}
rawPackets={message.packets}
chatState={{
assistant: persona,
docs: message.documents,
citations: message.citations,
setPresentingDocument: setPresentingDocument,
overriddenModel: message.overridden_model,
}}
nodeId={message.nodeId}
llmManager={null}
otherMessagesCanSwitchTo={undefined}
onMessageSelection={undefined}
/>
);
} else {
// Error message case
return (
<div key={message.messageId} className="py-5 ml-4 lg:px-5">
<div className="mx-auto w-[90%] max-w-message-max">
<p className="text-status-text-error-05 text-sm my-auto">
{message.message}
</p>
</div>
{isMounted ? (
<div className="w-[min(50rem,100%)]">
{messages.map((message, i) => {
if (message.type === "user") {
return (
<HumanMessage
key={message.messageId}
content={message.message}
files={message.files}
nodeId={message.nodeId}
/>
);
} else if (message.type === "assistant") {
return (
<AgentMessage
key={message.messageId}
rawPackets={message.packets}
chatState={{
assistant: persona,
docs: message.documents,
citations: message.citations,
setPresentingDocument: setPresentingDocument,
overriddenModel: message.overridden_model,
}}
nodeId={message.nodeId}
llmManager={null}
otherMessagesCanSwitchTo={undefined}
onMessageSelection={undefined}
/>
);
} else {
// Error message case
return (
<div key={message.messageId} className="py-5 ml-4 lg:px-5">
<div className="mx-auto w-[90%] max-w-message-max">
<p className="text-status-text-error-05 text-sm my-auto">
{message.message}
</p>
</div>
);
}
})}
</div>
) : (
<div className="h-full w-full flex items-center justify-center">
<OnyxInitializingLoader />
</div>
)}
</div>
<div className="w-full max-w-[50rem] mx-auto px-4 pb-4">
<SharedAppInputBar />
</div>
</div>
);
}
})}
</div>
) : (
<div className="h-full w-full flex items-center justify-center">
<OnyxInitializingLoader />
</div>
)}
</div>
</>
);

View File

@@ -12,8 +12,7 @@ export type AppFocusType =
| { type: "agent" | "project" | "chat"; id: string }
| "new-session"
| "more-agents"
| "user-settings"
| "shared-chat";
| "user-settings";
export class AppFocus {
constructor(public value: AppFocusType) {}
@@ -30,10 +29,6 @@ export class AppFocus {
return typeof this.value === "object" && this.value.type === "chat";
}
isSharedChat(): boolean {
return this.value === "shared-chat";
}
isNewSession(): boolean {
return this.value === "new-session";
}
@@ -54,7 +49,6 @@ export class AppFocus {
| "agent"
| "project"
| "chat"
| "shared-chat"
| "new-session"
| "more-agents"
| "user-settings" {
@@ -66,11 +60,6 @@ export default function useAppFocus(): AppFocus {
const pathname = usePathname();
const searchParams = useSearchParams();
// Check if we're viewing a shared chat
if (pathname.startsWith("/app/shared/")) {
return new AppFocus("shared-chat");
}
// Check if we're on the user settings page
if (pathname.startsWith("/app/settings")) {
return new AppFocus("user-settings");

View File

@@ -1,63 +0,0 @@
"use client";
import { useState, useEffect } from "react";
const SELECTOR = "[data-main-container]";
interface ContainerCenter {
centerX: number | null;
centerY: number | null;
hasContainerCenter: boolean;
}
function measure(el: HTMLElement): { x: number; y: number } {
const rect = el.getBoundingClientRect();
return { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 };
}
/**
* Tracks the center point of the `[data-main-container]` element so that
* portaled overlays (modals, command menus) can center relative to the main
* content area rather than the full viewport.
*
* Returns `{ centerX, centerY, hasContainerCenter }`.
* When the container is not present (e.g. pages without `AppLayouts.Root`),
* both center values are `null` and `hasContainerCenter` is `false`, allowing
* callers to fall back to standard viewport centering.
*
* Uses a lazy `useState` initializer so the first render already has the
* correct values (no flash), and a `ResizeObserver` to stay reactive when
* the sidebar folds/unfolds.
*/
export default function useContainerCenter(): ContainerCenter {
const [center, setCenter] = useState<{ x: number | null; y: number | null }>(
() => {
if (typeof document === "undefined") return { x: null, y: null };
const el = document.querySelector<HTMLElement>(SELECTOR);
if (!el) return { x: null, y: null };
const m = measure(el);
return { x: m.x, y: m.y };
}
);
useEffect(() => {
const container = document.querySelector<HTMLElement>(SELECTOR);
if (!container) return;
const update = () => {
const m = measure(container);
setCenter({ x: m.x, y: m.y });
};
update();
const observer = new ResizeObserver(update);
observer.observe(container);
return () => observer.disconnect();
}, []);
return {
centerX: center.x,
centerY: center.y,
hasContainerCenter: center.x !== null && center.y !== null,
};
}

View File

@@ -27,7 +27,7 @@ import Button from "@/refresh-components/buttons/Button";
import { useCallback, useMemo, useState, useEffect } from "react";
import { useAppBackground } from "@/providers/AppBackgroundProvider";
import { useTheme } from "next-themes";
import ShareChatSessionModal from "@/sections/modals/ShareChatSessionModal";
import ShareChatSessionModal from "@/app/app/components/modal/ShareChatSessionModal";
import IconButton from "@/refresh-components/buttons/IconButton";
import LineItem from "@/refresh-components/buttons/LineItem";
import { useProjectsContext } from "@/providers/ProjectsContext";
@@ -112,10 +112,6 @@ function Header() {
const customHeaderContent =
settings?.enterpriseSettings?.custom_header_content;
// Only the Chat, Search, and NewSession pages show the custom header content. The header
// itself still renders on every page, since it provides features such as the open-sidebar
// button on mobile that pages without this content still need.
const pageWithHeaderContent = appFocus.isChat() || appFocus.isNewSession();
const effectiveMode: AppMode = appFocus.isNewSession() ? appMode : "chat";
@@ -362,7 +358,7 @@ function Header() {
*/}
<div className="flex-1 flex flex-col items-center overflow-hidden">
<Text text03 className="text-center w-full">
{pageWithHeaderContent && customHeaderContent}
{customHeaderContent}
</Text>
</div>
@@ -379,7 +375,6 @@ function Header() {
transient={showShareModal}
tertiary
onClick={() => setShowShareModal(true)}
aria-label="share-chat-button"
>
Share Chat
</Button>
@@ -515,12 +510,8 @@ function Root({ children, enableBackground }: AppRootProps) {
return (
/* NOTE: Some elements, markdown tables in particular, refer to this `@container` in order to
break out of their immediate containers using cqw units.
The `data-main-container` attribute is used by portaled elements (e.g. CommandMenu) to
render inside this container so they can be centered relative to the main content area
rather than the full viewport (which would include the sidebar).
*/
<div
data-main-container
className={cn(
"@container flex flex-col h-full w-full relative overflow-hidden",
showBackground && "bg-cover bg-center bg-fixed"
@@ -573,7 +564,7 @@ function Root({ children, enableBackground }: AppRootProps) {
)}
<div className="z-app-layout">
{!appFocus.isSharedChat() && <Header />}
<Header />
</div>
<div className="z-app-layout flex-1 overflow-auto h-full w-full">
{children}

View File

@@ -216,7 +216,7 @@ function SettingsHeader({
ref={headerRef}
className={cn(
"sticky top-0 z-settings-header w-full bg-background-tint-01",
backButton ? "md:pt-4" : "md:pt-10"
backButton ? "pt-4" : "pt-10"
)}
>
{backButton && (

40
web/src/lib/languages.ts Normal file
View File

@@ -0,0 +1,40 @@
import * as languages from "linguist-languages";
interface LinguistLanguage {
name: string;
type: string;
extensions?: string[];
filenames?: string[];
}
// Build extension → language name and filename → language name maps at module load
const extensionMap = new Map<string, string>();
const filenameMap = new Map<string, string>();
for (const lang of Object.values(languages) as LinguistLanguage[]) {
if (lang.type !== "programming") continue;
const name = lang.name.toLowerCase();
for (const ext of lang.extensions ?? []) {
// First language to claim an extension wins
if (!extensionMap.has(ext)) {
extensionMap.set(ext, name);
}
}
for (const filename of lang.filenames ?? []) {
if (!filenameMap.has(filename.toLowerCase())) {
filenameMap.set(filename.toLowerCase(), name);
}
}
}
/**
* Returns the language name for a given file name, or null if it's not a
* recognised code file. Looks up by extension first, then by exact filename
* (e.g. "Dockerfile", "Makefile"). Runs in O(1).
*/
export function getCodeLanguage(name: string): string | null {
const lower = name.toLowerCase();
const ext = lower.match(/\.[^.]+$/)?.[0];
return (ext && extensionMap.get(ext)) ?? filenameMap.get(lower) ?? null;
}

View File

@@ -9,7 +9,6 @@ import { Button } from "@opal/components";
import { SvgX } from "@opal/icons";
import { WithoutStyles } from "@/types";
import { Section, SectionProps } from "@/layouts/general-layouts";
import useContainerCenter from "@/hooks/useContainerCenter";
/**
* Modal Root Component
@@ -265,8 +264,6 @@ const ModalContent = React.forwardRef<
contentRef(node);
};
const { centerX, centerY, hasContainerCenter } = useContainerCenter();
const animationClasses = cn(
"data-[state=open]:fade-in-0 data-[state=closed]:fade-out-0",
"data-[state=open]:zoom-in-95 data-[state=closed]:zoom-out-95",
@@ -274,22 +271,6 @@ const ModalContent = React.forwardRef<
"duration-200"
);
const containerStyle: React.CSSProperties | undefined = hasContainerCenter
? ({
left: centerX,
top: centerY,
"--tw-enter-translate-x": "-50%",
"--tw-exit-translate-x": "-50%",
"--tw-enter-translate-y": "-50%",
"--tw-exit-translate-y": "-50%",
} as React.CSSProperties)
: undefined;
const positionClasses = cn(
"fixed -translate-x-1/2 -translate-y-1/2",
!hasContainerCenter && "left-1/2 top-1/2"
);
const dialogEventHandlers = {
onOpenAutoFocus: (e: Event) => {
resetState();
@@ -334,9 +315,8 @@ const ModalContent = React.forwardRef<
{...dialogEventHandlers}
>
<div
style={containerStyle}
className={cn(
positionClasses,
"fixed left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2",
"z-modal",
"flex flex-col gap-4 items-center",
"max-w-[calc(100dvw-2rem)] max-h-[calc(100dvh-2rem)]",
@@ -354,10 +334,8 @@ const ModalContent = React.forwardRef<
// Without bottomSlot: original single-element rendering
<DialogPrimitive.Content
ref={handleRef}
style={containerStyle}
className={cn(
positionClasses,
"overflow-hidden",
"fixed left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2 overflow-hidden",
"z-modal",
background === "gray"
? "bg-background-tint-01"

View File

@@ -32,7 +32,7 @@ const sizeClasses = {
container: "rounded-04 p-0.5 gap-0.5",
},
tag: {
container: "rounded-08 h-[2.25rem] min-w-[2.25rem] p-2 gap-1",
container: "rounded-08 p-1 gap-1",
},
} as const;

View File

@@ -10,7 +10,6 @@ import React, {
} from "react";
import * as DialogPrimitive from "@radix-ui/react-dialog";
import * as VisuallyHidden from "@radix-ui/react-visually-hidden";
import useContainerCenter from "@/hooks/useContainerCenter";
import { cn } from "@/lib/utils";
import Text from "@/refresh-components/texts/Text";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
@@ -367,11 +366,10 @@ const CommandMenuContent = React.forwardRef<
CommandMenuContentProps
>(({ children }, ref) => {
const { handleKeyDown } = useCommandMenuContext();
const { centerX, hasContainerCenter } = useContainerCenter();
return (
<DialogPrimitive.Portal>
{/* Overlay - fixed to full viewport, hidden from assistive technology */}
{/* Overlay - hidden from assistive technology */}
<DialogPrimitive.Overlay
aria-hidden="true"
className={cn(
@@ -380,23 +378,12 @@ const CommandMenuContent = React.forwardRef<
"data-[state=open]:fade-in-0 data-[state=closed]:fade-out-0"
)}
/>
{/* Content - centered within the main container when available,
otherwise falls back to viewport centering */}
{/* Content */}
<DialogPrimitive.Content
ref={ref}
onKeyDown={handleKeyDown}
style={
hasContainerCenter
? ({
left: centerX,
"--tw-enter-translate-x": "-50%",
"--tw-exit-translate-x": "-50%",
} as React.CSSProperties)
: undefined
}
className={cn(
"fixed top-[72px]",
hasContainerCenter ? "-translate-x-1/2" : "inset-x-0 mx-auto",
"fixed inset-x-0 top-[72px] mx-auto",
"z-modal",
"bg-background-tint-00 border rounded-16 shadow-2xl outline-none",
"flex flex-col overflow-hidden",

View File

@@ -26,6 +26,8 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
import { useUser } from "@/providers/UserProvider";
import NoAssistantModal from "@/components/modals/NoAssistantModal";
import TextViewModal from "@/sections/modals/TextViewModal";
import CodeViewModal from "@/sections/modals/CodeViewModal";
import { getCodeLanguage } from "@/lib/languages";
import Modal from "@/refresh-components/Modal";
import { useSendMessageToParent } from "@/lib/extension/utils";
import { SUBMIT_MESSAGE_TYPES } from "@/lib/extension/constants";
@@ -682,12 +684,18 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
</div>
)}
{presentingDocument && (
<TextViewModal
presentingDocument={presentingDocument}
onClose={() => setPresentingDocument(null)}
/>
)}
{presentingDocument &&
(getCodeLanguage(presentingDocument.semantic_identifier || "") ? (
<CodeViewModal
presentingDocument={presentingDocument}
onClose={() => setPresentingDocument(null)}
/>
) : (
<TextViewModal
presentingDocument={presentingDocument}
onClose={() => setPresentingDocument(null)}
/>
))}
{stackTraceModalContent && (
<ExceptionTraceModal

View File

@@ -346,7 +346,6 @@ const ChatScrollContainer = React.memo(
<div
key={sessionId}
ref={scrollContainerRef}
data-testid="chat-scroll-container"
className="flex flex-col flex-1 min-h-0 overflow-y-auto overflow-x-hidden default-scrollbar"
onScroll={handleScroll}
style={{

View File

@@ -1,55 +0,0 @@
"use client";
import Text from "@/refresh-components/texts/Text";
import { Button, OpenButton } from "@opal/components";
import { OpenAISVG } from "@/components/icons/icons";
import {
SvgPlusCircle,
SvgArrowUp,
SvgSliders,
SvgHourglass,
SvgEditBig,
} from "@opal/icons";
export default function SharedAppInputBar() {
return (
<div className="relative w-full">
<div className="w-full flex flex-col shadow-01 bg-background-neutral-00 rounded-16">
{/* Textarea area */}
<div className="flex flex-row items-center w-full">
<Text text03 className="w-full px-3 pt-3 pb-2 select-none">
How can Onyx help you today
</Text>
</div>
{/* Bottom toolbar */}
<div className="flex justify-between items-center w-full p-1 min-h-[40px]">
{/* Left side controls */}
<div className="flex flex-row items-center">
<Button icon={SvgPlusCircle} prominence="tertiary" disabled />
<Button icon={SvgSliders} prominence="tertiary" disabled />
<Button icon={SvgHourglass} variant="select" disabled />
</div>
{/* Right side controls */}
<div className="flex flex-row items-center gap-1">
<OpenButton icon={OpenAISVG} foldable disabled>
GPT-4o
</OpenButton>
<Button icon={SvgArrowUp} disabled />
</div>
</div>
</div>
{/* Fade overlay */}
<div className="absolute inset-0 rounded-16 backdrop-blur-sm bg-background-neutral-00/50" />
{/* CTA button */}
<div className="absolute inset-0 flex items-center justify-center">
<Button prominence="secondary" icon={SvgEditBig} href="/app">
Start New Session
</Button>
</div>
</div>
);
}

View File

@@ -0,0 +1,168 @@
"use client";
import { useState, useEffect, useCallback } from "react";
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
import "@/app/app/message/custom-code-styles.css";
import Button from "@/refresh-components/buttons/Button";
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
import Text from "@/refresh-components/texts/Text";
import { SvgFileText } from "@opal/icons";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import ScrollIndicatorDiv from "@/refresh-components/ScrollIndicatorDiv";
import { Section } from "@/layouts/general-layouts";
import { getCodeLanguage } from "@/lib/languages";
export interface CodeViewProps {
presentingDocument: MinimalOnyxDocument;
onClose: () => void;
}
export default function CodeViewModal({
presentingDocument,
onClose,
}: CodeViewProps) {
const [fileContent, setFileContent] = useState("");
const [fileUrl, setFileUrl] = useState("");
const [fileName, setFileName] = useState("");
const [isLoading, setIsLoading] = useState(true);
const [loadError, setLoadError] = useState<string | null>(null);
const language =
getCodeLanguage(presentingDocument.semantic_identifier || "") ||
"plaintext";
const fetchFile = useCallback(
async (signal?: AbortSignal) => {
setIsLoading(true);
setLoadError(null);
setFileContent("");
const fileIdLocal =
presentingDocument.document_id.split("__")[1] ||
presentingDocument.document_id;
try {
const response = await fetch(
`/api/chat/file/${encodeURIComponent(fileIdLocal)}`,
{
method: "GET",
signal,
cache: "force-cache",
}
);
if (!response.ok) {
setLoadError("Failed to load document.");
return;
}
const blob = await response.blob();
const url = window.URL.createObjectURL(blob);
setFileUrl((prev) => {
if (prev) {
window.URL.revokeObjectURL(prev);
}
return url;
});
const originalFileName =
presentingDocument.semantic_identifier || "document";
setFileName(originalFileName);
const text = await blob.text();
setFileContent(text);
} catch (error) {
if (signal?.aborted) {
return;
}
setLoadError("Failed to load document.");
} finally {
if (!signal?.aborted) {
setIsLoading(false);
}
}
},
[presentingDocument]
);
useEffect(() => {
const controller = new AbortController();
fetchFile(controller.signal);
return () => {
controller.abort();
};
}, [fetchFile]);
useEffect(() => {
return () => {
if (fileUrl) {
window.URL.revokeObjectURL(fileUrl);
}
};
}, [fileUrl]);
const handleDownload = () => {
const link = document.createElement("a");
link.href = fileUrl;
link.download = fileName || presentingDocument.document_id;
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
};
return (
<Modal
open
onOpenChange={(open) => {
if (!open) {
onClose();
}
}}
>
<Modal.Content
width="md"
height="fit"
preventAccidentalClose={false}
onOpenAutoFocus={(e) => e.preventDefault()}
>
<Modal.Header
icon={SvgFileText}
title={fileName || "Code"}
onClose={onClose}
/>
<Modal.Body padding={0} gap={0}>
<Section padding={0} gap={0}>
{isLoading ? (
<Section>
<SimpleLoader className="h-8 w-8" />
</Section>
) : loadError ? (
<Section padding={1}>
<Text text03 mainUiBody>
{loadError}
</Text>
</Section>
) : (
<ScrollIndicatorDiv
className="flex-1 min-h-0 w-full"
variant="shadow"
>
<MinimalMarkdown
content={`\`\`\`${language}\n${fileContent}\n\`\`\``}
className="w-full h-full break-words"
/>
</ScrollIndicatorDiv>
)}
</Section>
</Modal.Body>
<Modal.Footer>
<BasicModalFooter
submit={<Button onClick={handleDownload}>Download File</Button>}
/>
</Modal.Footer>
</Modal.Content>
</Modal>
);
}

View File

@@ -1,255 +0,0 @@
"use client";
import { useState } from "react";
import { cn } from "@/lib/utils";
import { ChatSession, ChatSessionSharedStatus } from "@/app/app/interfaces";
import { toast } from "@/hooks/useToast";
import { useChatSessionStore } from "@/app/app/stores/useChatSessionStore";
import { copyAll } from "@/app/app/message/copyingUtils";
import { Section } from "@/layouts/general-layouts";
import Modal from "@/refresh-components/Modal";
import Button from "@/refresh-components/buttons/Button";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import Text from "@/refresh-components/texts/Text";
import { SvgLink, SvgShare, SvgUsers } from "@opal/icons";
import SvgCheck from "@opal/icons/check";
import SvgLock from "@opal/icons/lock";
import type { IconProps } from "@opal/types";
import useChatSessions from "@/hooks/useChatSessions";
function buildShareLink(chatSessionId: string) {
const baseUrl = `${window.location.protocol}//${window.location.host}`;
return `${baseUrl}/app/shared/${chatSessionId}`;
}
async function generateShareLink(chatSessionId: string) {
const response = await fetch(`/api/chat/chat-session/${chatSessionId}`, {
method: "PATCH",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ sharing_status: "public" }),
});
if (response.ok) {
return buildShareLink(chatSessionId);
}
return null;
}
async function deleteShareLink(chatSessionId: string) {
const response = await fetch(`/api/chat/chat-session/${chatSessionId}`, {
method: "PATCH",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ sharing_status: "private" }),
});
return response.ok;
}
interface PrivacyOptionProps {
icon: React.FunctionComponent<IconProps>;
title: string;
description: string;
selected: boolean;
onClick: () => void;
ariaLabel?: string;
}
function PrivacyOption({
icon: Icon,
title,
description,
selected,
onClick,
ariaLabel,
}: PrivacyOptionProps) {
return (
<div
className={cn(
"p-1.5 rounded-08 cursor-pointer ",
selected ? "bg-background-tint-00" : "bg-transparent",
"hover:bg-background-tint-02"
)}
onClick={onClick}
aria-label={ariaLabel}
>
<div className="flex flex-row gap-1 items-center">
<div className="flex w-5 p-[2px] self-stretch justify-center">
<Icon
size={16}
className={cn(selected ? "stroke-text-05" : "stroke-text-03")}
/>
</div>
<div className="flex flex-col flex-1 px-0.5">
<Text mainUiBody text05={selected} text03={!selected}>
{title}
</Text>
<Text secondaryBody text03>
{description}
</Text>
</div>
{selected && (
<div className="flex w-5 self-stretch justify-center">
<SvgCheck size={16} className="stroke-action-link-05" />
</div>
)}
</div>
</div>
);
}
interface ShareChatSessionModalProps {
chatSession: ChatSession;
onClose: () => void;
}
export default function ShareChatSessionModal({
chatSession,
onClose,
}: ShareChatSessionModalProps) {
const isCurrentlyPublic =
chatSession.shared_status === ChatSessionSharedStatus.Public;
const [selectedPrivacy, setSelectedPrivacy] = useState<"private" | "public">(
isCurrentlyPublic ? "public" : "private"
);
const [shareLink, setShareLink] = useState<string>(
isCurrentlyPublic ? buildShareLink(chatSession.id) : ""
);
const [isLoading, setIsLoading] = useState(false);
const updateCurrentChatSessionSharedStatus = useChatSessionStore(
(state) => state.updateCurrentChatSessionSharedStatus
);
const { refreshChatSessions } = useChatSessions();
const wantsPublic = selectedPrivacy === "public";
const isShared = shareLink && selectedPrivacy === "public";
let submitButtonText = "Done";
if (wantsPublic && !isCurrentlyPublic && !shareLink) {
submitButtonText = "Create Share Link";
} else if (!wantsPublic && isCurrentlyPublic) {
submitButtonText = "Make Private";
} else if (isShared) {
submitButtonText = "Copy Link";
}
async function handleSubmit() {
setIsLoading(true);
try {
if (wantsPublic && !isCurrentlyPublic && !shareLink) {
const link = await generateShareLink(chatSession.id);
if (link) {
setShareLink(link);
updateCurrentChatSessionSharedStatus(ChatSessionSharedStatus.Public);
await refreshChatSessions();
copyAll(link);
toast.success("Share link copied to clipboard!");
} else {
toast.error("Failed to generate share link");
}
} else if (!wantsPublic && isCurrentlyPublic) {
const success = await deleteShareLink(chatSession.id);
if (success) {
setShareLink("");
updateCurrentChatSessionSharedStatus(ChatSessionSharedStatus.Private);
await refreshChatSessions();
toast.success("Chat is now private");
onClose();
} else {
toast.error("Failed to make chat private");
}
} else if (wantsPublic && shareLink) {
copyAll(shareLink);
toast.success("Share link copied to clipboard!");
} else {
onClose();
}
} catch (e) {
console.error(e);
toast.error("An error occurred");
} finally {
setIsLoading(false);
}
}
return (
<Modal open onOpenChange={(isOpen) => !isOpen && onClose()}>
<Modal.Content width="sm">
<Modal.Header
icon={SvgShare}
title={isShared ? "Chat shared" : "Share this chat"}
description="All existing and future messages in this chat will be shared."
onClose={onClose}
/>
<Modal.Body twoTone>
<Section
justifyContent="start"
alignItems="stretch"
gap={1}
height="auto"
>
<Section
justifyContent="start"
alignItems="stretch"
height="auto"
gap={0.12}
>
<PrivacyOption
icon={SvgLock}
title="Private"
description="Only you have access to this chat."
selected={selectedPrivacy === "private"}
onClick={() => setSelectedPrivacy("private")}
ariaLabel="share-modal-option-private"
/>
<PrivacyOption
icon={SvgUsers}
title="Your Organization"
description="Anyone in your organization can view this chat."
selected={selectedPrivacy === "public"}
onClick={() => setSelectedPrivacy("public")}
ariaLabel="share-modal-option-public"
/>
</Section>
{isShared && (
<div aria-label="share-modal-link-input">
<InputTypeIn
readOnly
value={shareLink}
rightSection={
<CopyIconButton
getCopyText={() => shareLink}
tooltip="Copy link"
size="sm"
aria-label="share-modal-copy-link"
/>
}
/>
</div>
)}
</Section>
</Modal.Body>
<Modal.Footer>
{!isShared && (
<Button secondary onClick={onClose} aria-label="share-modal-cancel">
Cancel
</Button>
)}
<Button
onClick={handleSubmit}
disabled={isLoading}
leftIcon={isShared ? SvgLink : undefined}
className={isShared ? "w-full" : undefined}
aria-label="share-modal-submit"
>
{submitButtonText}
</Button>
</Modal.Footer>
</Modal.Content>
</Modal>
);
}
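For reference, `generateShareLink` and `deleteShareLink` above differ only in the `sharing_status` value they PATCH, so the contract condenses into one helper. A sketch, assuming the same endpoint and payload shape shown above:

```typescript
// Sketch: both the share and unshare paths PATCH the chat session and only
// vary the sharing_status value; response.ok signals success.
async function setSharingStatus(
  chatSessionId: string,
  status: "public" | "private"
): Promise<boolean> {
  const response = await fetch(`/api/chat/chat-session/${chatSessionId}`, {
    method: "PATCH",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ sharing_status: status }),
  });
  return response.ok;
}
```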

View File

@@ -18,7 +18,7 @@ import {
import { useProjectsContext } from "@/providers/ProjectsContext";
import MoveCustomAgentChatModal from "@/components/modals/MoveCustomAgentChatModal";
import { UNNAMED_CHAT } from "@/lib/constants";
import ShareChatSessionModal from "@/sections/modals/ShareChatSessionModal";
import ShareChatSessionModal from "@/app/app/components/modal/ShareChatSessionModal";
import SidebarTab from "@/refresh-components/buttons/SidebarTab";
import IconButton from "@/refresh-components/buttons/IconButton";
import { Button as OpalButton } from "@opal/components";

View File

@@ -1,6 +1,5 @@
import { test, expect } from "@playwright/test";
import type { Page } from "@playwright/test";
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";
import { expectScreenshot } from "@tests/e2e/utils/visualRegression";
test.use({ storageState: "admin_auth.json" });
@@ -163,10 +162,16 @@ async function verifyAdminPageNavigation(
}
}
const THEMES = ["light", "dark"] as const;
for (const theme of THEMES) {
test.describe(`Admin pages (${theme} mode)`, () => {
// Inject the theme into localStorage before every navigation so
// next-themes picks it up on first render.
test.beforeEach(async ({ page }) => {
await setThemeBeforeNavigation(page, theme);
await page.addInitScript((t: string) => {
localStorage.setItem("theme", t);
}, theme);
});
for (const snapshot of ADMIN_PAGES) {
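The hunk above trades the shared `setThemeBeforeNavigation` helper for an inline `addInitScript` call; both seed the `theme` key that next-themes reads on first render. A sketch of what such a helper presumably looks like (the actual body of `@tests/e2e/utils/theme` is not shown in this diff):

```typescript
import type { Page } from "@playwright/test";

export const THEMES = ["light", "dark"] as const;
export type Theme = (typeof THEMES)[number];

// Sketch: persist the theme before any navigation so next-themes picks it
// up on first paint instead of flashing the default theme.
export async function setThemeBeforeNavigation(
  page: Page,
  theme: Theme
): Promise<void> {
  await page.addInitScript((t: string) => {
    localStorage.setItem("theme", t);
  }, theme);
}
```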

View File

@@ -141,7 +141,7 @@ test.describe("Web Content Provider Configuration", () => {
await modalDialog
.getByRole("button", { name: "Connect", exact: true })
.click();
await expect(modalDialog).not.toBeVisible({ timeout: 60000 });
await expect(modalDialog).not.toBeVisible({ timeout: 30000 });
await page.waitForLoadState("networkidle");
} else if (await setDefaultButton.isVisible()) {
// If already configured but not active, set as default

View File

@@ -1,7 +1,7 @@
import { test, expect } from "@playwright/test";
import {
TEST_ADMIN_CREDENTIALS,
workerUserCredentials,
TEST_USER_CREDENTIALS,
} from "@tests/e2e/constants";
import { expectScreenshot } from "@tests/e2e/utils/visualRegression";
@@ -48,7 +48,7 @@ test.describe("Login flow", () => {
await page.goto("/auth/login");
await page.waitForLoadState("networkidle");
await page.getByTestId("email").fill(workerUserCredentials(0).email);
await page.getByTestId("email").fill(TEST_USER_CREDENTIALS.email);
await page.getByTestId("password").fill("WrongPassword123!");
await page.getByRole("button", { name: "Sign In" }).click();
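The swap above moves this spec from the worker-user pool to a fixed test user. For orientation, a sketch of the credentials shape these fills appear to rely on; the real constants live in `@tests/e2e/constants`, whose body is not shown in this diff, so the fields and values here are assumptions:

```typescript
// Sketch (assumed, not shown in this diff): the record shape behind
// TEST_USER_CREDENTIALS; workerUserCredentials(i) presumably produced one
// such record per worker index.
interface TestCredentials {
  email: string; // placeholder value below
  password: string; // placeholder value below
}

const TEST_USER_CREDENTIALS: TestCredentials = {
  email: "user@example.com",
  password: "TestPassword123!",
};
```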

View File

@@ -1,10 +1,10 @@
import { test, expect, Page, Locator } from "@playwright/test";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
import { expectScreenshot } from "@tests/e2e/utils/visualRegression";
test.use({ storageState: "admin_auth.json" });
// Test data storage
const TEST_PREFIX = "E2E-CMD";
const TEST_PREFIX = `E2E-CMD-${Date.now()}`;
let chatSessionIds: string[] = [];
let projectIds: number[] = [];
@@ -12,9 +12,17 @@ let projectIds: number[] = [];
* Helper to get the command menu dialog locator (using the content wrapper)
*/
function getCommandMenuContent(page: Page): Locator {
// Use DialogPrimitive.Content which has role="dialog" and contains the visually-hidden title
return page.locator('[role="dialog"]:has([data-command-menu-list])');
}
/**
* Helper to get the command menu list locator
*/
function getCommandMenuList(page: Page): Locator {
return page.locator("[data-command-menu-list]");
}
/**
* Helper to open the command menu and return a scoped locator
*/
@@ -28,20 +36,25 @@ async function openCommandMenu(page: Page): Promise<Locator> {
}
test.describe("Chat Search Command Menu", () => {
test.beforeAll(async ({ browser }, workerInfo) => {
const context = await browser.newContext();
// Create all test data ONCE before all tests
test.beforeAll(async ({ browser }) => {
const context = await browser.newContext({
storageState: "admin_auth.json",
});
const page = await context.newPage();
await loginAsWorkerUser(page, workerInfo.workerIndex);
const client = new OnyxApiClient(page.request);
// Navigate to app to establish session
await page.goto("/app");
await page.waitForLoadState("networkidle");
// Create 5 chat sessions
for (let i = 1; i <= 5; i++) {
const id = await client.createChatSession(`${TEST_PREFIX} Chat ${i}`);
chatSessionIds.push(id);
}
// Create 4 projects
for (let i = 1; i <= 4; i++) {
const id = await client.createProject(`${TEST_PREFIX} Project ${i}`);
projectIds.push(id);
@@ -50,18 +63,24 @@ test.describe("Chat Search Command Menu", () => {
await context.close();
});
test.afterAll(async ({ browser }, workerInfo) => {
const context = await browser.newContext();
// Cleanup all test data ONCE after all tests
test.afterAll(async ({ browser }) => {
const context = await browser.newContext({
storageState: "admin_auth.json",
});
const page = await context.newPage();
await loginAsWorkerUser(page, workerInfo.workerIndex);
const client = new OnyxApiClient(page.request);
// Navigate to app to establish session
await page.goto("/app");
await page.waitForLoadState("networkidle");
// Delete chat sessions
for (const id of chatSessionIds) {
await client.deleteChatSession(id);
}
// Delete projects
for (const id of projectIds) {
await client.deleteProject(id);
}
@@ -69,269 +88,472 @@ test.describe("Chat Search Command Menu", () => {
await context.close();
});
test.beforeEach(async ({ page }, testInfo) => {
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
await page.goto("/app");
await page.waitForLoadState("networkidle");
});
test.describe("Menu Opening", () => {
test("Opens when clicking sidebar search trigger", async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
// -- Opening --
const dialog = await openCommandMenu(page);
test("Opens with search input, New Session action, and correct positioning", async ({
page,
}) => {
const dialog = await openCommandMenu(page);
await expect(
dialog.getByPlaceholder("Search chat sessions, projects...")
).toBeFocused();
await expect(
dialog.locator('[data-command-item="new-session"]')
).toBeVisible();
await expectScreenshot(page, { name: "command-menu-default-open" });
});
// -- Preview limits --
test("Shows at most 4 chats and 3 projects in preview", async ({ page }) => {
const dialog = await openCommandMenu(page);
const chatCount = await dialog
.locator('[data-command-item^="chat-"]')
.count();
expect(chatCount).toBeLessThanOrEqual(4);
const projectCount = await dialog
.locator('[data-command-item^="project-"]')
.count();
expect(projectCount).toBeLessThanOrEqual(3);
});
test('Shows "Recent Sessions", "Projects" filters and "New Project" action', async ({
page,
}) => {
const dialog = await openCommandMenu(page);
await expect(
dialog.locator('[data-command-item="recent-sessions"]')
).toBeVisible();
await expect(
dialog.locator('[data-command-item="projects"]')
).toBeVisible();
await expect(
dialog.locator('[data-command-item="new-project"]')
).toBeVisible();
});
// -- Filter expansion --
test('"Recent Sessions" filter expands to show all 5 chats', async ({
page,
}) => {
const dialog = await openCommandMenu(page);
await dialog.locator('[data-command-item="recent-sessions"]').click();
await page.waitForTimeout(500);
for (let i = 1; i <= 5; i++) {
await expect(
dialog.locator(`[data-command-item="chat-${chatSessionIds[i - 1]}"]`)
dialog.getByPlaceholder("Search chat sessions, projects...")
).toBeVisible();
}
await expect(dialog.getByText("Sessions")).toBeVisible();
await expectScreenshot(page, { name: "command-menu-sessions-filter" });
});
test('"Projects" filter expands to show all 4 projects', async ({ page }) => {
const dialog = await openCommandMenu(page);
await dialog.locator('[data-command-item="projects"]').click();
await page.waitForTimeout(500);
for (let i = 1; i <= 4; i++) {
// "New Session" action should be visible within the command menu
await expect(
dialog.locator(`[data-command-item="project-${projectIds[i - 1]}"]`)
dialog.locator('[data-command-item="new-session"]')
).toBeVisible();
}
});
await expectScreenshot(page, { name: "command-menu-projects-filter" });
test("Shows search input with placeholder and focus", async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
const input = dialog.getByPlaceholder(
"Search chat sessions, projects..."
);
await expect(input).toBeVisible();
await expect(input).toBeFocused();
});
test('Shows "New Session" action when no search term', async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
// Use data-command-item attribute to target the action
await expect(
dialog.locator('[data-command-item="new-session"]')
).toBeVisible();
});
});
test("Filter chip X removes filter and returns to all", async ({ page }) => {
const dialog = await openCommandMenu(page);
await dialog.locator('[data-command-item="recent-sessions"]').click();
await expect(dialog.getByText("Sessions")).toBeVisible();
test.describe("Preview Display", () => {
test("Shows at most 4 chat sessions (PREVIEW_CHATS_LIMIT)", async ({
page,
}) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
await dialog.locator('button[aria-label="Remove Sessions filter"]').click();
const dialog = await openCommandMenu(page);
await expect(
dialog.locator('[data-command-item="new-session"]')
).toBeVisible();
// Should show at most 4 chat sessions in preview mode
const chatItems = dialog.locator('[data-command-item^="chat-"]');
const chatCount = await chatItems.count();
// In "all" filter with no search, should show max 4 chats
expect(chatCount).toBeLessThanOrEqual(4);
});
test("Shows at most 3 projects (PREVIEW_PROJECTS_LIMIT)", async ({
page,
}) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
// Should show at most 3 projects in preview mode
const projectItems = dialog.locator('[data-command-item^="project-"]');
const projectCount = await projectItems.count();
// In "all" filter with no search, should show max 3 projects
expect(projectCount).toBeLessThanOrEqual(3);
});
test('Shows "Recent Sessions" filter', async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
// The "Recent Sessions" filter should be visible
await expect(
dialog.locator('[data-command-item="recent-sessions"]')
).toBeVisible();
});
test('Shows "Projects" filter', async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
// The "Projects" filter should be visible
await expect(
dialog.locator('[data-command-item="projects"]')
).toBeVisible();
});
test('Shows "New Project" action', async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
await expect(
dialog.locator('[data-command-item="new-project"]')
).toBeVisible();
});
});
test("Backspace on empty input removes active filter", async ({ page }) => {
const dialog = await openCommandMenu(page);
await dialog.locator('[data-command-item="recent-sessions"]').click();
await expect(dialog.getByText("Sessions")).toBeVisible();
test.describe("Filter Expansion", () => {
test('Click "Recent Sessions" filter shows all 5 chats', async ({
page,
}) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.focus();
await page.keyboard.press("Backspace");
const dialog = await openCommandMenu(page);
await expect(
dialog.locator('[data-command-item="new-session"]')
).toBeVisible();
// Click on Recent Sessions filter to expand
await dialog.locator('[data-command-item="recent-sessions"]').click();
// Wait for the filter to be applied and all chats to load
await page.waitForTimeout(500);
// Should now show all 5 test chats - use data-command-item to find them
for (let i = 1; i <= 5; i++) {
await expect(
dialog.locator(`[data-command-item="chat-${chatSessionIds[i - 1]}"]`)
).toBeVisible();
}
});
test('Filter chip "Sessions" appears in header when chats filter is active', async ({
page,
}) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
// Click on Recent Sessions filter
await dialog.locator('[data-command-item="recent-sessions"]').click();
// The filter chip should appear (look for the editable tag with "Sessions")
await expect(dialog.getByText("Sessions")).toBeVisible();
});
test('Click "Projects" filter shows all 4 projects', async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
// Click on Projects filter to expand
await dialog.locator('[data-command-item="projects"]').click();
// Wait for the filter to be applied
await page.waitForTimeout(500);
// Should now show all 4 test projects - use data-command-item to find them
for (let i = 1; i <= 4; i++) {
await expect(
dialog.locator(`[data-command-item="project-${projectIds[i - 1]}"]`)
).toBeVisible();
}
});
test("Clicking filter chip X removes filter and returns to 'all'", async ({
page,
}) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
// Click on Recent Sessions filter
await dialog.locator('[data-command-item="recent-sessions"]').click();
// Wait for the filter to be applied
await expect(dialog.getByText("Sessions")).toBeVisible();
// Click the X on the filter tag to remove it (aria-label is "Remove Sessions filter")
await dialog
.locator('button[aria-label="Remove Sessions filter"]')
.click();
// Should be back to "all" view - "New Session" action should be visible again
await expect(
dialog.locator('[data-command-item="new-session"]')
).toBeVisible();
});
test("Backspace on empty input removes active filter", async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
// Click on Recent Sessions filter
await dialog.locator('[data-command-item="recent-sessions"]').click();
// Wait for the filter to be applied
await expect(dialog.getByText("Sessions")).toBeVisible();
// Ensure focus is on the input field
const input = dialog.getByPlaceholder(
"Search chat sessions, projects..."
);
await input.focus();
// Press backspace on empty input to remove filter
await page.keyboard.press("Backspace");
// Should be back to "all" view
await expect(
dialog.locator('[data-command-item="new-session"]')
).toBeVisible();
});
test("Backspace on empty input with no filter closes menu", async ({
page,
}) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
await openCommandMenu(page);
// Press backspace on empty input (no filter active)
await page.keyboard.press("Backspace");
// Menu should close
await expect(getCommandMenuContent(page)).not.toBeVisible();
});
});
test("Backspace on empty input with no filter closes menu", async ({
page,
}) => {
await openCommandMenu(page);
await page.keyboard.press("Backspace");
await expect(getCommandMenuContent(page)).not.toBeVisible();
test.describe("Search Filtering", () => {
test("Search finds matching chat session", async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
const input = dialog.getByPlaceholder(
"Search chat sessions, projects..."
);
await input.fill(`${TEST_PREFIX} Chat 3`);
// Wait for search results
await page.waitForTimeout(500);
// Should show the matching chat - use specific data-command-item
await expect(
dialog.locator(`[data-command-item="chat-${chatSessionIds[2]}"]`)
).toBeVisible();
});
test("Search finds matching project", async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
const input = dialog.getByPlaceholder(
"Search chat sessions, projects..."
);
await input.fill(`${TEST_PREFIX} Project 2`);
// Wait for search results
await page.waitForTimeout(500);
// Should show the matching project - use specific data-command-item
await expect(
dialog.locator(`[data-command-item="project-${projectIds[1]}"]`)
).toBeVisible();
});
test('Search shows "Create New Project" action with typed name', async ({
page,
}) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
const input = dialog.getByPlaceholder(
"Search chat sessions, projects..."
);
await input.fill("my custom project name");
// Should show create project action with the search term
await expect(
dialog.locator('[data-command-item="create-project-with-name"]')
).toBeVisible();
});
test('Search with no results shows "No results found"', async ({
page,
}) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
const input = dialog.getByPlaceholder(
"Search chat sessions, projects..."
);
await input.fill("xyz123nonexistent9999");
// Wait for search to complete
await page.waitForTimeout(500);
// Should show no results message or the "No more results" separator
// The component shows "No results found" when there are no matches
const noResults = dialog.getByText("No results found");
const noMore = dialog.getByText("No more results");
await expect(noResults.or(noMore)).toBeVisible();
});
});
// -- Search --
test.describe("Navigation Actions", () => {
test('"New Session" click navigates to /app', async ({ page }) => {
await page.goto("/chat");
await page.waitForLoadState("networkidle");
test("Search finds matching chat session", async ({ page }) => {
const dialog = await openCommandMenu(page);
const dialog = await openCommandMenu(page);
const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.fill(`${TEST_PREFIX} Chat 3`);
await page.waitForTimeout(500);
// Click New Session action
await dialog.locator('[data-command-item="new-session"]').click();
await expect(
dialog.locator(`[data-command-item="chat-${chatSessionIds[2]}"]`)
).toBeVisible();
// Should navigate to /app
await page.waitForURL(/\/app/);
expect(page.url()).toContain("/app");
});
await expectScreenshot(page, { name: "command-menu-search-results" });
test("Click chat session navigates to /chat?chatId={id}", async ({
page,
}) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
// Search for a specific chat
const input = dialog.getByPlaceholder(
"Search chat sessions, projects..."
);
await input.fill(`${TEST_PREFIX} Chat 1`);
// Wait for search results
await page.waitForTimeout(500);
// Click on the chat using data-command-item
await dialog
.locator(`[data-command-item="chat-${chatSessionIds[0]}"]`)
.click();
// Should navigate to the chat URL
await page.waitForURL(/chatId=/);
expect(page.url()).toContain(`chatId=${chatSessionIds[0]}`);
});
test("Click project navigates to /chat?projectId={id}", async ({
page,
}) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
// Search for a specific project
const input = dialog.getByPlaceholder(
"Search chat sessions, projects..."
);
await input.fill(`${TEST_PREFIX} Project 1`);
// Wait for search results
await page.waitForTimeout(500);
// Click on the project using data-command-item
await dialog
.locator(`[data-command-item="project-${projectIds[0]}"]`)
.click();
// Should navigate to the project URL
await page.waitForURL(/projectId=/);
expect(page.url()).toContain(`projectId=${projectIds[0]}`);
});
test('"New Project" click opens create project modal', async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const dialog = await openCommandMenu(page);
// Click New Project action
await dialog.locator('[data-command-item="new-project"]').click();
// Should open the create project modal
await expect(page.getByText("Create New Project")).toBeVisible();
});
});
test("Search finds matching project", async ({ page }) => {
const dialog = await openCommandMenu(page);
test.describe("Menu State", () => {
test("Menu closes after selecting an action/item", async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.fill(`${TEST_PREFIX} Project 2`);
await page.waitForTimeout(500);
const dialog = await openCommandMenu(page);
await expect(
dialog.locator(`[data-command-item="project-${projectIds[1]}"]`)
).toBeVisible();
});
// Select New Session
await dialog.locator('[data-command-item="new-session"]').click();
test('Search shows "Create New Project" action with typed name', async ({
page,
}) => {
const dialog = await openCommandMenu(page);
// Menu should close
await expect(getCommandMenuContent(page)).not.toBeVisible();
});
const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.fill("my custom project name");
test("Menu state resets when reopened (search cleared, filter reset)", async ({
page,
}) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
await expect(
dialog.locator('[data-command-item="create-project-with-name"]')
).toBeVisible();
});
// Open menu and apply a filter first (filter is only visible when search is empty)
let dialog = await openCommandMenu(page);
await dialog.locator('[data-command-item="recent-sessions"]').click();
test("Search with no results shows empty state", async ({ page }) => {
const dialog = await openCommandMenu(page);
// Wait for the filter to be applied
await expect(dialog.getByText("Sessions")).toBeVisible();
const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.fill("xyz123nonexistent9999");
await page.waitForTimeout(500);
// Now type something in the search
const input = dialog.getByPlaceholder(
"Search chat sessions, projects..."
);
await input.fill("test query");
const noResults = dialog.getByText("No results found");
const noMore = dialog.getByText("No more results");
await expect(noResults.or(noMore)).toBeVisible();
// Close with Escape
await page.keyboard.press("Escape");
await expectScreenshot(page, { name: "command-menu-no-results" });
});
// Wait for menu to close
await expect(getCommandMenuContent(page)).not.toBeVisible();
// -- Navigation --
// Reopen
dialog = await openCommandMenu(page);
test('"New Session" navigates to /app', async ({ page }) => {
// Start from /chat so navigation is observable
await page.goto("/chat");
await page.waitForLoadState("networkidle");
// Search input should be empty
await expect(
dialog.getByPlaceholder("Search chat sessions, projects...")
).toHaveValue("");
const dialog = await openCommandMenu(page);
await dialog.locator('[data-command-item="new-session"]').click();
// Should be back to "all" view with "New Session" action visible
await expect(
dialog.locator('[data-command-item="new-session"]')
).toBeVisible();
});
await page.waitForURL(/\/app/);
expect(page.url()).toContain("/app");
});
test("Escape closes menu", async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
test("Clicking a chat session navigates to its URL", async ({ page }) => {
const dialog = await openCommandMenu(page);
await openCommandMenu(page);
const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.fill(`${TEST_PREFIX} Chat 1`);
await page.waitForTimeout(500);
// Press Escape
await page.keyboard.press("Escape");
await dialog
.locator(`[data-command-item="chat-${chatSessionIds[0]}"]`)
.click();
await page.waitForURL(/chatId=/);
expect(page.url()).toContain(`chatId=${chatSessionIds[0]}`);
});
test("Clicking a project navigates to its URL", async ({ page }) => {
const dialog = await openCommandMenu(page);
const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.fill(`${TEST_PREFIX} Project 1`);
await page.waitForTimeout(500);
await dialog
.locator(`[data-command-item="project-${projectIds[0]}"]`)
.click();
await page.waitForURL(/projectId=/);
expect(page.url()).toContain(`projectId=${projectIds[0]}`);
});
test('"New Project" opens create project modal', async ({ page }) => {
const dialog = await openCommandMenu(page);
await dialog.locator('[data-command-item="new-project"]').click();
await expect(page.getByText("Create New Project")).toBeVisible();
});
// -- Menu state --
test("Menu closes after selecting an item", async ({ page }) => {
const dialog = await openCommandMenu(page);
await dialog.locator('[data-command-item="new-session"]').click();
await expect(getCommandMenuContent(page)).not.toBeVisible();
});
test("Escape closes menu", async ({ page }) => {
await openCommandMenu(page);
await page.keyboard.press("Escape");
await expect(getCommandMenuContent(page)).not.toBeVisible();
});
test("Menu state resets when reopened", async ({ page }) => {
let dialog = await openCommandMenu(page);
await dialog.locator('[data-command-item="recent-sessions"]').click();
await expect(dialog.getByText("Sessions")).toBeVisible();
const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.fill("test query");
await page.keyboard.press("Escape");
await expect(getCommandMenuContent(page)).not.toBeVisible();
dialog = await openCommandMenu(page);
await expect(
dialog.getByPlaceholder("Search chat sessions, projects...")
).toHaveValue("");
await expect(
dialog.locator('[data-command-item="new-session"]')
).toBeVisible();
// Menu should close
await expect(getCommandMenuContent(page)).not.toBeVisible();
});
});
});
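Several of the tests above repeat the same fill-then-settle sequence against the search input; a small helper could factor it out. A sketch mirroring the `fill` + `waitForTimeout(500)` pattern used throughout this spec (not part of the diff):

```typescript
import { expect, Page, Locator } from "@playwright/test";

// Sketch: type a query into the command menu and give debounced search
// results a moment to settle, as the tests above do inline.
async function searchCommandMenu(
  page: Page,
  dialog: Locator,
  query: string
): Promise<void> {
  const input = dialog.getByPlaceholder("Search chat sessions, projects...");
  await expect(input).toBeVisible();
  await input.fill(query);
  await page.waitForTimeout(500);
}
```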

View File

@@ -1,709 +0,0 @@
import { expect, Page, test } from "@playwright/test";
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
import { sendMessage } from "@tests/e2e/utils/chatActions";
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";
const SHORT_USER_MESSAGE = "What is Onyx?";
const LONG_USER_MESSAGE = `I've been evaluating several enterprise search and AI platforms for our organization, and I have a number of detailed questions about Onyx that I'd like to understand before we make a decision.
First, can you explain how Onyx handles document indexing across multiple data sources? We currently use Confluence, Google Drive, Slack, and GitHub, and we need to ensure that all of these can be indexed simultaneously without performance degradation.
Second, I'm interested in understanding the security model. Specifically, how does Onyx handle document-level permissions when syncing from sources that have their own ACL systems? Does it respect the original source permissions, or does it create its own permission layer?
Third, we have a requirement for real-time or near-real-time indexing. What is the typical latency between a document being updated in a source system and it becoming searchable in Onyx?
Finally, could you walk me through the architecture of the AI chat system? How does it decide which documents to reference when answering a question, and how does it handle cases where the retrieved documents might contain conflicting information?`;
const SHORT_AI_RESPONSE =
"Onyx is an open-source AI-powered enterprise search platform that connects to your company's documents, apps, and people.";
const LONG_AI_RESPONSE = `Onyx is an open-source Gen-AI and Enterprise Search platform designed to connect to your company's documents, applications, and people. Let me address each of your questions in detail.
## Document Indexing
Onyx uses a **connector-based architecture** where each data source has a dedicated connector. These connectors run as background workers and can index simultaneously without interfering with each other. The supported connectors include:
- **Confluence** — Full page and space indexing with attachment support
- **Google Drive** — File and folder indexing with shared drive support
- **Slack** — Channel message indexing with thread support
- **GitHub** — Repository, issue, and pull request indexing
Each connector runs on its own schedule and can be configured independently for polling frequency.
## Security Model
Onyx implements a **document-level permission system** that syncs with source ACLs. When documents are indexed, their permissions are preserved:
\`\`\`
Source Permission → Onyx ACL Sync → Query-time Filtering
\`\`\`
This means that when a user searches, they only see documents they have access to in the original source system. The permission sync runs periodically to stay up to date.
## Indexing Latency
The typical indexing latency depends on your configuration:
1. **Polling mode**: Documents are picked up on the next polling cycle (configurable, default 10 minutes)
2. **Webhook mode**: Near real-time, typically under 30 seconds
3. **Manual trigger**: Immediate indexing on demand
## AI Chat Architecture
The chat system uses a **Retrieval-Augmented Generation (RAG)** pipeline:
1. User query is analyzed and expanded
2. Relevant documents are retrieved from the vector database (Vespa)
3. Documents are ranked and filtered by relevance and permissions
4. The LLM generates a response grounded in the retrieved documents
5. Citations are attached to specific claims in the response
When documents contain conflicting information, the system presents the most relevant and recent information first, and includes citations so users can verify the source material themselves.`;
const MARKDOWN_AI_RESPONSE = `Here's a quick overview with various formatting:
### Key Features
| Feature | Status | Notes |
|---------|--------|-------|
| Enterprise Search | ✅ Available | Full-text and semantic |
| AI Chat | ✅ Available | Multi-model support |
| Connectors | ✅ Available | 30+ integrations |
| Permissions | ✅ Available | Source ACL sync |
### Code Example
\`\`\`python
from onyx import OnyxClient
client = OnyxClient(api_key="your-key")
results = client.search("quarterly revenue report")
for doc in results:
print(f"{doc.title}: {doc.score:.2f}")
\`\`\`
> **Note**: Onyx supports both cloud and self-hosted deployments. The self-hosted option gives you full control over your data.
Key benefits include:
- **Privacy**: Your data stays within your infrastructure
- **Flexibility**: Connect any data source via custom connectors
- **Extensibility**: Open-source codebase with active community`;
interface MockDocument {
document_id: string;
semantic_identifier: string;
link: string;
source_type: string;
blurb: string;
is_internet: boolean;
}
interface SearchMockOptions {
content: string;
queries: string[];
documents: MockDocument[];
/** Maps citation number -> document_id */
citations: Record<number, string>;
isInternetSearch?: boolean;
}
let turnCounter = 0;
function buildMockStream(content: string): string {
turnCounter += 1;
const userMessageId = turnCounter * 100 + 1;
const assistantMessageId = turnCounter * 100 + 2;
const packets = [
{
user_message_id: userMessageId,
reserved_assistant_message_id: assistantMessageId,
},
{
placement: { turn_index: 0, tab_index: 0 },
obj: {
type: "message_start",
id: `mock-${assistantMessageId}`,
content,
final_documents: null,
},
},
{
placement: { turn_index: 0, tab_index: 0 },
obj: { type: "stop", stop_reason: "finished" },
},
{
message_id: assistantMessageId,
citations: {},
files: [],
},
];
return `${packets.map((p) => JSON.stringify(p)).join("\n")}\n`;
}
function buildMockSearchStream(options: SearchMockOptions): string {
turnCounter += 1;
const userMessageId = turnCounter * 100 + 1;
const assistantMessageId = turnCounter * 100 + 2;
const fullDocs = options.documents.map((doc) => ({
...doc,
boost: 0,
hidden: false,
score: 0.95,
chunk_ind: 0,
match_highlights: [],
metadata: {},
updated_at: null,
}));
// Turn 0: search tool
// Turn 1: answer + citations
const packets: Record<string, unknown>[] = [
{
user_message_id: userMessageId,
reserved_assistant_message_id: assistantMessageId,
},
{
placement: { turn_index: 0, tab_index: 0 },
obj: {
type: "search_tool_start",
...(options.isInternetSearch !== undefined && {
is_internet_search: options.isInternetSearch,
}),
},
},
{
placement: { turn_index: 0, tab_index: 0 },
obj: { type: "search_tool_queries_delta", queries: options.queries },
},
{
placement: { turn_index: 0, tab_index: 0 },
obj: { type: "search_tool_documents_delta", documents: fullDocs },
},
{
placement: { turn_index: 0, tab_index: 0 },
obj: { type: "section_end" },
},
{
placement: { turn_index: 1, tab_index: 0 },
obj: {
type: "message_start",
id: `mock-${assistantMessageId}`,
content: options.content,
final_documents: fullDocs,
},
},
...Object.entries(options.citations).map(([num, docId]) => ({
placement: { turn_index: 1, tab_index: 0 },
obj: {
type: "citation_info",
citation_number: Number(num),
document_id: docId,
},
})),
{
placement: { turn_index: 1, tab_index: 0 },
obj: { type: "stop", stop_reason: "finished" },
},
{
message_id: assistantMessageId,
citations: options.citations,
files: [],
},
];
return `${packets.map((p) => JSON.stringify(p)).join("\n")}\n`;
}
async function openChat(page: Page): Promise<void> {
await page.goto("/app");
await page.waitForLoadState("networkidle");
await page.waitForSelector("#onyx-chat-input-textarea", { timeout: 15000 });
}
async function mockChatEndpoint(
page: Page,
responseContent: string
): Promise<void> {
await page.route("**/api/chat/send-chat-message", async (route) => {
await route.fulfill({
status: 200,
contentType: "text/plain",
body: buildMockStream(responseContent),
});
});
}
async function mockChatEndpointSequence(
page: Page,
responses: string[]
): Promise<void> {
let callIndex = 0;
await page.route("**/api/chat/send-chat-message", async (route) => {
const content =
responses[Math.min(callIndex, responses.length - 1)] ??
responses[responses.length - 1]!;
callIndex += 1;
await route.fulfill({
status: 200,
contentType: "text/plain",
body: buildMockStream(content),
});
});
}
async function screenshotChatContainer(
page: Page,
name: string
): Promise<void> {
const container = page.locator("[data-main-container]");
await expect(container).toBeVisible();
await expectElementScreenshot(container, { name });
}
/**
* Captures two screenshots of the chat container for long-content tests:
* one scrolled to the top and one scrolled to the bottom. Both are captured
* for the current theme, ensuring consistent scroll positions regardless of
* whether the page was just navigated to (top) or just finished streaming (bottom).
*/
async function screenshotChatContainerTopAndBottom(
page: Page,
name: string
): Promise<void> {
const container = page.locator("[data-main-container]");
await expect(container).toBeVisible();
const scrollContainer = page.getByTestId("chat-scroll-container");
await scrollContainer.evaluate((el) => el.scrollTo({ top: 0 }));
await expectElementScreenshot(container, { name: `${name}-top` });
await scrollContainer.evaluate((el) => el.scrollTo({ top: el.scrollHeight }));
await expectElementScreenshot(container, { name: `${name}-bottom` });
}
for (const theme of THEMES) {
test.describe(`Chat Message Rendering (${theme} mode)`, () => {
test.beforeEach(async ({ page }, testInfo) => {
turnCounter = 0;
await page.context().clearCookies();
await setThemeBeforeNavigation(page, theme);
await loginAsWorkerUser(page, testInfo.workerIndex);
});
test.describe("Short Messages", () => {
test("short user message with short AI response renders correctly", async ({
page,
}) => {
await openChat(page);
await mockChatEndpoint(page, SHORT_AI_RESPONSE);
await sendMessage(page, SHORT_USER_MESSAGE);
const userMessage = page.locator("#onyx-human-message").first();
await expect(userMessage).toContainText(SHORT_USER_MESSAGE);
const aiMessage = page.getByTestId("onyx-ai-message").first();
await expect(aiMessage).toContainText("open-source AI-powered");
await screenshotChatContainer(
page,
`chat-short-message-short-response-${theme}`
);
});
});
test.describe("Long Messages", () => {
test("long user message renders without truncation", async ({ page }) => {
await openChat(page);
await mockChatEndpoint(page, SHORT_AI_RESPONSE);
await sendMessage(page, LONG_USER_MESSAGE);
const userMessage = page.locator("#onyx-human-message").first();
await expect(userMessage).toContainText("document indexing");
await expect(userMessage).toContainText("security model");
await expect(userMessage).toContainText("real-time or near-real-time");
await expect(userMessage).toContainText("architecture of the AI chat");
await screenshotChatContainer(
page,
`chat-long-user-message-short-response-${theme}`
);
});
test("long AI response with markdown renders correctly", async ({
page,
}) => {
await openChat(page);
await mockChatEndpoint(page, LONG_AI_RESPONSE);
await sendMessage(page, SHORT_USER_MESSAGE);
const aiMessage = page.getByTestId("onyx-ai-message").first();
await expect(aiMessage).toContainText("Document Indexing");
await expect(aiMessage).toContainText("Security Model");
await expect(aiMessage).toContainText("Indexing Latency");
await expect(aiMessage).toContainText("AI Chat Architecture");
await screenshotChatContainerTopAndBottom(
page,
`chat-short-message-long-response-${theme}`
);
});
test("long user message with long AI response renders correctly", async ({
page,
}) => {
await openChat(page);
await mockChatEndpoint(page, LONG_AI_RESPONSE);
await sendMessage(page, LONG_USER_MESSAGE);
const userMessage = page.locator("#onyx-human-message").first();
await expect(userMessage).toContainText("document indexing");
const aiMessage = page.getByTestId("onyx-ai-message").first();
await expect(aiMessage).toContainText("Retrieval-Augmented Generation");
await screenshotChatContainerTopAndBottom(
page,
`chat-long-message-long-response-${theme}`
);
});
});
test.describe("Markdown and Code Rendering", () => {
test("AI response with tables and code blocks renders correctly", async ({
page,
}) => {
await openChat(page);
await mockChatEndpoint(page, MARKDOWN_AI_RESPONSE);
await sendMessage(page, "Give me an overview of Onyx features");
const aiMessage = page.getByTestId("onyx-ai-message").first();
await expect(aiMessage).toContainText("Key Features");
await expect(aiMessage).toContainText("OnyxClient");
await expect(aiMessage).toContainText("Privacy");
await screenshotChatContainer(
page,
`chat-markdown-code-response-${theme}`
);
});
});
test.describe("Multi-Turn Conversation", () => {
test("multi-turn conversation renders all messages correctly", async ({
page,
}) => {
await openChat(page);
const responses = [
SHORT_AI_RESPONSE,
"Yes, Onyx supports over 30 data source connectors including Confluence, Google Drive, Slack, GitHub, Jira, Notion, and many more.",
"To get started, you can deploy Onyx using Docker Compose with a single command. The setup takes about 5 minutes.",
];
await mockChatEndpointSequence(page, responses);
await sendMessage(page, SHORT_USER_MESSAGE);
await expect(page.getByTestId("onyx-ai-message").first()).toContainText(
"open-source AI-powered"
);
await sendMessage(page, "What connectors does it support?");
await expect(page.getByTestId("onyx-ai-message")).toHaveCount(2, {
timeout: 30000,
});
await sendMessage(page, "How do I get started?");
await expect(page.getByTestId("onyx-ai-message")).toHaveCount(3, {
timeout: 30000,
});
const userMessages = page.locator("#onyx-human-message");
await expect(userMessages).toHaveCount(3);
await screenshotChatContainerTopAndBottom(
page,
`chat-multi-turn-conversation-${theme}`
);
});
test("multi-turn with mixed message lengths renders correctly", async ({
page,
}) => {
await openChat(page);
const responses = [LONG_AI_RESPONSE, SHORT_AI_RESPONSE];
await mockChatEndpointSequence(page, responses);
await sendMessage(page, LONG_USER_MESSAGE);
await expect(page.getByTestId("onyx-ai-message").first()).toContainText(
"Document Indexing"
);
await sendMessage(page, SHORT_USER_MESSAGE);
await expect(page.getByTestId("onyx-ai-message")).toHaveCount(2, {
timeout: 30000,
});
await screenshotChatContainerTopAndBottom(
page,
`chat-multi-turn-mixed-lengths-${theme}`
);
});
});
test.describe("Web Search with Citations", () => {
const TOOLBAR_BUTTONS = [
"AgentMessage/copy-button",
"AgentMessage/like-button",
"AgentMessage/dislike-button",
] as const;
async function screenshotToolbarButtonHoverStates(
page: Page,
namePrefix: string
): Promise<void> {
const aiMessage = page.getByTestId("onyx-ai-message").first();
const toolbar = aiMessage.getByTestId("AgentMessage/toolbar");
await expect(toolbar).toBeVisible({ timeout: 10000 });
for (const buttonTestId of TOOLBAR_BUTTONS) {
const button = aiMessage.getByTestId(buttonTestId);
await button.hover();
const buttonSlug = buttonTestId.split("/")[1];
await expectElementScreenshot(toolbar, {
name: `${namePrefix}-toolbar-${buttonSlug}-hover-${theme}`,
});
}
// Sources tag is located by role+name since SourceTag has no testid.
const sourcesButton = toolbar.getByRole("button", { name: "Sources" });
if (await sourcesButton.isVisible()) {
await sourcesButton.hover();
await expectElementScreenshot(toolbar, {
name: `${namePrefix}-toolbar-sources-hover-${theme}`,
});
}
// LLMPopover trigger is only rendered when the regenerate action is
// available (requires onRegenerate + parentMessage + llmManager props).
const llmTrigger = aiMessage.getByTestId("llm-popover-trigger");
if (await llmTrigger.isVisible()) {
await llmTrigger.hover();
await expectElementScreenshot(toolbar, {
name: `${namePrefix}-toolbar-llm-popover-hover-${theme}`,
});
}
}
const WEB_SEARCH_DOCUMENTS: MockDocument[] = [
{
document_id: "web-doc-1",
semantic_identifier: "Onyx Documentation - Getting Started",
link: "https://docs.onyx.app/getting-started",
source_type: "web",
blurb:
"Onyx is an open-source enterprise search and AI platform. Deploy in minutes with Docker Compose.",
is_internet: true,
},
{
document_id: "web-doc-2",
semantic_identifier: "Onyx GitHub Repository",
link: "https://github.com/onyx-dot-app/onyx",
source_type: "web",
blurb:
"Open-source Gen-AI platform with 30+ connectors. MIT licensed community edition.",
is_internet: true,
},
{
document_id: "web-doc-3",
semantic_identifier: "Enterprise Search Comparison 2025",
link: "https://example.com/enterprise-search-comparison",
source_type: "web",
blurb:
"Comparing top enterprise search platforms including Onyx, Glean, and Coveo.",
is_internet: true,
},
];
const WEB_SEARCH_RESPONSE = `Based on my web search, here's what I found about Onyx:
Onyx is an open-source enterprise search and AI platform that can be deployed in minutes using Docker Compose [[D1]](https://docs.onyx.app/getting-started). The project is hosted on GitHub and is MIT licensed for the community edition, with over 30 connectors available [[D2]](https://github.com/onyx-dot-app/onyx).
In comparisons with other enterprise search platforms, Onyx stands out for its open-source nature and self-hosted deployment option [[D3]](https://example.com/enterprise-search-comparison). Unlike proprietary alternatives, you maintain full control over your data and infrastructure.
Key advantages include:
- **Self-hosted**: Deploy on your own infrastructure
- **Open source**: Full visibility into the codebase [[D2]](https://github.com/onyx-dot-app/onyx)
- **Quick setup**: Get running in under 5 minutes [[D1]](https://docs.onyx.app/getting-started)
- **Extensible**: 30+ pre-built connectors with custom connector support`;
test("web search response with citations renders correctly", async ({
page,
}) => {
await openChat(page);
await page.route("**/api/chat/send-chat-message", async (route) => {
await route.fulfill({
status: 200,
contentType: "text/plain",
body: buildMockSearchStream({
content: WEB_SEARCH_RESPONSE,
queries: ["Onyx enterprise search platform overview"],
documents: WEB_SEARCH_DOCUMENTS,
citations: {
1: "web-doc-1",
2: "web-doc-2",
3: "web-doc-3",
},
isInternetSearch: true,
}),
});
});
await sendMessage(page, "Search the web for information about Onyx");
const aiMessage = page.getByTestId("onyx-ai-message").first();
await expect(aiMessage).toContainText("open-source enterprise search");
await expect(aiMessage).toContainText("Docker Compose");
await expect(aiMessage).toContainText("MIT licensed");
await screenshotChatContainer(
page,
`chat-web-search-with-citations-${theme}`
);
await screenshotToolbarButtonHoverStates(page, "chat-web-search");
});
test("internal document search response renders correctly", async ({
page,
}) => {
const internalDocs: MockDocument[] = [
{
document_id: "confluence-doc-1",
semantic_identifier: "Q3 2025 Engineering Roadmap",
link: "https://company.atlassian.net/wiki/spaces/ENG/pages/123",
source_type: "confluence",
blurb:
"Engineering priorities for Q3 include platform stability, new connector integrations, and performance improvements.",
is_internet: false,
},
{
document_id: "gdrive-doc-1",
semantic_identifier: "Platform Architecture Overview.pdf",
link: "https://drive.google.com/file/d/abc123",
source_type: "google_drive",
blurb:
"Onyx platform architecture document covering microservices, data flow, and deployment topology.",
is_internet: false,
},
];
const internalResponse = `Based on your company's internal documents, here is the engineering roadmap:
The Q3 2025 priorities focus on three main areas [[D1]](https://company.atlassian.net/wiki/spaces/ENG/pages/123):
1. **Platform stability** — Improving error handling and retry mechanisms across all connectors
2. **New integrations** — Adding support for ServiceNow and Zendesk connectors
3. **Performance** — Optimizing vector search latency and reducing indexing time
The platform architecture document provides additional context on how these improvements fit into the overall system design [[D2]](https://drive.google.com/file/d/abc123). The microservices architecture allows each component to be scaled independently.`;
await openChat(page);
await page.route("**/api/chat/send-chat-message", async (route) => {
await route.fulfill({
status: 200,
contentType: "text/plain",
body: buildMockSearchStream({
content: internalResponse,
queries: ["Q3 engineering roadmap priorities"],
documents: internalDocs,
citations: {
1: "confluence-doc-1",
2: "gdrive-doc-1",
},
isInternetSearch: false,
}),
});
});
await sendMessage(page, "What are our engineering priorities for Q3?");
const aiMessage = page.getByTestId("onyx-ai-message").first();
await expect(aiMessage).toContainText("Platform stability");
await expect(aiMessage).toContainText("New integrations");
await expect(aiMessage).toContainText("Performance");
await screenshotChatContainer(
page,
`chat-internal-search-with-citations-${theme}`
);
await screenshotToolbarButtonHoverStates(page, "chat-internal-search");
});
});
test.describe("Message Interaction States", () => {
test("hovering over user message shows action buttons", async ({
page,
}) => {
await openChat(page);
await mockChatEndpoint(page, SHORT_AI_RESPONSE);
await sendMessage(page, SHORT_USER_MESSAGE);
const userMessage = page.locator("#onyx-human-message").first();
await userMessage.hover();
const editButton = userMessage.getByTestId("HumanMessage/edit-button");
await expect(editButton).toBeVisible({ timeout: 5000 });
await screenshotChatContainer(
page,
`chat-user-message-hover-state-${theme}`
);
});
test("AI message toolbar is visible after response completes", async ({
page,
}) => {
await openChat(page);
await mockChatEndpoint(page, SHORT_AI_RESPONSE);
await sendMessage(page, SHORT_USER_MESSAGE);
const aiMessage = page.getByTestId("onyx-ai-message").first();
const copyButton = aiMessage.getByTestId("AgentMessage/copy-button");
const likeButton = aiMessage.getByTestId("AgentMessage/like-button");
const dislikeButton = aiMessage.getByTestId(
"AgentMessage/dislike-button"
);
await expect(copyButton).toBeVisible({ timeout: 10000 });
await expect(likeButton).toBeVisible();
await expect(dislikeButton).toBeVisible();
await screenshotChatContainer(
page,
`chat-ai-message-with-toolbar-${theme}`
);
});
});
});
}
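To make the search-mock protocol above easier to follow: turn 0 carries the search-tool packets, turn 1 carries the answer plus its citations, and the stream is bracketed by an ID record and a final message record. A condensed sketch of the packet order `buildMockSearchStream` serializes, with illustrative placeholder values (not part of the diff):

```typescript
// Sketch: packet order emitted by the search mock above, serialized as one
// JSON object per line. IDs, queries, and documents are placeholders.
const packets = [
  { user_message_id: 101, reserved_assistant_message_id: 102 },
  { placement: { turn_index: 0, tab_index: 0 }, obj: { type: "search_tool_start", is_internet_search: true } },
  { placement: { turn_index: 0, tab_index: 0 }, obj: { type: "search_tool_queries_delta", queries: ["example query"] } },
  { placement: { turn_index: 0, tab_index: 0 }, obj: { type: "search_tool_documents_delta", documents: [] } },
  { placement: { turn_index: 0, tab_index: 0 }, obj: { type: "section_end" } },
  { placement: { turn_index: 1, tab_index: 0 }, obj: { type: "message_start", id: "mock-102", content: "...", final_documents: [] } },
  { placement: { turn_index: 1, tab_index: 0 }, obj: { type: "citation_info", citation_number: 1, document_id: "web-doc-1" } },
  { placement: { turn_index: 1, tab_index: 0 }, obj: { type: "stop", stop_reason: "finished" } },
  { message_id: 102, citations: { 1: "web-doc-1" }, files: [] },
];
const body = packets.map((p) => JSON.stringify(p)).join("\n") + "\n";
```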

View File

@@ -0,0 +1,218 @@
import { test, expect, Page } from "@playwright/test";
import { loginAsRandomUser } from "../utils/auth";
/**
* Builds a newline-delimited JSON stream body matching the packet
* format that useChatController expects:
*
* 1. MessageResponseIDInfo — identifies the user/assistant messages
* 2. Packet-wrapped streaming objects ({placement, obj}) — the actual content
* 3. BackendMessage — the final completed message
*
* Each line is a raw JSON object parsed by handleSSEStream.
*/
function buildMockStream(messageContent: string): string {
const packets = [
// 1. Message ID info — tells the frontend the message IDs
JSON.stringify({
user_message_id: 1,
reserved_assistant_message_id: 2,
}),
// 2. Streaming content packets wrapped in {placement, obj}
JSON.stringify({
placement: { turn_index: 0 },
obj: {
type: "message_start",
id: "mock-message-id",
content: "",
final_documents: null,
},
}),
JSON.stringify({
placement: { turn_index: 0 },
obj: {
type: "message_delta",
content: messageContent,
},
}),
JSON.stringify({
placement: { turn_index: 0 },
obj: {
type: "message_end",
},
}),
JSON.stringify({
placement: { turn_index: 0 },
obj: {
type: "stop",
stop_reason: "finished",
},
}),
// 3. Final BackendMessage — the completed message record
JSON.stringify({
message_id: 2,
message_type: "assistant",
research_type: null,
parent_message: 1,
latest_child_message: null,
message: messageContent,
rephrased_query: null,
context_docs: null,
time_sent: new Date().toISOString(),
citations: {},
files: [],
tool_call: null,
overridden_model: null,
}),
];
return packets.join("\n") + "\n";
}
/**
* Sends a message while intercepting the backend response with
* a controlled mock stream. Returns once the AI message renders.
*/
async function sendMessageWithMockResponse(
page: Page,
userMessage: string,
mockResponseContent: string
) {
const existingMessageCount = await page
.locator('[data-testid="onyx-ai-message"]')
.count();
// Intercept the send-chat-message endpoint and return our mock stream
await page.route("**/api/chat/send-chat-message", async (route) => {
await route.fulfill({
status: 200,
contentType: "application/json",
body: buildMockStream(mockResponseContent),
});
});
await page.locator("#onyx-chat-input-textarea").click();
await page.locator("#onyx-chat-input-textarea").fill(userMessage);
await page.locator("#onyx-chat-input-send-button").click();
// Wait for the AI message to appear
await expect(page.locator('[data-testid="onyx-ai-message"]')).toHaveCount(
existingMessageCount + 1,
{ timeout: 30000 }
);
// Unroute so future requests go through normally
await page.unroute("**/api/chat/send-chat-message");
}
const MOCK_FILE_ID = "00000000-0000-0000-0000-000000000001";
test.describe("File preview modal from chat file links", () => {
test.beforeEach(async ({ page }) => {
await page.context().clearCookies();
await loginAsRandomUser(page);
await page.goto("/app");
await page.waitForLoadState("networkidle");
});
test("clicking a text file link opens the TextViewModal", async ({
page,
}) => {
const mockContent = `Here is your file: [notes.txt](/api/chat/file/${MOCK_FILE_ID})`;
// Mock the file endpoint to return text content
await page.route(`**/api/chat/file/${MOCK_FILE_ID}`, async (route) => {
await route.fulfill({
status: 200,
contentType: "text/plain",
body: "Hello from the mock file!",
});
});
await sendMessageWithMockResponse(page, "Give me the file", mockContent);
// Find the link in the AI message and click it
const aiMessage = page.getByTestId("onyx-ai-message").last();
const fileLink = aiMessage.locator("a").filter({ hasText: "notes.txt" });
await expect(fileLink).toBeVisible({ timeout: 5000 });
await fileLink.click();
// Verify the modal opens
const modal = page.getByRole("dialog");
await expect(modal).toBeVisible({ timeout: 5000 });
// Verify the file name is shown in the header
await expect(modal.getByText("notes.txt")).toBeVisible();
// Verify the download button exists
await expect(modal.getByText("Download File")).toBeVisible();
// Verify the file content is rendered
await expect(modal.getByText("Hello from the mock file!")).toBeVisible();
});
test("clicking a code file link opens the CodeViewModal with syntax highlighting", async ({
page,
}) => {
const mockContent = `Here is your script: [app.py](/api/chat/file/${MOCK_FILE_ID})`;
const pythonCode = 'def hello():\n print("Hello, world!")';
// Mock the file endpoint to return Python code
await page.route(`**/api/chat/file/${MOCK_FILE_ID}`, async (route) => {
await route.fulfill({
status: 200,
contentType: "application/octet-stream",
body: pythonCode,
});
});
await sendMessageWithMockResponse(page, "Give me the script", mockContent);
// Find the link in the AI message and click it
const aiMessage = page.getByTestId("onyx-ai-message").last();
const fileLink = aiMessage.locator("a").filter({ hasText: "app.py" });
await expect(fileLink).toBeVisible({ timeout: 5000 });
await fileLink.click();
// Verify the modal opens
const modal = page.getByRole("dialog");
await expect(modal).toBeVisible({ timeout: 5000 });
// Verify the file name is shown
await expect(modal.getByText("app.py")).toBeVisible();
// Verify the code content is rendered
await expect(modal.getByText("Hello, world!")).toBeVisible();
// Verify the download button exists
await expect(modal.getByText("Download File")).toBeVisible();
});
test("download button triggers file download", async ({ page }) => {
const mockContent = `Here: [data.csv](/api/chat/file/${MOCK_FILE_ID})`;
await page.route(`**/api/chat/file/${MOCK_FILE_ID}`, async (route) => {
await route.fulfill({
status: 200,
contentType: "text/csv",
body: "name,age\nAlice,30\nBob,25",
});
});
await sendMessageWithMockResponse(page, "Give me the csv", mockContent);
const aiMessage = page.getByTestId("onyx-ai-message").last();
const fileLink = aiMessage.locator("a").filter({ hasText: "data.csv" });
await expect(fileLink).toBeVisible({ timeout: 5000 });
await fileLink.click();
const modal = page.getByRole("dialog");
await expect(modal).toBeVisible({ timeout: 5000 });
// Click the download button and verify a download starts
const downloadPromise = page.waitForEvent("download");
await modal.getByText("Download File").last().click();
const download = await downloadPromise;
expect(download.suggestedFilename()).toContain("data.csv");
});
});
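A note on the mock above: `buildMockStream` produces newline-delimited JSON, even though the route fulfills it with `contentType: "application/json"`; the frontend presumably splits the body on newlines rather than parsing it as one document. A minimal, reusable sketch of the same interception pattern (the packet shape is an assumption; only the fields visible in `buildMockStream` are known):

```typescript
import type { Page } from "@playwright/test";

// Hypothetical helper mirroring the pattern these tests use: fulfill the
// chat endpoint with one JSON packet per line (NDJSON).
async function mockChatStream(
  page: Page,
  packets: Array<Record<string, unknown>>
): Promise<void> {
  await page.route("**/api/chat/send-chat-message", async (route) => {
    await route.fulfill({
      status: 200,
      contentType: "application/json",
      // Join packets into an NDJSON body with a trailing newline,
      // matching what buildMockStream() produces above.
      body: packets.map((p) => JSON.stringify(p)).join("\n") + "\n",
    });
  });
}
```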

View File

@@ -1,5 +1,5 @@
import { expect, Page, test } from "@playwright/test";
import { loginAs, loginAsWorkerUser } from "@tests/e2e/utils/auth";
import { loginAs } from "@tests/e2e/utils/auth";
import {
selectModelFromInputPopover,
sendMessage,
@@ -28,14 +28,10 @@ async function openChat(page: Page): Promise<void> {
async function loginWithCleanCookies(
page: Page,
user: "admin" | number
user: "admin" | "user"
): Promise<void> {
await page.context().clearCookies();
if (typeof user === "number") {
await loginAsWorkerUser(page, user);
} else {
await loginAs(page, user);
}
await loginAs(page, user);
}
async function createLlmProvider(
@@ -161,10 +157,10 @@ test.describe("LLM Runtime Selection", () => {
let providersToCleanup: number[] = [];
let groupsToCleanup: number[] = [];
test.beforeEach(async ({ page }, testInfo) => {
test.beforeEach(async ({ page }) => {
providersToCleanup = [];
groupsToCleanup = [];
await loginWithCleanCookies(page, testInfo.workerIndex);
await loginWithCleanCookies(page, "user");
});
test.afterEach(async ({ page }) => {
@@ -195,7 +191,7 @@ test.describe("LLM Runtime Selection", () => {
test("model selection persists across refresh and subsequent messages in the same chat", async ({
page,
}, testInfo) => {
}) => {
await loginWithCleanCookies(page, "admin");
const persistenceProviderName = uniqueName("PW Runtime Persist Provider");
@@ -211,7 +207,7 @@ test.describe("LLM Runtime Selection", () => {
persistenceProviderName,
]);
await loginWithCleanCookies(page, testInfo.workerIndex);
await loginWithCleanCookies(page, "user");
await openChat(page);
let turn = 0;
@@ -356,7 +352,7 @@ test.describe("LLM Runtime Selection", () => {
test("same model name across providers resolves to provider-specific runtime payloads", async ({
page,
}, testInfo) => {
}) => {
await loginWithCleanCookies(page, "admin");
const sharedModelName = `shared-runtime-model-${Date.now()}`;
@@ -383,7 +379,7 @@ test.describe("LLM Runtime Selection", () => {
anthropicProviderName,
]);
await loginWithCleanCookies(page, testInfo.workerIndex);
await loginWithCleanCookies(page, "user");
const capturedPayloads: SendChatMessagePayload[] = [];
let turn = 0;
@@ -473,7 +469,7 @@ test.describe("LLM Runtime Selection", () => {
test("restricted provider model is unavailable to unauthorized runtime user selection", async ({
page,
}, testInfo) => {
}) => {
await loginWithCleanCookies(page, "admin");
const client = new OnyxApiClient(page.request);
@@ -506,7 +502,7 @@ test.describe("LLM Runtime Selection", () => {
});
providersToCleanup.push(restrictedProviderId);
await loginWithCleanCookies(page, testInfo.workerIndex);
await loginWithCleanCookies(page, "user");
await openChat(page);
await page.getByTestId("AppInputBar/llm-popover-trigger").click();

View File

@@ -1,248 +0,0 @@
import { test, expect } from "@playwright/test";
import type { Page } from "@playwright/test";
import { loginAsRandomUser } from "../utils/auth";
import { expectElementScreenshot } from "../utils/visualRegression";
async function sendMessageAndWaitForChat(page: Page, message: string) {
await page.locator("#onyx-chat-input-textarea").click();
await page.locator("#onyx-chat-input-textarea").fill(message);
await page.locator("#onyx-chat-input-send-button").click();
await page.waitForFunction(
() => window.location.href.includes("chatId="),
null,
{ timeout: 15000 }
);
await expect(page.locator('[aria-label="share-chat-button"]')).toBeVisible({
timeout: 10000,
});
}
async function openShareModal(page: Page) {
await page.locator('[aria-label="share-chat-button"]').click();
await expect(page.getByRole("dialog")).toBeVisible({ timeout: 5000 });
}
test.describe("Share Chat Session Modal", () => {
test.describe.configure({ mode: "serial" });
let page: Page;
test.beforeAll(async ({ browser }) => {
page = await browser.newPage();
await loginAsRandomUser(page);
await sendMessageAndWaitForChat(page, "Hello for share test");
});
test.afterAll(async () => {
await page.close();
});
test("shows Private selected by default", async () => {
await openShareModal(page);
const dialog = page.getByRole("dialog");
await expect(dialog).toBeVisible();
const privateOption = dialog.locator(
'[aria-label="share-modal-option-private"]'
);
await expect(privateOption.locator("svg").last()).toBeVisible();
const submitButton = dialog.locator('[aria-label="share-modal-submit"]');
await expect(submitButton).toHaveText("Done");
const cancelButton = dialog.locator('[aria-label="share-modal-cancel"]');
await expect(cancelButton).toBeVisible();
await expectElementScreenshot(dialog, {
name: "share-modal-default-private",
});
await page.keyboard.press("Escape");
await expect(dialog).toBeHidden({ timeout: 5000 });
});
test("selecting Your Organization changes submit text", async () => {
await openShareModal(page);
const dialog = page.getByRole("dialog");
await dialog.locator('[aria-label="share-modal-option-public"]').click();
const submitButton = dialog.locator('[aria-label="share-modal-submit"]');
await expect(submitButton).toHaveText("Create Share Link");
const cancelButton = dialog.locator('[aria-label="share-modal-cancel"]');
await expect(cancelButton).toBeVisible();
await expectElementScreenshot(dialog, {
name: "share-modal-public-selected",
});
await page.keyboard.press("Escape");
await expect(dialog).toBeHidden({ timeout: 5000 });
});
test("Cancel closes modal without API calls", async () => {
let patchCallCount = 0;
await page.route("**/api/chat/chat-session/*", async (route) => {
if (route.request().method() === "PATCH") {
patchCallCount++;
}
await route.continue();
});
await openShareModal(page);
const dialog = page.getByRole("dialog");
const cancelButton = dialog.locator('[aria-label="share-modal-cancel"]');
await cancelButton.click();
await expect(dialog).toBeHidden({ timeout: 5000 });
expect(patchCallCount).toBe(0);
await page.unrouteAll({ behavior: "ignoreErrors" });
});
test("X button closes modal without API calls", async () => {
let patchCallCount = 0;
await page.route("**/api/chat/chat-session/*", async (route) => {
if (route.request().method() === "PATCH") {
patchCallCount++;
}
await route.continue();
});
await openShareModal(page);
const dialog = page.getByRole("dialog");
const closeButton = dialog.locator('div[tabindex="-1"] button');
await closeButton.click();
await expect(dialog).toBeHidden({ timeout: 5000 });
expect(patchCallCount).toBe(0);
await page.unrouteAll({ behavior: "ignoreErrors" });
});
test("creating a share link calls API and shows link", async () => {
await openShareModal(page);
const dialog = page.getByRole("dialog");
let patchBody: Record<string, unknown> | null = null;
await page.route("**/api/chat/chat-session/*", async (route) => {
if (route.request().method() === "PATCH") {
patchBody = JSON.parse(route.request().postData() ?? "{}");
await route.continue();
} else {
await route.continue();
}
});
await dialog.locator('[aria-label="share-modal-option-public"]').click();
const submitButton = dialog.locator('[aria-label="share-modal-submit"]');
await submitButton.click();
await page.waitForResponse(
(r) =>
r.url().includes("/api/chat/chat-session/") &&
r.request().method() === "PATCH",
{ timeout: 10000 }
);
expect(patchBody).toEqual({ sharing_status: "public" });
const linkInput = dialog.locator('[aria-label="share-modal-link-input"]');
await expect(linkInput).toBeVisible({ timeout: 5000 });
const inputValue = await linkInput.locator("input").inputValue();
expect(inputValue).toContain("/app/shared/");
await expect(submitButton).toHaveText("Copy Link");
await expect(dialog.getByText("Chat shared")).toBeVisible();
await expect(
dialog.locator('[aria-label="share-modal-cancel"]')
).toBeHidden();
await expectElementScreenshot(dialog, {
name: "share-modal-link-created",
mask: ['[aria-label="share-modal-link-input"]'],
});
await page.unrouteAll({ behavior: "ignoreErrors" });
// Wait for the toast to confirm SWR data has been refreshed
// before closing, so the next test sees up-to-date shared_status
await expect(
page.getByText("Share link copied to clipboard!").first()
).toBeVisible({ timeout: 5000 });
await page.keyboard.press("Escape");
await expect(dialog).toBeHidden({ timeout: 5000 });
});
test("Copy Link triggers clipboard copy", async () => {
await openShareModal(page);
const dialog = page.getByRole("dialog");
await expect(
dialog.locator('[aria-label="share-modal-link-input"]')
).toBeVisible({ timeout: 5000 });
const submitButton = dialog.locator('[aria-label="share-modal-submit"]');
await expect(submitButton).toHaveText("Copy Link");
await submitButton.click();
await expect(
page.getByText("Share link copied to clipboard!").first()
).toBeVisible({ timeout: 5000 });
await page.keyboard.press("Escape");
await expect(dialog).toBeHidden({ timeout: 5000 });
});
test("making chat private again calls API and closes modal", async () => {
let patchBody: Record<string, unknown> | null = null;
await page.route("**/api/chat/chat-session/*", async (route) => {
if (route.request().method() === "PATCH") {
patchBody = JSON.parse(route.request().postData() ?? "{}");
await route.continue();
} else {
await route.continue();
}
});
await openShareModal(page);
const dialog = page.getByRole("dialog");
const submitButton = dialog.locator('[aria-label="share-modal-submit"]');
await dialog.locator('[aria-label="share-modal-option-private"]').click();
await expect(submitButton).toHaveText("Make Private");
await submitButton.click();
await page.waitForResponse(
(r) =>
r.url().includes("/api/chat/chat-session/") &&
r.request().method() === "PATCH",
{ timeout: 10000 }
);
expect(patchBody).toEqual({ sharing_status: "private" });
await expect(dialog).toBeHidden({ timeout: 5000 });
await expect(page.getByText("Chat is now private")).toBeVisible({
timeout: 5000,
});
await page.unrouteAll({ behavior: "ignoreErrors" });
});
});
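One pattern from the deleted file worth keeping in mind wherever it gets reused: the PATCH-body capture registers an if/else in which both branches call `route.continue()`. A slightly tighter equivalent (behavior unchanged):

```typescript
let patchBody: Record<string, unknown> | null = null;
await page.route("**/api/chat/chat-session/*", async (route) => {
  // Record the body of PATCH requests, then let every request through.
  if (route.request().method() === "PATCH") {
    patchBody = JSON.parse(route.request().postData() ?? "{}");
  }
  await route.continue();
});
```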

View File

@@ -1,3 +1,8 @@
export const TEST_USER_CREDENTIALS = {
email: "user1@example.com",
password: "User1Password123!",
};
export const TEST_ADMIN_CREDENTIALS = {
email: "admin_user@example.com",
password: "TestPassword123!",
@@ -7,22 +12,3 @@ export const TEST_ADMIN2_CREDENTIALS = {
email: "admin2_user@example.com",
password: "TestPassword123!",
};
/**
* Number of distinct worker users provisioned during global setup.
* Must be >= the max concurrent workers in playwright.config.ts.
* Playwright's workerIndex can exceed this (retries spawn new workers
* with incrementing indices), so callers should use modulo:
* workerIndex % WORKER_USER_POOL_SIZE
*/
export const WORKER_USER_POOL_SIZE = 8;
export function workerUserCredentials(workerIndex: number): {
email: string;
password: string;
} {
return {
email: `worker${workerIndex}@example.com`,
password: "WorkerPassword123!",
};
}
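For the record, the modulo mapping the removed doc comment describes behaves like this (illustration only; these helpers no longer exist after this change):

```typescript
// Retry workers get indices past the pool size, and modulo folds them
// back onto provisioned users.
const WORKER_USER_POOL_SIZE = 8;
for (const workerIndex of [0, 7, 8, 9]) {
  const slot = workerIndex % WORKER_USER_POOL_SIZE;
  console.log(`workerIndex ${workerIndex} -> worker${slot}@example.com`);
}
// workerIndex 0 -> worker0@example.com
// workerIndex 7 -> worker7@example.com
// workerIndex 8 -> worker0@example.com
// workerIndex 9 -> worker1@example.com
```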

View File

@@ -2,8 +2,7 @@ import { FullConfig, request } from "@playwright/test";
import {
TEST_ADMIN_CREDENTIALS,
TEST_ADMIN2_CREDENTIALS,
WORKER_USER_POOL_SIZE,
workerUserCredentials,
TEST_USER_CREDENTIALS,
} from "@tests/e2e/constants";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
@@ -159,23 +158,23 @@ async function globalSetup(config: FullConfig) {
// ── Provision test users via API ─────────────────────────────────────
// The first user registered becomes the admin automatically.
// Order matters: admin first, then admin2, then worker users.
// Order matters: admin first, then user, then admin2.
await ensureUserExists(
baseURL,
TEST_ADMIN_CREDENTIALS.email,
TEST_ADMIN_CREDENTIALS.password
);
await ensureUserExists(
baseURL,
TEST_USER_CREDENTIALS.email,
TEST_USER_CREDENTIALS.password
);
await ensureUserExists(
baseURL,
TEST_ADMIN2_CREDENTIALS.email,
TEST_ADMIN2_CREDENTIALS.password
);
for (let i = 0; i < WORKER_USER_POOL_SIZE; i++) {
const { email, password } = workerUserCredentials(i);
await ensureUserExists(baseURL, email, password);
}
// ── Login via API and save storage state ───────────────────────────
await apiLoginAndSaveState(
baseURL,
@@ -191,6 +190,13 @@ async function globalSetup(config: FullConfig) {
TEST_ADMIN2_CREDENTIALS.email
);
await apiLoginAndSaveState(
baseURL,
TEST_USER_CREDENTIALS.email,
TEST_USER_CREDENTIALS.password,
"user_auth.json"
);
await apiLoginAndSaveState(
baseURL,
TEST_ADMIN2_CREDENTIALS.email,
@@ -198,29 +204,6 @@ async function globalSetup(config: FullConfig) {
"admin2_auth.json"
);
for (let i = 0; i < WORKER_USER_POOL_SIZE; i++) {
const { email, password } = workerUserCredentials(i);
const storageStatePath = `worker${i}_auth.json`;
await apiLoginAndSaveState(baseURL, email, password, storageStatePath);
const workerCtx = await request.newContext({
baseURL,
storageState: storageStatePath,
});
try {
const res = await workerCtx.patch("/api/user/personalization", {
data: { name: "worker" },
});
if (!res.ok()) {
console.warn(
`[global-setup] Failed to set display name for ${email}: ${res.status()}`
);
}
} finally {
await workerCtx.dispose();
}
}
// ── Ensure a public LLM provider exists ───────────────────────────
// Many tests depend on a default LLM being configured (file uploads,
// assistant creation, etc.). Re-use the admin session we just saved.
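Reassembled from the interleaved lines above, the provisioning sequence after this change is (a reconstruction; the surrounding function is otherwise unchanged):

```typescript
// The first user registered becomes the admin automatically,
// so order matters: admin first, then user, then admin2.
await ensureUserExists(
  baseURL,
  TEST_ADMIN_CREDENTIALS.email,
  TEST_ADMIN_CREDENTIALS.password
);
await ensureUserExists(
  baseURL,
  TEST_USER_CREDENTIALS.email,
  TEST_USER_CREDENTIALS.password
);
await ensureUserExists(
  baseURL,
  TEST_ADMIN2_CREDENTIALS.email,
  TEST_ADMIN2_CREDENTIALS.password
);
```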

View File

@@ -1,6 +1,6 @@
import { test, expect } from "@playwright/test";
import type { Page, Browser, Locator } from "@playwright/test";
import { loginAs, loginAsWorkerUser, apiLogin } from "@tests/e2e/utils/auth";
import { loginAs, apiLogin } from "@tests/e2e/utils/auth";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
import {
startMcpOauthServer,
@@ -2061,8 +2061,8 @@ test.describe("MCP OAuth flows", () => {
await page.context().clearCookies();
logStep("Cleared cookies");
await loginAsWorkerUser(page, testInfo.workerIndex);
logStep("Logged in as worker user");
await loginAs(page, "user");
logStep("Logged in as user");
const assistantId = adminArtifacts!.assistantId;
const serverName = adminArtifacts!.serverName;

View File

@@ -2,8 +2,7 @@ import type { Page } from "@playwright/test";
import {
TEST_ADMIN2_CREDENTIALS,
TEST_ADMIN_CREDENTIALS,
WORKER_USER_POOL_SIZE,
workerUserCredentials,
TEST_USER_CREDENTIALS,
} from "@tests/e2e/constants";
/**
@@ -24,34 +23,22 @@ export async function apiLogin(
}
}
// Logs in a known test user (admin or admin2) via the API.
// Logs in a known test user (admin, user, or admin2) via the API.
// Users must already be provisioned (see global-setup.ts).
export async function loginAs(
page: Page,
userType: "admin" | "admin2"
userType: "admin" | "user" | "admin2"
): Promise<void> {
const { email, password } =
userType === "admin" ? TEST_ADMIN_CREDENTIALS : TEST_ADMIN2_CREDENTIALS;
userType === "admin"
? TEST_ADMIN_CREDENTIALS
: userType === "admin2"
? TEST_ADMIN2_CREDENTIALS
: TEST_USER_CREDENTIALS;
await apiLogin(page, email, password);
}
/**
* Log in as a worker-specific user for test isolation.
* Uses modulo to map any workerIndex (which can exceed the pool size due to
* retries spawning new workers) back to a provisioned user. This is safe
* because retries never run in parallel with the original attempt.
*/
export async function loginAsWorkerUser(
page: Page,
workerIndex: number
): Promise<void> {
const { email, password } = workerUserCredentials(
workerIndex % WORKER_USER_POOL_SIZE
);
await apiLogin(page, email, password);
}
// Generate a random email and password for throwaway test users.
const generateRandomCredentials = () => {
const randomString = Math.random().toString(36).substring(2, 10);
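The new `loginAs` selects credentials with a nested ternary, which works but gets harder to read if more fixed users are ever added. A lookup map is one possible alternative (a sketch, not part of this change; `apiLogin` is the helper defined earlier in this module):

```typescript
import type { Page } from "@playwright/test";
import {
  TEST_ADMIN_CREDENTIALS,
  TEST_ADMIN2_CREDENTIALS,
  TEST_USER_CREDENTIALS,
} from "@tests/e2e/constants";

const CREDENTIALS_BY_USER = {
  admin: TEST_ADMIN_CREDENTIALS,
  user: TEST_USER_CREDENTIALS,
  admin2: TEST_ADMIN2_CREDENTIALS,
} as const;

export async function loginAs(
  page: Page,
  userType: keyof typeof CREDENTIALS_BY_USER
): Promise<void> {
  const { email, password } = CREDENTIALS_BY_USER[userType];
  await apiLogin(page, email, password);
}
```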

View File

@@ -1,18 +0,0 @@
import type { Page } from "@playwright/test";
export const THEMES = ["light", "dark"] as const;
export type Theme = (typeof THEMES)[number];
/**
* Injects the given theme into localStorage via `addInitScript` so that
* `next-themes` applies it on first render. Call this in `beforeEach`
* **before** any `page.goto()`.
*/
export async function setThemeBeforeNavigation(
page: Page,
theme: Theme
): Promise<void> {
await page.addInitScript((t: string) => {
localStorage.setItem("theme", t);
}, theme);
}
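With `setThemeBeforeNavigation` deleted, any remaining test that needs a deterministic theme would presumably inline the one-liner the helper wrapped (a sketch; assumes `next-themes` still reads the `theme` localStorage key, as the removed doc comment states):

```typescript
// Must run before page.goto() so next-themes picks the value up on
// first render, the same constraint the deleted helper documented.
await page.addInitScript(() => {
  localStorage.setItem("theme", "dark");
});
await page.goto("/app");
```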

View File

@@ -104,33 +104,6 @@ interface ElementScreenshotOptions {
threshold?: number;
}
/**
* Wait for all running CSS animations and transitions on the page to finish
* before proceeding. This prevents screenshot tests from being non-deterministic
* when animated elements (e.g. slide-in cards) are still mid-flight.
*
* The implementation:
* 1. Yields one animation frame so that any pending animations have a chance
* to register with the Web Animations API.
* 2. Calls `Promise.allSettled` on every active animation's `.finished`
* promise so we wait for completion (or cancellation) of all of them.
*/
export async function waitForAnimations(page: Page): Promise<void> {
await page.evaluate(async () => {
// Allow any freshly-scheduled animations to start
await new Promise<void>((resolve) =>
requestAnimationFrame(() => resolve())
);
// Wait for every currently-registered animation to finish (or be cancelled)
const animations = document
.getAnimations()
.filter(
(animation) => animation.effect?.getTiming().iterations !== Infinity
);
await Promise.allSettled(animations.map((animation) => animation.finished));
});
}
/**
* Take a screenshot and optionally assert it matches the stored baseline.
*
@@ -161,11 +134,6 @@ export async function expectScreenshot(
threshold,
} = options;
// Wait for any in-flight CSS animations / transitions to settle so that
// screenshots are deterministic (e.g. slide-in card animations on the
// onboarding flow).
await waitForAnimations(page);
// Merge default hide selectors with per-call selectors
const allHideSelectors = [...DEFAULT_HIDE_SELECTORS, ...hide];
@@ -249,10 +217,6 @@ export async function expectElementScreenshot(
const page = locator.page();
// Wait for any in-flight CSS animations / transitions to settle so that
// element screenshots are deterministic (same reasoning as expectScreenshot).
await waitForAnimations(page);
// Merge default hide selectors with per-call selectors
const allHideSelectors = [...DEFAULT_HIDE_SELECTORS, ...hide];
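With `waitForAnimations` gone, screenshot determinism has to come from somewhere else. One common alternative (an assumption, not something this diff introduces) is to freeze animations outright before capturing; Playwright's screenshot assertions also accept an `animations: "disabled"` option that serves the same purpose for `toHaveScreenshot` calls:

```typescript
// Sketch: neutralize CSS animations/transitions instead of waiting for
// them to finish, so captures are stable regardless of timing.
await page.addStyleTag({
  content: `
    *, *::before, *::after {
      animation: none !important;
      transition: none !important;
    }
  `,
});
```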