Compare commits

..

4 Commits

73 changed files with 699 additions and 2564 deletions

View File

@@ -8,5 +8,5 @@
## Additional Options
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
- [ ] [Optional] Override Linear Check

View File

@@ -1,79 +0,0 @@
name: Post-Merge Beta Cherry-Pick
on:
push:
branches:
- main
permissions:
contents: write
pull-requests: write
jobs:
cherry-pick-to-latest-release:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Resolve merged PR and checkbox state
id: gate
env:
GH_TOKEN: ${{ github.token }}
run: |
# For the commit that triggered this workflow (HEAD on main), fetch all
# associated PRs and keep only the PR that was actually merged into main
# with this exact merge commit SHA.
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
if [ "${match_count}" -gt 1 ]; then
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
fi
if [ -z "$pr_number" ]; then
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
exit 0
fi
# Read the PR body and check whether the helper checkbox is checked.
pr_body="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}" --jq '.body // ""')"
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox checked for PR #${pr_number}."
exit 0
fi
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
- name: Checkout repository
if: steps.gate.outputs.should_cherrypick == 'true'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: true
ref: main
- name: Install the latest version of uv
if: steps.gate.outputs.should_cherrypick == 'true'
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Configure git identity
if: steps.gate.outputs.should_cherrypick == 'true'
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Create cherry-pick PR to latest release
if: steps.gate.outputs.should_cherrypick == 'true'
env:
GH_TOKEN: ${{ github.token }}
GITHUB_TOKEN: ${{ github.token }}
run: |
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify

View File

@@ -0,0 +1,28 @@
name: Require beta cherry-pick consideration
concurrency:
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
pull_request:
types: [opened, edited, reopened, synchronize]
permissions:
contents: read
jobs:
beta-cherrypick-check:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR body for beta cherry-pick consideration
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
echo "Cherry-pick consideration box is checked. Check passed."
exit 0
fi
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
exit 1

View File

@@ -21,14 +21,15 @@ import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import NamedTuple
from typing import List, NamedTuple
from alembic.config import Config
from alembic.script import ScriptDirectory
from sqlalchemy import text
from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
from shared_configs.configs import TENANT_ID_PREFIX
@@ -104,6 +105,56 @@ def get_head_revision() -> str | None:
return script.get_current_head()
def get_schemas_needing_migration(
tenant_schemas: List[str], head_rev: str
) -> List[str]:
"""Return only schemas whose current alembic version is not at head."""
if not tenant_schemas:
return []
engine = SqlEngine.get_engine()
with engine.connect() as conn:
# Find which schemas actually have an alembic_version table
rows = conn.execute(
text(
"SELECT table_schema FROM information_schema.tables "
"WHERE table_name = 'alembic_version' "
"AND table_schema = ANY(:schemas)"
),
{"schemas": tenant_schemas},
)
schemas_with_table = set(row[0] for row in rows)
# Schemas without the table definitely need migration
needs_migration = [s for s in tenant_schemas if s not in schemas_with_table]
if not schemas_with_table:
return needs_migration
# Validate schema names before interpolating into SQL
for schema in schemas_with_table:
if not is_valid_schema_name(schema):
raise ValueError(f"Invalid schema name: {schema}")
# Single query to get every schema's current revision at once.
# Use integer tags instead of interpolating schema names into
# string literals to avoid quoting issues.
schema_list = list(schemas_with_table)
union_parts = [
f'SELECT {i} AS idx, version_num FROM "{schema}".alembic_version'
for i, schema in enumerate(schema_list)
]
rows = conn.execute(text(" UNION ALL ".join(union_parts)))
version_by_schema = {schema_list[row[0]]: row[1] for row in rows}
needs_migration.extend(
s for s in schemas_with_table if version_by_schema.get(s) != head_rev
)
return needs_migration
def run_migrations_parallel(
schemas: list[str],
max_workers: int,

View File

@@ -127,14 +127,9 @@ class ScimDAL(DAL):
self,
external_id: str,
user_id: UUID,
scim_username: str | None = None,
) -> ScimUserMapping:
"""Create a mapping between a SCIM externalId and an Onyx user."""
mapping = ScimUserMapping(
external_id=external_id,
user_id=user_id,
scim_username=scim_username,
)
mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
self._session.add(mapping)
self._session.flush()
return mapping
@@ -253,11 +248,11 @@ class ScimDAL(DAL):
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[User, ScimUserMapping | None]], int]:
) -> tuple[list[tuple[User, str | None]], int]:
"""Query users with optional SCIM filter and pagination.
Returns:
A tuple of (list of (user, mapping) pairs, total_count).
A tuple of (list of (user, external_id) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
@@ -297,104 +292,33 @@ class ScimDAL(DAL):
users = list(
self._session.scalars(
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
)
.unique()
.all()
).all()
)
# Batch-fetch SCIM mappings to avoid N+1 queries
mapping_map = self._get_user_mappings_batch([u.id for u in users])
return [(u, mapping_map.get(u.id)) for u in users], total
# Batch-fetch external IDs to avoid N+1 queries
ext_id_map = self._get_user_external_ids([u.id for u in users])
return [(u, ext_id_map.get(u.id)) for u in users], total
def sync_user_external_id(
self,
user_id: UUID,
new_external_id: str | None,
scim_username: str | None = None,
) -> None:
def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
"""Create, update, or delete the external ID mapping for a user."""
mapping = self.get_user_mapping_by_user_id(user_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
else:
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
)
self.create_user_mapping(external_id=new_external_id, user_id=user_id)
elif mapping:
self.delete_user_mapping(mapping.id)
def _get_user_mappings_batch(
self, user_ids: list[UUID]
) -> dict[UUID, ScimUserMapping]:
"""Batch-fetch SCIM user mappings keyed by user ID."""
def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
"""Batch-fetch external IDs for a list of user IDs."""
if not user_ids:
return {}
mappings = self._session.scalars(
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
).all()
return {m.user_id: m for m in mappings}
def get_user_groups(self, user_id: UUID) -> list[tuple[int, str]]:
"""Get groups a user belongs to as ``(group_id, group_name)`` pairs.
Excludes groups marked for deletion.
"""
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_id == user_id)
).all()
group_ids = [r.user_group_id for r in rels]
if not group_ids:
return []
groups = self._session.scalars(
select(UserGroup).where(
UserGroup.id.in_(group_ids),
UserGroup.is_up_for_deletion.is_(False),
)
).all()
return [(g.id, g.name) for g in groups]
def get_users_groups_batch(
self, user_ids: list[UUID]
) -> dict[UUID, list[tuple[int, str]]]:
"""Batch-fetch group memberships for multiple users.
Returns a mapping of ``user_id → [(group_id, group_name), ...]``.
Avoids N+1 queries when building user list responses.
"""
if not user_ids:
return {}
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_id.in_(user_ids))
).all()
group_ids = list({r.user_group_id for r in rels})
if not group_ids:
return {}
groups = self._session.scalars(
select(UserGroup).where(
UserGroup.id.in_(group_ids),
UserGroup.is_up_for_deletion.is_(False),
)
).all()
groups_by_id = {g.id: g.name for g in groups}
result: dict[UUID, list[tuple[int, str]]] = {}
for r in rels:
if r.user_id and r.user_group_id in groups_by_id:
result.setdefault(r.user_id, []).append(
(r.user_group_id, groups_by_id[r.user_group_id])
)
return result
return {m.user_id: m.external_id for m in mappings}
# ------------------------------------------------------------------
# Group mapping operations
@@ -559,13 +483,9 @@ class ScimDAL(DAL):
if not user_ids:
return []
users = (
self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
)
.unique()
.all()
)
users = self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
).all()
users_by_id = {u.id: u for u in users}
return [
@@ -584,13 +504,9 @@ class ScimDAL(DAL):
"""
if not uuids:
return []
existing_users = (
self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
)
.unique()
.all()
)
existing_users = self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
).all()
existing_ids = {u.id for u in existing_users}
return [uid for uid in uuids if uid not in existing_ids]

View File

@@ -34,7 +34,7 @@ class SendSearchQueryRequest(BaseModel):
filters: BaseFilters | None = None
num_docs_fed_to_llm_selection: int | None = None
run_query_expansion: bool = False
num_hits: int = 30
num_hits: int = 50
include_content: bool = False
stream: bool = False

View File

@@ -26,10 +26,12 @@ from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimResourceType
@@ -39,8 +41,6 @@ from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.providers.base import get_default_provider
from ee.onyx.server.scim.providers.base import ScimProvider
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
@@ -53,6 +53,7 @@ from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
@@ -62,18 +63,6 @@ scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
def _get_provider(
_token: ScimToken = Depends(verify_scim_token),
) -> ScimProvider:
"""Resolve the SCIM provider for the current request.
Currently returns OktaProvider for all requests. When multi-provider
support is added (ENG-3652), this will resolve based on token metadata
or tenant configuration — no endpoint changes required.
"""
return get_default_provider()
# ---------------------------------------------------------------------------
# Service Discovery Endpoints (unauthenticated)
# ---------------------------------------------------------------------------
@@ -111,6 +100,28 @@ def _scim_error_response(status: int, detail: str) -> JSONResponse:
)
def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
"""Convert an Onyx User to a SCIM User resource representation."""
name = None
if user.personal_name:
parts = user.personal_name.split(" ", 1)
name = ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
return ScimUserResource(
id=str(user.id),
externalId=external_id,
userName=user.email,
name=name,
emails=[ScimEmail(value=user.email, type="work", primary=True)],
active=user.is_active,
meta=ScimMeta(resourceType="User"),
)
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None."""
check_fn = fetch_ee_implementation_or_noop(
@@ -144,10 +155,9 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
"""
if not name:
return None
# Build from givenName/familyName first — IdPs like Okta may send a stale
# ``formatted`` value while updating the individual name components.
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
return parts or name.formatted
return name.formatted or " ".join(
part for part in [name.givenName, name.familyName] if part
)
# ---------------------------------------------------------------------------
@@ -161,7 +171,6 @@ def list_users(
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List users with optional SCIM filter and pagination."""
@@ -174,19 +183,12 @@ def list_users(
return _scim_error_response(400, str(e))
try:
users_with_mappings, total = dal.list_users(scim_filter, startIndex, count)
users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
user_groups_map = dal.get_users_groups_batch([u.id for u, _ in users_with_mappings])
resources: list[ScimUserResource | ScimGroupResource] = [
provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=user_groups_map.get(user.id, []),
scim_username=mapping.scim_username if mapping else None,
)
for user, mapping in users_with_mappings
_user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
]
return ScimListResponse(
@@ -201,7 +203,6 @@ def list_users(
def get_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Get a single user by ID."""
@@ -214,26 +215,20 @@ def get_user(
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
return provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=dal.get_user_groups(user.id),
scim_username=mapping.scim_username if mapping else None,
)
return _user_to_scim(user, mapping.external_id if mapping else None)
@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Create a new user from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
email = user_resource.userName.strip()
email = user_resource.userName.strip().lower()
# externalId is how the IdP correlates this user on subsequent requests.
# Without it, the IdP can't find the user and will try to re-create,
@@ -269,14 +264,11 @@ def create_user(
# Create SCIM mapping (externalId is validated above, always present)
external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
dal.create_user_mapping(
external_id=external_id, user_id=user.id, scim_username=scim_username
)
dal.create_user_mapping(external_id=external_id, user_id=user.id)
dal.commit()
return provider.build_user_resource(user, external_id, scim_username=scim_username)
return _user_to_scim(user, external_id)
@scim_router.put("/Users/{user_id}", response_model=None)
@@ -284,7 +276,6 @@ def replace_user(
user_id: str,
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Replace a user entirely (RFC 7644 §3.5.1)."""
@@ -302,27 +293,19 @@ def replace_user(
if seat_error:
return _scim_error_response(403, seat_error)
personal_name = _scim_name_to_str(user_resource.name)
dal.update_user(
user,
email=user_resource.userName.strip(),
email=user_resource.userName.strip().lower(),
is_active=user_resource.active,
personal_name=personal_name,
personal_name=_scim_name_to_str(user_resource.name),
)
new_external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
dal.sync_user_external_id(user.id, new_external_id, scim_username=scim_username)
dal.sync_user_external_id(user.id, new_external_id)
dal.commit()
return provider.build_user_resource(
user,
new_external_id,
groups=dal.get_user_groups(user.id),
scim_username=scim_username,
)
return _user_to_scim(user, new_external_id)
@scim_router.patch("/Users/{user_id}", response_model=None)
@@ -330,7 +313,6 @@ def patch_user(
user_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Partially update a user (RFC 7644 §3.5.2).
@@ -348,19 +330,11 @@ def patch_user(
mapping = dal.get_user_mapping_by_user_id(user.id)
external_id = mapping.external_id if mapping else None
current_scim_username = mapping.scim_username if mapping else None
current = provider.build_user_resource(
user,
external_id,
groups=dal.get_user_groups(user.id),
scim_username=current_scim_username,
)
current = _user_to_scim(user, external_id)
try:
patched = apply_user_patch(
patch_request.Operations, current, provider.ignored_patch_paths
)
patched = apply_user_patch(patch_request.Operations, current)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
@@ -371,40 +345,22 @@ def patch_user(
if seat_error:
return _scim_error_response(403, seat_error)
# Track the scim_username — if userName was patched, update it
new_scim_username = patched.userName.strip() if patched.userName else None
# If displayName was explicitly patched (different from the original), use
# it as personal_name directly. Otherwise, derive from name components.
personal_name: str | None
if patched.displayName and patched.displayName != current.displayName:
personal_name = patched.displayName
else:
personal_name = _scim_name_to_str(patched.name)
dal.update_user(
user,
email=(
patched.userName.strip()
if patched.userName.strip().lower() != user.email.lower()
patched.userName.strip().lower()
if patched.userName.lower() != user.email
else None
),
is_active=patched.active if patched.active != user.is_active else None,
personal_name=personal_name,
personal_name=_scim_name_to_str(patched.name),
)
dal.sync_user_external_id(
user.id, patched.externalId, scim_username=new_scim_username
)
dal.sync_user_external_id(user.id, patched.externalId)
dal.commit()
return provider.build_user_resource(
user,
patched.externalId,
groups=dal.get_user_groups(user.id),
scim_username=new_scim_username,
)
return _user_to_scim(user, patched.externalId)
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
@@ -442,6 +398,24 @@ def delete_user(
# ---------------------------------------------------------------------------
def _group_to_scim(
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Convert an Onyx UserGroup to a SCIM Group resource."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
"""Parse *group_id* as int, look up the group, or return a 404 error."""
try:
@@ -500,7 +474,6 @@ def list_groups(
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List groups with optional SCIM filter and pagination."""
@@ -518,7 +491,7 @@ def list_groups(
return _scim_error_response(400, str(e))
resources: list[ScimUserResource | ScimGroupResource] = [
provider.build_group_resource(group, dal.get_group_members(group.id), ext_id)
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
for group, ext_id in groups_with_ext_ids
]
@@ -534,7 +507,6 @@ def list_groups(
def get_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Get a single group by ID."""
@@ -549,16 +521,13 @@ def get_group(
mapping = dal.get_group_mapping_by_group_id(group.id)
members = dal.get_group_members(group.id)
return provider.build_group_resource(
group, members, mapping.external_id if mapping else None
)
return _group_to_scim(group, members, mapping.external_id if mapping else None)
@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Create a new group from a SCIM provisioning request."""
@@ -596,7 +565,7 @@ def create_group(
dal.commit()
members = dal.get_group_members(db_group.id)
return provider.build_group_resource(db_group, members, external_id)
return _group_to_scim(db_group, members, external_id)
@scim_router.put("/Groups/{group_id}", response_model=None)
@@ -604,7 +573,6 @@ def replace_group(
group_id: str,
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Replace a group entirely (RFC 7644 §3.5.1)."""
@@ -627,7 +595,7 @@ def replace_group(
dal.commit()
members = dal.get_group_members(group.id)
return provider.build_group_resource(group, members, group_resource.externalId)
return _group_to_scim(group, members, group_resource.externalId)
@scim_router.patch("/Groups/{group_id}", response_model=None)
@@ -635,7 +603,6 @@ def patch_group(
group_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Partially update a group (RFC 7644 §3.5.2).
@@ -654,11 +621,11 @@ def patch_group(
external_id = mapping.external_id if mapping else None
current_members = dal.get_group_members(group.id)
current = provider.build_group_resource(group, current_members, external_id)
current = _group_to_scim(group, current_members, external_id)
try:
patched, added_ids, removed_ids = apply_group_patch(
patch_request.Operations, current, provider.ignored_patch_paths
patch_request.Operations, current
)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
@@ -685,7 +652,7 @@ def patch_group(
dal.commit()
members = dal.get_group_members(group.id)
return provider.build_group_resource(group, members, patched.externalId)
return _group_to_scim(group, members, patched.externalId)
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)

View File

@@ -63,13 +63,6 @@ class ScimMeta(BaseModel):
location: str | None = None
class ScimUserGroupRef(BaseModel):
"""Group reference within a User resource (RFC 7643 §4.1.2, read-only)."""
value: str
display: str | None = None
class ScimUserResource(BaseModel):
"""SCIM User resource representation (RFC 7643 §4.1).
@@ -83,10 +76,8 @@ class ScimUserResource(BaseModel):
externalId: str | None = None # IdP's identifier for this user
userName: str # Typically the user's email address
name: ScimName | None = None
displayName: str | None = None
emails: list[ScimEmail] = Field(default_factory=list)
active: bool = True
groups: list[ScimUserGroupRef] = Field(default_factory=list)
meta: ScimMeta | None = None
@@ -130,40 +121,12 @@ class ScimPatchOperationType(str, Enum):
REMOVE = "remove"
class ScimPatchResourceValue(BaseModel):
"""Partial resource dict for path-less PATCH replace operations.
When an IdP sends a PATCH without a ``path``, the ``value`` is a dict
of resource attributes to set. IdPs may include read-only fields
(``id``, ``schemas``, ``meta``) alongside actual changes — these are
stripped by the provider's ``ignored_patch_paths`` before processing.
``extra="allow"`` lets unknown attributes pass through so the patch
handler can decide what to do with them (ignore or reject).
"""
model_config = ConfigDict(extra="allow")
active: bool | None = None
userName: str | None = None
displayName: str | None = None
externalId: str | None = None
name: ScimName | None = None
members: list[ScimGroupMember] | None = None
id: str | None = None
schemas: list[str] | None = None
meta: ScimMeta | None = None
ScimPatchValue = str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None
class ScimPatchOperation(BaseModel):
"""Single PATCH operation (RFC 7644 §3.5.2)."""
op: ScimPatchOperationType
path: str | None = None
value: ScimPatchValue = None
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
class ScimPatchRequest(BaseModel):

View File

@@ -16,12 +16,9 @@ from __future__ import annotations
import re
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchResourceValue
from ee.onyx.server.scim.models import ScimPatchValue
from ee.onyx.server.scim.models import ScimUserResource
@@ -44,15 +41,9 @@ _MEMBER_FILTER_RE = re.compile(
def apply_user_patch(
operations: list[ScimPatchOperation],
current: ScimUserResource,
ignored_paths: frozenset[str] = frozenset(),
) -> ScimUserResource:
"""Apply SCIM PATCH operations to a user resource.
Args:
operations: The PATCH operations to apply.
current: The current user resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns a new ``ScimUserResource`` with the modifications applied.
The original object is not mutated.
@@ -64,9 +55,9 @@ def apply_user_patch(
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_user_replace(op, data, name_data, ignored_paths)
_apply_user_replace(op, data, name_data)
elif op.op == ScimPatchOperationType.ADD:
_apply_user_replace(op, data, name_data, ignored_paths)
_apply_user_replace(op, data, name_data)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on User resource"
@@ -80,34 +71,30 @@ def _apply_user_replace(
op: ScimPatchOperation,
data: dict,
name_data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace/add operation to user data."""
path = (op.path or "").lower()
if not path:
# No path — value is a resource dict of top-level attributes to set
if isinstance(op.value, ScimPatchResourceValue):
for key, val in op.value.model_dump(exclude_unset=True).items():
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
# No path — value is a dict of top-level attributes to set
if isinstance(op.value, dict):
for key, val in op.value.items():
_set_user_field(key.lower(), val, data, name_data)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
_set_user_field(path, op.value, data, name_data, ignored_paths)
_set_user_field(path, op.value, data, name_data)
def _set_user_field(
path: str,
value: ScimPatchValue,
value: str | bool | dict | list | None,
data: dict,
name_data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Set a single field on user data by SCIM path."""
if path in ignored_paths:
return
elif path == "active":
if path == "active":
data["active"] = value
elif path == "username":
data["userName"] = value
@@ -120,7 +107,7 @@ def _set_user_field(
elif path == "name.formatted":
name_data["formatted"] = value
elif path == "displayname":
data["displayName"] = value
# Some IdPs send displayName on users; map to formatted name
name_data["formatted"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
@@ -129,15 +116,9 @@ def _set_user_field(
def apply_group_patch(
operations: list[ScimPatchOperation],
current: ScimGroupResource,
ignored_paths: frozenset[str] = frozenset(),
) -> tuple[ScimGroupResource, list[str], list[str]]:
"""Apply SCIM PATCH operations to a group resource.
Args:
operations: The PATCH operations to apply.
current: The current group resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns:
A tuple of (modified group, added member IDs, removed member IDs).
The caller uses the member ID lists to update the database.
@@ -152,9 +133,7 @@ def apply_group_patch(
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_group_replace(
op, data, current_members, added_ids, removed_ids, ignored_paths
)
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
elif op.op == ScimPatchOperationType.ADD:
_apply_group_add(op, current_members, added_ids)
elif op.op == ScimPatchOperationType.REMOVE:
@@ -175,48 +154,38 @@ def _apply_group_replace(
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace operation to group data."""
path = (op.path or "").lower()
if not path:
if isinstance(op.value, ScimPatchResourceValue):
dumped = op.value.model_dump(exclude_unset=True)
for key, val in dumped.items():
if isinstance(op.value, dict):
for key, val in op.value.items():
if key.lower() == "members":
_replace_members(val, current_members, added_ids, removed_ids)
else:
_set_group_field(key.lower(), val, data, ignored_paths)
_set_group_field(key.lower(), val, data)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
if path == "members":
_replace_members(
_members_to_dicts(op.value), current_members, added_ids, removed_ids
)
_replace_members(op.value, current_members, added_ids, removed_ids)
return
_set_group_field(path, op.value, data, ignored_paths)
def _members_to_dicts(
value: str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None,
) -> list[dict]:
"""Convert a member list value to a list of dicts for internal processing."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
return [m.model_dump(exclude_none=True) for m in value]
_set_group_field(path, op.value, data)
def _replace_members(
value: list[dict],
value: str | list | dict | bool | None,
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
) -> None:
"""Replace the entire group member list."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
old_ids = {m["value"] for m in current_members}
new_ids = {m.get("value", "") for m in value}
@@ -228,14 +197,11 @@ def _replace_members(
def _set_group_field(
path: str,
value: ScimPatchValue,
value: str | bool | dict | list | None,
data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Set a single field on group data by SCIM path."""
if path in ignored_paths:
return
elif path == "displayname":
if path == "displayname":
data["displayName"] = value
elif path == "externalid":
data["externalId"] = value
@@ -257,10 +223,8 @@ def _apply_group_add(
if not isinstance(op.value, list):
raise ScimPatchError("Add members requires a list value")
member_dicts = [m.model_dump(exclude_none=True) for m in op.value]
existing_ids = {m["value"] for m in members}
for member_data in member_dicts:
for member_data in op.value:
member_id = member_data.get("value", "")
if member_id and member_id not in existing_ids:
members.append(member_data)

View File

@@ -1,123 +0,0 @@
"""Base SCIM provider abstraction."""
from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from uuid import UUID
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserGroupRef
from ee.onyx.server.scim.models import ScimUserResource
from onyx.db.models import User
from onyx.db.models import UserGroup
class ScimProvider(ABC):
"""Base class for provider-specific SCIM behavior.
Subclass this to handle IdP-specific quirks. The base class provides
RFC 7643-compliant response builders that populate all standard fields.
"""
@property
@abstractmethod
def name(self) -> str:
"""Short identifier for this provider (e.g. ``"okta"``)."""
...
@property
@abstractmethod
def ignored_patch_paths(self) -> frozenset[str]:
"""SCIM attribute paths to silently skip in PATCH value-object dicts.
IdPs may include read-only or meta fields alongside actual changes
(e.g. Okta sends ``{"id": "...", "active": false}``). Paths listed
here are silently dropped instead of raising an error.
"""
...
def build_user_resource(
self,
user: User,
external_id: str | None = None,
groups: list[tuple[int, str]] | None = None,
scim_username: str | None = None,
) -> ScimUserResource:
"""Build a SCIM User response from an Onyx User.
Args:
user: The Onyx user model.
external_id: The IdP's external identifier for this user.
groups: List of ``(group_id, group_name)`` tuples for the
``groups`` read-only attribute. Pass ``None`` or ``[]``
for newly-created users.
scim_username: The original-case userName from the IdP. Falls
back to ``user.email`` (lowercase) when not available.
"""
group_refs = [
ScimUserGroupRef(value=str(gid), display=gname)
for gid, gname in (groups or [])
]
# Use original-case userName if stored, otherwise fall back to the
# lowercased email from the User model.
username = scim_username or user.email
return ScimUserResource(
id=str(user.id),
externalId=external_id,
userName=username,
name=self._build_scim_name(user),
displayName=user.personal_name,
emails=[ScimEmail(value=username, type="work", primary=True)],
active=user.is_active,
groups=group_refs,
meta=ScimMeta(resourceType="User"),
)
def build_group_resource(
self,
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Build a SCIM Group response from an Onyx UserGroup."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
@staticmethod
def _build_scim_name(user: User) -> ScimName | None:
"""Extract SCIM name components from a user's personal name."""
if not user.personal_name:
return None
parts = user.personal_name.split(" ", 1)
return ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
def get_default_provider() -> ScimProvider:
"""Return the default SCIM provider.
Currently returns ``OktaProvider`` since Okta is the primary supported
IdP. When provider detection is added (via token metadata or tenant
config), this can be replaced with dynamic resolution.
"""
from ee.onyx.server.scim.providers.okta import OktaProvider
return OktaProvider()

View File

@@ -1,25 +0,0 @@
"""Okta SCIM provider."""
from __future__ import annotations
from ee.onyx.server.scim.providers.base import ScimProvider
class OktaProvider(ScimProvider):
"""Okta SCIM provider.
Okta behavioral notes:
- Uses ``PATCH {"active": false}`` for deprovisioning (not DELETE)
- Sends path-less PATCH with value dicts containing extra fields
(``id``, ``schemas``)
- Expects ``displayName`` and ``groups`` in user responses
- Only uses ``eq`` operator for ``userName`` filter
"""
@property
def name(self) -> str:
return "okta"
@property
def ignored_patch_paths(self) -> frozenset[str]:
return frozenset({"id", "schemas", "meta"})

View File

@@ -58,6 +58,14 @@ class OAuthTokenManager:
if not user_token.token_data:
raise ValueError("No token data available for refresh")
if (
self.oauth_config.client_id is None
or self.oauth_config.client_secret is None
):
raise ValueError(
"OAuth client_id and client_secret are required for token refresh"
)
token_data = self._unwrap_token_data(user_token.token_data)
response = requests.post(
@@ -65,8 +73,10 @@ class OAuthTokenManager:
data={
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": self.oauth_config.client_id,
"client_secret": self.oauth_config.client_secret,
"client_id": self.oauth_config.client_id.get_value(apply_mask=False),
"client_secret": self.oauth_config.client_secret.get_value(
apply_mask=False
),
},
headers={"Accept": "application/json"},
)
@@ -115,13 +125,23 @@ class OAuthTokenManager:
def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
"""Exchange authorization code for access token"""
if (
self.oauth_config.client_id is None
or self.oauth_config.client_secret is None
):
raise ValueError(
"OAuth client_id and client_secret are required for code exchange"
)
response = requests.post(
self.oauth_config.token_url,
data={
"grant_type": "authorization_code",
"code": code,
"client_id": self.oauth_config.client_id,
"client_secret": self.oauth_config.client_secret,
"client_id": self.oauth_config.client_id.get_value(apply_mask=False),
"client_secret": self.oauth_config.client_secret.get_value(
apply_mask=False
),
"redirect_uri": redirect_uri,
},
headers={"Accept": "application/json"},
@@ -141,8 +161,11 @@ class OAuthTokenManager:
oauth_config: OAuthConfig, redirect_uri: str, state: str
) -> str:
"""Build OAuth authorization URL"""
if oauth_config.client_id is None:
raise ValueError("OAuth client_id is required to build authorization URL")
params: dict[str, Any] = {
"client_id": oauth_config.client_id,
"client_id": oauth_config.client_id.get_value(apply_mask=False),
"redirect_uri": redirect_uri,
"response_type": "code",
"state": state,

View File

@@ -277,32 +277,13 @@ def verify_email_domain(email: str) -> None:
detail="Email is not valid",
)
local_part, domain = email.split("@")
domain = domain.lower()
if AUTH_TYPE == AuthType.CLOUD:
# Normalize googlemail.com to gmail.com (they deliver to the same inbox)
if domain == "googlemail.com":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
)
if "+" in local_part and domain != "onyx.app":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "Email addresses with '+' are not allowed. Please use your base email address."
},
)
domain = email.split("@")[-1].lower()
# Check if email uses a disposable/temporary domain
if is_disposable_email(email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "Disposable email addresses are not allowed. Please use a permanent email address."
},
detail="Disposable email addresses are not allowed. Please use a permanent email address.",
)
# Check domain whitelist if configured

View File

@@ -59,11 +59,12 @@ def _build_index_filters(
base_filters = user_provided_filters or BaseFilters()
document_set_filter = (
base_filters.document_set
if base_filters.document_set is not None
else persona_document_sets
)
if (
user_provided_filters
and user_provided_filters.document_set is None
and persona_document_sets is not None
):
base_filters.document_set = persona_document_sets
time_filter = base_filters.time_cutoff or persona_time_cutoff
source_filter = base_filters.source_type
@@ -119,7 +120,7 @@ def _build_index_filters(
user_file_ids=user_file_ids,
project_id=project_id,
source_type=source_filter,
document_set=document_set_filter,
document_set=persona_document_sets,
time_cutoff=time_filter,
tags=base_filters.tags,
access_control_list=user_acl_filters,

View File

@@ -1,102 +1,11 @@
from sqlalchemy import text
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import SqlEngine
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
def get_schemas_needing_migration(
tenant_schemas: list[str], head_rev: str
) -> list[str]:
"""Return only schemas whose current alembic version is not at head.
Uses a server-side PL/pgSQL loop to collect each schema's alembic version
into a temp table one at a time. This avoids building a massive UNION ALL
query (which locks the DB and times out at 17k+ schemas) and instead
acquires locks sequentially, one schema per iteration.
"""
if not tenant_schemas:
return []
engine = SqlEngine.get_engine()
with engine.connect() as conn:
# Populate a temp input table with exactly the schemas we care about.
# The DO block reads from this table so it only iterates the requested
# schemas instead of every tenant_% schema in the database.
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot"))
conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input"))
conn.execute(text("CREATE TEMP TABLE _tenant_schemas_input (schema_name text)"))
conn.execute(
text(
"INSERT INTO _tenant_schemas_input (schema_name) "
"SELECT unnest(CAST(:schemas AS text[]))"
),
{"schemas": tenant_schemas},
)
conn.execute(
text(
"CREATE TEMP TABLE _alembic_version_snapshot "
"(schema_name text, version_num text)"
)
)
conn.execute(
text(
"""
DO $$
DECLARE
s text;
schemas text[];
BEGIN
SELECT array_agg(schema_name) INTO schemas
FROM _tenant_schemas_input;
IF schemas IS NULL THEN
RAISE NOTICE 'No tenant schemas found.';
RETURN;
END IF;
FOREACH s IN ARRAY schemas LOOP
BEGIN
EXECUTE format(
'INSERT INTO _alembic_version_snapshot
SELECT %L, version_num FROM %I.alembic_version',
s, s
);
EXCEPTION
-- undefined_table: schema exists but has no alembic_version
-- table yet (new tenant, not yet migrated).
-- invalid_schema_name: tenant is registered but its
-- PostgreSQL schema does not exist yet (e.g. provisioning
-- incomplete). Both cases mean no version is available and
-- the schema will be included in the migration list.
WHEN undefined_table THEN NULL;
WHEN invalid_schema_name THEN NULL;
END;
END LOOP;
END;
$$
"""
)
)
rows = conn.execute(
text("SELECT schema_name, version_num FROM _alembic_version_snapshot")
)
version_by_schema = {row[0]: row[1] for row in rows}
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot"))
conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input"))
# Schemas missing from the snapshot have no alembic_version table yet and
# also need migration. version_by_schema.get(s) returns None for those,
# and None != head_rev, so they are included automatically.
return [s for s in tenant_schemas if version_by_schema.get(s) != head_rev]
def get_all_tenant_ids() -> list[str]:
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""

View File

@@ -554,9 +554,10 @@ class VespaDocumentIndex(DocumentIndex):
num_to_retrieve: int,
) -> list[InferenceChunk]:
vespa_where_clauses = build_vespa_filters(filters)
# Avoid over-fetching a very large candidate set for global-phase reranking.
# Keep enough headroom for quality while capping cost on larger indices.
target_hits = min(max(4 * num_to_retrieve, 100), RERANK_COUNT)
# Needs to be at least as much as the rerank-count value set in the
# Vespa schema config. Otherwise we would be getting fewer results than
# expected for reranking.
target_hits = max(10 * num_to_retrieve, RERANK_COUNT)
yql = (
YQL_BASE.format(index_name=self._index_name)

View File

@@ -106,13 +106,13 @@ class TestGuildDataIsolation:
# Create admin user for tenant 1
admin_user1: DATestUser = UserManager.create(
email=f"discord_admin1_{unique}@example.com",
email=f"discord_admin1+{unique}@example.com",
)
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
# Create admin user for tenant 2
admin_user2: DATestUser = UserManager.create(
email=f"discord_admin2_{unique}@example.com",
email=f"discord_admin2+{unique}@example.com",
)
assert UserManager.is_role(admin_user2, UserRole.ADMIN)
@@ -170,10 +170,10 @@ class TestGuildDataIsolation:
# Create admin users for two tenants
admin_user1: DATestUser = UserManager.create(
email=f"discord_list1_{unique}@example.com",
email=f"discord_list1+{unique}@example.com",
)
admin_user2: DATestUser = UserManager.create(
email=f"discord_list2_{unique}@example.com",
email=f"discord_list2+{unique}@example.com",
)
# Create 1 guild in tenant 1
@@ -350,10 +350,10 @@ class TestGuildAccessIsolation:
# Create admin users for two tenants
admin_user1: DATestUser = UserManager.create(
email=f"discord_access1_{unique}@example.com",
email=f"discord_access1+{unique}@example.com",
)
admin_user2: DATestUser = UserManager.create(
email=f"discord_access2_{unique}@example.com",
email=f"discord_access2+{unique}@example.com",
)
# Create a guild in tenant 1

View File

@@ -21,7 +21,7 @@ def test_admin_can_invite_users(reset_multitenant: None) -> None: # noqa: ARG00
# Admin user invites the previously registered and non-registered user
UserManager.invite_user(invited_user.email, admin_user)
UserManager.invite_user(f"{INVITED_BASIC_USER}_{unique}@example.com", admin_user)
UserManager.invite_user(f"{INVITED_BASIC_USER}+{unique}@example.com", admin_user)
# Verify users are in the invited users list
invited_users = UserManager.get_invited_users(admin_user)
@@ -40,7 +40,7 @@ def test_non_registered_user_gets_basic_role(
assert UserManager.is_role(admin_user, UserRole.ADMIN)
# Admin user invites a non-registered user
invited_email = f"{INVITED_BASIC_USER}_{unique}@example.com"
invited_email = f"{INVITED_BASIC_USER}+{unique}@example.com"
UserManager.invite_user(invited_email, admin_user)
# Non-registered user registers
@@ -58,7 +58,7 @@ def test_user_can_accept_invitation(reset_multitenant: None) -> None: # noqa: A
assert UserManager.is_role(admin_user, UserRole.ADMIN)
# Create a user to be invited
invited_user_email = f"invited_user_{unique}@example.com"
invited_user_email = f"invited_user+{unique}@example.com"
# User registers with the same email as the invitation
invited_user: DATestUser = UserManager.create(

View File

@@ -20,13 +20,13 @@ def setup_test_tenants(reset_multitenant: None) -> dict[str, Any]: # noqa: ARG0
unique = uuid4().hex
# Creating an admin user for Tenant 1
admin_user1: DATestUser = UserManager.create(
email=f"admin_{unique}@example.com",
email=f"admin+{unique}@example.com",
)
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
# Create Tenant 2 and its Admin User
admin_user2: DATestUser = UserManager.create(
email=f"admin2_{unique}@example.com",
email=f"admin2+{unique}@example.com",
)
assert UserManager.is_role(admin_user2, UserRole.ADMIN)

View File

@@ -1,167 +0,0 @@
"""
Integration tests for onyx.db.engine.tenant_utils.get_schemas_needing_migration.
These tests require a live database and exercise the function directly,
independent of the alembic migration runner script.
Usage:
pytest tests/integration/multitenant_tests/test_get_schemas_needing_migration.py -v
"""
from __future__ import annotations
import subprocess
import uuid
from collections.abc import Generator
import pytest
from sqlalchemy import text
from sqlalchemy.engine import Engine
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
_BACKEND_DIR = __file__[: __file__.index("/tests/")]
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def engine() -> Engine:
return SqlEngine.get_engine()
@pytest.fixture
def current_head_rev() -> str:
result = subprocess.run(
["alembic", "heads", "--resolve-dependencies"],
cwd=_BACKEND_DIR,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
assert (
result.returncode == 0
), f"alembic heads failed (exit {result.returncode}):\n{result.stdout}"
rev = result.stdout.strip().split()[0]
assert len(rev) > 0
return rev
@pytest.fixture
def tenant_schema_at_head(
engine: Engine, current_head_rev: str
) -> Generator[str, None, None]:
"""Tenant schema with alembic_version already at head — should be excluded."""
schema = f"tenant_test_{uuid.uuid4().hex[:12]}"
with engine.connect() as conn:
conn.execute(text(f'CREATE SCHEMA "{schema}"'))
conn.execute(
text(
f'CREATE TABLE "{schema}".alembic_version '
f"(version_num VARCHAR(32) NOT NULL)"
)
)
conn.execute(
text(f'INSERT INTO "{schema}".alembic_version (version_num) VALUES (:rev)'),
{"rev": current_head_rev},
)
conn.commit()
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
@pytest.fixture
def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
"""Tenant schema with no tables — should be included (needs migration)."""
schema = f"tenant_test_{uuid.uuid4().hex[:12]}"
with engine.connect() as conn:
conn.execute(text(f'CREATE SCHEMA "{schema}"'))
conn.commit()
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
@pytest.fixture
def tenant_schema_stale_rev(engine: Engine) -> Generator[str, None, None]:
"""Tenant schema with a non-head revision — should be included (needs migration)."""
schema = f"tenant_test_{uuid.uuid4().hex[:12]}"
with engine.connect() as conn:
conn.execute(text(f'CREATE SCHEMA "{schema}"'))
conn.execute(
text(
f'CREATE TABLE "{schema}".alembic_version '
f"(version_num VARCHAR(32) NOT NULL)"
)
)
conn.execute(
text(
f'INSERT INTO "{schema}".alembic_version (version_num) '
f"VALUES ('stalerev000000000000')"
)
)
conn.commit()
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_classifies_all_cases(
current_head_rev: str,
tenant_schema_at_head: str,
tenant_schema_empty: str,
tenant_schema_stale_rev: str,
) -> None:
"""Correctly classifies all three schema states:
- at head → excluded
- no table → included (needs migration)
- stale rev → included (needs migration)
"""
all_schemas = [tenant_schema_at_head, tenant_schema_empty, tenant_schema_stale_rev]
result = get_schemas_needing_migration(all_schemas, current_head_rev)
assert tenant_schema_at_head not in result
assert tenant_schema_empty in result
assert tenant_schema_stale_rev in result
def test_idempotent(
current_head_rev: str,
tenant_schema_at_head: str,
tenant_schema_empty: str,
) -> None:
"""Calling the function twice returns the same result.
Verifies that the DROP TABLE IF EXISTS guards correctly clean up temp
tables so a second call succeeds even if the first left state behind.
"""
schemas = [tenant_schema_at_head, tenant_schema_empty]
first = get_schemas_needing_migration(schemas, current_head_rev)
second = get_schemas_needing_migration(schemas, current_head_rev)
assert first == second
def test_empty_input(current_head_rev: str) -> None:
"""An empty input list returns immediately without touching the DB."""
assert get_schemas_needing_migration([], current_head_rev) == []

View File

@@ -3,7 +3,6 @@ from fastapi import HTTPException
import onyx.auth.users as users
from onyx.auth.users import verify_email_domain
from onyx.configs.constants import AuthType
def test_verify_email_domain_allows_case_insensitive_match(
@@ -36,37 +35,3 @@ def test_verify_email_domain_invalid_email_format(
verify_email_domain("userexample.com") # missing '@'
assert exc.value.status_code == 400
assert "Email is not valid" in exc.value.detail
def test_verify_email_domain_rejects_plus_addressing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
with pytest.raises(HTTPException) as exc:
verify_email_domain("user+tag@gmail.com")
assert exc.value.status_code == 400
assert "'+'" in str(exc.value.detail)
def test_verify_email_domain_allows_plus_for_onyx_app(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
# Should not raise for onyx.app domain
verify_email_domain("user+tag@onyx.app")
def test_verify_email_domain_rejects_googlemail(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
with pytest.raises(HTTPException) as exc:
verify_email_domain("user@googlemail.com")
assert exc.value.status_code == 400
assert "gmail.com" in str(exc.value.detail)

View File

@@ -97,7 +97,6 @@ class TestScimDALUserMappings:
assert model_attrs(added_obj) == {
"external_id": "ext-1",
"user_id": user_id,
"scim_username": None,
}
def test_delete_user_mapping(

View File

@@ -15,10 +15,7 @@ from sqlalchemy.orm import Session
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.providers.base import ScimProvider
from ee.onyx.server.scim.providers.okta import OktaProvider
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
@@ -38,12 +35,6 @@ def mock_token() -> MagicMock:
return token
@pytest.fixture
def provider() -> ScimProvider:
"""An OktaProvider instance for endpoint tests."""
return OktaProvider()
@pytest.fixture
def mock_dal() -> Generator[MagicMock, None, None]:
"""Patch ScimDAL construction in api module and yield the mock instance."""
@@ -62,9 +53,6 @@ def mock_dal() -> Generator[MagicMock, None, None]:
dal.get_group_mapping_by_external_id.return_value = None
dal.get_group_members.return_value = []
dal.list_groups.return_value = ([], 0)
# User-group relationship defaults
dal.get_user_groups.return_value = []
dal.get_users_groups_batch.return_value = {}
yield dal
@@ -108,16 +96,6 @@ def make_db_group(**kwargs: Any) -> MagicMock:
return group
def make_user_mapping(**kwargs: Any) -> MagicMock:
"""Build a mock ScimUserMapping ORM object with configurable attributes."""
mapping = MagicMock(spec=ScimUserMapping)
mapping.id = kwargs.get("id", 1)
mapping.external_id = kwargs.get("external_id", "ext-default")
mapping.user_id = kwargs.get("user_id", uuid4())
mapping.scim_username = kwargs.get("scim_username", None)
return mapping
def assert_scim_error(result: object, expected_status: int) -> None:
"""Assert *result* is a JSONResponse with the given status code."""
assert isinstance(result, JSONResponse)

View File

@@ -21,7 +21,6 @@ from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.providers.base import ScimProvider
from tests.unit.onyx.server.scim.conftest import assert_scim_error
from tests.unit.onyx.server.scim.conftest import make_db_group
from tests.unit.onyx.server.scim.conftest import make_scim_group
@@ -35,7 +34,6 @@ class TestListGroups:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.list_groups.return_value = ([], 0)
@@ -44,7 +42,6 @@ class TestListGroups:
startIndex=1,
count=100,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -57,7 +54,6 @@ class TestListGroups:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.list_groups.side_effect = ValueError(
"Unsupported filter attribute: userName"
@@ -68,7 +64,6 @@ class TestListGroups:
startIndex=1,
count=100,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -79,7 +74,6 @@ class TestListGroups:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
group = make_db_group(id=5, name="Engineering")
uid = uuid4()
@@ -91,7 +85,6 @@ class TestListGroups:
startIndex=1,
count=100,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -113,7 +106,6 @@ class TestGetGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
group = make_db_group(id=5, name="Engineering")
mock_dal.get_group.return_value = group
@@ -122,7 +114,6 @@ class TestGetGroup:
result = get_group(
group_id="5",
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -135,12 +126,10 @@ class TestGetGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock, # noqa: ARG002
provider: ScimProvider,
) -> None:
result = get_group(
group_id="not-a-number",
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -151,14 +140,12 @@ class TestGetGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_group.return_value = None
result = get_group(
group_id="999",
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -175,7 +162,6 @@ class TestCreateGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_group_by_name.return_value = None
mock_validate.return_value = ([], None)
@@ -186,7 +172,6 @@ class TestCreateGroup:
result = create_group(
group_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -200,7 +185,6 @@ class TestCreateGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_group_by_name.return_value = make_db_group()
resource = make_scim_group()
@@ -208,7 +192,6 @@ class TestCreateGroup:
result = create_group(
group_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -221,7 +204,6 @@ class TestCreateGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_group_by_name.return_value = None
mock_validate.return_value = ([], "Invalid member ID: bad-uuid")
@@ -231,7 +213,6 @@ class TestCreateGroup:
result = create_group(
group_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -244,7 +225,6 @@ class TestCreateGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_group_by_name.return_value = None
uid = uuid4()
@@ -255,7 +235,6 @@ class TestCreateGroup:
result = create_group(
group_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -268,7 +247,6 @@ class TestCreateGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_group_by_name.return_value = None
mock_validate.return_value = ([], None)
@@ -279,7 +257,6 @@ class TestCreateGroup:
result = create_group(
group_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -297,7 +274,6 @@ class TestReplaceGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
group = make_db_group(id=5, name="Old Name")
mock_dal.get_group.return_value = group
@@ -310,7 +286,6 @@ class TestReplaceGroup:
group_id="5",
group_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -324,7 +299,6 @@ class TestReplaceGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_group.return_value = None
@@ -332,7 +306,6 @@ class TestReplaceGroup:
group_id="999",
group_resource=make_scim_group(),
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -345,7 +318,6 @@ class TestReplaceGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
group = make_db_group(id=5)
mock_dal.get_group.return_value = group
@@ -357,7 +329,6 @@ class TestReplaceGroup:
group_id="5",
group_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -370,7 +341,6 @@ class TestReplaceGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
group = make_db_group(id=5)
mock_dal.get_group.return_value = group
@@ -383,7 +353,6 @@ class TestReplaceGroup:
group_id="5",
group_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -400,7 +369,6 @@ class TestPatchGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
group = make_db_group(id=5, name="Old Name")
mock_dal.get_group.return_value = group
@@ -423,7 +391,6 @@ class TestPatchGroup:
group_id="5",
patch_request=patch_req,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -435,7 +402,6 @@ class TestPatchGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_group.return_value = None
@@ -453,7 +419,6 @@ class TestPatchGroup:
group_id="999",
patch_request=patch_req,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -466,7 +431,6 @@ class TestPatchGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
group = make_db_group(id=5)
mock_dal.get_group.return_value = group
@@ -488,7 +452,6 @@ class TestPatchGroup:
group_id="5",
patch_request=patch_req,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -501,7 +464,6 @@ class TestPatchGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
group = make_db_group(id=5)
mock_dal.get_group.return_value = group
@@ -521,7 +483,7 @@ class TestPatchGroup:
ScimPatchOperation(
op=ScimPatchOperationType.ADD,
path="members",
value=[ScimGroupMember(value=uid)],
value=[{"value": uid}],
)
]
)
@@ -530,7 +492,6 @@ class TestPatchGroup:
group_id="5",
patch_request=patch_req,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -545,7 +506,6 @@ class TestPatchGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
group = make_db_group(id=5)
mock_dal.get_group.return_value = group
@@ -565,7 +525,7 @@ class TestPatchGroup:
ScimPatchOperation(
op=ScimPatchOperationType.ADD,
path="members",
value=[ScimGroupMember(value=str(uid))],
value=[{"value": str(uid)}],
)
]
)
@@ -574,7 +534,6 @@ class TestPatchGroup:
group_id="5",
patch_request=patch_req,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -587,7 +546,6 @@ class TestPatchGroup:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
group = make_db_group(id=5)
mock_dal.get_group.return_value = group
@@ -610,7 +568,6 @@ class TestPatchGroup:
group_id="5",
patch_request=patch_req,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)

View File

@@ -2,19 +2,13 @@ import pytest
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchResourceValue
from ee.onyx.server.scim.models import ScimPatchValue
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.providers.okta import OktaProvider
_OKTA_IGNORED = OktaProvider().ignored_patch_paths
def _make_user(**kwargs: object) -> ScimUserResource:
@@ -35,14 +29,14 @@ def _make_group(**kwargs: object) -> ScimGroupResource:
def _replace_op(
path: str | None = None,
value: ScimPatchValue = None,
value: str | bool | dict | list | None = None,
) -> ScimPatchOperation:
return ScimPatchOperation(op=ScimPatchOperationType.REPLACE, path=path, value=value)
def _add_op(
path: str | None = None,
value: ScimPatchValue = None,
value: str | bool | dict | list | None = None,
) -> ScimPatchOperation:
return ScimPatchOperation(op=ScimPatchOperationType.ADD, path=path, value=value)
@@ -86,12 +80,7 @@ class TestApplyUserPatch:
def test_replace_without_path_uses_dict(self) -> None:
user = _make_user()
result = apply_user_patch(
[
_replace_op(
None,
ScimPatchResourceValue(active=False, userName="new@example.com"),
)
],
[_replace_op(None, {"active": False, "userName": "new@example.com"})],
user,
)
assert result.active is False
@@ -130,86 +119,6 @@ class TestApplyUserPatch:
with pytest.raises(ScimPatchError, match="Unsupported operation"):
apply_user_patch([_remove_op("active")], user)
def test_replace_without_path_ignores_id(self) -> None:
"""Okta sends 'id' alongside actual changes — it should be silently ignored."""
user = _make_user()
result = apply_user_patch(
[_replace_op(None, ScimPatchResourceValue(active=False, id="some-uuid"))],
user,
ignored_paths=_OKTA_IGNORED,
)
assert result.active is False
def test_replace_without_path_ignores_schemas(self) -> None:
"""The 'schemas' key in a value dict should be silently ignored."""
user = _make_user()
result = apply_user_patch(
[
_replace_op(
None,
ScimPatchResourceValue(
active=False,
schemas=["urn:ietf:params:scim:schemas:core:2.0:User"],
),
)
],
user,
ignored_paths=_OKTA_IGNORED,
)
assert result.active is False
def test_okta_deactivation_payload(self) -> None:
"""Exact Okta deactivation payload: path-less replace with id + active."""
user = _make_user()
result = apply_user_patch(
[
_replace_op(
None,
ScimPatchResourceValue(id="abc-123", active=False),
)
],
user,
ignored_paths=_OKTA_IGNORED,
)
assert result.active is False
assert result.userName == "test@example.com"
def test_replace_displayname(self) -> None:
user = _make_user()
result = apply_user_patch(
[_replace_op("displayName", "New Display Name")], user
)
assert result.displayName == "New Display Name"
assert result.name is not None
assert result.name.formatted == "New Display Name"
def test_replace_without_path_complex_value_dict(self) -> None:
"""Okta sends id/schemas/meta alongside actual changes — complex types
(lists, nested dicts) must not cause Pydantic validation errors."""
user = _make_user()
result = apply_user_patch(
[
_replace_op(
None,
ScimPatchResourceValue(
active=False,
id="some-uuid",
schemas=["urn:ietf:params:scim:schemas:core:2.0:User"],
meta=ScimMeta(resourceType="User"),
),
)
],
user,
ignored_paths=_OKTA_IGNORED,
)
assert result.active is False
assert result.userName == "test@example.com"
def test_add_operation_works_like_replace(self) -> None:
user = _make_user()
result = apply_user_patch([_add_op("externalId", "ext-456")], user)
assert result.externalId == "ext-456"
class TestApplyGroupPatch:
"""Tests for SCIM group PATCH operations."""
@@ -226,12 +135,7 @@ class TestApplyGroupPatch:
def test_add_members(self) -> None:
group = _make_group()
result, added, removed = apply_group_patch(
[
_add_op(
"members",
[ScimGroupMember(value="user-1"), ScimGroupMember(value="user-2")],
)
],
[_add_op("members", [{"value": "user-1"}, {"value": "user-2"}])],
group,
)
assert len(result.members) == 2
@@ -241,7 +145,7 @@ class TestApplyGroupPatch:
def test_add_members_without_path(self) -> None:
group = _make_group()
result, added, _ = apply_group_patch(
[_add_op(None, [ScimGroupMember(value="user-1")])],
[_add_op(None, [{"value": "user-1"}])],
group,
)
assert len(result.members) == 1
@@ -250,12 +154,7 @@ class TestApplyGroupPatch:
def test_add_duplicate_member_skipped(self) -> None:
group = _make_group(members=[ScimGroupMember(value="user-1")])
result, added, _ = apply_group_patch(
[
_add_op(
"members",
[ScimGroupMember(value="user-1"), ScimGroupMember(value="user-2")],
)
],
[_add_op("members", [{"value": "user-1"}, {"value": "user-2"}])],
group,
)
assert len(result.members) == 2
@@ -291,7 +190,7 @@ class TestApplyGroupPatch:
result, added, removed = apply_group_patch(
[
_replace_op("displayName", "Renamed"),
_add_op("members", [ScimGroupMember(value="user-2")]),
_add_op("members", [{"value": "user-2"}]),
_remove_op('members[value eq "user-1"]'),
],
group,
@@ -322,12 +221,7 @@ class TestApplyGroupPatch:
]
)
result, added, removed = apply_group_patch(
[
_replace_op(
"members",
[ScimGroupMember(value="user-2"), ScimGroupMember(value="user-3")],
)
],
[_replace_op("members", [{"value": "user-2"}, {"value": "user-3"}])],
group,
)
assert len(result.members) == 2
@@ -362,55 +256,3 @@ class TestApplyGroupPatch:
group = _make_group()
apply_group_patch([_replace_op("displayName", "Changed")], group)
assert group.displayName == "Engineering"
def test_replace_without_path_ignores_id(self) -> None:
"""Group replace with 'id' in value dict should be silently ignored."""
group = _make_group()
result, _, _ = apply_group_patch(
[
_replace_op(
None, ScimPatchResourceValue(displayName="Updated", id="some-id")
)
],
group,
ignored_paths=_OKTA_IGNORED,
)
assert result.displayName == "Updated"
def test_replace_without_path_ignores_schemas(self) -> None:
group = _make_group()
result, _, _ = apply_group_patch(
[
_replace_op(
None,
ScimPatchResourceValue(
displayName="Updated",
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
),
)
],
group,
ignored_paths=_OKTA_IGNORED,
)
assert result.displayName == "Updated"
def test_replace_without_path_complex_value_dict(self) -> None:
"""Group PATCH with complex types in value dict (lists, nested dicts)
must not cause Pydantic validation errors."""
group = _make_group()
result, _, _ = apply_group_patch(
[
_replace_op(
None,
ScimPatchResourceValue(
displayName="Updated",
id="123",
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
meta=ScimMeta(resourceType="Group"),
),
)
],
group,
ignored_paths=_OKTA_IGNORED,
)
assert result.displayName == "Updated"

View File

@@ -1,167 +0,0 @@
from unittest.mock import MagicMock
from uuid import UUID
from uuid import uuid4
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserGroupRef
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.providers.base import get_default_provider
from ee.onyx.server.scim.providers.okta import OktaProvider
def _make_mock_user(
user_id: UUID | None = None,
email: str = "test@example.com",
personal_name: str | None = "Test User",
is_active: bool = True,
) -> MagicMock:
user = MagicMock()
user.id = user_id or uuid4()
user.email = email
user.personal_name = personal_name
user.is_active = is_active
return user
def _make_mock_group(group_id: int = 42, name: str = "Engineering") -> MagicMock:
group = MagicMock()
group.id = group_id
group.name = name
return group
class TestOktaProvider:
def test_name(self) -> None:
assert OktaProvider().name == "okta"
def test_ignored_patch_paths(self) -> None:
assert OktaProvider().ignored_patch_paths == frozenset(
{"id", "schemas", "meta"}
)
def test_build_user_resource_basic(self) -> None:
provider = OktaProvider()
user = _make_mock_user()
result = provider.build_user_resource(user, "ext-123")
assert result == ScimUserResource(
id=str(user.id),
externalId="ext-123",
userName="test@example.com",
name=ScimName(givenName="Test", familyName="User", formatted="Test User"),
displayName="Test User",
emails=[ScimEmail(value="test@example.com", type="work", primary=True)],
active=True,
groups=[],
meta=ScimMeta(resourceType="User"),
)
def test_build_user_resource_with_groups(self) -> None:
provider = OktaProvider()
user = _make_mock_user()
groups = [(1, "Engineering"), (2, "Design")]
result = provider.build_user_resource(user, "ext-123", groups=groups)
assert result.groups == [
ScimUserGroupRef(value="1", display="Engineering"),
ScimUserGroupRef(value="2", display="Design"),
]
def test_build_user_resource_empty_groups(self) -> None:
provider = OktaProvider()
user = _make_mock_user()
result = provider.build_user_resource(user, "ext-123", groups=[])
assert result.groups == []
def test_build_user_resource_no_groups(self) -> None:
provider = OktaProvider()
user = _make_mock_user()
result = provider.build_user_resource(user, "ext-123")
assert result.groups == []
def test_build_user_resource_name_parsing(self) -> None:
provider = OktaProvider()
user = _make_mock_user(personal_name="Jane Doe")
result = provider.build_user_resource(user, None)
assert result.name == ScimName(
givenName="Jane", familyName="Doe", formatted="Jane Doe"
)
def test_build_user_resource_single_name(self) -> None:
provider = OktaProvider()
user = _make_mock_user(personal_name="Madonna")
result = provider.build_user_resource(user, None)
assert result.name == ScimName(
givenName="Madonna", familyName=None, formatted="Madonna"
)
def test_build_user_resource_no_name(self) -> None:
provider = OktaProvider()
user = _make_mock_user(personal_name=None)
result = provider.build_user_resource(user, None)
assert result.name is None
assert result.displayName is None
def test_build_user_resource_scim_username_preserves_case(self) -> None:
"""When scim_username is set, userName and emails use original case."""
provider = OktaProvider()
user = _make_mock_user(email="alice@example.com")
result = provider.build_user_resource(
user, "ext-1", scim_username="Alice@Example.com"
)
assert result.userName == "Alice@Example.com"
assert result.emails[0].value == "Alice@Example.com"
def test_build_user_resource_scim_username_none_falls_back(self) -> None:
"""When scim_username is None, userName falls back to user.email."""
provider = OktaProvider()
user = _make_mock_user(email="alice@example.com")
result = provider.build_user_resource(user, "ext-1", scim_username=None)
assert result.userName == "alice@example.com"
assert result.emails[0].value == "alice@example.com"
def test_build_group_resource(self) -> None:
provider = OktaProvider()
group = _make_mock_group()
uid1, uid2 = uuid4(), uuid4()
members: list[tuple[UUID, str | None]] = [
(uid1, "alice@example.com"),
(uid2, "bob@example.com"),
]
result = provider.build_group_resource(group, members, "ext-g-1")
assert result == ScimGroupResource(
id="42",
externalId="ext-g-1",
displayName="Engineering",
members=[
ScimGroupMember(value=str(uid1), display="alice@example.com"),
ScimGroupMember(value=str(uid2), display="bob@example.com"),
],
meta=ScimMeta(resourceType="Group"),
)
def test_build_group_resource_empty_members(self) -> None:
provider = OktaProvider()
group = _make_mock_group()
result = provider.build_group_resource(group, [])
assert result.members == []
class TestGetDefaultProvider:
def test_returns_okta(self) -> None:
provider = get_default_provider()
assert isinstance(provider, OktaProvider)

View File

@@ -9,7 +9,6 @@ from uuid import uuid4
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.server.scim.api import _scim_name_to_str
from ee.onyx.server.scim.api import create_user
from ee.onyx.server.scim.api import delete_user
from ee.onyx.server.scim.api import get_user
@@ -23,11 +22,9 @@ from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.providers.base import ScimProvider
from tests.unit.onyx.server.scim.conftest import assert_scim_error
from tests.unit.onyx.server.scim.conftest import make_db_user
from tests.unit.onyx.server.scim.conftest import make_scim_user
from tests.unit.onyx.server.scim.conftest import make_user_mapping
class TestListUsers:
@@ -38,7 +35,6 @@ class TestListUsers:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.list_users.return_value = ([], 0)
@@ -47,7 +43,6 @@ class TestListUsers:
startIndex=1,
count=100,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -60,20 +55,15 @@ class TestListUsers:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
user = make_db_user(email="alice@example.com", personal_name="Alice Smith")
mapping = make_user_mapping(
external_id="ext-abc", user_id=user.id, scim_username="Alice@example.com"
)
mock_dal.list_users.return_value = ([(user, mapping)], 1)
mock_dal.list_users.return_value = ([(user, "ext-abc")], 1)
result = list_users(
filter=None,
startIndex=1,
count=100,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -82,7 +72,7 @@ class TestListUsers:
assert len(result.Resources) == 1
resource = result.Resources[0]
assert isinstance(resource, ScimUserResource)
assert resource.userName == "Alice@example.com"
assert resource.userName == "alice@example.com"
assert resource.externalId == "ext-abc"
def test_unsupported_filter_attribute_returns_400(
@@ -90,7 +80,6 @@ class TestListUsers:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.list_users.side_effect = ValueError(
"Unsupported filter attribute: emails"
@@ -101,7 +90,6 @@ class TestListUsers:
startIndex=1,
count=100,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -112,14 +100,12 @@ class TestListUsers:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock, # noqa: ARG002
provider: ScimProvider,
) -> None:
result = list_users(
filter="not a valid filter",
startIndex=1,
count=100,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -134,7 +120,6 @@ class TestGetUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
user = make_db_user(email="alice@example.com")
mock_dal.get_user.return_value = user
@@ -142,7 +127,6 @@ class TestGetUser:
result = get_user(
user_id=str(user.id),
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -155,12 +139,10 @@ class TestGetUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock, # noqa: ARG002
provider: ScimProvider,
) -> None:
result = get_user(
user_id="not-a-uuid",
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -171,14 +153,12 @@ class TestGetUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_user.return_value = None
result = get_user(
user_id=str(uuid4()),
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -195,7 +175,6 @@ class TestCreateUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_user_by_email.return_value = None
resource = make_scim_user(userName="new@example.com")
@@ -203,7 +182,6 @@ class TestCreateUser:
result = create_user(
user_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -217,14 +195,12 @@ class TestCreateUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock, # noqa: ARG002
provider: ScimProvider,
) -> None:
resource = make_scim_user(externalId=None)
result = create_user(
user_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -237,7 +213,6 @@ class TestCreateUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_user_by_email.return_value = make_db_user()
resource = make_scim_user()
@@ -245,7 +220,6 @@ class TestCreateUser:
result = create_user(
user_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -258,7 +232,6 @@ class TestCreateUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_user_by_email.return_value = None
mock_dal.add_user.side_effect = IntegrityError("dup", {}, Exception())
@@ -267,7 +240,6 @@ class TestCreateUser:
result = create_user(
user_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -281,7 +253,6 @@ class TestCreateUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock, # noqa: ARG002
provider: ScimProvider,
) -> None:
mock_seats.return_value = "Seat limit reached"
resource = make_scim_user()
@@ -289,7 +260,6 @@ class TestCreateUser:
result = create_user(
user_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -302,7 +272,6 @@ class TestCreateUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_user_by_email.return_value = None
resource = make_scim_user(externalId="ext-123")
@@ -310,7 +279,6 @@ class TestCreateUser:
result = create_user(
user_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -327,7 +295,6 @@ class TestReplaceUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
user = make_db_user(email="old@example.com")
mock_dal.get_user.return_value = user
@@ -340,7 +307,6 @@ class TestReplaceUser:
user_id=str(user.id),
user_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -353,7 +319,6 @@ class TestReplaceUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_user.return_value = None
@@ -361,7 +326,6 @@ class TestReplaceUser:
user_id=str(uuid4()),
user_resource=make_scim_user(),
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -374,7 +338,6 @@ class TestReplaceUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
user = make_db_user(is_active=False)
mock_dal.get_user.return_value = user
@@ -385,7 +348,6 @@ class TestReplaceUser:
user_id=str(user.id),
user_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -397,7 +359,6 @@ class TestReplaceUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
user = make_db_user()
mock_dal.get_user.return_value = user
@@ -408,14 +369,11 @@ class TestReplaceUser:
user_id=str(user.id),
user_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
mock_dal.sync_user_external_id.assert_called_once_with(
user.id, None, scim_username="test@example.com"
)
mock_dal.sync_user_external_id.assert_called_once_with(user.id, None)
class TestPatchUser:
@@ -426,7 +384,6 @@ class TestPatchUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
user = make_db_user(is_active=True)
mock_dal.get_user.return_value = user
@@ -444,7 +401,6 @@ class TestPatchUser:
user_id=str(user.id),
patch_request=patch_req,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -456,7 +412,6 @@ class TestPatchUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_user.return_value = None
patch_req = ScimPatchRequest(
@@ -473,45 +428,11 @@ class TestPatchUser:
user_id=str(uuid4()),
patch_request=patch_req,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
assert_scim_error(result, 404)
def test_patch_displayname_persists(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
"""PATCH displayName should update personal_name in the DB."""
user = make_db_user(personal_name="Old Name")
mock_dal.get_user.return_value = user
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op=ScimPatchOperationType.REPLACE,
path="displayName",
value="New Display Name",
)
]
)
result = patch_user(
user_id=str(user.id),
patch_request=patch_req,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
# Verify the update_user call received the new display name
call_kwargs = mock_dal.update_user.call_args
assert call_kwargs[1]["personal_name"] == "New Display Name"
@patch("ee.onyx.server.scim.api.apply_user_patch")
def test_patch_error_returns_error_response(
self,
@@ -519,7 +440,6 @@ class TestPatchUser:
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
user = make_db_user()
mock_dal.get_user.return_value = user
@@ -537,7 +457,6 @@ class TestPatchUser:
user_id=str(user.id),
patch_request=patch_req,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
@@ -600,87 +519,3 @@ class TestDeleteUser:
)
assert_scim_error(result, 404)
class TestScimNameToStr:
"""Tests for _scim_name_to_str helper."""
def test_prefers_given_family_over_formatted(self) -> None:
"""Okta may send stale formatted while updating givenName/familyName."""
name = ScimName(givenName="Jane", familyName="Smith", formatted="Old Name")
assert _scim_name_to_str(name) == "Jane Smith"
def test_given_name_only(self) -> None:
name = ScimName(givenName="Jane")
assert _scim_name_to_str(name) == "Jane"
def test_family_name_only(self) -> None:
name = ScimName(familyName="Smith")
assert _scim_name_to_str(name) == "Smith"
def test_falls_back_to_formatted(self) -> None:
name = ScimName(formatted="Display Name")
assert _scim_name_to_str(name) == "Display Name"
def test_none_returns_none(self) -> None:
assert _scim_name_to_str(None) is None
def test_empty_name_returns_none(self) -> None:
name = ScimName()
assert _scim_name_to_str(name) is None
class TestEmailCasePreservation:
"""Tests verifying email case is preserved through SCIM endpoints."""
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_create_preserves_username_case(
self,
mock_seats: MagicMock, # noqa: ARG002
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
"""POST /Users with mixed-case userName returns the original case."""
mock_dal.get_user_by_email.return_value = None
resource = make_scim_user(userName="Alice@Example.COM")
result = create_user(
user_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
assert result.userName == "Alice@Example.COM"
assert result.emails[0].value == "Alice@Example.COM"
def test_get_preserves_username_case(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
"""GET /Users/{id} returns the original-case userName from mapping."""
user = make_db_user(email="alice@example.com")
mock_dal.get_user.return_value = user
mapping = make_user_mapping(
external_id="ext-1",
user_id=user.id,
scim_username="Alice@Example.COM",
)
mock_dal.get_user_mapping_by_user_id.return_value = mapping
result = get_user(
user_id=str(user.id),
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
assert result.userName == "Alice@Example.COM"
assert result.emails[0].value == "Alice@Example.COM"

View File

@@ -490,6 +490,7 @@ func createCherryPickPR(headBranch, baseBranch, title string, commitSHAs, commit
// Add standard checklist
body += "\n\n"
body += "- [x] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.\n"
body += "- [x] [Optional] Override Linear Check\n"
cmd := exec.Command("gh", "pr", "create",

View File

@@ -118,7 +118,7 @@ func runCI(cmd *cobra.Command, args []string, opts *RunCIOptions) {
// Create the CI branch
ciBranch := fmt.Sprintf("run-ci/%s", prNumber)
prTitle := fmt.Sprintf("chore: [Running GitHub actions for #%s]", prNumber)
prBody := fmt.Sprintf("This PR runs GitHub Actions CI for #%s.\n\n- [x] Override Linear Check\n\n**This PR should be closed (not merged) after CI completes.**", prNumber)
prBody := fmt.Sprintf("This PR runs GitHub Actions CI for #%s.\n\n- [x] I have considered whether this PR needs to be cherry-picked to the latest beta branch.\n- [x] Override Linear Check\n\n**This PR should be closed (not merged) after CI completes.**", prNumber)
// Fetch the fork's branch
if forkRepo == "" {

View File

@@ -105,18 +105,6 @@ const nextConfig = {
destination: "/app",
permanent: true,
},
// NRF routes: Redirect to /nrf which doesn't require auth
// (NRFPage handles unauthenticated users gracefully with a login modal)
{
source: "/app/nrf/side-panel",
destination: "/nrf/side-panel",
permanent: true,
},
{
source: "/app/nrf",
destination: "/nrf",
permanent: true,
},
{
source: "/chat/:path*",
destination: "/app/:path*",

View File

@@ -2,10 +2,10 @@
import Image from "next/image";
import { useMemo, useState, useReducer } from "react";
import { AdminPageTitle } from "@/components/admin/Title";
import { InfoIcon } from "@/components/icons/icons";
import Text from "@/refresh-components/texts/Text";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { Content } from "@opal/layouts";
import Separator from "@/refresh-components/Separator";
import useSWR from "swr";
import { errorHandlingFetcher, FetchError } from "@/lib/fetcher";
import { ThreeDotsLoader } from "@/components/Loading";
@@ -22,6 +22,7 @@ import {
SvgOnyxLogo,
SvgX,
} from "@opal/icons";
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
import {
SEARCH_PROVIDERS_URL,
@@ -401,40 +402,36 @@ export default function Page() {
: undefined);
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgGlobe}
<>
<AdminPageTitle
title="Web Search"
description="Search settings for external search across the internet."
separator
icon={SvgGlobe}
includeDivider={false}
/>
<SettingsLayouts.Body>
<Callout type="danger" title="Failed to load web search settings">
{message}
{detail && (
<Text as="p" className="mt-2 text-text-03" mainContentBody text03>
{detail}
</Text>
)}
</Callout>
</SettingsLayouts.Body>
</SettingsLayouts.Root>
<Callout type="danger" title="Failed to load web search settings">
{message}
{detail && (
<Text as="p" className="mt-2 text-text-03" mainContentBody text03>
{detail}
</Text>
)}
</Callout>
</>
);
}
if (isLoading) {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgGlobe}
<>
<AdminPageTitle
title="Web Search"
description="Search settings for external search across the internet."
separator
icon={SvgGlobe}
includeDivider={false}
/>
<SettingsLayouts.Body>
<div className="mt-8">
<ThreeDotsLoader />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
</div>
</>
);
}
@@ -830,22 +827,32 @@ export default function Page() {
return (
<>
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgGlobe}
title="Web Search"
description="Search settings for external search across the internet."
separator
/>
<>
<AdminPageTitle icon={SvgGlobe} title="Web Search" />
<div className="pt-4 pb-4">
<Text as="p" className="text-text-dark">
Search settings for external search across the internet.
</Text>
</div>
<SettingsLayouts.Body>
<div className="flex w-full flex-col gap-3">
<Content
title="Search Engine"
description="External search engine API used for web search result URLs, snippets, and metadata."
sizePreset="main-content"
variant="section"
/>
<Separator />
<div className="flex w-full flex-col gap-8 pb-6">
<div className="flex w-full max-w-[960px] flex-col gap-3">
<div className="flex flex-col gap-0.5">
<Text as="p" mainContentEmphasis text05>
Search Engine
</Text>
<Text
as="p"
className="flex items-start gap-[2px] self-stretch text-text-03"
secondaryBody
text03
>
External search engine API used for web search result URLs,
snippets, and metadata.
</Text>
</div>
{activationError && (
<Callout type="danger" title="Unable to update default provider">
@@ -967,12 +974,14 @@ export default function Page() {
size: 16,
isHighlighted,
})}
<Content
title={label}
description={subtitle}
sizePreset="main-ui"
variant="section"
/>
<div className="flex flex-col gap-0.5">
<Text as="p" mainUiAction text05>
{label}
</Text>
<Text as="p" secondaryBody text03>
{subtitle}
</Text>
</div>
</div>
<div className="flex items-center justify-end gap-2">
{isConfigured && (
@@ -1036,13 +1045,20 @@ export default function Page() {
</div>
</div>
<div className="flex w-full flex-col gap-3">
<Content
title="Web Crawler"
description="Used to read the full contents of search result pages."
sizePreset="main-content"
variant="section"
/>
<div className="flex w-full max-w-[960px] flex-col gap-3">
<div className="flex flex-col gap-0.5">
<Text as="p" mainContentEmphasis text05>
Web Crawler
</Text>
<Text
as="p"
className="flex items-start gap-[2px] self-stretch text-text-03"
secondaryBody
text03
>
Used to read the full contents of search result pages.
</Text>
</div>
{contentActivationError && (
<Callout type="danger" title="Unable to update crawler">
@@ -1157,12 +1173,14 @@ export default function Page() {
size: 16,
isHighlighted: isCurrentCrawler,
})}
<Content
title={label}
description={subtitle}
sizePreset="main-ui"
variant="section"
/>
<div className="flex flex-col gap-0.5">
<Text as="p" mainUiAction text05>
{label}
</Text>
<Text as="p" secondaryBody text03>
{subtitle}
</Text>
</div>
</div>
<div className="flex items-center justify-end gap-2">
{provider.provider_type !== "onyx_web_crawler" &&
@@ -1226,8 +1244,8 @@ export default function Page() {
})}
</div>
</div>
</SettingsLayouts.Body>
</SettingsLayouts.Root>
</div>
</>
<WebProviderSetupModal
isOpen={selectedProviderType !== null}

View File

@@ -382,8 +382,9 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
<IconButton
icon={SvgMenu}
onClick={toggleSettings}
secondary
tertiary
tooltip="Open settings"
className="bg-mask-02 backdrop-blur-[12px] rounded-full shadow-01 hover:bg-mask-03"
/>
</div>
)}

View File

@@ -1,26 +0,0 @@
import { unstable_noStore as noStore } from "next/cache";
import AppSidebar from "@/sections/sidebar/AppSidebar";
import { getCurrentUserSS } from "@/lib/userSS";
export interface LayoutProps {
children: React.ReactNode;
}
/**
* NRF Main (New Tab) Layout
*
* Shows the app sidebar when the user is authenticated.
* This layout is NOT used by the side-panel route.
*/
export default async function Layout({ children }: LayoutProps) {
noStore();
const user = await getCurrentUserSS();
return (
<div className="flex flex-row w-full h-full">
{user && <AppSidebar />}
{children}
</div>
);
}

View File

@@ -1,31 +0,0 @@
import { unstable_noStore as noStore } from "next/cache";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import NRFPage from "@/app/app/nrf/NRFPage";
import { NRFPreferencesProvider } from "@/components/context/NRFPreferencesContext";
import NRFChrome from "../NRFChrome";
/**
* NRF (New Tab Page) Route - No Auth Required
*
* This route is placed outside /app/app/ to bypass the authentication
* requirement in /app/app/layout.tsx. The NRFPage component handles
* unauthenticated users gracefully by showing a login modal instead of
* redirecting, which is better UX for the Chrome extension.
*
* Instead of AppLayouts.Root (which pulls in heavy Header state management),
* we use NRFChrome — a lightweight overlay that renders only the search/chat
* mode toggle and footer, floating transparently over NRFPage's background.
*/
export default async function Page() {
noStore();
return (
<div className="relative w-full h-full">
<InstantSSRAutoRefresh />
<NRFPreferencesProvider>
<NRFPage />
</NRFPreferencesProvider>
<NRFChrome />
</div>
);
}

View File

@@ -1,148 +0,0 @@
"use client";
import { useState } from "react";
import { cn, ensureHrefProtocol, noProp } from "@/lib/utils";
import type { Components } from "react-markdown";
import Text from "@/refresh-components/texts/Text";
import Popover from "@/refresh-components/Popover";
import { OpenButton } from "@opal/components";
import LineItem from "@/refresh-components/buttons/LineItem";
import IconButton from "@/refresh-components/buttons/IconButton";
import { SvgBubbleText, SvgSearchMenu, SvgSidebar } from "@opal/icons";
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
import { useSettingsContext } from "@/providers/SettingsProvider";
import { AppMode, useAppMode } from "@/providers/AppModeProvider";
import useAppFocus from "@/hooks/useAppFocus";
import { useQueryController } from "@/providers/QueryControllerProvider";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { useAppSidebarContext } from "@/providers/AppSidebarProvider";
import useScreenSize from "@/hooks/useScreenSize";
const footerMarkdownComponents = {
p: ({ children }: { children?: React.ReactNode }) => (
<Text as="p" text03 secondaryAction className="!my-0 text-center">
{children}
</Text>
),
a: ({
href,
className,
children,
...rest
}: React.AnchorHTMLAttributes<HTMLAnchorElement>) => {
const fullHref = ensureHrefProtocol(href);
return (
<a
href={fullHref}
target="_blank"
rel="noopener noreferrer"
{...rest}
className={cn(className, "underline underline-offset-2")}
>
<Text text03 secondaryAction>
{children}
</Text>
</a>
);
},
} satisfies Partial<Components>;
/**
* Lightweight chrome overlay for the NRF page.
*
* Renders only the search/chat mode toggle (top-left) and footer (bottom),
* absolutely positioned so they float transparently over NRFPage's own
* background. This avoids pulling in the full AppLayouts.Root Header which
* carries heavy state management (share/delete/move modals) that the
* extension doesn't need.
*/
export default function NRFChrome() {
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
const { appMode, setAppMode } = useAppMode();
const settings = useSettingsContext();
const { isMobile } = useScreenSize();
const { setFolded } = useAppSidebarContext();
const appFocus = useAppFocus();
const { classification } = useQueryController();
const [modePopoverOpen, setModePopoverOpen] = useState(false);
const effectiveMode: AppMode = appFocus.isNewSession() ? appMode : "chat";
const customFooterContent =
settings?.enterpriseSettings?.custom_lower_disclaimer_content ||
`[Onyx ${
settings?.webVersion || "dev"
}](https://www.onyx.app/) - Open Source AI Platform`;
const showModeToggle =
isPaidEnterpriseFeaturesEnabled &&
appFocus.isNewSession() &&
!classification;
const showHeader = isMobile || showModeToggle;
return (
<>
{/* Header chrome — top-left, mirrors position of settings button at top-right */}
{showHeader && (
<div className="absolute top-0 left-0 p-4 z-10 flex flex-row items-center gap-2">
{isMobile && (
<IconButton
icon={SvgSidebar}
onClick={() => setFolded(false)}
internal
/>
)}
{showModeToggle && (
<Popover open={modePopoverOpen} onOpenChange={setModePopoverOpen}>
<Popover.Trigger asChild>
<OpenButton
icon={
effectiveMode === "search" ? SvgSearchMenu : SvgBubbleText
}
>
{effectiveMode === "search" ? "Search" : "Chat"}
</OpenButton>
</Popover.Trigger>
<Popover.Content align="start" width="lg">
<Popover.Menu>
<LineItem
icon={SvgSearchMenu}
selected={effectiveMode === "search"}
description="Quick search for documents"
onClick={noProp(() => {
setAppMode("search");
setModePopoverOpen(false);
})}
>
Search
</LineItem>
<LineItem
icon={SvgBubbleText}
selected={effectiveMode === "chat"}
description="Conversation and research"
onClick={noProp(() => {
setAppMode("chat");
setModePopoverOpen(false);
})}
>
Chat
</LineItem>
</Popover.Menu>
</Popover.Content>
</Popover>
)}
</div>
)}
{/* Footer — bottom-center, transparent background */}
<footer className="absolute bottom-0 left-0 w-full z-10 flex flex-row justify-center items-center gap-2 px-2 pb-2 pointer-events-auto">
<MinimalMarkdown
content={customFooterContent}
className="max-w-full text-center"
components={footerMarkdownComponents}
/>
</footer>
</>
);
}

View File

@@ -1,15 +0,0 @@
import { ProjectsProvider } from "@/providers/ProjectsContext";
export interface LayoutProps {
children: React.ReactNode;
}
/**
* NRF Root Layout - Shared by all NRF routes
*
* Provides ProjectsProvider (needed by NRFPage) without auth redirect.
* Sidebar and chrome are handled by sub-layouts / individual pages.
*/
export default function Layout({ children }: LayoutProps) {
return <ProjectsProvider>{children}</ProjectsProvider>;
}

View File

@@ -1,24 +0,0 @@
import { unstable_noStore as noStore } from "next/cache";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import NRFPage from "@/app/app/nrf/NRFPage";
import { NRFPreferencesProvider } from "@/components/context/NRFPreferencesContext";
/**
* NRF Side Panel Route - No Auth Required
*
* Side panel variant — no NRFChrome overlay needed since the side panel
* has its own header (logo + "Open in Onyx" button) and doesn't show
* the mode toggle or footer.
*/
export default async function Page() {
noStore();
return (
<>
<InstantSSRAutoRefresh />
<NRFPreferencesProvider>
<NRFPage isSidePanel />
</NRFPreferencesProvider>
</>
);
}

View File

@@ -49,7 +49,7 @@ export async function searchDocuments(
const request: SendSearchQueryRequest = {
search_query: query,
filters: options?.filters,
num_hits: options?.numHits ?? 30,
num_hits: options?.numHits ?? 50,
include_content: options?.includeContent ?? false,
stream: false,
};

View File

@@ -67,7 +67,7 @@ export function QueryControllerProvider({
searchQuery,
{
filters,
numHits: 30,
numHits: 50,
includeContent: false,
signal: controller.signal,
}

View File

@@ -36,11 +36,10 @@
import BackButton from "@/refresh-components/buttons/BackButton";
import { cn } from "@/lib/utils";
import Separator from "@/refresh-components/Separator";
import Text from "@/refresh-components/texts/Text";
import { WithoutStyles } from "@/types";
import { IconFunctionComponent } from "@opal/types";
import { IconProps } from "@opal/types";
import { HtmlHTMLAttributes, useEffect, useRef, useState } from "react";
import { Content } from "@opal/layouts";
import Spacer from "@/refresh-components/Spacer";
const widthClasses = {
md: "w-[min(50rem,100%)]",
@@ -164,7 +163,7 @@ function SettingsRoot({ width = "md", ...props }: SettingsRootProps) {
* ```
*/
export interface SettingsHeaderProps {
icon: IconFunctionComponent;
icon: React.FunctionComponent<IconProps>;
title: string;
description?: string;
children?: React.ReactNode;
@@ -185,10 +184,7 @@ function SettingsHeader({
}: SettingsHeaderProps) {
const [showShadow, setShowShadow] = useState(false);
const headerRef = useRef<HTMLDivElement>(null);
// # NOTE (@Subash-Mohan)
// Headers with actions are always sticky, others are not.
const isSticky = !!rightChildren;
const isSticky = !!rightChildren; //headers with actions are always sticky, others are not
useEffect(() => {
if (!isSticky) return;
@@ -225,35 +221,34 @@ function SettingsHeader({
<BackButton behaviorOverride={onBack} />
</div>
)}
<Spacer vertical rem={1} />
<div className="flex flex-col gap-6 px-4">
<div className="flex w-full justify-between">
<div aria-label="admin-page-title">
<Content
icon={Icon}
title={title}
description={description}
sizePreset="headline"
variant="heading"
/>
<div
className={cn("flex flex-col gap-6 px-4", backButton ? "pt-2" : "pt-4")}
>
<div className="flex flex-col">
<div className="flex flex-row justify-between items-center gap-4">
<Icon className="stroke-text-04 h-[1.75rem] w-[1.75rem]" />
{rightChildren}
</div>
<div className={cn("flex flex-col", separator ? "pb-6" : "pb-2")}>
<div aria-label="admin-page-title">
<Text as="p" headingH2>
{title}
</Text>
</div>
{description && (
<Text secondaryBody text03>
{description}
</Text>
)}
</div>
{rightChildren}
</div>
{children}
</div>
{separator ? (
{separator && (
<>
<Spacer vertical rem={1.5} />
<Separator noPadding className="px-4" />
</>
) : (
<Spacer vertical rem={0.5} />
)}
{isSticky && (
<div
className={cn(

View File

@@ -11,10 +11,10 @@ export function getExtensionContext(): {
return { isExtension: false, context: null };
const pathname = window.location.pathname;
if (pathname.includes("/nrf/side-panel")) {
if (pathname.includes("/app/nrf/side-panel")) {
return { isExtension: true, context: "side_panel" };
}
if (pathname.includes("/nrf")) {
if (pathname.includes("/app/nrf")) {
return { isExtension: true, context: "new_tab" };
}
return { isExtension: false, context: null };

View File

@@ -240,7 +240,7 @@ export interface SendSearchQueryRequest {
filters?: BaseFilters | null;
num_docs_fed_to_llm_selection?: number | null;
run_query_expansion?: boolean;
num_hits?: number; // default 30
num_hits?: number; // default 50
include_content?: boolean;
stream?: boolean;
}

View File

@@ -114,7 +114,7 @@ const heightClasses = {
* </Modal.Content>
* ```
*/
export interface ModalContentProps
interface ModalContentProps
extends WithoutStyles<
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Content>
> {

View File

@@ -1,21 +0,0 @@
import { cn } from "@/lib/utils";
interface PreviewImageProps {
src: string;
alt: string;
className?: string;
}
export default function PreviewImage({
src,
alt,
className,
}: PreviewImageProps) {
return (
<img
src={src}
alt={alt}
className={cn("object-contain object-center", className)}
/>
);
}

View File

@@ -1,8 +1,8 @@
"use client";
import Button from "@/refresh-components/buttons/Button";
import { useRouter } from "next/navigation";
import type { Route } from "next";
import { Button } from "@opal/components";
import { SvgArrowLeft } from "@opal/icons";
export interface BackButtonProps {
@@ -18,8 +18,8 @@ export default function BackButton({
return (
<Button
icon={SvgArrowLeft}
prominence="tertiary"
leftIcon={SvgArrowLeft}
tertiary
onClick={() => {
if (behaviorOverride) {
behaviorOverride();

View File

@@ -115,14 +115,9 @@ export default function ActionLineItem({
<Section gap={0.25} flexDirection="row">
{!isUnavailable && tool?.oauth_config_id && toolAuthStatus && (
<Button
icon={({ className }) => (
<SvgKey
className={cn(
className,
"stroke-yellow-500 hover:stroke-yellow-600"
)}
/>
)}
icon={SvgKey}
prominence="secondary"
size="sm"
onClick={noProp(() => {
if (
!toolAuthStatus.hasToken ||

View File

@@ -14,7 +14,7 @@ import Tabs from "@/refresh-components/Tabs";
import FilterButton from "@/refresh-components/buttons/FilterButton";
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
import LineItem from "@/refresh-components/buttons/LineItem";
import { Button } from "@opal/components";
import Button from "@/refresh-components/buttons/Button";
import {
SEARCH_TOOL_ID,
IMAGE_GENERATION_TOOL_ID,
@@ -428,13 +428,11 @@ export default function AgentsNavigationPage() {
title="Agents & Assistants"
description="Customize AI behavior and knowledge for you and your team's use cases."
rightChildren={
<Button
href="/app/agents/create"
icon={SvgPlus}
aria-label="AgentsPage/new-agent-button"
>
New Agent
</Button>
<div data-testid="AgentsPage/new-agent-button">
<Button href="/app/agents/create" leftIcon={SvgPlus}>
New Agent
</Button>
</div>
}
>
<div className="flex flex-col gap-2">

View File

@@ -25,7 +25,9 @@ import { AppPopup } from "@/app/app/components/AppPopup";
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
import { useUser } from "@/providers/UserProvider";
import NoAssistantModal from "@/components/modals/NoAssistantModal";
import PreviewModal from "@/sections/modals/PreviewModal";
import TextViewModal from "@/sections/modals/TextViewModal";
import CodeViewModal from "@/sections/modals/CodeViewModal";
import { getCodeLanguage } from "@/lib/languages";
import Modal from "@/refresh-components/Modal";
import { useSendMessageToParent } from "@/lib/extension/utils";
import { SUBMIT_MESSAGE_TYPES } from "@/lib/extension/constants";
@@ -684,12 +686,18 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
</div>
)}
{presentingDocument && (
<PreviewModal
presentingDocument={presentingDocument}
onClose={() => setPresentingDocument(null)}
/>
)}
{presentingDocument &&
(getCodeLanguage(presentingDocument.semantic_identifier || "") ? (
<CodeViewModal
presentingDocument={presentingDocument}
onClose={() => setPresentingDocument(null)}
/>
) : (
<TextViewModal
presentingDocument={presentingDocument}
onClose={() => setPresentingDocument(null)}
/>
))}
{stackTraceModalContent && (
<ExceptionTraceModal

View File

@@ -0,0 +1,188 @@
"use client";
import { useState, useEffect, useMemo } from "react";
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
import Modal from "@/refresh-components/Modal";
import Text from "@/refresh-components/texts/Text";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import { Button } from "@opal/components";
import { SvgDownload } from "@opal/icons";
import { Section } from "@/layouts/general-layouts";
import { getCodeLanguage } from "@/lib/languages";
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
import { CodeBlock } from "@/app/app/message/CodeBlock";
import { extractCodeText } from "@/app/app/message/codeUtils";
import { fetchChatFile } from "@/lib/chat/svc";
export interface CodeViewProps {
presentingDocument: MinimalOnyxDocument;
onClose: () => void;
}
export default function CodeViewModal({
presentingDocument,
onClose,
}: CodeViewProps) {
const [fileContent, setFileContent] = useState("");
const [fileUrl, setFileUrl] = useState("");
const [fileName, setFileName] = useState("");
const [isLoading, setIsLoading] = useState(true);
const [loadError, setLoadError] = useState<string | null>(null);
const language =
getCodeLanguage(presentingDocument.semantic_identifier || "") ||
"plaintext";
const lineCount = useMemo(() => {
if (!fileContent) return 0;
return fileContent.split("\n").length;
}, [fileContent]);
const fileSize = useMemo(() => {
if (!fileContent) return "";
const bytes = new TextEncoder().encode(fileContent).length;
if (bytes < 1024) return `${bytes} B`;
const kb = bytes / 1024;
if (kb < 1024) return `${kb.toFixed(2)} KB`;
const mb = kb / 1024;
return `${mb.toFixed(2)} MB`;
}, [fileContent]);
const headerDescription = useMemo(() => {
if (!fileContent) return "";
return `${language} - ${lineCount} ${
lineCount === 1 ? "line" : "lines"
} · ${fileSize}`;
}, [fileContent, language, lineCount, fileSize]);
useEffect(() => {
(async () => {
setIsLoading(true);
setLoadError(null);
setFileContent("");
const fileIdLocal =
presentingDocument.document_id.split("__")[1] ||
presentingDocument.document_id;
try {
const response = await fetchChatFile(fileIdLocal);
const blob = await response.blob();
const url = window.URL.createObjectURL(blob);
setFileUrl((prev) => {
if (prev) {
window.URL.revokeObjectURL(prev);
}
return url;
});
setFileName(presentingDocument.semantic_identifier || "document");
setFileContent(await blob.text());
} catch {
setLoadError("Failed to load document.");
} finally {
setIsLoading(false);
}
})();
}, [presentingDocument]);
useEffect(() => {
return () => {
if (fileUrl) {
window.URL.revokeObjectURL(fileUrl);
}
};
}, [fileUrl]);
return (
<Modal
open
onOpenChange={(open) => {
if (!open) {
onClose();
}
}}
>
<Modal.Content
width="md"
height="lg"
preventAccidentalClose={false}
onOpenAutoFocus={(e) => e.preventDefault()}
>
<Modal.Header
title={fileName || "Code"}
description={headerDescription}
onClose={onClose}
/>
<Modal.Body padding={0} gap={0}>
<Section padding={0} gap={0}>
{isLoading ? (
<Section>
<SimpleLoader className="h-8 w-8" />
</Section>
) : loadError ? (
<Section padding={1}>
<Text text03 mainUiBody>
{loadError}
</Text>
</Section>
) : (
<MinimalMarkdown
content={`\`\`\`${language}\n${fileContent}\n\`\`\``}
className="w-full h-full break-words"
components={{
code: ({
node,
className: codeClassName,
children,
...props
}: any) => {
const codeText = extractCodeText(
node,
fileContent,
children
);
return (
<CodeBlock className="" codeText={codeText}>
{children}
</CodeBlock>
);
},
}}
/>
)}
</Section>
</Modal.Body>
<Modal.Footer>
<Section
flexDirection="row"
justifyContent="between"
alignItems="center"
>
<Text text03 mainContentMuted>
{lineCount} {lineCount === 1 ? "line" : "lines"}
</Text>
<Section flexDirection="row" gap={0.5} width="fit">
<CopyIconButton
getCopyText={() => fileContent}
tooltip="Copy code"
size="sm"
/>
<a
href={fileUrl}
download={fileName || presentingDocument.document_id}
>
<Button
icon={SvgDownload}
tooltip="Download"
size="sm"
prominence="tertiary"
/>
</a>
</Section>
</Section>
</Modal.Footer>
</Modal.Content>
</Modal>
);
}

View File

@@ -1,219 +0,0 @@
"use client";
import { useState, useEffect, useCallback, useMemo } from "react";
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
import Modal from "@/refresh-components/Modal";
import Text from "@/refresh-components/texts/Text";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import { cn } from "@/lib/utils";
import { Section } from "@/layouts/general-layouts";
import { getCodeLanguage } from "@/lib/languages";
import { fetchChatFile } from "@/lib/chat/svc";
import { PreviewContext } from "@/sections/modals/PreviewModal/interfaces";
import { resolveVariant } from "@/sections/modals/PreviewModal/variants";
function resolveMimeType(mimeType: string, fileName: string): string {
if (mimeType !== "application/octet-stream") return mimeType;
const lower = fileName.toLowerCase();
if (lower.endsWith(".md") || lower.endsWith(".markdown"))
return "text/markdown";
if (lower.endsWith(".txt")) return "text/plain";
if (lower.endsWith(".csv")) return "text/csv";
return mimeType;
}
interface PreviewModalProps {
presentingDocument: MinimalOnyxDocument;
onClose: () => void;
}
export default function PreviewModal({
presentingDocument,
onClose,
}: PreviewModalProps) {
const [fileContent, setFileContent] = useState("");
const [fileUrl, setFileUrl] = useState("");
const [fileName, setFileName] = useState("");
const [isLoading, setIsLoading] = useState(true);
const [loadError, setLoadError] = useState<string | null>(null);
const [mimeType, setMimeType] = useState("application/octet-stream");
const [zoom, setZoom] = useState(100);
const variant = useMemo(
() => resolveVariant(presentingDocument.semantic_identifier, mimeType),
[presentingDocument.semantic_identifier, mimeType]
);
const language = useMemo(
() =>
getCodeLanguage(presentingDocument.semantic_identifier || "") ||
"plaintext",
[presentingDocument.semantic_identifier]
);
const lineCount = useMemo(() => {
if (!fileContent) return 0;
return fileContent.split("\n").length;
}, [fileContent]);
const fileSize = useMemo(() => {
if (!fileContent) return "";
const bytes = new TextEncoder().encode(fileContent).length;
if (bytes < 1024) return `${bytes} B`;
const kb = bytes / 1024;
if (kb < 1024) return `${kb.toFixed(2)} KB`;
const mb = kb / 1024;
return `${mb.toFixed(2)} MB`;
}, [fileContent]);
const fetchFile = useCallback(async () => {
setIsLoading(true);
setLoadError(null);
setFileContent("");
const fileIdLocal =
presentingDocument.document_id.split("__")[1] ||
presentingDocument.document_id;
try {
const response = await fetchChatFile(fileIdLocal);
const blob = await response.blob();
const url = window.URL.createObjectURL(blob);
setFileUrl((prev) => {
if (prev) window.URL.revokeObjectURL(prev);
return url;
});
const originalFileName =
presentingDocument.semantic_identifier || "document";
setFileName(originalFileName);
const rawContentType =
response.headers.get("Content-Type") || "application/octet-stream";
const resolvedMime = resolveMimeType(rawContentType, originalFileName);
setMimeType(resolvedMime);
const resolved = resolveVariant(
presentingDocument.semantic_identifier,
resolvedMime
);
if (resolved.needsTextContent) {
setFileContent(await blob.text());
}
} catch {
setLoadError("Failed to load document.");
} finally {
setIsLoading(false);
}
}, [presentingDocument]);
useEffect(() => {
fetchFile();
}, [fetchFile]);
useEffect(() => {
return () => {
if (fileUrl) window.URL.revokeObjectURL(fileUrl);
};
}, [fileUrl]);
const handleZoomIn = useCallback(
() => setZoom((prev) => Math.min(prev + 25, 200)),
[]
);
const handleZoomOut = useCallback(
() => setZoom((prev) => Math.max(prev - 25, 25)),
[]
);
const ctx: PreviewContext = useMemo(
() => ({
fileContent,
fileUrl,
fileName,
language,
lineCount,
fileSize,
zoom,
onZoomIn: handleZoomIn,
onZoomOut: handleZoomOut,
}),
[
fileContent,
fileUrl,
fileName,
language,
lineCount,
fileSize,
zoom,
handleZoomIn,
handleZoomOut,
]
);
return (
<Modal
open
onOpenChange={(open) => {
if (!open) onClose();
}}
>
<Modal.Content
width={variant.width}
height={variant.height}
preventAccidentalClose={false}
onOpenAutoFocus={(e) => e.preventDefault()}
>
<Modal.Header
title={fileName || "Document"}
description={variant.headerDescription(ctx)}
onClose={onClose}
/>
{/* Body + floating footer wrapper */}
<Modal.Body padding={0} gap={0}>
<Section padding={0} gap={0}>
{isLoading ? (
<Section>
<SimpleLoader className="h-8 w-8" />
</Section>
) : loadError ? (
<Section padding={1}>
<Text text03 mainUiBody>
{loadError}
</Text>
</Section>
) : (
variant.renderContent(ctx)
)}
</Section>
</Modal.Body>
{/* Floating footer */}
{!isLoading && !loadError && (
<div
className={cn(
"absolute bottom-0 left-0 right-0",
"flex items-center justify-between",
"p-4 pointer-events-none w-full"
)}
style={{
background:
"linear-gradient(to top, var(--background-tint-01) 40%, transparent)",
}}
>
{/* Left slot */}
<div className="pointer-events-auto">
{variant.renderFooterLeft(ctx)}
</div>
{/* Right slot */}
<div className="pointer-events-auto rounded-12 bg-background-tint-00 p-1 shadow-lg">
{variant.renderFooterRight(ctx)}
</div>
</div>
)}
</Modal.Content>
</Modal>
);
}

View File

@@ -1 +0,0 @@
export { default } from "@/sections/modals/PreviewModal/PreviewModal";

View File

@@ -1,30 +0,0 @@
import React from "react";
import { ModalContentProps } from "@/refresh-components/Modal";
export interface PreviewContext {
fileContent: string;
fileUrl: string;
fileName: string;
language: string;
lineCount: number;
fileSize: string;
zoom: number;
onZoomIn: () => void;
onZoomOut: () => void;
}
export interface PreviewVariant
extends Required<Pick<ModalContentProps, "width" | "height">> {
/** Return true if this variant should handle the given file. */
matches: (semanticIdentifier: string | null, mimeType: string) => boolean;
/** Whether the fetcher should read the blob as text. */
needsTextContent: boolean;
/** String shown below the title in the modal header. */
headerDescription: (ctx: PreviewContext) => string;
/** Body content. */
renderContent: (ctx: PreviewContext) => React.ReactNode;
/** Left side of the floating footer (e.g. line count text, zoom controls). Return null for nothing. */
renderFooterLeft: (ctx: PreviewContext) => React.ReactNode;
/** Right side of the floating footer (e.g. copy + download buttons). */
renderFooterRight: (ctx: PreviewContext) => React.ReactNode;
}

View File

@@ -1,55 +0,0 @@
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
import Text from "@/refresh-components/texts/Text";
import { Section } from "@/layouts/general-layouts";
import { getCodeLanguage } from "@/lib/languages";
import { CodeBlock } from "@/app/app/message/CodeBlock";
import { extractCodeText } from "@/app/app/message/codeUtils";
import { PreviewVariant } from "@/sections/modals/PreviewModal/interfaces";
import {
CopyButton,
DownloadButton,
} from "@/sections/modals/PreviewModal/variants/shared";
export const codeVariant: PreviewVariant = {
matches: (name) => !!getCodeLanguage(name || ""),
width: "md",
height: "lg",
needsTextContent: true,
headerDescription: (ctx) =>
ctx.fileContent
? `${ctx.language} - ${ctx.lineCount} ${
ctx.lineCount === 1 ? "line" : "lines"
} · ${ctx.fileSize}`
: "",
renderContent: (ctx) => (
<MinimalMarkdown
content={`\`\`\`${ctx.language}\n${ctx.fileContent}\n\n\`\`\``}
className="w-full break-words h-full"
components={{
code: ({ node, children }: any) => {
const codeText = extractCodeText(node, ctx.fileContent, children);
return (
<CodeBlock className="" codeText={codeText}>
{children}
</CodeBlock>
);
},
}}
/>
),
renderFooterLeft: (ctx) => (
<Text text03 mainUiBody className="select-none">
{ctx.lineCount} {ctx.lineCount === 1 ? "line" : "lines"}
</Text>
),
renderFooterRight: (ctx) => (
<Section flexDirection="row" width="fit">
<CopyButton getText={() => ctx.fileContent} />
<DownloadButton fileUrl={ctx.fileUrl} fileName={ctx.fileName} />
</Section>
),
};

View File

@@ -1,102 +0,0 @@
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import Text from "@/refresh-components/texts/Text";
import { cn } from "@/lib/utils";
import { Section } from "@/layouts/general-layouts";
import { PreviewVariant } from "@/sections/modals/PreviewModal/interfaces";
import {
CopyButton,
DownloadButton,
} from "@/sections/modals/PreviewModal/variants/shared";
interface CsvData {
headers: string[];
rows: string[][];
}
function parseCsv(content: string): CsvData {
const lines = content.split(/\r?\n/).filter((l) => l.length > 0);
const headers = lines.length > 0 ? lines[0]?.split(",") ?? [] : [];
const rows = lines.slice(1).map((line) => line.split(","));
return { headers, rows };
}
export const csvVariant: PreviewVariant = {
matches: (name, mime) =>
mime.startsWith("text/csv") || (name || "").toLowerCase().endsWith(".csv"),
width: "lg",
height: "full",
needsTextContent: true,
headerDescription: (ctx) => {
if (!ctx.fileContent) return "";
const { rows } = parseCsv(ctx.fileContent);
return `CSV - ${rows.length} rows · ${ctx.fileSize}`;
},
renderContent: (ctx) => {
if (!ctx.fileContent) return null;
const { headers, rows } = parseCsv(ctx.fileContent);
return (
<Section justifyContent="start" alignItems="start" padding={1}>
<Table>
<TableHeader className="sticky top-0 z-sticky">
<TableRow className="bg-background-tint-02">
{headers.map((h: string, i: number) => (
<TableHead key={i}>
<Text
as="p"
className="line-clamp-2 font-medium"
text03
mainUiBody
>
{h}
</Text>
</TableHead>
))}
</TableRow>
</TableHeader>
<TableBody>
{rows.map((row: string[], rIdx: number) => (
<TableRow key={rIdx}>
{headers.map((_: string, cIdx: number) => (
<TableCell
key={cIdx}
className={cn(
cIdx === 0 && "sticky left-0 bg-background-tint-01",
"py-0 px-4 whitespace-normal break-words"
)}
>
{row?.[cIdx] ?? ""}
</TableCell>
))}
</TableRow>
))}
</TableBody>
</Table>
</Section>
);
},
renderFooterLeft: (ctx) => {
if (!ctx.fileContent) return null;
const { headers, rows } = parseCsv(ctx.fileContent);
return (
<Text text03 mainUiBody className="select-none">
{headers.length} {headers.length === 1 ? "column" : "columns"} ·{" "}
{rows.length} {rows.length === 1 ? "row" : "rows"}
</Text>
);
},
renderFooterRight: (ctx) => (
<Section flexDirection="row" width="fit">
<CopyButton getText={() => ctx.fileContent} />
<DownloadButton fileUrl={ctx.fileUrl} fileName={ctx.fileName} />
</Section>
),
};

View File

@@ -1,45 +0,0 @@
import { Section } from "@/layouts/general-layouts";
import PreviewImage from "@/refresh-components/PreviewImage";
import { PreviewVariant } from "@/sections/modals/PreviewModal/interfaces";
import {
DownloadButton,
ZoomControls,
} from "@/sections/modals/PreviewModal/variants/shared";
export const imageVariant: PreviewVariant = {
matches: (_name, mime) => mime.startsWith("image/"),
width: "lg",
height: "full",
needsTextContent: false,
headerDescription: () => "",
renderContent: (ctx) => (
<div
className="flex flex-1 min-h-0 items-center justify-center p-4 transition-transform duration-300 ease-in-out"
style={{
transform: `scale(${ctx.zoom / 100})`,
transformOrigin: "center",
}}
>
<PreviewImage
src={ctx.fileUrl}
alt={ctx.fileName}
className="max-w-full max-h-full"
/>
</div>
),
renderFooterLeft: (ctx) => (
<ZoomControls
zoom={ctx.zoom}
onZoomIn={ctx.onZoomIn}
onZoomOut={ctx.onZoomOut}
/>
),
renderFooterRight: (ctx) => (
<Section flexDirection="row" width="fit">
<DownloadButton fileUrl={ctx.fileUrl} fileName={ctx.fileName} />
</Section>
),
};

View File

@@ -1,25 +0,0 @@
import { PreviewVariant } from "@/sections/modals/PreviewModal/interfaces";
import { codeVariant } from "@/sections/modals/PreviewModal/variants/codeVariant";
import { imageVariant } from "@/sections/modals/PreviewModal/variants/imageVariant";
import { pdfVariant } from "@/sections/modals/PreviewModal/variants/pdfVariant";
import { csvVariant } from "@/sections/modals/PreviewModal/variants/csvVariant";
import { markdownVariant } from "@/sections/modals/PreviewModal/variants/markdownVariant";
import { unsupportedVariant } from "@/sections/modals/PreviewModal/variants/unsupportedVariant";
const PREVIEW_VARIANTS: PreviewVariant[] = [
codeVariant,
imageVariant,
pdfVariant,
csvVariant,
markdownVariant,
];
export function resolveVariant(
semanticIdentifier: string | null,
mimeType: string
): PreviewVariant {
return (
PREVIEW_VARIANTS.find((v) => v.matches(semanticIdentifier, mimeType)) ??
unsupportedVariant
);
}

View File

@@ -1,52 +0,0 @@
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
import ScrollIndicatorDiv from "@/refresh-components/ScrollIndicatorDiv";
import { Section } from "@/layouts/general-layouts";
import { PreviewVariant } from "@/sections/modals/PreviewModal/interfaces";
import {
CopyButton,
DownloadButton,
} from "@/sections/modals/PreviewModal/variants/shared";
const MARKDOWN_MIMES = [
"text/markdown",
"text/x-markdown",
"text/plain",
"text/x-rst",
"text/x-org",
];
export const markdownVariant: PreviewVariant = {
matches: (name, mime) => {
if (MARKDOWN_MIMES.some((m) => mime.startsWith(m))) return true;
const lower = (name || "").toLowerCase();
return (
lower.endsWith(".md") ||
lower.endsWith(".markdown") ||
lower.endsWith(".txt") ||
lower.endsWith(".rst") ||
lower.endsWith(".org")
);
},
width: "lg",
height: "full",
needsTextContent: true,
headerDescription: () => "",
renderContent: (ctx) => (
<ScrollIndicatorDiv className="flex-1 min-h-0 p-4" variant="shadow">
<MinimalMarkdown
content={ctx.fileContent}
className="w-full pb-4 h-full text-lg break-words"
/>
</ScrollIndicatorDiv>
),
renderFooterLeft: () => null,
renderFooterRight: (ctx) => (
<Section flexDirection="row" width="fit">
<CopyButton getText={() => ctx.fileContent} />
<DownloadButton fileUrl={ctx.fileUrl} fileName={ctx.fileName} />
</Section>
),
};

View File

@@ -1,26 +0,0 @@
import { Section } from "@/layouts/general-layouts";
import { PreviewVariant } from "@/sections/modals/PreviewModal/interfaces";
import { DownloadButton } from "@/sections/modals/PreviewModal/variants/shared";
export const pdfVariant: PreviewVariant = {
matches: (_name, mime) => mime === "application/pdf",
width: "lg",
height: "full",
needsTextContent: false,
headerDescription: () => "",
renderContent: (ctx) => (
<iframe
src={`${ctx.fileUrl}#toolbar=0`}
className="w-full h-full flex-1 min-h-0 border-none"
title="PDF Viewer"
/>
),
renderFooterLeft: () => null,
renderFooterRight: (ctx) => (
<Section flexDirection="row" width="fit">
<DownloadButton fileUrl={ctx.fileUrl} fileName={ctx.fileName} />
</Section>
),
};

View File

@@ -1,65 +0,0 @@
import { Button } from "@opal/components";
import { SvgDownload, SvgZoomIn, SvgZoomOut } from "@opal/icons";
import Text from "@/refresh-components/texts/Text";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
import { Section } from "@/layouts/general-layouts";
interface DownloadButtonProps {
fileUrl: string;
fileName: string;
}
export function DownloadButton({ fileUrl, fileName }: DownloadButtonProps) {
return (
<a href={fileUrl} download={fileName}>
<Button
prominence="tertiary"
size="sm"
icon={SvgDownload}
tooltip="Download"
/>
</a>
);
}
interface CopyButtonProps {
getText: () => string;
}
export function CopyButton({ getText }: CopyButtonProps) {
return (
<CopyIconButton getCopyText={getText} tooltip="Copy content" size="sm" />
);
}
interface ZoomControlsProps {
zoom: number;
onZoomIn: () => void;
onZoomOut: () => void;
}
export function ZoomControls({ zoom, onZoomIn, onZoomOut }: ZoomControlsProps) {
return (
<div className="rounded-12 bg-background-tint-00 p-1 shadow-lg">
<Section flexDirection="row" width="fit">
<Button
prominence="tertiary"
size="sm"
icon={SvgZoomOut}
onClick={onZoomOut}
tooltip="Zoom Out"
/>
<Text mainUiMono text03>
{zoom}%
</Text>
<Button
prominence="tertiary"
size="sm"
icon={SvgZoomIn}
onClick={onZoomIn}
tooltip="Zoom In"
/>
</Section>
</div>
);
}

View File

@@ -1,28 +0,0 @@
import { Button } from "@opal/components";
import Text from "@/refresh-components/texts/Text";
import { PreviewVariant } from "@/sections/modals/PreviewModal/interfaces";
import { DownloadButton } from "@/sections/modals/PreviewModal/variants/shared";
export const unsupportedVariant: PreviewVariant = {
matches: () => true,
width: "lg",
height: "full",
needsTextContent: false,
headerDescription: () => "",
renderContent: (ctx) => (
<div className="flex flex-col items-center justify-center flex-1 min-h-0 gap-4 p-6">
<Text as="p" text03 mainUiBody>
This file format is not supported for preview.
</Text>
<a href={ctx.fileUrl} download={ctx.fileName}>
<Button>Download File</Button>
</a>
</div>
),
renderFooterLeft: () => null,
renderFooterRight: (ctx) => (
<DownloadButton fileUrl={ctx.fileUrl} fileName={ctx.fileName} />
),
};

View File

@@ -21,7 +21,6 @@ import {
SvgZoomIn,
SvgZoomOut,
} from "@opal/icons";
import PreviewImage from "@/refresh-components/PreviewImage";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import ScrollIndicatorDiv from "@/refresh-components/ScrollIndicatorDiv";
import { cn } from "@/lib/utils";
@@ -250,10 +249,10 @@ export default function TextViewModal({
style={{ transform: `scale(${zoom / 100})` }}
>
{isImageFormat(fileType) ? (
<PreviewImage
<img
src={fileUrl}
alt={fileName}
className="w-full flex-1 min-h-0"
className="w-full flex-1 min-h-0 object-contain object-center"
/>
) : isSupportedIframeFormat(fileType) ? (
<iframe

View File

@@ -136,7 +136,7 @@ async function verifyAdminPageNavigation(
try {
await expect(page.locator('[aria-label="admin-page-title"]')).toHaveText(
new RegExp(`^${pageTitle}`),
pageTitle,
{
timeout: 10000,
}

View File

@@ -21,12 +21,12 @@ function getToolSwitch(page: Page, toolName: string): Locator {
}
/**
* Click a button and wait for the PATCH response to complete.
* Click a tool switch and wait for the PATCH response to complete.
* Uses waitForResponse set up *before* the click to avoid race conditions.
*/
async function clickAndWaitForPatch(
async function clickToolSwitchAndWaitForSave(
page: Page,
buttonLocator: Locator
switchLocator: Locator
): Promise<void> {
const patchPromise = page.waitForResponse(
(r) =>
@@ -34,7 +34,7 @@ async function clickAndWaitForPatch(
r.request().method() === "PATCH",
{ timeout: 8000 }
);
await buttonLocator.click();
await switchLocator.click();
await patchPromise;
}
@@ -161,7 +161,7 @@ test.describe("Chat Preferences Admin Page", () => {
}) => {
// Verify page loads with expected content
await expect(page.locator('[aria-label="admin-page-title"]')).toHaveText(
/^Chat Preferences/
"Chat Preferences"
);
await expect(page.getByText("Actions & Tools")).toBeVisible();
});
@@ -215,7 +215,7 @@ test.describe("Chat Preferences Admin Page", () => {
);
// Toggle back to original state
await clickAndWaitForPatch(page, searchSwitch);
await clickToolSwitchAndWaitForSave(page, searchSwitch);
});
test("should toggle Web Search tool on and off", async ({ page }) => {
@@ -267,7 +267,7 @@ test.describe("Chat Preferences Admin Page", () => {
);
// Toggle back to original state
await clickAndWaitForPatch(page, webSearchSwitch);
await clickToolSwitchAndWaitForSave(page, webSearchSwitch);
});
test("should toggle Image Generation tool on and off", async ({ page }) => {
@@ -321,7 +321,7 @@ test.describe("Chat Preferences Admin Page", () => {
);
// Toggle back to original state
await clickAndWaitForPatch(page, imageGenSwitch);
await clickToolSwitchAndWaitForSave(page, imageGenSwitch);
});
test("should edit and save system prompt", async ({ page }) => {
@@ -339,12 +339,30 @@ test.describe("Chat Preferences Admin Page", () => {
const textarea = modal.getByPlaceholder("Enter your system prompt...");
await textarea.fill(testPrompt);
// Click Save and wait for PATCH to complete
await clickAndWaitForPatch(
page,
modal.getByRole("button", { name: "Save" })
// Set up response listener before the click to avoid race conditions
const patchRespPromise = page.waitForResponse(
(r) =>
r.url().includes("/api/admin/default-assistant") &&
r.request().method() === "PATCH",
{ timeout: 8000 }
);
// Click Save in the modal footer
await modal.getByRole("button", { name: "Save" }).click();
// Wait for PATCH to complete
const patchResp = await patchRespPromise;
console.log(
`[prompt] Save PATCH status=${patchResp.status()} body=${(
await patchResp.text()
).slice(0, 300)}`
);
// Wait for success toast
await expect(page.getByText("System prompt updated")).toBeVisible({
timeout: 5000,
});
// Modal should close after save
await expect(modal).not.toBeVisible();
@@ -378,10 +396,11 @@ test.describe("Chat Preferences Admin Page", () => {
// If already empty, add some text first
if (initialValue === "") {
await textarea.fill("Temporary text");
await clickAndWaitForPatch(
page,
modal.getByRole("button", { name: "Save" })
);
await modal.getByRole("button", { name: "Save" }).click();
await expect(page.getByText("System prompt updated")).toBeVisible({
timeout: 5000,
});
await page.waitForTimeout(500);
// Reopen modal
await page.getByText("Modify Prompt").click();
await expect(modal).toBeVisible({ timeout: 5000 });
@@ -390,12 +409,28 @@ test.describe("Chat Preferences Admin Page", () => {
// Clear the textarea
await textarea.fill("");
// Save
await clickAndWaitForPatch(
page,
modal.getByRole("button", { name: "Save" })
// Set up response listener before the click to avoid race conditions
const patchRespPromise = page.waitForResponse(
(r) =>
r.url().includes("/api/admin/default-assistant") &&
r.request().method() === "PATCH",
{ timeout: 8000 }
);
// Save
await modal.getByRole("button", { name: "Save" }).click();
const patchResp = await patchRespPromise;
console.log(
`[prompt-empty] Save empty PATCH status=${patchResp.status()} body=${(
await patchResp.text()
).slice(0, 300)}`
);
await expect(page.getByText("System prompt updated")).toBeVisible({
timeout: 5000,
});
// Refresh page to verify persistence
await page.reload();
await page.waitForLoadState("networkidle");
@@ -415,10 +450,10 @@ test.describe("Chat Preferences Admin Page", () => {
// Restore original value if it wasn't already empty
if (initialValue !== "") {
await textareaAfter.fill(initialValue);
await clickAndWaitForPatch(
page,
modalAfter.getByRole("button", { name: "Save" })
);
await modalAfter.getByRole("button", { name: "Save" }).click();
await expect(page.getByText("System prompt updated")).toBeVisible({
timeout: 5000,
});
} else {
await modalAfter.getByRole("button", { name: "Cancel" }).click();
}
@@ -440,12 +475,27 @@ test.describe("Chat Preferences Admin Page", () => {
await textarea.fill(longPrompt);
// Save
await clickAndWaitForPatch(
page,
modal.getByRole("button", { name: "Save" })
// Set up response listener before the click to avoid race conditions
const patchRespPromise = page.waitForResponse(
(r) =>
r.url().includes("/api/admin/default-assistant") &&
r.request().method() === "PATCH",
{ timeout: 8000 }
);
// Save
await modal.getByRole("button", { name: "Save" }).click();
const patchResp = await patchRespPromise;
console.log(
`[prompt-long] Save PATCH status=${patchResp.status()} body=${(
await patchResp.text()
).slice(0, 300)}`
);
await expect(page.getByText("System prompt updated")).toBeVisible({
timeout: 5000,
});
// Verify persistence after reload
await page.reload();
await page.waitForLoadState("networkidle");
@@ -463,10 +513,10 @@ test.describe("Chat Preferences Admin Page", () => {
"Enter your system prompt..."
);
await restoreTextarea.fill(initialValue);
await clickAndWaitForPatch(
page,
modalAfter.getByRole("button", { name: "Save" })
);
await modalAfter.getByRole("button", { name: "Save" }).click();
await expect(page.getByText("System prompt updated")).toBeVisible({
timeout: 5000,
});
} else {
await modalAfter.getByRole("button", { name: "Cancel" }).click();
}
@@ -552,7 +602,7 @@ test.describe("Chat Preferences Admin Page", () => {
const toolSwitch = getToolSwitch(page, toolName);
const currentState = await toolSwitch.getAttribute("aria-checked");
if (currentState === "true") {
await clickAndWaitForPatch(page, toolSwitch);
await clickToolSwitchAndWaitForSave(page, toolSwitch);
const newState = await toolSwitch.getAttribute("aria-checked");
console.log(`[toggle-all] Clicked ${toolName}, new state=${newState}`);
}
@@ -578,7 +628,7 @@ test.describe("Chat Preferences Admin Page", () => {
const toolSwitch = getToolSwitch(page, toolName);
const currentState = await toolSwitch.getAttribute("aria-checked");
if (currentState === "false") {
await clickAndWaitForPatch(page, toolSwitch);
await clickToolSwitchAndWaitForSave(page, toolSwitch);
const newState = await toolSwitch.getAttribute("aria-checked");
console.log(`[toggle-all] Clicked ${toolName}, new state=${newState}`);
}
@@ -672,7 +722,7 @@ test.describe("Chat Preferences Admin Page", () => {
const originalState = toolStates[toolName];
if (currentState !== originalState) {
await clickAndWaitForPatch(page, toolSwitch);
await clickToolSwitchAndWaitForSave(page, toolSwitch);
needsSave = true;
}
}

View File

@@ -190,7 +190,7 @@ test.describe("Disable Default Assistant Setting @exclusive", () => {
// Wait for the page to fully render (page title signals form is loaded)
await expect(page.locator('[aria-label="admin-page-title"]')).toHaveText(
/^Chat Preferences/,
"Chat Preferences",
{ timeout: 10000 }
);
@@ -224,7 +224,7 @@ test.describe("Disable Default Assistant Setting @exclusive", () => {
// Verify the page title
await expect(page.locator('[aria-label="admin-page-title"]')).toHaveText(
/^Chat Preferences/
"Chat Preferences"
);
});

View File

@@ -105,7 +105,7 @@ test.describe("LLM Provider Setup @exclusive", () => {
await loginAs(page, "admin");
await page.goto(LLM_SETUP_URL);
await page.waitForLoadState("networkidle");
await expect(page.getByLabel("admin-page-title")).toHaveText(/^LLM Setup/);
await expect(page.getByLabel("admin-page-title")).toHaveText("LLM Setup");
});
test.afterEach(async ({ page }) => {

View File

@@ -137,7 +137,7 @@ test.describe("Signup flow", () => {
// Wait for error message to appear
await expect(
page.getByText("Disposable email addresses are not allowed").first()
page.getByText("Unknown error", { exact: true })
).toBeVisible();
// Capture the error state with hidden email to avoid non-deterministic diffs

View File

@@ -107,7 +107,7 @@ test.describe("Default Assistant Tests", () => {
// Create a custom assistant to test non-default behavior
await page.getByTestId("AppSidebar/more-agents").click();
await page.getByLabel("AgentsPage/new-agent-button").click();
await page.getByTestId("AgentsPage/new-agent-button").click();
await page
.locator('input[name="name"]')
.waitFor({ state: "visible", timeout: 10000 });
@@ -150,7 +150,7 @@ test.describe("Default Assistant Tests", () => {
}) => {
// Create a custom assistant
await page.getByTestId("AppSidebar/more-agents").click();
await page.getByLabel("AgentsPage/new-agent-button").click();
await page.getByTestId("AgentsPage/new-agent-button").click();
await page
.locator('input[name="name"]')
.waitFor({ state: "visible", timeout: 10000 });
@@ -200,7 +200,7 @@ test.describe("Default Assistant Tests", () => {
}) => {
// Create a custom assistant with starter messages
await page.getByTestId("AppSidebar/more-agents").click();
await page.getByLabel("AgentsPage/new-agent-button").click();
await page.getByTestId("AgentsPage/new-agent-button").click();
await page
.locator('input[name="name"]')
.waitFor({ state: "visible", timeout: 10000 });
@@ -253,7 +253,7 @@ test.describe("Default Assistant Tests", () => {
// Wait for modal or assistant list to appear
// The selector might be in a modal or dropdown.
await page
.getByLabel("AgentsPage/new-agent-button")
.getByTestId("AgentsPage/new-agent-button")
.waitFor({ state: "visible", timeout: 5000 });
// Look for default assistant by name - it should NOT be there
@@ -280,7 +280,7 @@ test.describe("Default Assistant Tests", () => {
}) => {
// Create a custom assistant
await page.getByTestId("AppSidebar/more-agents").click();
await page.getByLabel("AgentsPage/new-agent-button").click();
await page.getByTestId("AgentsPage/new-agent-button").click();
await page
.locator('input[name="name"]')
.waitFor({ state: "visible", timeout: 10000 });

View File

@@ -143,14 +143,14 @@ test.describe("File preview modal from chat file links", () => {
// Verify the file name is shown in the header
await expect(modal.getByText("notes.txt")).toBeVisible();
// Verify the download link exists
await expect(modal.locator("a[download]")).toBeVisible();
// Verify the download button exists
await expect(modal.getByText("Download File")).toBeVisible();
// Verify the file content is rendered
await expect(modal.getByText("Hello from the mock file!")).toBeVisible();
});
test("clicking a code file link opens the PreviewModal with syntax highlighting", async ({
test("clicking a code file link opens the CodeViewModal with syntax highlighting", async ({
page,
}) => {
const mockContent = `Here is your script: [app.py](/api/chat/file/${MOCK_FILE_ID})`;
@@ -173,7 +173,7 @@ test.describe("File preview modal from chat file links", () => {
await expect(fileLink).toBeVisible({ timeout: 5000 });
await fileLink.click();
// Verify the PreviewModal opens
// Verify the CodeViewModal opens
const modal = page.getByRole("dialog");
await expect(modal).toBeVisible({ timeout: 5000 });
@@ -217,9 +217,9 @@ test.describe("File preview modal from chat file links", () => {
const modal = page.getByRole("dialog");
await expect(modal).toBeVisible({ timeout: 5000 });
// Click the download link and verify a download starts
// Click the download button and verify a download starts
const downloadPromise = page.waitForEvent("download");
await modal.locator("a[download]").last().click();
await modal.getByText("Download File").last().click();
const download = await downloadPromise;
expect(download.suggestedFilename()).toContain("data.csv");

View File

@@ -32,7 +32,7 @@ test("Chat workflow", async ({ page }) => {
// Test creation of a new assistant
await page.getByTestId("AppSidebar/more-agents").click();
await page.getByLabel("AgentsPage/new-agent-button").click();
await page.getByTestId("AgentsPage/new-agent-button").click();
await page.locator('input[name="name"]').click();
await page.locator('input[name="name"]').fill("Test Assistant");
await page.locator('textarea[name="description"]').click();

View File

@@ -9,7 +9,6 @@ import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
const PREFLIGHT_TIMEOUT_MS = 60_000;
const PREFLIGHT_POLL_INTERVAL_MS = 2_000;
const PREFLIGHT_WARN_AFTER_MS = 15_000;
/**
* Poll the health endpoint until the server is ready or we time out.
@@ -18,8 +17,6 @@ const PREFLIGHT_WARN_AFTER_MS = 15_000;
async function waitForServer(baseURL: string): Promise<void> {
const healthURL = baseURL;
const deadline = Date.now() + PREFLIGHT_TIMEOUT_MS;
const startTime = Date.now();
let warned = false;
console.log(`[global-setup] Waiting for server at ${healthURL} ...`);
@@ -34,18 +31,6 @@ async function waitForServer(baseURL: string): Promise<void> {
} catch {
// Connection refused / DNS error — server not up yet.
}
if (!warned && Date.now() - startTime >= PREFLIGHT_WARN_AFTER_MS) {
warned = true;
console.warn(
`[global-setup] ⚠ Still waiting for server after ${
PREFLIGHT_WARN_AFTER_MS / 1000
}s.\n` +
` Please verify that both the backend and frontend are running.\n` +
` You can start them with: ods compose dev`
);
}
await new Promise((r) => setTimeout(r, PREFLIGHT_POLL_INTERVAL_MS));
}
@@ -54,7 +39,7 @@ async function waitForServer(baseURL: string): Promise<void> {
`Timed out after ${
PREFLIGHT_TIMEOUT_MS / 1000
}s waiting for ${healthURL} to return 200. ` +
`Make sure the backend and frontend are running (e.g. \`ods compose dev\`).`
`Make sure the server is running (e.g. \`ods compose dev\`).`
);
}

View File

@@ -496,7 +496,10 @@ test.describe("Default Assistant MCP Integration", () => {
await page.getByTestId("AppSidebar/more-agents").click();
await page.waitForURL("**/app/agents");
await page.getByLabel("AgentsPage/new-agent-button").click();
await page
.getByTestId("AgentsPage/new-agent-button")
.getByRole("link", { name: "New Agent" })
.click();
await page.waitForURL("**/app/agents/create");
const assistantName = `MCP Assistant ${Date.now()}`;

View File

@@ -20,7 +20,7 @@ export async function createAssistant(page: Page, params: AssistantParams) {
// Open Assistants modal/list
await page.getByTestId("AppSidebar/more-agents").click();
await page.getByLabel("AgentsPage/new-agent-button").click();
await page.getByTestId("AgentsPage/new-agent-button").click();
// Fill required fields
await page.locator('input[name="name"]').fill(name);