mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-24 11:15:47 +00:00
Compare commits
15 Commits
ci_script
...
jamison/ag
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
088c7804a8 | ||
|
|
cd2d86bb6a | ||
|
|
058c5ea494 | ||
|
|
3cb6ec2f85 | ||
|
|
691eebf00a | ||
|
|
905b6633e6 | ||
|
|
fd088196ff | ||
|
|
cafbf5b8be | ||
|
|
1235181559 | ||
|
|
caa2e45632 | ||
|
|
9c62e03120 | ||
|
|
0937305064 | ||
|
|
e4c06570e3 | ||
|
|
78fc7c86d7 | ||
|
|
84d3aea847 |
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -8,5 +8,5 @@
|
||||
|
||||
## Additional Options
|
||||
|
||||
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
|
||||
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
|
||||
- [ ] [Optional] Override Linear Check
|
||||
|
||||
79
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
Normal file
79
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
Normal file
@@ -0,0 +1,79 @@
|
||||
name: Post-Merge Beta Cherry-Pick
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Resolve merged PR and checkbox state
|
||||
id: gate
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
# For the commit that triggered this workflow (HEAD on main), fetch all
|
||||
# associated PRs and keep only the PR that was actually merged into main
|
||||
# with this exact merge commit SHA.
|
||||
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
|
||||
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
|
||||
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
|
||||
|
||||
if [ "${match_count}" -gt 1 ]; then
|
||||
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
|
||||
fi
|
||||
|
||||
if [ -z "$pr_number" ]; then
|
||||
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Read the PR body and check whether the helper checkbox is checked.
|
||||
pr_body="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}" --jq '.body // ""')"
|
||||
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
|
||||
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox checked for PR #${pr_number}."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
|
||||
|
||||
- name: Checkout repository
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: true
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Configure git identity
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Create cherry-pick PR to latest release
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify
|
||||
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
@@ -1,28 +0,0 @@
|
||||
name: Require beta cherry-pick consideration
|
||||
concurrency:
|
||||
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, reopened, synchronize]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
beta-cherrypick-check:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Check PR body for beta cherry-pick consideration
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
|
||||
echo "Cherry-pick consideration box is checked. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
|
||||
exit 1
|
||||
@@ -21,15 +21,14 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, NamedTuple
|
||||
from typing import NamedTuple
|
||||
|
||||
from alembic.config import Config
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.db.engine.sql_engine import is_valid_schema_name
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
|
||||
@@ -105,56 +104,6 @@ def get_head_revision() -> str | None:
|
||||
return script.get_current_head()
|
||||
|
||||
|
||||
def get_schemas_needing_migration(
|
||||
tenant_schemas: List[str], head_rev: str
|
||||
) -> List[str]:
|
||||
"""Return only schemas whose current alembic version is not at head."""
|
||||
if not tenant_schemas:
|
||||
return []
|
||||
|
||||
engine = SqlEngine.get_engine()
|
||||
|
||||
with engine.connect() as conn:
|
||||
# Find which schemas actually have an alembic_version table
|
||||
rows = conn.execute(
|
||||
text(
|
||||
"SELECT table_schema FROM information_schema.tables "
|
||||
"WHERE table_name = 'alembic_version' "
|
||||
"AND table_schema = ANY(:schemas)"
|
||||
),
|
||||
{"schemas": tenant_schemas},
|
||||
)
|
||||
schemas_with_table = set(row[0] for row in rows)
|
||||
|
||||
# Schemas without the table definitely need migration
|
||||
needs_migration = [s for s in tenant_schemas if s not in schemas_with_table]
|
||||
|
||||
if not schemas_with_table:
|
||||
return needs_migration
|
||||
|
||||
# Validate schema names before interpolating into SQL
|
||||
for schema in schemas_with_table:
|
||||
if not is_valid_schema_name(schema):
|
||||
raise ValueError(f"Invalid schema name: {schema}")
|
||||
|
||||
# Single query to get every schema's current revision at once.
|
||||
# Use integer tags instead of interpolating schema names into
|
||||
# string literals to avoid quoting issues.
|
||||
schema_list = list(schemas_with_table)
|
||||
union_parts = [
|
||||
f'SELECT {i} AS idx, version_num FROM "{schema}".alembic_version'
|
||||
for i, schema in enumerate(schema_list)
|
||||
]
|
||||
rows = conn.execute(text(" UNION ALL ".join(union_parts)))
|
||||
version_by_schema = {schema_list[row[0]]: row[1] for row in rows}
|
||||
|
||||
needs_migration.extend(
|
||||
s for s in schemas_with_table if version_by_schema.get(s) != head_rev
|
||||
)
|
||||
|
||||
return needs_migration
|
||||
|
||||
|
||||
def run_migrations_parallel(
|
||||
schemas: list[str],
|
||||
max_workers: int,
|
||||
|
||||
@@ -127,9 +127,14 @@ class ScimDAL(DAL):
|
||||
self,
|
||||
external_id: str,
|
||||
user_id: UUID,
|
||||
scim_username: str | None = None,
|
||||
) -> ScimUserMapping:
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user."""
|
||||
mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
|
||||
mapping = ScimUserMapping(
|
||||
external_id=external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
)
|
||||
self._session.add(mapping)
|
||||
self._session.flush()
|
||||
return mapping
|
||||
@@ -248,11 +253,11 @@ class ScimDAL(DAL):
|
||||
scim_filter: ScimFilter | None,
|
||||
start_index: int = 1,
|
||||
count: int = 100,
|
||||
) -> tuple[list[tuple[User, str | None]], int]:
|
||||
) -> tuple[list[tuple[User, ScimUserMapping | None]], int]:
|
||||
"""Query users with optional SCIM filter and pagination.
|
||||
|
||||
Returns:
|
||||
A tuple of (list of (user, external_id) pairs, total_count).
|
||||
A tuple of (list of (user, mapping) pairs, total_count).
|
||||
|
||||
Raises:
|
||||
ValueError: If the filter uses an unsupported attribute.
|
||||
@@ -292,33 +297,104 @@ class ScimDAL(DAL):
|
||||
users = list(
|
||||
self._session.scalars(
|
||||
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
|
||||
).all()
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
)
|
||||
|
||||
# Batch-fetch external IDs to avoid N+1 queries
|
||||
ext_id_map = self._get_user_external_ids([u.id for u in users])
|
||||
return [(u, ext_id_map.get(u.id)) for u in users], total
|
||||
# Batch-fetch SCIM mappings to avoid N+1 queries
|
||||
mapping_map = self._get_user_mappings_batch([u.id for u in users])
|
||||
return [(u, mapping_map.get(u.id)) for u in users], total
|
||||
|
||||
def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
|
||||
def sync_user_external_id(
|
||||
self,
|
||||
user_id: UUID,
|
||||
new_external_id: str | None,
|
||||
scim_username: str | None = None,
|
||||
) -> None:
|
||||
"""Create, update, or delete the external ID mapping for a user."""
|
||||
mapping = self.get_user_mapping_by_user_id(user_id)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
else:
|
||||
self.create_user_mapping(external_id=new_external_id, user_id=user_id)
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
)
|
||||
elif mapping:
|
||||
self.delete_user_mapping(mapping.id)
|
||||
|
||||
def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
|
||||
"""Batch-fetch external IDs for a list of user IDs."""
|
||||
def _get_user_mappings_batch(
|
||||
self, user_ids: list[UUID]
|
||||
) -> dict[UUID, ScimUserMapping]:
|
||||
"""Batch-fetch SCIM user mappings keyed by user ID."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
mappings = self._session.scalars(
|
||||
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
|
||||
).all()
|
||||
return {m.user_id: m.external_id for m in mappings}
|
||||
return {m.user_id: m for m in mappings}
|
||||
|
||||
def get_user_groups(self, user_id: UUID) -> list[tuple[int, str]]:
|
||||
"""Get groups a user belongs to as ``(group_id, group_name)`` pairs.
|
||||
|
||||
Excludes groups marked for deletion.
|
||||
"""
|
||||
rels = self._session.scalars(
|
||||
select(User__UserGroup).where(User__UserGroup.user_id == user_id)
|
||||
).all()
|
||||
|
||||
group_ids = [r.user_group_id for r in rels]
|
||||
if not group_ids:
|
||||
return []
|
||||
|
||||
groups = self._session.scalars(
|
||||
select(UserGroup).where(
|
||||
UserGroup.id.in_(group_ids),
|
||||
UserGroup.is_up_for_deletion.is_(False),
|
||||
)
|
||||
).all()
|
||||
return [(g.id, g.name) for g in groups]
|
||||
|
||||
def get_users_groups_batch(
|
||||
self, user_ids: list[UUID]
|
||||
) -> dict[UUID, list[tuple[int, str]]]:
|
||||
"""Batch-fetch group memberships for multiple users.
|
||||
|
||||
Returns a mapping of ``user_id → [(group_id, group_name), ...]``.
|
||||
Avoids N+1 queries when building user list responses.
|
||||
"""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
rels = self._session.scalars(
|
||||
select(User__UserGroup).where(User__UserGroup.user_id.in_(user_ids))
|
||||
).all()
|
||||
|
||||
group_ids = list({r.user_group_id for r in rels})
|
||||
if not group_ids:
|
||||
return {}
|
||||
|
||||
groups = self._session.scalars(
|
||||
select(UserGroup).where(
|
||||
UserGroup.id.in_(group_ids),
|
||||
UserGroup.is_up_for_deletion.is_(False),
|
||||
)
|
||||
).all()
|
||||
groups_by_id = {g.id: g.name for g in groups}
|
||||
|
||||
result: dict[UUID, list[tuple[int, str]]] = {}
|
||||
for r in rels:
|
||||
if r.user_id and r.user_group_id in groups_by_id:
|
||||
result.setdefault(r.user_id, []).append(
|
||||
(r.user_group_id, groups_by_id[r.user_group_id])
|
||||
)
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Group mapping operations
|
||||
@@ -483,9 +559,13 @@ class ScimDAL(DAL):
|
||||
if not user_ids:
|
||||
return []
|
||||
|
||||
users = self._session.scalars(
|
||||
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
|
||||
).all()
|
||||
users = (
|
||||
self._session.scalars(
|
||||
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
)
|
||||
users_by_id = {u.id: u for u in users}
|
||||
|
||||
return [
|
||||
@@ -504,9 +584,13 @@ class ScimDAL(DAL):
|
||||
"""
|
||||
if not uuids:
|
||||
return []
|
||||
existing_users = self._session.scalars(
|
||||
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
|
||||
).all()
|
||||
existing_users = (
|
||||
self._session.scalars(
|
||||
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
)
|
||||
existing_ids = {u.id for u in existing_users}
|
||||
return [uid for uid in uuids if uid not in existing_ids]
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ class SendSearchQueryRequest(BaseModel):
|
||||
filters: BaseFilters | None = None
|
||||
num_docs_fed_to_llm_selection: int | None = None
|
||||
run_query_expansion: bool = False
|
||||
num_hits: int = 50
|
||||
num_hits: int = 30
|
||||
|
||||
include_content: bool = False
|
||||
stream: bool = False
|
||||
|
||||
@@ -26,12 +26,10 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimError
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
@@ -41,6 +39,8 @@ from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
@@ -53,7 +53,6 @@ from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
|
||||
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
|
||||
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
|
||||
@@ -63,6 +62,18 @@ scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
_pw_helper = PasswordHelper()
|
||||
|
||||
|
||||
def _get_provider(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
) -> ScimProvider:
|
||||
"""Resolve the SCIM provider for the current request.
|
||||
|
||||
Currently returns OktaProvider for all requests. When multi-provider
|
||||
support is added (ENG-3652), this will resolve based on token metadata
|
||||
or tenant configuration — no endpoint changes required.
|
||||
"""
|
||||
return get_default_provider()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service Discovery Endpoints (unauthenticated)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -100,28 +111,6 @@ def _scim_error_response(status: int, detail: str) -> JSONResponse:
|
||||
)
|
||||
|
||||
|
||||
def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
|
||||
"""Convert an Onyx User to a SCIM User resource representation."""
|
||||
name = None
|
||||
if user.personal_name:
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
name = ScimName(
|
||||
givenName=parts[0],
|
||||
familyName=parts[1] if len(parts) > 1 else None,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
|
||||
return ScimUserResource(
|
||||
id=str(user.id),
|
||||
externalId=external_id,
|
||||
userName=user.email,
|
||||
name=name,
|
||||
emails=[ScimEmail(value=user.email, type="work", primary=True)],
|
||||
active=user.is_active,
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
@@ -155,9 +144,10 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
|
||||
"""
|
||||
if not name:
|
||||
return None
|
||||
return name.formatted or " ".join(
|
||||
part for part in [name.givenName, name.familyName] if part
|
||||
)
|
||||
# Build from givenName/familyName first — IdPs like Okta may send a stale
|
||||
# ``formatted`` value while updating the individual name components.
|
||||
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
|
||||
return parts or name.formatted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -171,6 +161,7 @@ def list_users(
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List users with optional SCIM filter and pagination."""
|
||||
@@ -183,12 +174,19 @@ def list_users(
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
try:
|
||||
users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
|
||||
users_with_mappings, total = dal.list_users(scim_filter, startIndex, count)
|
||||
except ValueError as e:
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
user_groups_map = dal.get_users_groups_batch([u.id for u, _ in users_with_mappings])
|
||||
resources: list[ScimUserResource | ScimGroupResource] = [
|
||||
_user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
mapping.external_id if mapping else None,
|
||||
groups=user_groups_map.get(user.id, []),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
)
|
||||
for user, mapping in users_with_mappings
|
||||
]
|
||||
|
||||
return ScimListResponse(
|
||||
@@ -203,6 +201,7 @@ def list_users(
|
||||
def get_user(
|
||||
user_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Get a single user by ID."""
|
||||
@@ -215,20 +214,26 @@ def get_user(
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
return _user_to_scim(user, mapping.external_id if mapping else None)
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
mapping.external_id if mapping else None,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.post("/Users", status_code=201, response_model=None)
|
||||
def create_user(
|
||||
user_resource: ScimUserResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Create a new user from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
email = user_resource.userName.strip().lower()
|
||||
email = user_resource.userName.strip()
|
||||
|
||||
# externalId is how the IdP correlates this user on subsequent requests.
|
||||
# Without it, the IdP can't find the user and will try to re-create,
|
||||
@@ -264,11 +269,14 @@ def create_user(
|
||||
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
dal.create_user_mapping(external_id=external_id, user_id=user.id)
|
||||
scim_username = user_resource.userName.strip()
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id, user_id=user.id, scim_username=scim_username
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, external_id)
|
||||
return provider.build_user_resource(user, external_id, scim_username=scim_username)
|
||||
|
||||
|
||||
@scim_router.put("/Users/{user_id}", response_model=None)
|
||||
@@ -276,6 +284,7 @@ def replace_user(
|
||||
user_id: str,
|
||||
user_resource: ScimUserResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Replace a user entirely (RFC 7644 §3.5.1)."""
|
||||
@@ -293,19 +302,27 @@ def replace_user(
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
|
||||
dal.update_user(
|
||||
user,
|
||||
email=user_resource.userName.strip().lower(),
|
||||
email=user_resource.userName.strip(),
|
||||
is_active=user_resource.active,
|
||||
personal_name=_scim_name_to_str(user_resource.name),
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
dal.sync_user_external_id(user.id, new_external_id)
|
||||
scim_username = user_resource.userName.strip()
|
||||
dal.sync_user_external_id(user.id, new_external_id, scim_username=scim_username)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, new_external_id)
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.patch("/Users/{user_id}", response_model=None)
|
||||
@@ -313,6 +330,7 @@ def patch_user(
|
||||
user_id: str,
|
||||
patch_request: ScimPatchRequest,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Partially update a user (RFC 7644 §3.5.2).
|
||||
@@ -330,11 +348,19 @@ def patch_user(
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
current_scim_username = mapping.scim_username if mapping else None
|
||||
|
||||
current = _user_to_scim(user, external_id)
|
||||
current = provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=current_scim_username,
|
||||
)
|
||||
|
||||
try:
|
||||
patched = apply_user_patch(patch_request.Operations, current)
|
||||
patched = apply_user_patch(
|
||||
patch_request.Operations, current, provider.ignored_patch_paths
|
||||
)
|
||||
except ScimPatchError as e:
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
@@ -345,22 +371,40 @@ def patch_user(
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
# Track the scim_username — if userName was patched, update it
|
||||
new_scim_username = patched.userName.strip() if patched.userName else None
|
||||
|
||||
# If displayName was explicitly patched (different from the original), use
|
||||
# it as personal_name directly. Otherwise, derive from name components.
|
||||
personal_name: str | None
|
||||
if patched.displayName and patched.displayName != current.displayName:
|
||||
personal_name = patched.displayName
|
||||
else:
|
||||
personal_name = _scim_name_to_str(patched.name)
|
||||
|
||||
dal.update_user(
|
||||
user,
|
||||
email=(
|
||||
patched.userName.strip().lower()
|
||||
if patched.userName.lower() != user.email
|
||||
patched.userName.strip()
|
||||
if patched.userName.strip().lower() != user.email.lower()
|
||||
else None
|
||||
),
|
||||
is_active=patched.active if patched.active != user.is_active else None,
|
||||
personal_name=_scim_name_to_str(patched.name),
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
dal.sync_user_external_id(user.id, patched.externalId)
|
||||
dal.sync_user_external_id(
|
||||
user.id, patched.externalId, scim_username=new_scim_username
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, patched.externalId)
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
|
||||
@@ -398,24 +442,6 @@ def delete_user(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _group_to_scim(
|
||||
group: UserGroup,
|
||||
members: list[tuple[UUID, str | None]],
|
||||
external_id: str | None = None,
|
||||
) -> ScimGroupResource:
|
||||
"""Convert an Onyx UserGroup to a SCIM Group resource."""
|
||||
scim_members = [
|
||||
ScimGroupMember(value=str(uid), display=email) for uid, email in members
|
||||
]
|
||||
return ScimGroupResource(
|
||||
id=str(group.id),
|
||||
externalId=external_id,
|
||||
displayName=group.name,
|
||||
members=scim_members,
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
|
||||
"""Parse *group_id* as int, look up the group, or return a 404 error."""
|
||||
try:
|
||||
@@ -474,6 +500,7 @@ def list_groups(
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
@@ -491,7 +518,7 @@ def list_groups(
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
resources: list[ScimUserResource | ScimGroupResource] = [
|
||||
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
|
||||
provider.build_group_resource(group, dal.get_group_members(group.id), ext_id)
|
||||
for group, ext_id in groups_with_ext_ids
|
||||
]
|
||||
|
||||
@@ -507,6 +534,7 @@ def list_groups(
|
||||
def get_group(
|
||||
group_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Get a single group by ID."""
|
||||
@@ -521,13 +549,16 @@ def get_group(
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
members = dal.get_group_members(group.id)
|
||||
|
||||
return _group_to_scim(group, members, mapping.external_id if mapping else None)
|
||||
return provider.build_group_resource(
|
||||
group, members, mapping.external_id if mapping else None
|
||||
)
|
||||
|
||||
|
||||
@scim_router.post("/Groups", status_code=201, response_model=None)
|
||||
def create_group(
|
||||
group_resource: ScimGroupResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Create a new group from a SCIM provisioning request."""
|
||||
@@ -565,7 +596,7 @@ def create_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(db_group.id)
|
||||
return _group_to_scim(db_group, members, external_id)
|
||||
return provider.build_group_resource(db_group, members, external_id)
|
||||
|
||||
|
||||
@scim_router.put("/Groups/{group_id}", response_model=None)
|
||||
@@ -573,6 +604,7 @@ def replace_group(
|
||||
group_id: str,
|
||||
group_resource: ScimGroupResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Replace a group entirely (RFC 7644 §3.5.1)."""
|
||||
@@ -595,7 +627,7 @@ def replace_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _group_to_scim(group, members, group_resource.externalId)
|
||||
return provider.build_group_resource(group, members, group_resource.externalId)
|
||||
|
||||
|
||||
@scim_router.patch("/Groups/{group_id}", response_model=None)
|
||||
@@ -603,6 +635,7 @@ def patch_group(
|
||||
group_id: str,
|
||||
patch_request: ScimPatchRequest,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Partially update a group (RFC 7644 §3.5.2).
|
||||
@@ -621,11 +654,11 @@ def patch_group(
|
||||
external_id = mapping.external_id if mapping else None
|
||||
|
||||
current_members = dal.get_group_members(group.id)
|
||||
current = _group_to_scim(group, current_members, external_id)
|
||||
current = provider.build_group_resource(group, current_members, external_id)
|
||||
|
||||
try:
|
||||
patched, added_ids, removed_ids = apply_group_patch(
|
||||
patch_request.Operations, current
|
||||
patch_request.Operations, current, provider.ignored_patch_paths
|
||||
)
|
||||
except ScimPatchError as e:
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
@@ -652,7 +685,7 @@ def patch_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _group_to_scim(group, members, patched.externalId)
|
||||
return provider.build_group_resource(group, members, patched.externalId)
|
||||
|
||||
|
||||
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
|
||||
|
||||
@@ -63,6 +63,13 @@ class ScimMeta(BaseModel):
|
||||
location: str | None = None
|
||||
|
||||
|
||||
class ScimUserGroupRef(BaseModel):
|
||||
"""Group reference within a User resource (RFC 7643 §4.1.2, read-only)."""
|
||||
|
||||
value: str
|
||||
display: str | None = None
|
||||
|
||||
|
||||
class ScimUserResource(BaseModel):
|
||||
"""SCIM User resource representation (RFC 7643 §4.1).
|
||||
|
||||
@@ -76,8 +83,10 @@ class ScimUserResource(BaseModel):
|
||||
externalId: str | None = None # IdP's identifier for this user
|
||||
userName: str # Typically the user's email address
|
||||
name: ScimName | None = None
|
||||
displayName: str | None = None
|
||||
emails: list[ScimEmail] = Field(default_factory=list)
|
||||
active: bool = True
|
||||
groups: list[ScimUserGroupRef] = Field(default_factory=list)
|
||||
meta: ScimMeta | None = None
|
||||
|
||||
|
||||
@@ -121,12 +130,40 @@ class ScimPatchOperationType(str, Enum):
|
||||
REMOVE = "remove"
|
||||
|
||||
|
||||
class ScimPatchResourceValue(BaseModel):
|
||||
"""Partial resource dict for path-less PATCH replace operations.
|
||||
|
||||
When an IdP sends a PATCH without a ``path``, the ``value`` is a dict
|
||||
of resource attributes to set. IdPs may include read-only fields
|
||||
(``id``, ``schemas``, ``meta``) alongside actual changes — these are
|
||||
stripped by the provider's ``ignored_patch_paths`` before processing.
|
||||
|
||||
``extra="allow"`` lets unknown attributes pass through so the patch
|
||||
handler can decide what to do with them (ignore or reject).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
active: bool | None = None
|
||||
userName: str | None = None
|
||||
displayName: str | None = None
|
||||
externalId: str | None = None
|
||||
name: ScimName | None = None
|
||||
members: list[ScimGroupMember] | None = None
|
||||
id: str | None = None
|
||||
schemas: list[str] | None = None
|
||||
meta: ScimMeta | None = None
|
||||
|
||||
|
||||
ScimPatchValue = str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None
|
||||
|
||||
|
||||
class ScimPatchOperation(BaseModel):
|
||||
"""Single PATCH operation (RFC 7644 §3.5.2)."""
|
||||
|
||||
op: ScimPatchOperationType
|
||||
path: str | None = None
|
||||
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
|
||||
value: ScimPatchValue = None
|
||||
|
||||
|
||||
class ScimPatchRequest(BaseModel):
|
||||
|
||||
@@ -16,9 +16,12 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimPatchValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
|
||||
|
||||
@@ -41,9 +44,15 @@ _MEMBER_FILTER_RE = re.compile(
|
||||
def apply_user_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimUserResource,
|
||||
ignored_paths: frozenset[str] = frozenset(),
|
||||
) -> ScimUserResource:
|
||||
"""Apply SCIM PATCH operations to a user resource.
|
||||
|
||||
Args:
|
||||
operations: The PATCH operations to apply.
|
||||
current: The current user resource state.
|
||||
ignored_paths: SCIM attribute paths to silently skip (from provider).
|
||||
|
||||
Returns a new ``ScimUserResource`` with the modifications applied.
|
||||
The original object is not mutated.
|
||||
|
||||
@@ -55,9 +64,9 @@ def apply_user_patch(
|
||||
|
||||
for op in operations:
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_user_replace(op, data, name_data)
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_user_replace(op, data, name_data)
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported operation '{op.op.value}' on User resource"
|
||||
@@ -71,30 +80,34 @@ def _apply_user_replace(
|
||||
op: ScimPatchOperation,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a replace/add operation to user data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
# No path — value is a dict of top-level attributes to set
|
||||
if isinstance(op.value, dict):
|
||||
for key, val in op.value.items():
|
||||
_set_user_field(key.lower(), val, data, name_data)
|
||||
# No path — value is a resource dict of top-level attributes to set
|
||||
if isinstance(op.value, ScimPatchResourceValue):
|
||||
for key, val in op.value.model_dump(exclude_unset=True).items():
|
||||
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
_set_user_field(path, op.value, data, name_data)
|
||||
_set_user_field(path, op.value, data, name_data, ignored_paths)
|
||||
|
||||
|
||||
def _set_user_field(
|
||||
path: str,
|
||||
value: str | bool | dict | list | None,
|
||||
value: ScimPatchValue,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Set a single field on user data by SCIM path."""
|
||||
if path == "active":
|
||||
if path in ignored_paths:
|
||||
return
|
||||
elif path == "active":
|
||||
data["active"] = value
|
||||
elif path == "username":
|
||||
data["userName"] = value
|
||||
@@ -107,7 +120,7 @@ def _set_user_field(
|
||||
elif path == "name.formatted":
|
||||
name_data["formatted"] = value
|
||||
elif path == "displayname":
|
||||
# Some IdPs send displayName on users; map to formatted name
|
||||
data["displayName"] = value
|
||||
name_data["formatted"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
|
||||
@@ -116,9 +129,15 @@ def _set_user_field(
|
||||
def apply_group_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimGroupResource,
|
||||
ignored_paths: frozenset[str] = frozenset(),
|
||||
) -> tuple[ScimGroupResource, list[str], list[str]]:
|
||||
"""Apply SCIM PATCH operations to a group resource.
|
||||
|
||||
Args:
|
||||
operations: The PATCH operations to apply.
|
||||
current: The current group resource state.
|
||||
ignored_paths: SCIM attribute paths to silently skip (from provider).
|
||||
|
||||
Returns:
|
||||
A tuple of (modified group, added member IDs, removed member IDs).
|
||||
The caller uses the member ID lists to update the database.
|
||||
@@ -133,7 +152,9 @@ def apply_group_patch(
|
||||
|
||||
for op in operations:
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
|
||||
_apply_group_replace(
|
||||
op, data, current_members, added_ids, removed_ids, ignored_paths
|
||||
)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_group_add(op, current_members, added_ids)
|
||||
elif op.op == ScimPatchOperationType.REMOVE:
|
||||
@@ -154,38 +175,48 @@ def _apply_group_replace(
|
||||
current_members: list[dict],
|
||||
added_ids: list[str],
|
||||
removed_ids: list[str],
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a replace operation to group data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
if isinstance(op.value, dict):
|
||||
for key, val in op.value.items():
|
||||
if isinstance(op.value, ScimPatchResourceValue):
|
||||
dumped = op.value.model_dump(exclude_unset=True)
|
||||
for key, val in dumped.items():
|
||||
if key.lower() == "members":
|
||||
_replace_members(val, current_members, added_ids, removed_ids)
|
||||
else:
|
||||
_set_group_field(key.lower(), val, data)
|
||||
_set_group_field(key.lower(), val, data, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
if path == "members":
|
||||
_replace_members(op.value, current_members, added_ids, removed_ids)
|
||||
_replace_members(
|
||||
_members_to_dicts(op.value), current_members, added_ids, removed_ids
|
||||
)
|
||||
return
|
||||
|
||||
_set_group_field(path, op.value, data)
|
||||
_set_group_field(path, op.value, data, ignored_paths)
|
||||
|
||||
|
||||
def _members_to_dicts(
|
||||
value: str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None,
|
||||
) -> list[dict]:
|
||||
"""Convert a member list value to a list of dicts for internal processing."""
|
||||
if not isinstance(value, list):
|
||||
raise ScimPatchError("Replace members requires a list value")
|
||||
return [m.model_dump(exclude_none=True) for m in value]
|
||||
|
||||
|
||||
def _replace_members(
|
||||
value: str | list | dict | bool | None,
|
||||
value: list[dict],
|
||||
current_members: list[dict],
|
||||
added_ids: list[str],
|
||||
removed_ids: list[str],
|
||||
) -> None:
|
||||
"""Replace the entire group member list."""
|
||||
if not isinstance(value, list):
|
||||
raise ScimPatchError("Replace members requires a list value")
|
||||
|
||||
old_ids = {m["value"] for m in current_members}
|
||||
new_ids = {m.get("value", "") for m in value}
|
||||
|
||||
@@ -197,11 +228,14 @@ def _replace_members(
|
||||
|
||||
def _set_group_field(
|
||||
path: str,
|
||||
value: str | bool | dict | list | None,
|
||||
value: ScimPatchValue,
|
||||
data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Set a single field on group data by SCIM path."""
|
||||
if path == "displayname":
|
||||
if path in ignored_paths:
|
||||
return
|
||||
elif path == "displayname":
|
||||
data["displayName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
@@ -223,8 +257,10 @@ def _apply_group_add(
|
||||
if not isinstance(op.value, list):
|
||||
raise ScimPatchError("Add members requires a list value")
|
||||
|
||||
member_dicts = [m.model_dump(exclude_none=True) for m in op.value]
|
||||
|
||||
existing_ids = {m["value"] for m in members}
|
||||
for member_data in op.value:
|
||||
for member_data in member_dicts:
|
||||
member_id = member_data.get("value", "")
|
||||
if member_id and member_id not in existing_ids:
|
||||
members.append(member_data)
|
||||
|
||||
0
backend/ee/onyx/server/scim/providers/__init__.py
Normal file
0
backend/ee/onyx/server/scim/providers/__init__.py
Normal file
123
backend/ee/onyx/server/scim/providers/base.py
Normal file
123
backend/ee/onyx/server/scim/providers/base.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Base SCIM provider abstraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from uuid import UUID
|
||||
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
|
||||
|
||||
class ScimProvider(ABC):
|
||||
"""Base class for provider-specific SCIM behavior.
|
||||
|
||||
Subclass this to handle IdP-specific quirks. The base class provides
|
||||
RFC 7643-compliant response builders that populate all standard fields.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Short identifier for this provider (e.g. ``"okta"``)."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
"""SCIM attribute paths to silently skip in PATCH value-object dicts.
|
||||
|
||||
IdPs may include read-only or meta fields alongside actual changes
|
||||
(e.g. Okta sends ``{"id": "...", "active": false}``). Paths listed
|
||||
here are silently dropped instead of raising an error.
|
||||
"""
|
||||
...
|
||||
|
||||
def build_user_resource(
|
||||
self,
|
||||
user: User,
|
||||
external_id: str | None = None,
|
||||
groups: list[tuple[int, str]] | None = None,
|
||||
scim_username: str | None = None,
|
||||
) -> ScimUserResource:
|
||||
"""Build a SCIM User response from an Onyx User.
|
||||
|
||||
Args:
|
||||
user: The Onyx user model.
|
||||
external_id: The IdP's external identifier for this user.
|
||||
groups: List of ``(group_id, group_name)`` tuples for the
|
||||
``groups`` read-only attribute. Pass ``None`` or ``[]``
|
||||
for newly-created users.
|
||||
scim_username: The original-case userName from the IdP. Falls
|
||||
back to ``user.email`` (lowercase) when not available.
|
||||
"""
|
||||
group_refs = [
|
||||
ScimUserGroupRef(value=str(gid), display=gname)
|
||||
for gid, gname in (groups or [])
|
||||
]
|
||||
|
||||
# Use original-case userName if stored, otherwise fall back to the
|
||||
# lowercased email from the User model.
|
||||
username = scim_username or user.email
|
||||
|
||||
return ScimUserResource(
|
||||
id=str(user.id),
|
||||
externalId=external_id,
|
||||
userName=username,
|
||||
name=self._build_scim_name(user),
|
||||
displayName=user.personal_name,
|
||||
emails=[ScimEmail(value=username, type="work", primary=True)],
|
||||
active=user.is_active,
|
||||
groups=group_refs,
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
def build_group_resource(
|
||||
self,
|
||||
group: UserGroup,
|
||||
members: list[tuple[UUID, str | None]],
|
||||
external_id: str | None = None,
|
||||
) -> ScimGroupResource:
|
||||
"""Build a SCIM Group response from an Onyx UserGroup."""
|
||||
scim_members = [
|
||||
ScimGroupMember(value=str(uid), display=email) for uid, email in members
|
||||
]
|
||||
return ScimGroupResource(
|
||||
id=str(group.id),
|
||||
externalId=external_id,
|
||||
displayName=group.name,
|
||||
members=scim_members,
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_scim_name(user: User) -> ScimName | None:
|
||||
"""Extract SCIM name components from a user's personal name."""
|
||||
if not user.personal_name:
|
||||
return None
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
return ScimName(
|
||||
givenName=parts[0],
|
||||
familyName=parts[1] if len(parts) > 1 else None,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
|
||||
|
||||
def get_default_provider() -> ScimProvider:
|
||||
"""Return the default SCIM provider.
|
||||
|
||||
Currently returns ``OktaProvider`` since Okta is the primary supported
|
||||
IdP. When provider detection is added (via token metadata or tenant
|
||||
config), this can be replaced with dynamic resolution.
|
||||
"""
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
return OktaProvider()
|
||||
25
backend/ee/onyx/server/scim/providers/okta.py
Normal file
25
backend/ee/onyx/server/scim/providers/okta.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Okta SCIM provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
|
||||
class OktaProvider(ScimProvider):
|
||||
"""Okta SCIM provider.
|
||||
|
||||
Okta behavioral notes:
|
||||
- Uses ``PATCH {"active": false}`` for deprovisioning (not DELETE)
|
||||
- Sends path-less PATCH with value dicts containing extra fields
|
||||
(``id``, ``schemas``)
|
||||
- Expects ``displayName`` and ``groups`` in user responses
|
||||
- Only uses ``eq`` operator for ``userName`` filter
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "okta"
|
||||
|
||||
@property
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
return frozenset({"id", "schemas", "meta"})
|
||||
@@ -277,13 +277,32 @@ def verify_email_domain(email: str) -> None:
|
||||
detail="Email is not valid",
|
||||
)
|
||||
|
||||
domain = email.split("@")[-1].lower()
|
||||
local_part, domain = email.split("@")
|
||||
domain = domain.lower()
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
# Normalize googlemail.com to gmail.com (they deliver to the same inbox)
|
||||
if domain == "googlemail.com":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
|
||||
)
|
||||
|
||||
if "+" in local_part and domain != "onyx.app":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"reason": "Email addresses with '+' are not allowed. Please use your base email address."
|
||||
},
|
||||
)
|
||||
|
||||
# Check if email uses a disposable/temporary domain
|
||||
if is_disposable_email(email):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Disposable email addresses are not allowed. Please use a permanent email address.",
|
||||
detail={
|
||||
"reason": "Disposable email addresses are not allowed. Please use a permanent email address."
|
||||
},
|
||||
)
|
||||
|
||||
# Check domain whitelist if configured
|
||||
|
||||
@@ -59,12 +59,11 @@ def _build_index_filters(
|
||||
|
||||
base_filters = user_provided_filters or BaseFilters()
|
||||
|
||||
if (
|
||||
user_provided_filters
|
||||
and user_provided_filters.document_set is None
|
||||
and persona_document_sets is not None
|
||||
):
|
||||
base_filters.document_set = persona_document_sets
|
||||
document_set_filter = (
|
||||
base_filters.document_set
|
||||
if base_filters.document_set is not None
|
||||
else persona_document_sets
|
||||
)
|
||||
|
||||
time_filter = base_filters.time_cutoff or persona_time_cutoff
|
||||
source_filter = base_filters.source_type
|
||||
@@ -120,7 +119,7 @@ def _build_index_filters(
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
source_type=source_filter,
|
||||
document_set=persona_document_sets,
|
||||
document_set=document_set_filter,
|
||||
time_cutoff=time_filter,
|
||||
tags=base_filters.tags,
|
||||
access_control_list=user_acl_filters,
|
||||
|
||||
@@ -1,11 +1,102 @@
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
|
||||
def get_schemas_needing_migration(
|
||||
tenant_schemas: list[str], head_rev: str
|
||||
) -> list[str]:
|
||||
"""Return only schemas whose current alembic version is not at head.
|
||||
|
||||
Uses a server-side PL/pgSQL loop to collect each schema's alembic version
|
||||
into a temp table one at a time. This avoids building a massive UNION ALL
|
||||
query (which locks the DB and times out at 17k+ schemas) and instead
|
||||
acquires locks sequentially, one schema per iteration.
|
||||
"""
|
||||
if not tenant_schemas:
|
||||
return []
|
||||
|
||||
engine = SqlEngine.get_engine()
|
||||
|
||||
with engine.connect() as conn:
|
||||
# Populate a temp input table with exactly the schemas we care about.
|
||||
# The DO block reads from this table so it only iterates the requested
|
||||
# schemas instead of every tenant_% schema in the database.
|
||||
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot"))
|
||||
conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input"))
|
||||
conn.execute(text("CREATE TEMP TABLE _tenant_schemas_input (schema_name text)"))
|
||||
conn.execute(
|
||||
text(
|
||||
"INSERT INTO _tenant_schemas_input (schema_name) "
|
||||
"SELECT unnest(CAST(:schemas AS text[]))"
|
||||
),
|
||||
{"schemas": tenant_schemas},
|
||||
)
|
||||
conn.execute(
|
||||
text(
|
||||
"CREATE TEMP TABLE _alembic_version_snapshot "
|
||||
"(schema_name text, version_num text)"
|
||||
)
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
DO $$
|
||||
DECLARE
|
||||
s text;
|
||||
schemas text[];
|
||||
BEGIN
|
||||
SELECT array_agg(schema_name) INTO schemas
|
||||
FROM _tenant_schemas_input;
|
||||
|
||||
IF schemas IS NULL THEN
|
||||
RAISE NOTICE 'No tenant schemas found.';
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
FOREACH s IN ARRAY schemas LOOP
|
||||
BEGIN
|
||||
EXECUTE format(
|
||||
'INSERT INTO _alembic_version_snapshot
|
||||
SELECT %L, version_num FROM %I.alembic_version',
|
||||
s, s
|
||||
);
|
||||
EXCEPTION
|
||||
-- undefined_table: schema exists but has no alembic_version
|
||||
-- table yet (new tenant, not yet migrated).
|
||||
-- invalid_schema_name: tenant is registered but its
|
||||
-- PostgreSQL schema does not exist yet (e.g. provisioning
|
||||
-- incomplete). Both cases mean no version is available and
|
||||
-- the schema will be included in the migration list.
|
||||
WHEN undefined_table THEN NULL;
|
||||
WHEN invalid_schema_name THEN NULL;
|
||||
END;
|
||||
END LOOP;
|
||||
END;
|
||||
$$
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
rows = conn.execute(
|
||||
text("SELECT schema_name, version_num FROM _alembic_version_snapshot")
|
||||
)
|
||||
version_by_schema = {row[0]: row[1] for row in rows}
|
||||
|
||||
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot"))
|
||||
conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input"))
|
||||
|
||||
# Schemas missing from the snapshot have no alembic_version table yet and
|
||||
# also need migration. version_by_schema.get(s) returns None for those,
|
||||
# and None != head_rev, so they are included automatically.
|
||||
return [s for s in tenant_schemas if version_by_schema.get(s) != head_rev]
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str]:
|
||||
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
|
||||
|
||||
|
||||
@@ -554,10 +554,9 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
vespa_where_clauses = build_vespa_filters(filters)
|
||||
# Needs to be at least as much as the rerank-count value set in the
|
||||
# Vespa schema config. Otherwise we would be getting fewer results than
|
||||
# expected for reranking.
|
||||
target_hits = max(10 * num_to_retrieve, RERANK_COUNT)
|
||||
# Avoid over-fetching a very large candidate set for global-phase reranking.
|
||||
# Keep enough headroom for quality while capping cost on larger indices.
|
||||
target_hits = min(max(4 * num_to_retrieve, 100), RERANK_COUNT)
|
||||
|
||||
yql = (
|
||||
YQL_BASE.format(index_name=self._index_name)
|
||||
|
||||
@@ -3,8 +3,8 @@ set -e
|
||||
|
||||
cleanup() {
|
||||
echo "Error occurred. Cleaning up..."
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Trap errors and output a message, then cleanup
|
||||
@@ -20,8 +20,8 @@ MINIO_VOLUME=${4:-""} # Default is empty if not provided
|
||||
|
||||
# Stop and remove the existing containers
|
||||
echo "Stopping and removing existing containers..."
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
|
||||
# Start the PostgreSQL container with optional volume
|
||||
echo "Starting PostgreSQL container..."
|
||||
@@ -55,10 +55,6 @@ else
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
|
||||
fi
|
||||
|
||||
# Start the Code Interpreter container
|
||||
echo "Starting Code Interpreter container..."
|
||||
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
|
||||
|
||||
# Ensure alembic runs in the correct directory (backend/)
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
|
||||
@@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator
|
||||
from collections.abc import Generator
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
@@ -46,11 +47,15 @@ def mock_current_admin_user() -> MagicMock:
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client() -> Generator[TestClient, None, None]:
|
||||
# Initialize TestClient with the FastAPI app using a no-op test lifespan
|
||||
# Initialize TestClient with the FastAPI app using a no-op test lifespan.
|
||||
# Patch out prometheus metrics setup to avoid "Duplicated timeseries in
|
||||
# CollectorRegistry" errors when multiple tests each create a new app
|
||||
# (prometheus registers metrics globally and rejects duplicate names).
|
||||
get_app = fetch_versioned_implementation(
|
||||
module="onyx.main", attribute="get_application"
|
||||
)
|
||||
app: FastAPI = get_app(lifespan_override=test_lifespan)
|
||||
with patch("onyx.main.setup_prometheus_metrics"):
|
||||
app: FastAPI = get_app(lifespan_override=test_lifespan)
|
||||
|
||||
# Override the database session dependency with a mock
|
||||
# (these tests don't actually need DB access)
|
||||
|
||||
@@ -106,13 +106,13 @@ class TestGuildDataIsolation:
|
||||
|
||||
# Create admin user for tenant 1
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"discord_admin1+{unique}@example.com",
|
||||
email=f"discord_admin1_{unique}@example.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
|
||||
|
||||
# Create admin user for tenant 2
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"discord_admin2+{unique}@example.com",
|
||||
email=f"discord_admin2_{unique}@example.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user2, UserRole.ADMIN)
|
||||
|
||||
@@ -170,10 +170,10 @@ class TestGuildDataIsolation:
|
||||
|
||||
# Create admin users for two tenants
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"discord_list1+{unique}@example.com",
|
||||
email=f"discord_list1_{unique}@example.com",
|
||||
)
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"discord_list2+{unique}@example.com",
|
||||
email=f"discord_list2_{unique}@example.com",
|
||||
)
|
||||
|
||||
# Create 1 guild in tenant 1
|
||||
@@ -350,10 +350,10 @@ class TestGuildAccessIsolation:
|
||||
|
||||
# Create admin users for two tenants
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"discord_access1+{unique}@example.com",
|
||||
email=f"discord_access1_{unique}@example.com",
|
||||
)
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"discord_access2+{unique}@example.com",
|
||||
email=f"discord_access2_{unique}@example.com",
|
||||
)
|
||||
|
||||
# Create a guild in tenant 1
|
||||
|
||||
@@ -21,7 +21,7 @@ def test_admin_can_invite_users(reset_multitenant: None) -> None: # noqa: ARG00
|
||||
|
||||
# Admin user invites the previously registered and non-registered user
|
||||
UserManager.invite_user(invited_user.email, admin_user)
|
||||
UserManager.invite_user(f"{INVITED_BASIC_USER}+{unique}@example.com", admin_user)
|
||||
UserManager.invite_user(f"{INVITED_BASIC_USER}_{unique}@example.com", admin_user)
|
||||
|
||||
# Verify users are in the invited users list
|
||||
invited_users = UserManager.get_invited_users(admin_user)
|
||||
@@ -40,7 +40,7 @@ def test_non_registered_user_gets_basic_role(
|
||||
assert UserManager.is_role(admin_user, UserRole.ADMIN)
|
||||
|
||||
# Admin user invites a non-registered user
|
||||
invited_email = f"{INVITED_BASIC_USER}+{unique}@example.com"
|
||||
invited_email = f"{INVITED_BASIC_USER}_{unique}@example.com"
|
||||
UserManager.invite_user(invited_email, admin_user)
|
||||
|
||||
# Non-registered user registers
|
||||
@@ -58,7 +58,7 @@ def test_user_can_accept_invitation(reset_multitenant: None) -> None: # noqa: A
|
||||
assert UserManager.is_role(admin_user, UserRole.ADMIN)
|
||||
|
||||
# Create a user to be invited
|
||||
invited_user_email = f"invited_user+{unique}@example.com"
|
||||
invited_user_email = f"invited_user_{unique}@example.com"
|
||||
|
||||
# User registers with the same email as the invitation
|
||||
invited_user: DATestUser = UserManager.create(
|
||||
|
||||
@@ -20,13 +20,13 @@ def setup_test_tenants(reset_multitenant: None) -> dict[str, Any]: # noqa: ARG0
|
||||
unique = uuid4().hex
|
||||
# Creating an admin user for Tenant 1
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"admin+{unique}@example.com",
|
||||
email=f"admin_{unique}@example.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
|
||||
|
||||
# Create Tenant 2 and its Admin User
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"admin2+{unique}@example.com",
|
||||
email=f"admin2_{unique}@example.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user2, UserRole.ADMIN)
|
||||
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Integration tests for onyx.db.engine.tenant_utils.get_schemas_needing_migration.
|
||||
|
||||
These tests require a live database and exercise the function directly,
|
||||
independent of the alembic migration runner script.
|
||||
|
||||
Usage:
|
||||
pytest tests/integration/multitenant_tests/test_get_schemas_needing_migration.py -v
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
|
||||
|
||||
_BACKEND_DIR = __file__[: __file__.index("/tests/")]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine() -> Engine:
|
||||
return SqlEngine.get_engine()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_head_rev() -> str:
|
||||
result = subprocess.run(
|
||||
["alembic", "heads", "--resolve-dependencies"],
|
||||
cwd=_BACKEND_DIR,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
assert (
|
||||
result.returncode == 0
|
||||
), f"alembic heads failed (exit {result.returncode}):\n{result.stdout}"
|
||||
rev = result.stdout.strip().split()[0]
|
||||
assert len(rev) > 0
|
||||
return rev
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_schema_at_head(
|
||||
engine: Engine, current_head_rev: str
|
||||
) -> Generator[str, None, None]:
|
||||
"""Tenant schema with alembic_version already at head — should be excluded."""
|
||||
schema = f"tenant_test_{uuid.uuid4().hex[:12]}"
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'CREATE SCHEMA "{schema}"'))
|
||||
conn.execute(
|
||||
text(
|
||||
f'CREATE TABLE "{schema}".alembic_version '
|
||||
f"(version_num VARCHAR(32) NOT NULL)"
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
text(f'INSERT INTO "{schema}".alembic_version (version_num) VALUES (:rev)'),
|
||||
{"rev": current_head_rev},
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
yield schema
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
|
||||
"""Tenant schema with no tables — should be included (needs migration)."""
|
||||
schema = f"tenant_test_{uuid.uuid4().hex[:12]}"
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'CREATE SCHEMA "{schema}"'))
|
||||
conn.commit()
|
||||
|
||||
yield schema
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_schema_stale_rev(engine: Engine) -> Generator[str, None, None]:
|
||||
"""Tenant schema with a non-head revision — should be included (needs migration)."""
|
||||
schema = f"tenant_test_{uuid.uuid4().hex[:12]}"
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'CREATE SCHEMA "{schema}"'))
|
||||
conn.execute(
|
||||
text(
|
||||
f'CREATE TABLE "{schema}".alembic_version '
|
||||
f"(version_num VARCHAR(32) NOT NULL)"
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
text(
|
||||
f'INSERT INTO "{schema}".alembic_version (version_num) '
|
||||
f"VALUES ('stalerev000000000000')"
|
||||
)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
yield schema
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_classifies_all_cases(
|
||||
current_head_rev: str,
|
||||
tenant_schema_at_head: str,
|
||||
tenant_schema_empty: str,
|
||||
tenant_schema_stale_rev: str,
|
||||
) -> None:
|
||||
"""Correctly classifies all three schema states:
|
||||
- at head → excluded
|
||||
- no table → included (needs migration)
|
||||
- stale rev → included (needs migration)
|
||||
"""
|
||||
all_schemas = [tenant_schema_at_head, tenant_schema_empty, tenant_schema_stale_rev]
|
||||
result = get_schemas_needing_migration(all_schemas, current_head_rev)
|
||||
|
||||
assert tenant_schema_at_head not in result
|
||||
assert tenant_schema_empty in result
|
||||
assert tenant_schema_stale_rev in result
|
||||
|
||||
|
||||
def test_idempotent(
|
||||
current_head_rev: str,
|
||||
tenant_schema_at_head: str,
|
||||
tenant_schema_empty: str,
|
||||
) -> None:
|
||||
"""Calling the function twice returns the same result.
|
||||
|
||||
Verifies that the DROP TABLE IF EXISTS guards correctly clean up temp
|
||||
tables so a second call succeeds even if the first left state behind.
|
||||
"""
|
||||
schemas = [tenant_schema_at_head, tenant_schema_empty]
|
||||
|
||||
first = get_schemas_needing_migration(schemas, current_head_rev)
|
||||
second = get_schemas_needing_migration(schemas, current_head_rev)
|
||||
|
||||
assert first == second
|
||||
|
||||
|
||||
def test_empty_input(current_head_rev: str) -> None:
|
||||
"""An empty input list returns immediately without touching the DB."""
|
||||
assert get_schemas_needing_migration([], current_head_rev) == []
|
||||
@@ -3,6 +3,7 @@ from fastapi import HTTPException
|
||||
|
||||
import onyx.auth.users as users
|
||||
from onyx.auth.users import verify_email_domain
|
||||
from onyx.configs.constants import AuthType
|
||||
|
||||
|
||||
def test_verify_email_domain_allows_case_insensitive_match(
|
||||
@@ -35,3 +36,37 @@ def test_verify_email_domain_invalid_email_format(
|
||||
verify_email_domain("userexample.com") # missing '@'
|
||||
assert exc.value.status_code == 400
|
||||
assert "Email is not valid" in exc.value.detail
|
||||
|
||||
|
||||
def test_verify_email_domain_rejects_plus_addressing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
|
||||
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
verify_email_domain("user+tag@gmail.com")
|
||||
assert exc.value.status_code == 400
|
||||
assert "'+'" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_verify_email_domain_allows_plus_for_onyx_app(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
|
||||
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
|
||||
|
||||
# Should not raise for onyx.app domain
|
||||
verify_email_domain("user+tag@onyx.app")
|
||||
|
||||
|
||||
def test_verify_email_domain_rejects_googlemail(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
|
||||
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
verify_email_domain("user@googlemail.com")
|
||||
assert exc.value.status_code == 400
|
||||
assert "gmail.com" in str(exc.value.detail)
|
||||
|
||||
@@ -97,6 +97,7 @@ class TestScimDALUserMappings:
|
||||
assert model_attrs(added_obj) == {
|
||||
"external_id": "ext-1",
|
||||
"user_id": user_id,
|
||||
"scim_username": None,
|
||||
}
|
||||
|
||||
def test_delete_user_mapping(
|
||||
|
||||
@@ -15,7 +15,10 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
@@ -35,6 +38,12 @@ def mock_token() -> MagicMock:
|
||||
return token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider() -> ScimProvider:
|
||||
"""An OktaProvider instance for endpoint tests."""
|
||||
return OktaProvider()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dal() -> Generator[MagicMock, None, None]:
|
||||
"""Patch ScimDAL construction in api module and yield the mock instance."""
|
||||
@@ -53,6 +62,9 @@ def mock_dal() -> Generator[MagicMock, None, None]:
|
||||
dal.get_group_mapping_by_external_id.return_value = None
|
||||
dal.get_group_members.return_value = []
|
||||
dal.list_groups.return_value = ([], 0)
|
||||
# User-group relationship defaults
|
||||
dal.get_user_groups.return_value = []
|
||||
dal.get_users_groups_batch.return_value = {}
|
||||
yield dal
|
||||
|
||||
|
||||
@@ -96,6 +108,16 @@ def make_db_group(**kwargs: Any) -> MagicMock:
|
||||
return group
|
||||
|
||||
|
||||
def make_user_mapping(**kwargs: Any) -> MagicMock:
|
||||
"""Build a mock ScimUserMapping ORM object with configurable attributes."""
|
||||
mapping = MagicMock(spec=ScimUserMapping)
|
||||
mapping.id = kwargs.get("id", 1)
|
||||
mapping.external_id = kwargs.get("external_id", "ext-default")
|
||||
mapping.user_id = kwargs.get("user_id", uuid4())
|
||||
mapping.scim_username = kwargs.get("scim_username", None)
|
||||
return mapping
|
||||
|
||||
|
||||
def assert_scim_error(result: object, expected_status: int) -> None:
|
||||
"""Assert *result* is a JSONResponse with the given status code."""
|
||||
assert isinstance(result, JSONResponse)
|
||||
|
||||
@@ -21,6 +21,7 @@ from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from tests.unit.onyx.server.scim.conftest import assert_scim_error
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_group
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_group
|
||||
@@ -34,6 +35,7 @@ class TestListGroups:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.list_groups.return_value = ([], 0)
|
||||
|
||||
@@ -42,6 +44,7 @@ class TestListGroups:
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -54,6 +57,7 @@ class TestListGroups:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.list_groups.side_effect = ValueError(
|
||||
"Unsupported filter attribute: userName"
|
||||
@@ -64,6 +68,7 @@ class TestListGroups:
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -74,6 +79,7 @@ class TestListGroups:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5, name="Engineering")
|
||||
uid = uuid4()
|
||||
@@ -85,6 +91,7 @@ class TestListGroups:
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -106,6 +113,7 @@ class TestGetGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5, name="Engineering")
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -114,6 +122,7 @@ class TestGetGroup:
|
||||
result = get_group(
|
||||
group_id="5",
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -126,10 +135,12 @@ class TestGetGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
result = get_group(
|
||||
group_id="not-a-number",
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -140,12 +151,14 @@ class TestGetGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group.return_value = None
|
||||
|
||||
result = get_group(
|
||||
group_id="999",
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -162,6 +175,7 @@ class TestCreateGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = None
|
||||
mock_validate.return_value = ([], None)
|
||||
@@ -172,6 +186,7 @@ class TestCreateGroup:
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -185,6 +200,7 @@ class TestCreateGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = make_db_group()
|
||||
resource = make_scim_group()
|
||||
@@ -192,6 +208,7 @@ class TestCreateGroup:
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -204,6 +221,7 @@ class TestCreateGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = None
|
||||
mock_validate.return_value = ([], "Invalid member ID: bad-uuid")
|
||||
@@ -213,6 +231,7 @@ class TestCreateGroup:
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -225,6 +244,7 @@ class TestCreateGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = None
|
||||
uid = uuid4()
|
||||
@@ -235,6 +255,7 @@ class TestCreateGroup:
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -247,6 +268,7 @@ class TestCreateGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = None
|
||||
mock_validate.return_value = ([], None)
|
||||
@@ -257,6 +279,7 @@ class TestCreateGroup:
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -274,6 +297,7 @@ class TestReplaceGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5, name="Old Name")
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -286,6 +310,7 @@ class TestReplaceGroup:
|
||||
group_id="5",
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -299,6 +324,7 @@ class TestReplaceGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group.return_value = None
|
||||
|
||||
@@ -306,6 +332,7 @@ class TestReplaceGroup:
|
||||
group_id="999",
|
||||
group_resource=make_scim_group(),
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -318,6 +345,7 @@ class TestReplaceGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -329,6 +357,7 @@ class TestReplaceGroup:
|
||||
group_id="5",
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -341,6 +370,7 @@ class TestReplaceGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -353,6 +383,7 @@ class TestReplaceGroup:
|
||||
group_id="5",
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -369,6 +400,7 @@ class TestPatchGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5, name="Old Name")
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -391,6 +423,7 @@ class TestPatchGroup:
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -402,6 +435,7 @@ class TestPatchGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group.return_value = None
|
||||
|
||||
@@ -419,6 +453,7 @@ class TestPatchGroup:
|
||||
group_id="999",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -431,6 +466,7 @@ class TestPatchGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -452,6 +488,7 @@ class TestPatchGroup:
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -464,6 +501,7 @@ class TestPatchGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -483,7 +521,7 @@ class TestPatchGroup:
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.ADD,
|
||||
path="members",
|
||||
value=[{"value": uid}],
|
||||
value=[ScimGroupMember(value=uid)],
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -492,6 +530,7 @@ class TestPatchGroup:
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -506,6 +545,7 @@ class TestPatchGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -525,7 +565,7 @@ class TestPatchGroup:
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.ADD,
|
||||
path="members",
|
||||
value=[{"value": str(uid)}],
|
||||
value=[ScimGroupMember(value=str(uid))],
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -534,6 +574,7 @@ class TestPatchGroup:
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -546,6 +587,7 @@ class TestPatchGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -568,6 +610,7 @@ class TestPatchGroup:
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -2,13 +2,19 @@ import pytest
|
||||
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimPatchValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
_OKTA_IGNORED = OktaProvider().ignored_patch_paths
|
||||
|
||||
|
||||
def _make_user(**kwargs: object) -> ScimUserResource:
|
||||
@@ -29,14 +35,14 @@ def _make_group(**kwargs: object) -> ScimGroupResource:
|
||||
|
||||
def _replace_op(
|
||||
path: str | None = None,
|
||||
value: str | bool | dict | list | None = None,
|
||||
value: ScimPatchValue = None,
|
||||
) -> ScimPatchOperation:
|
||||
return ScimPatchOperation(op=ScimPatchOperationType.REPLACE, path=path, value=value)
|
||||
|
||||
|
||||
def _add_op(
|
||||
path: str | None = None,
|
||||
value: str | bool | dict | list | None = None,
|
||||
value: ScimPatchValue = None,
|
||||
) -> ScimPatchOperation:
|
||||
return ScimPatchOperation(op=ScimPatchOperationType.ADD, path=path, value=value)
|
||||
|
||||
@@ -80,7 +86,12 @@ class TestApplyUserPatch:
|
||||
def test_replace_without_path_uses_dict(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
[_replace_op(None, {"active": False, "userName": "new@example.com"})],
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
ScimPatchResourceValue(active=False, userName="new@example.com"),
|
||||
)
|
||||
],
|
||||
user,
|
||||
)
|
||||
assert result.active is False
|
||||
@@ -119,6 +130,86 @@ class TestApplyUserPatch:
|
||||
with pytest.raises(ScimPatchError, match="Unsupported operation"):
|
||||
apply_user_patch([_remove_op("active")], user)
|
||||
|
||||
def test_replace_without_path_ignores_id(self) -> None:
|
||||
"""Okta sends 'id' alongside actual changes — it should be silently ignored."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
[_replace_op(None, ScimPatchResourceValue(active=False, id="some-uuid"))],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
|
||||
def test_replace_without_path_ignores_schemas(self) -> None:
|
||||
"""The 'schemas' key in a value dict should be silently ignored."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
ScimPatchResourceValue(
|
||||
active=False,
|
||||
schemas=["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
),
|
||||
)
|
||||
],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
|
||||
def test_okta_deactivation_payload(self) -> None:
|
||||
"""Exact Okta deactivation payload: path-less replace with id + active."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
ScimPatchResourceValue(id="abc-123", active=False),
|
||||
)
|
||||
],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
assert result.userName == "test@example.com"
|
||||
|
||||
def test_replace_displayname(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
[_replace_op("displayName", "New Display Name")], user
|
||||
)
|
||||
assert result.displayName == "New Display Name"
|
||||
assert result.name is not None
|
||||
assert result.name.formatted == "New Display Name"
|
||||
|
||||
def test_replace_without_path_complex_value_dict(self) -> None:
|
||||
"""Okta sends id/schemas/meta alongside actual changes — complex types
|
||||
(lists, nested dicts) must not cause Pydantic validation errors."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
ScimPatchResourceValue(
|
||||
active=False,
|
||||
id="some-uuid",
|
||||
schemas=["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
),
|
||||
)
|
||||
],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
assert result.userName == "test@example.com"
|
||||
|
||||
def test_add_operation_works_like_replace(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch([_add_op("externalId", "ext-456")], user)
|
||||
assert result.externalId == "ext-456"
|
||||
|
||||
|
||||
class TestApplyGroupPatch:
|
||||
"""Tests for SCIM group PATCH operations."""
|
||||
@@ -135,7 +226,12 @@ class TestApplyGroupPatch:
|
||||
def test_add_members(self) -> None:
|
||||
group = _make_group()
|
||||
result, added, removed = apply_group_patch(
|
||||
[_add_op("members", [{"value": "user-1"}, {"value": "user-2"}])],
|
||||
[
|
||||
_add_op(
|
||||
"members",
|
||||
[ScimGroupMember(value="user-1"), ScimGroupMember(value="user-2")],
|
||||
)
|
||||
],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 2
|
||||
@@ -145,7 +241,7 @@ class TestApplyGroupPatch:
|
||||
def test_add_members_without_path(self) -> None:
|
||||
group = _make_group()
|
||||
result, added, _ = apply_group_patch(
|
||||
[_add_op(None, [{"value": "user-1"}])],
|
||||
[_add_op(None, [ScimGroupMember(value="user-1")])],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 1
|
||||
@@ -154,7 +250,12 @@ class TestApplyGroupPatch:
|
||||
def test_add_duplicate_member_skipped(self) -> None:
|
||||
group = _make_group(members=[ScimGroupMember(value="user-1")])
|
||||
result, added, _ = apply_group_patch(
|
||||
[_add_op("members", [{"value": "user-1"}, {"value": "user-2"}])],
|
||||
[
|
||||
_add_op(
|
||||
"members",
|
||||
[ScimGroupMember(value="user-1"), ScimGroupMember(value="user-2")],
|
||||
)
|
||||
],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 2
|
||||
@@ -190,7 +291,7 @@ class TestApplyGroupPatch:
|
||||
result, added, removed = apply_group_patch(
|
||||
[
|
||||
_replace_op("displayName", "Renamed"),
|
||||
_add_op("members", [{"value": "user-2"}]),
|
||||
_add_op("members", [ScimGroupMember(value="user-2")]),
|
||||
_remove_op('members[value eq "user-1"]'),
|
||||
],
|
||||
group,
|
||||
@@ -221,7 +322,12 @@ class TestApplyGroupPatch:
|
||||
]
|
||||
)
|
||||
result, added, removed = apply_group_patch(
|
||||
[_replace_op("members", [{"value": "user-2"}, {"value": "user-3"}])],
|
||||
[
|
||||
_replace_op(
|
||||
"members",
|
||||
[ScimGroupMember(value="user-2"), ScimGroupMember(value="user-3")],
|
||||
)
|
||||
],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 2
|
||||
@@ -256,3 +362,55 @@ class TestApplyGroupPatch:
|
||||
group = _make_group()
|
||||
apply_group_patch([_replace_op("displayName", "Changed")], group)
|
||||
assert group.displayName == "Engineering"
|
||||
|
||||
def test_replace_without_path_ignores_id(self) -> None:
|
||||
"""Group replace with 'id' in value dict should be silently ignored."""
|
||||
group = _make_group()
|
||||
result, _, _ = apply_group_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None, ScimPatchResourceValue(displayName="Updated", id="some-id")
|
||||
)
|
||||
],
|
||||
group,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.displayName == "Updated"
|
||||
|
||||
def test_replace_without_path_ignores_schemas(self) -> None:
|
||||
group = _make_group()
|
||||
result, _, _ = apply_group_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
ScimPatchResourceValue(
|
||||
displayName="Updated",
|
||||
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
|
||||
),
|
||||
)
|
||||
],
|
||||
group,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.displayName == "Updated"
|
||||
|
||||
def test_replace_without_path_complex_value_dict(self) -> None:
|
||||
"""Group PATCH with complex types in value dict (lists, nested dicts)
|
||||
must not cause Pydantic validation errors."""
|
||||
group = _make_group()
|
||||
result, _, _ = apply_group_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
ScimPatchResourceValue(
|
||||
displayName="Updated",
|
||||
id="123",
|
||||
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
),
|
||||
)
|
||||
],
|
||||
group,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.displayName == "Updated"
|
||||
|
||||
167
backend/tests/unit/onyx/server/scim/test_providers.py
Normal file
167
backend/tests/unit/onyx/server/scim/test_providers.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
|
||||
def _make_mock_user(
|
||||
user_id: UUID | None = None,
|
||||
email: str = "test@example.com",
|
||||
personal_name: str | None = "Test User",
|
||||
is_active: bool = True,
|
||||
) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = user_id or uuid4()
|
||||
user.email = email
|
||||
user.personal_name = personal_name
|
||||
user.is_active = is_active
|
||||
return user
|
||||
|
||||
|
||||
def _make_mock_group(group_id: int = 42, name: str = "Engineering") -> MagicMock:
|
||||
group = MagicMock()
|
||||
group.id = group_id
|
||||
group.name = name
|
||||
return group
|
||||
|
||||
|
||||
class TestOktaProvider:
|
||||
def test_name(self) -> None:
|
||||
assert OktaProvider().name == "okta"
|
||||
|
||||
def test_ignored_patch_paths(self) -> None:
|
||||
assert OktaProvider().ignored_patch_paths == frozenset(
|
||||
{"id", "schemas", "meta"}
|
||||
)
|
||||
|
||||
def test_build_user_resource_basic(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-123")
|
||||
|
||||
assert result == ScimUserResource(
|
||||
id=str(user.id),
|
||||
externalId="ext-123",
|
||||
userName="test@example.com",
|
||||
name=ScimName(givenName="Test", familyName="User", formatted="Test User"),
|
||||
displayName="Test User",
|
||||
emails=[ScimEmail(value="test@example.com", type="work", primary=True)],
|
||||
active=True,
|
||||
groups=[],
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
def test_build_user_resource_with_groups(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user()
|
||||
groups = [(1, "Engineering"), (2, "Design")]
|
||||
result = provider.build_user_resource(user, "ext-123", groups=groups)
|
||||
|
||||
assert result.groups == [
|
||||
ScimUserGroupRef(value="1", display="Engineering"),
|
||||
ScimUserGroupRef(value="2", display="Design"),
|
||||
]
|
||||
|
||||
def test_build_user_resource_empty_groups(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-123", groups=[])
|
||||
|
||||
assert result.groups == []
|
||||
|
||||
def test_build_user_resource_no_groups(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-123")
|
||||
|
||||
assert result.groups == []
|
||||
|
||||
def test_build_user_resource_name_parsing(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user(personal_name="Jane Doe")
|
||||
result = provider.build_user_resource(user, None)
|
||||
|
||||
assert result.name == ScimName(
|
||||
givenName="Jane", familyName="Doe", formatted="Jane Doe"
|
||||
)
|
||||
|
||||
def test_build_user_resource_single_name(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user(personal_name="Madonna")
|
||||
result = provider.build_user_resource(user, None)
|
||||
|
||||
assert result.name == ScimName(
|
||||
givenName="Madonna", familyName=None, formatted="Madonna"
|
||||
)
|
||||
|
||||
def test_build_user_resource_no_name(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user(personal_name=None)
|
||||
result = provider.build_user_resource(user, None)
|
||||
|
||||
assert result.name is None
|
||||
assert result.displayName is None
|
||||
|
||||
def test_build_user_resource_scim_username_preserves_case(self) -> None:
|
||||
"""When scim_username is set, userName and emails use original case."""
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user(email="alice@example.com")
|
||||
result = provider.build_user_resource(
|
||||
user, "ext-1", scim_username="Alice@Example.com"
|
||||
)
|
||||
|
||||
assert result.userName == "Alice@Example.com"
|
||||
assert result.emails[0].value == "Alice@Example.com"
|
||||
|
||||
def test_build_user_resource_scim_username_none_falls_back(self) -> None:
|
||||
"""When scim_username is None, userName falls back to user.email."""
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user(email="alice@example.com")
|
||||
result = provider.build_user_resource(user, "ext-1", scim_username=None)
|
||||
|
||||
assert result.userName == "alice@example.com"
|
||||
assert result.emails[0].value == "alice@example.com"
|
||||
|
||||
def test_build_group_resource(self) -> None:
|
||||
provider = OktaProvider()
|
||||
group = _make_mock_group()
|
||||
uid1, uid2 = uuid4(), uuid4()
|
||||
members: list[tuple[UUID, str | None]] = [
|
||||
(uid1, "alice@example.com"),
|
||||
(uid2, "bob@example.com"),
|
||||
]
|
||||
|
||||
result = provider.build_group_resource(group, members, "ext-g-1")
|
||||
|
||||
assert result == ScimGroupResource(
|
||||
id="42",
|
||||
externalId="ext-g-1",
|
||||
displayName="Engineering",
|
||||
members=[
|
||||
ScimGroupMember(value=str(uid1), display="alice@example.com"),
|
||||
ScimGroupMember(value=str(uid2), display="bob@example.com"),
|
||||
],
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
def test_build_group_resource_empty_members(self) -> None:
|
||||
provider = OktaProvider()
|
||||
group = _make_mock_group()
|
||||
result = provider.build_group_resource(group, [])
|
||||
|
||||
assert result.members == []
|
||||
|
||||
|
||||
class TestGetDefaultProvider:
|
||||
def test_returns_okta(self) -> None:
|
||||
provider = get_default_provider()
|
||||
assert isinstance(provider, OktaProvider)
|
||||
@@ -9,6 +9,7 @@ from uuid import uuid4
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ee.onyx.server.scim.api import _scim_name_to_str
|
||||
from ee.onyx.server.scim.api import create_user
|
||||
from ee.onyx.server.scim.api import delete_user
|
||||
from ee.onyx.server.scim.api import get_user
|
||||
@@ -22,9 +23,11 @@ from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from tests.unit.onyx.server.scim.conftest import assert_scim_error
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_user_mapping
|
||||
|
||||
|
||||
class TestListUsers:
|
||||
@@ -35,6 +38,7 @@ class TestListUsers:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.list_users.return_value = ([], 0)
|
||||
|
||||
@@ -43,6 +47,7 @@ class TestListUsers:
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -55,15 +60,20 @@ class TestListUsers:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(email="alice@example.com", personal_name="Alice Smith")
|
||||
mock_dal.list_users.return_value = ([(user, "ext-abc")], 1)
|
||||
mapping = make_user_mapping(
|
||||
external_id="ext-abc", user_id=user.id, scim_username="Alice@example.com"
|
||||
)
|
||||
mock_dal.list_users.return_value = ([(user, mapping)], 1)
|
||||
|
||||
result = list_users(
|
||||
filter=None,
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -72,7 +82,7 @@ class TestListUsers:
|
||||
assert len(result.Resources) == 1
|
||||
resource = result.Resources[0]
|
||||
assert isinstance(resource, ScimUserResource)
|
||||
assert resource.userName == "alice@example.com"
|
||||
assert resource.userName == "Alice@example.com"
|
||||
assert resource.externalId == "ext-abc"
|
||||
|
||||
def test_unsupported_filter_attribute_returns_400(
|
||||
@@ -80,6 +90,7 @@ class TestListUsers:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.list_users.side_effect = ValueError(
|
||||
"Unsupported filter attribute: emails"
|
||||
@@ -90,6 +101,7 @@ class TestListUsers:
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -100,12 +112,14 @@ class TestListUsers:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
result = list_users(
|
||||
filter="not a valid filter",
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -120,6 +134,7 @@ class TestGetUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(email="alice@example.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
@@ -127,6 +142,7 @@ class TestGetUser:
|
||||
result = get_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -139,10 +155,12 @@ class TestGetUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
result = get_user(
|
||||
user_id="not-a-uuid",
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -153,12 +171,14 @@ class TestGetUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user.return_value = None
|
||||
|
||||
result = get_user(
|
||||
user_id=str(uuid4()),
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -175,6 +195,7 @@ class TestCreateUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(userName="new@example.com")
|
||||
@@ -182,6 +203,7 @@ class TestCreateUser:
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -195,12 +217,14 @@ class TestCreateUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
resource = make_scim_user(externalId=None)
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -213,6 +237,7 @@ class TestCreateUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = make_db_user()
|
||||
resource = make_scim_user()
|
||||
@@ -220,6 +245,7 @@ class TestCreateUser:
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -232,6 +258,7 @@ class TestCreateUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
mock_dal.add_user.side_effect = IntegrityError("dup", {}, Exception())
|
||||
@@ -240,6 +267,7 @@ class TestCreateUser:
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -253,6 +281,7 @@ class TestCreateUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_seats.return_value = "Seat limit reached"
|
||||
resource = make_scim_user()
|
||||
@@ -260,6 +289,7 @@ class TestCreateUser:
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -272,6 +302,7 @@ class TestCreateUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(externalId="ext-123")
|
||||
@@ -279,6 +310,7 @@ class TestCreateUser:
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -295,6 +327,7 @@ class TestReplaceUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(email="old@example.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
@@ -307,6 +340,7 @@ class TestReplaceUser:
|
||||
user_id=str(user.id),
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -319,6 +353,7 @@ class TestReplaceUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user.return_value = None
|
||||
|
||||
@@ -326,6 +361,7 @@ class TestReplaceUser:
|
||||
user_id=str(uuid4()),
|
||||
user_resource=make_scim_user(),
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -338,6 +374,7 @@ class TestReplaceUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(is_active=False)
|
||||
mock_dal.get_user.return_value = user
|
||||
@@ -348,6 +385,7 @@ class TestReplaceUser:
|
||||
user_id=str(user.id),
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -359,6 +397,7 @@ class TestReplaceUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
@@ -369,11 +408,14 @@ class TestReplaceUser:
|
||||
user_id=str(user.id),
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
mock_dal.sync_user_external_id.assert_called_once_with(user.id, None)
|
||||
mock_dal.sync_user_external_id.assert_called_once_with(
|
||||
user.id, None, scim_username="test@example.com"
|
||||
)
|
||||
|
||||
|
||||
class TestPatchUser:
|
||||
@@ -384,6 +426,7 @@ class TestPatchUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(is_active=True)
|
||||
mock_dal.get_user.return_value = user
|
||||
@@ -401,6 +444,7 @@ class TestPatchUser:
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -412,6 +456,7 @@ class TestPatchUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user.return_value = None
|
||||
patch_req = ScimPatchRequest(
|
||||
@@ -428,11 +473,45 @@ class TestPatchUser:
|
||||
user_id=str(uuid4()),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
|
||||
def test_patch_displayname_persists(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
"""PATCH displayName should update personal_name in the DB."""
|
||||
user = make_db_user(personal_name="Old Name")
|
||||
mock_dal.get_user.return_value = user
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path="displayName",
|
||||
value="New Display Name",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
# Verify the update_user call received the new display name
|
||||
call_kwargs = mock_dal.update_user.call_args
|
||||
assert call_kwargs[1]["personal_name"] == "New Display Name"
|
||||
|
||||
@patch("ee.onyx.server.scim.api.apply_user_patch")
|
||||
def test_patch_error_returns_error_response(
|
||||
self,
|
||||
@@ -440,6 +519,7 @@ class TestPatchUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
@@ -457,6 +537,7 @@ class TestPatchUser:
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -519,3 +600,87 @@ class TestDeleteUser:
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
|
||||
|
||||
class TestScimNameToStr:
|
||||
"""Tests for _scim_name_to_str helper."""
|
||||
|
||||
def test_prefers_given_family_over_formatted(self) -> None:
|
||||
"""Okta may send stale formatted while updating givenName/familyName."""
|
||||
name = ScimName(givenName="Jane", familyName="Smith", formatted="Old Name")
|
||||
assert _scim_name_to_str(name) == "Jane Smith"
|
||||
|
||||
def test_given_name_only(self) -> None:
|
||||
name = ScimName(givenName="Jane")
|
||||
assert _scim_name_to_str(name) == "Jane"
|
||||
|
||||
def test_family_name_only(self) -> None:
|
||||
name = ScimName(familyName="Smith")
|
||||
assert _scim_name_to_str(name) == "Smith"
|
||||
|
||||
def test_falls_back_to_formatted(self) -> None:
|
||||
name = ScimName(formatted="Display Name")
|
||||
assert _scim_name_to_str(name) == "Display Name"
|
||||
|
||||
def test_none_returns_none(self) -> None:
|
||||
assert _scim_name_to_str(None) is None
|
||||
|
||||
def test_empty_name_returns_none(self) -> None:
|
||||
name = ScimName()
|
||||
assert _scim_name_to_str(name) is None
|
||||
|
||||
|
||||
class TestEmailCasePreservation:
|
||||
"""Tests verifying email case is preserved through SCIM endpoints."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_create_preserves_username_case(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
"""POST /Users with mixed-case userName returns the original case."""
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(userName="Alice@Example.COM")
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "Alice@Example.COM"
|
||||
assert result.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
def test_get_preserves_username_case(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
"""GET /Users/{id} returns the original-case userName from mapping."""
|
||||
user = make_db_user(email="alice@example.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
mapping = make_user_mapping(
|
||||
external_id="ext-1",
|
||||
user_id=user.id,
|
||||
scim_username="Alice@Example.COM",
|
||||
)
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = mapping
|
||||
|
||||
result = get_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "Alice@Example.COM"
|
||||
assert result.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
@@ -490,7 +490,6 @@ func createCherryPickPR(headBranch, baseBranch, title string, commitSHAs, commit
|
||||
|
||||
// Add standard checklist
|
||||
body += "\n\n"
|
||||
body += "- [x] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.\n"
|
||||
body += "- [x] [Optional] Override Linear Check\n"
|
||||
|
||||
cmd := exec.Command("gh", "pr", "create",
|
||||
|
||||
@@ -118,7 +118,7 @@ func runCI(cmd *cobra.Command, args []string, opts *RunCIOptions) {
|
||||
// Create the CI branch
|
||||
ciBranch := fmt.Sprintf("run-ci/%s", prNumber)
|
||||
prTitle := fmt.Sprintf("chore: [Running GitHub actions for #%s]", prNumber)
|
||||
prBody := fmt.Sprintf("This PR runs GitHub Actions CI for #%s.\n\n- [x] I have considered whether this PR needs to be cherry-picked to the latest beta branch.\n- [x] Override Linear Check\n\n**This PR should be closed (not merged) after CI completes.**", prNumber)
|
||||
prBody := fmt.Sprintf("This PR runs GitHub Actions CI for #%s.\n\n- [x] Override Linear Check\n\n**This PR should be closed (not merged) after CI completes.**", prNumber)
|
||||
|
||||
// Fetch the fork's branch
|
||||
if forkRepo == "" {
|
||||
|
||||
@@ -105,6 +105,18 @@ const nextConfig = {
|
||||
destination: "/app",
|
||||
permanent: true,
|
||||
},
|
||||
// NRF routes: Redirect to /nrf which doesn't require auth
|
||||
// (NRFPage handles unauthenticated users gracefully with a login modal)
|
||||
{
|
||||
source: "/app/nrf/side-panel",
|
||||
destination: "/nrf/side-panel",
|
||||
permanent: true,
|
||||
},
|
||||
{
|
||||
source: "/app/nrf",
|
||||
destination: "/nrf",
|
||||
permanent: true,
|
||||
},
|
||||
{
|
||||
source: "/chat/:path*",
|
||||
destination: "/app/:path*",
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
import Image from "next/image";
|
||||
import { useMemo, useState, useReducer } from "react";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { InfoIcon } from "@/components/icons/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { Content } from "@opal/layouts";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher, FetchError } from "@/lib/fetcher";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
@@ -22,7 +22,6 @@ import {
|
||||
SvgOnyxLogo,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
|
||||
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
|
||||
import {
|
||||
SEARCH_PROVIDERS_URL,
|
||||
@@ -402,36 +401,40 @@ export default function Page() {
|
||||
: undefined);
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Web Search"
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgGlobe}
|
||||
includeDivider={false}
|
||||
title="Web Search"
|
||||
description="Search settings for external search across the internet."
|
||||
separator
|
||||
/>
|
||||
<Callout type="danger" title="Failed to load web search settings">
|
||||
{message}
|
||||
{detail && (
|
||||
<Text as="p" className="mt-2 text-text-03" mainContentBody text03>
|
||||
{detail}
|
||||
</Text>
|
||||
)}
|
||||
</Callout>
|
||||
</>
|
||||
<SettingsLayouts.Body>
|
||||
<Callout type="danger" title="Failed to load web search settings">
|
||||
{message}
|
||||
{detail && (
|
||||
<Text as="p" className="mt-2 text-text-03" mainContentBody text03>
|
||||
{detail}
|
||||
</Text>
|
||||
)}
|
||||
</Callout>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Web Search"
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgGlobe}
|
||||
includeDivider={false}
|
||||
title="Web Search"
|
||||
description="Search settings for external search across the internet."
|
||||
separator
|
||||
/>
|
||||
<div className="mt-8">
|
||||
<SettingsLayouts.Body>
|
||||
<ThreeDotsLoader />
|
||||
</div>
|
||||
</>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -827,32 +830,22 @@ export default function Page() {
|
||||
|
||||
return (
|
||||
<>
|
||||
<>
|
||||
<AdminPageTitle icon={SvgGlobe} title="Web Search" />
|
||||
<div className="pt-4 pb-4">
|
||||
<Text as="p" className="text-text-dark">
|
||||
Search settings for external search across the internet.
|
||||
</Text>
|
||||
</div>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgGlobe}
|
||||
title="Web Search"
|
||||
description="Search settings for external search across the internet."
|
||||
separator
|
||||
/>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className="flex w-full flex-col gap-8 pb-6">
|
||||
<div className="flex w-full max-w-[960px] flex-col gap-3">
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text as="p" mainContentEmphasis text05>
|
||||
Search Engine
|
||||
</Text>
|
||||
<Text
|
||||
as="p"
|
||||
className="flex items-start gap-[2px] self-stretch text-text-03"
|
||||
secondaryBody
|
||||
text03
|
||||
>
|
||||
External search engine API used for web search result URLs,
|
||||
snippets, and metadata.
|
||||
</Text>
|
||||
</div>
|
||||
<SettingsLayouts.Body>
|
||||
<div className="flex w-full flex-col gap-3">
|
||||
<Content
|
||||
title="Search Engine"
|
||||
description="External search engine API used for web search result URLs, snippets, and metadata."
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
/>
|
||||
|
||||
{activationError && (
|
||||
<Callout type="danger" title="Unable to update default provider">
|
||||
@@ -974,14 +967,12 @@ export default function Page() {
|
||||
size: 16,
|
||||
isHighlighted,
|
||||
})}
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text as="p" mainUiAction text05>
|
||||
{label}
|
||||
</Text>
|
||||
<Text as="p" secondaryBody text03>
|
||||
{subtitle}
|
||||
</Text>
|
||||
</div>
|
||||
<Content
|
||||
title={label}
|
||||
description={subtitle}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{isConfigured && (
|
||||
@@ -1045,20 +1036,13 @@ export default function Page() {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex w-full max-w-[960px] flex-col gap-3">
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text as="p" mainContentEmphasis text05>
|
||||
Web Crawler
|
||||
</Text>
|
||||
<Text
|
||||
as="p"
|
||||
className="flex items-start gap-[2px] self-stretch text-text-03"
|
||||
secondaryBody
|
||||
text03
|
||||
>
|
||||
Used to read the full contents of search result pages.
|
||||
</Text>
|
||||
</div>
|
||||
<div className="flex w-full flex-col gap-3">
|
||||
<Content
|
||||
title="Web Crawler"
|
||||
description="Used to read the full contents of search result pages."
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
/>
|
||||
|
||||
{contentActivationError && (
|
||||
<Callout type="danger" title="Unable to update crawler">
|
||||
@@ -1173,14 +1157,12 @@ export default function Page() {
|
||||
size: 16,
|
||||
isHighlighted: isCurrentCrawler,
|
||||
})}
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text as="p" mainUiAction text05>
|
||||
{label}
|
||||
</Text>
|
||||
<Text as="p" secondaryBody text03>
|
||||
{subtitle}
|
||||
</Text>
|
||||
</div>
|
||||
<Content
|
||||
title={label}
|
||||
description={subtitle}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{provider.provider_type !== "onyx_web_crawler" &&
|
||||
@@ -1244,8 +1226,8 @@ export default function Page() {
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
|
||||
<WebProviderSetupModal
|
||||
isOpen={selectedProviderType !== null}
|
||||
|
||||
@@ -382,9 +382,8 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
<IconButton
|
||||
icon={SvgMenu}
|
||||
onClick={toggleSettings}
|
||||
tertiary
|
||||
secondary
|
||||
tooltip="Open settings"
|
||||
className="bg-mask-02 backdrop-blur-[12px] rounded-full shadow-01 hover:bg-mask-03"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -75,11 +75,10 @@ export function useOnboardingModal(): OnboardingModalController {
|
||||
level: existingPersona?.level,
|
||||
};
|
||||
|
||||
// Check if user has completed initial onboarding
|
||||
// Check if user has completed initial onboarding (only role required, not name)
|
||||
const hasUserInfo = useMemo(() => {
|
||||
const existingPersona = getBuildUserPersona();
|
||||
return !!(user?.personalization?.name && existingPersona?.workArea);
|
||||
}, [user?.personalization?.name]);
|
||||
return !!getBuildUserPersona()?.workArea;
|
||||
}, [user]);
|
||||
|
||||
// Check if all providers are configured (skip LLM step entirely if so)
|
||||
const allProvidersConfigured = useMemo(
|
||||
@@ -94,7 +93,7 @@ export function useOnboardingModal(): OnboardingModalController {
|
||||
);
|
||||
|
||||
// Auto-open initial onboarding modal on first load
|
||||
// Shows if: user info is missing OR (admin AND no providers configured)
|
||||
// Shows if: user info (role) missing OR (admin AND no providers configured)
|
||||
useEffect(() => {
|
||||
if (hasInitialized || isLoadingLlm || !user) return;
|
||||
|
||||
|
||||
26
web/src/app/nrf/(main)/layout.tsx
Normal file
26
web/src/app/nrf/(main)/layout.tsx
Normal file
@@ -0,0 +1,26 @@
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import AppSidebar from "@/sections/sidebar/AppSidebar";
|
||||
import { getCurrentUserSS } from "@/lib/userSS";
|
||||
|
||||
export interface LayoutProps {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
/**
|
||||
* NRF Main (New Tab) Layout
|
||||
*
|
||||
* Shows the app sidebar when the user is authenticated.
|
||||
* This layout is NOT used by the side-panel route.
|
||||
*/
|
||||
export default async function Layout({ children }: LayoutProps) {
|
||||
noStore();
|
||||
|
||||
const user = await getCurrentUserSS();
|
||||
|
||||
return (
|
||||
<div className="flex flex-row w-full h-full">
|
||||
{user && <AppSidebar />}
|
||||
{children}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
31
web/src/app/nrf/(main)/page.tsx
Normal file
31
web/src/app/nrf/(main)/page.tsx
Normal file
@@ -0,0 +1,31 @@
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||
import NRFPage from "@/app/app/nrf/NRFPage";
|
||||
import { NRFPreferencesProvider } from "@/components/context/NRFPreferencesContext";
|
||||
import NRFChrome from "../NRFChrome";
|
||||
|
||||
/**
|
||||
* NRF (New Tab Page) Route - No Auth Required
|
||||
*
|
||||
* This route is placed outside /app/app/ to bypass the authentication
|
||||
* requirement in /app/app/layout.tsx. The NRFPage component handles
|
||||
* unauthenticated users gracefully by showing a login modal instead of
|
||||
* redirecting, which is better UX for the Chrome extension.
|
||||
*
|
||||
* Instead of AppLayouts.Root (which pulls in heavy Header state management),
|
||||
* we use NRFChrome — a lightweight overlay that renders only the search/chat
|
||||
* mode toggle and footer, floating transparently over NRFPage's background.
|
||||
*/
|
||||
export default async function Page() {
|
||||
noStore();
|
||||
|
||||
return (
|
||||
<div className="relative w-full h-full">
|
||||
<InstantSSRAutoRefresh />
|
||||
<NRFPreferencesProvider>
|
||||
<NRFPage />
|
||||
</NRFPreferencesProvider>
|
||||
<NRFChrome />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
148
web/src/app/nrf/NRFChrome.tsx
Normal file
148
web/src/app/nrf/NRFChrome.tsx
Normal file
@@ -0,0 +1,148 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { cn, ensureHrefProtocol, noProp } from "@/lib/utils";
|
||||
import type { Components } from "react-markdown";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import { OpenButton } from "@opal/components";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import { SvgBubbleText, SvgSearchMenu, SvgSidebar } from "@opal/icons";
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import { AppMode, useAppMode } from "@/providers/AppModeProvider";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import { useAppSidebarContext } from "@/providers/AppSidebarProvider";
|
||||
import useScreenSize from "@/hooks/useScreenSize";
|
||||
|
||||
const footerMarkdownComponents = {
|
||||
p: ({ children }: { children?: React.ReactNode }) => (
|
||||
<Text as="p" text03 secondaryAction className="!my-0 text-center">
|
||||
{children}
|
||||
</Text>
|
||||
),
|
||||
a: ({
|
||||
href,
|
||||
className,
|
||||
children,
|
||||
...rest
|
||||
}: React.AnchorHTMLAttributes<HTMLAnchorElement>) => {
|
||||
const fullHref = ensureHrefProtocol(href);
|
||||
return (
|
||||
<a
|
||||
href={fullHref}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
{...rest}
|
||||
className={cn(className, "underline underline-offset-2")}
|
||||
>
|
||||
<Text text03 secondaryAction>
|
||||
{children}
|
||||
</Text>
|
||||
</a>
|
||||
);
|
||||
},
|
||||
} satisfies Partial<Components>;
|
||||
|
||||
/**
|
||||
* Lightweight chrome overlay for the NRF page.
|
||||
*
|
||||
* Renders only the search/chat mode toggle (top-left) and footer (bottom),
|
||||
* absolutely positioned so they float transparently over NRFPage's own
|
||||
* background. This avoids pulling in the full AppLayouts.Root Header which
|
||||
* carries heavy state management (share/delete/move modals) that the
|
||||
* extension doesn't need.
|
||||
*/
|
||||
export default function NRFChrome() {
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const { appMode, setAppMode } = useAppMode();
|
||||
const settings = useSettingsContext();
|
||||
const { isMobile } = useScreenSize();
|
||||
const { setFolded } = useAppSidebarContext();
|
||||
const appFocus = useAppFocus();
|
||||
const { classification } = useQueryController();
|
||||
const [modePopoverOpen, setModePopoverOpen] = useState(false);
|
||||
|
||||
const effectiveMode: AppMode = appFocus.isNewSession() ? appMode : "chat";
|
||||
|
||||
const customFooterContent =
|
||||
settings?.enterpriseSettings?.custom_lower_disclaimer_content ||
|
||||
`[Onyx ${
|
||||
settings?.webVersion || "dev"
|
||||
}](https://www.onyx.app/) - Open Source AI Platform`;
|
||||
|
||||
const showModeToggle =
|
||||
isPaidEnterpriseFeaturesEnabled &&
|
||||
appFocus.isNewSession() &&
|
||||
!classification;
|
||||
|
||||
const showHeader = isMobile || showModeToggle;
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* Header chrome — top-left, mirrors position of settings button at top-right */}
|
||||
{showHeader && (
|
||||
<div className="absolute top-0 left-0 p-4 z-10 flex flex-row items-center gap-2">
|
||||
{isMobile && (
|
||||
<IconButton
|
||||
icon={SvgSidebar}
|
||||
onClick={() => setFolded(false)}
|
||||
internal
|
||||
/>
|
||||
)}
|
||||
{showModeToggle && (
|
||||
<Popover open={modePopoverOpen} onOpenChange={setModePopoverOpen}>
|
||||
<Popover.Trigger asChild>
|
||||
<OpenButton
|
||||
icon={
|
||||
effectiveMode === "search" ? SvgSearchMenu : SvgBubbleText
|
||||
}
|
||||
>
|
||||
{effectiveMode === "search" ? "Search" : "Chat"}
|
||||
</OpenButton>
|
||||
</Popover.Trigger>
|
||||
<Popover.Content align="start" width="lg">
|
||||
<Popover.Menu>
|
||||
<LineItem
|
||||
icon={SvgSearchMenu}
|
||||
selected={effectiveMode === "search"}
|
||||
description="Quick search for documents"
|
||||
onClick={noProp(() => {
|
||||
setAppMode("search");
|
||||
setModePopoverOpen(false);
|
||||
})}
|
||||
>
|
||||
Search
|
||||
</LineItem>
|
||||
<LineItem
|
||||
icon={SvgBubbleText}
|
||||
selected={effectiveMode === "chat"}
|
||||
description="Conversation and research"
|
||||
onClick={noProp(() => {
|
||||
setAppMode("chat");
|
||||
setModePopoverOpen(false);
|
||||
})}
|
||||
>
|
||||
Chat
|
||||
</LineItem>
|
||||
</Popover.Menu>
|
||||
</Popover.Content>
|
||||
</Popover>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Footer — bottom-center, transparent background */}
|
||||
<footer className="absolute bottom-0 left-0 w-full z-10 flex flex-row justify-center items-center gap-2 px-2 pb-2 pointer-events-auto">
|
||||
<MinimalMarkdown
|
||||
content={customFooterContent}
|
||||
className="max-w-full text-center"
|
||||
components={footerMarkdownComponents}
|
||||
/>
|
||||
</footer>
|
||||
</>
|
||||
);
|
||||
}
|
||||
15
web/src/app/nrf/layout.tsx
Normal file
15
web/src/app/nrf/layout.tsx
Normal file
@@ -0,0 +1,15 @@
|
||||
import { ProjectsProvider } from "@/providers/ProjectsContext";
|
||||
|
||||
export interface LayoutProps {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
/**
|
||||
* NRF Root Layout - Shared by all NRF routes
|
||||
*
|
||||
* Provides ProjectsProvider (needed by NRFPage) without auth redirect.
|
||||
* Sidebar and chrome are handled by sub-layouts / individual pages.
|
||||
*/
|
||||
export default function Layout({ children }: LayoutProps) {
|
||||
return <ProjectsProvider>{children}</ProjectsProvider>;
|
||||
}
|
||||
24
web/src/app/nrf/side-panel/page.tsx
Normal file
24
web/src/app/nrf/side-panel/page.tsx
Normal file
@@ -0,0 +1,24 @@
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||
import NRFPage from "@/app/app/nrf/NRFPage";
|
||||
import { NRFPreferencesProvider } from "@/components/context/NRFPreferencesContext";
|
||||
|
||||
/**
|
||||
* NRF Side Panel Route - No Auth Required
|
||||
*
|
||||
* Side panel variant — no NRFChrome overlay needed since the side panel
|
||||
* has its own header (logo + "Open in Onyx" button) and doesn't show
|
||||
* the mode toggle or footer.
|
||||
*/
|
||||
export default async function Page() {
|
||||
noStore();
|
||||
|
||||
return (
|
||||
<>
|
||||
<InstantSSRAutoRefresh />
|
||||
<NRFPreferencesProvider>
|
||||
<NRFPage isSidePanel />
|
||||
</NRFPreferencesProvider>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -49,7 +49,7 @@ export async function searchDocuments(
|
||||
const request: SendSearchQueryRequest = {
|
||||
search_query: query,
|
||||
filters: options?.filters,
|
||||
num_hits: options?.numHits ?? 50,
|
||||
num_hits: options?.numHits ?? 30,
|
||||
include_content: options?.includeContent ?? false,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
@@ -67,7 +67,7 @@ export function QueryControllerProvider({
|
||||
searchQuery,
|
||||
{
|
||||
filters,
|
||||
numHits: 50,
|
||||
numHits: 30,
|
||||
includeContent: false,
|
||||
signal: controller.signal,
|
||||
}
|
||||
|
||||
@@ -36,10 +36,11 @@
|
||||
import BackButton from "@/refresh-components/buttons/BackButton";
|
||||
import { cn } from "@/lib/utils";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { WithoutStyles } from "@/types";
|
||||
import { IconProps } from "@opal/types";
|
||||
import { IconFunctionComponent } from "@opal/types";
|
||||
import { HtmlHTMLAttributes, useEffect, useRef, useState } from "react";
|
||||
import { Content } from "@opal/layouts";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
|
||||
const widthClasses = {
|
||||
md: "w-[min(50rem,100%)]",
|
||||
@@ -163,7 +164,7 @@ function SettingsRoot({ width = "md", ...props }: SettingsRootProps) {
|
||||
* ```
|
||||
*/
|
||||
export interface SettingsHeaderProps {
|
||||
icon: React.FunctionComponent<IconProps>;
|
||||
icon: IconFunctionComponent;
|
||||
title: string;
|
||||
description?: string;
|
||||
children?: React.ReactNode;
|
||||
@@ -184,7 +185,10 @@ function SettingsHeader({
|
||||
}: SettingsHeaderProps) {
|
||||
const [showShadow, setShowShadow] = useState(false);
|
||||
const headerRef = useRef<HTMLDivElement>(null);
|
||||
const isSticky = !!rightChildren; //headers with actions are always sticky, others are not
|
||||
|
||||
// # NOTE (@Subash-Mohan)
|
||||
// Headers with actions are always sticky, others are not.
|
||||
const isSticky = !!rightChildren;
|
||||
|
||||
useEffect(() => {
|
||||
if (!isSticky) return;
|
||||
@@ -221,34 +225,35 @@ function SettingsHeader({
|
||||
<BackButton behaviorOverride={onBack} />
|
||||
</div>
|
||||
)}
|
||||
<div
|
||||
className={cn("flex flex-col gap-6 px-4", backButton ? "pt-2" : "pt-4")}
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
<div className="flex flex-row justify-between items-center gap-4">
|
||||
<Icon className="stroke-text-04 h-[1.75rem] w-[1.75rem]" />
|
||||
{rightChildren}
|
||||
</div>
|
||||
<div className={cn("flex flex-col", separator ? "pb-6" : "pb-2")}>
|
||||
<div aria-label="admin-page-title">
|
||||
<Text as="p" headingH2>
|
||||
{title}
|
||||
</Text>
|
||||
</div>
|
||||
{description && (
|
||||
<Text secondaryBody text03>
|
||||
{description}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
<Spacer vertical rem={1} />
|
||||
|
||||
<div className="flex flex-col gap-6 px-4">
|
||||
<div className="flex w-full justify-between">
|
||||
<div aria-label="admin-page-title">
|
||||
<Content
|
||||
icon={Icon}
|
||||
title={title}
|
||||
description={description}
|
||||
sizePreset="headline"
|
||||
variant="heading"
|
||||
/>
|
||||
</div>
|
||||
{rightChildren}
|
||||
</div>
|
||||
|
||||
{children}
|
||||
</div>
|
||||
{separator && (
|
||||
|
||||
{separator ? (
|
||||
<>
|
||||
<Spacer vertical rem={1.5} />
|
||||
<Separator noPadding className="px-4" />
|
||||
</>
|
||||
) : (
|
||||
<Spacer vertical rem={0.5} />
|
||||
)}
|
||||
|
||||
{isSticky && (
|
||||
<div
|
||||
className={cn(
|
||||
|
||||
@@ -11,10 +11,10 @@ export function getExtensionContext(): {
|
||||
return { isExtension: false, context: null };
|
||||
|
||||
const pathname = window.location.pathname;
|
||||
if (pathname.includes("/app/nrf/side-panel")) {
|
||||
if (pathname.includes("/nrf/side-panel")) {
|
||||
return { isExtension: true, context: "side_panel" };
|
||||
}
|
||||
if (pathname.includes("/app/nrf")) {
|
||||
if (pathname.includes("/nrf")) {
|
||||
return { isExtension: true, context: "new_tab" };
|
||||
}
|
||||
return { isExtension: false, context: null };
|
||||
|
||||
@@ -240,7 +240,7 @@ export interface SendSearchQueryRequest {
|
||||
filters?: BaseFilters | null;
|
||||
num_docs_fed_to_llm_selection?: number | null;
|
||||
run_query_expansion?: boolean;
|
||||
num_hits?: number; // default 50
|
||||
num_hits?: number; // default 30
|
||||
include_content?: boolean;
|
||||
stream?: boolean;
|
||||
}
|
||||
|
||||
@@ -114,7 +114,7 @@ const heightClasses = {
|
||||
* </Modal.Content>
|
||||
* ```
|
||||
*/
|
||||
interface ModalContentProps
|
||||
export interface ModalContentProps
|
||||
extends WithoutStyles<
|
||||
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Content>
|
||||
> {
|
||||
|
||||
21
web/src/refresh-components/PreviewImage.tsx
Normal file
21
web/src/refresh-components/PreviewImage.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface PreviewImageProps {
|
||||
src: string;
|
||||
alt: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function PreviewImage({
|
||||
src,
|
||||
alt,
|
||||
className,
|
||||
}: PreviewImageProps) {
|
||||
return (
|
||||
<img
|
||||
src={src}
|
||||
alt={alt}
|
||||
className={cn("object-contain object-center", className)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { Route } from "next";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgArrowLeft } from "@opal/icons";
|
||||
|
||||
export interface BackButtonProps {
|
||||
@@ -18,8 +18,8 @@ export default function BackButton({
|
||||
|
||||
return (
|
||||
<Button
|
||||
leftIcon={SvgArrowLeft}
|
||||
tertiary
|
||||
icon={SvgArrowLeft}
|
||||
prominence="tertiary"
|
||||
onClick={() => {
|
||||
if (behaviorOverride) {
|
||||
behaviorOverride();
|
||||
|
||||
@@ -14,7 +14,7 @@ import Tabs from "@/refresh-components/Tabs";
|
||||
import FilterButton from "@/refresh-components/buttons/FilterButton";
|
||||
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button } from "@opal/components";
|
||||
import {
|
||||
SEARCH_TOOL_ID,
|
||||
IMAGE_GENERATION_TOOL_ID,
|
||||
@@ -428,11 +428,13 @@ export default function AgentsNavigationPage() {
|
||||
title="Agents & Assistants"
|
||||
description="Customize AI behavior and knowledge for you and your team's use cases."
|
||||
rightChildren={
|
||||
<div data-testid="AgentsPage/new-agent-button">
|
||||
<Button href="/app/agents/create" leftIcon={SvgPlus}>
|
||||
New Agent
|
||||
</Button>
|
||||
</div>
|
||||
<Button
|
||||
href="/app/agents/create"
|
||||
icon={SvgPlus}
|
||||
aria-label="AgentsPage/new-agent-button"
|
||||
>
|
||||
New Agent
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
<div className="flex flex-col gap-2">
|
||||
|
||||
@@ -25,9 +25,7 @@ import { AppPopup } from "@/app/app/components/AppPopup";
|
||||
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import NoAssistantModal from "@/components/modals/NoAssistantModal";
|
||||
import TextViewModal from "@/sections/modals/TextViewModal";
|
||||
import CodeViewModal from "@/sections/modals/CodeViewModal";
|
||||
import { getCodeLanguage } from "@/lib/languages";
|
||||
import PreviewModal from "@/sections/modals/PreviewModal";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { useSendMessageToParent } from "@/lib/extension/utils";
|
||||
import { SUBMIT_MESSAGE_TYPES } from "@/lib/extension/constants";
|
||||
@@ -40,6 +38,7 @@ import useAgentController from "@/hooks/useAgentController";
|
||||
import useChatSessionController from "@/hooks/useChatSessionController";
|
||||
import useDeepResearchToggle from "@/hooks/useDeepResearchToggle";
|
||||
import useIsDefaultAgent from "@/hooks/useIsDefaultAgent";
|
||||
import AgentDescription from "@/app/app/components/AgentDescription";
|
||||
import {
|
||||
useChatSessionStore,
|
||||
useCurrentMessageHistory,
|
||||
@@ -686,18 +685,12 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{presentingDocument &&
|
||||
(getCodeLanguage(presentingDocument.semantic_identifier || "") ? (
|
||||
<CodeViewModal
|
||||
presentingDocument={presentingDocument}
|
||||
onClose={() => setPresentingDocument(null)}
|
||||
/>
|
||||
) : (
|
||||
<TextViewModal
|
||||
presentingDocument={presentingDocument}
|
||||
onClose={() => setPresentingDocument(null)}
|
||||
/>
|
||||
))}
|
||||
{presentingDocument && (
|
||||
<PreviewModal
|
||||
presentingDocument={presentingDocument}
|
||||
onClose={() => setPresentingDocument(null)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{stackTraceModalContent && (
|
||||
<ExceptionTraceModal
|
||||
@@ -889,6 +882,15 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
{/* ── Bottom: SearchResults + SourceFilter / Suggestions / ProjectChatList ── */}
|
||||
<div className="row-start-3 min-h-0 overflow-hidden flex flex-col items-center w-full">
|
||||
{/* Agent description below input */}
|
||||
{(appFocus.isNewSession() || appFocus.isAgent()) &&
|
||||
!isDefaultAgent && (
|
||||
<>
|
||||
<Spacer rem={1} />
|
||||
<AgentDescription agent={liveAssistant} />
|
||||
<Spacer rem={1.5} />
|
||||
</>
|
||||
)}
|
||||
{/* ProjectChatSessionList */}
|
||||
{appFocus.isProject() && (
|
||||
<div className="w-full max-w-[var(--app-page-main-content-width)] h-full overflow-y-auto overscroll-y-none mx-auto">
|
||||
|
||||
@@ -1,188 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect, useMemo } from "react";
|
||||
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgDownload } from "@opal/icons";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { getCodeLanguage } from "@/lib/languages";
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
import { CodeBlock } from "@/app/app/message/CodeBlock";
|
||||
import { extractCodeText } from "@/app/app/message/codeUtils";
|
||||
import { fetchChatFile } from "@/lib/chat/svc";
|
||||
|
||||
export interface CodeViewProps {
|
||||
presentingDocument: MinimalOnyxDocument;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export default function CodeViewModal({
|
||||
presentingDocument,
|
||||
onClose,
|
||||
}: CodeViewProps) {
|
||||
const [fileContent, setFileContent] = useState("");
|
||||
const [fileUrl, setFileUrl] = useState("");
|
||||
const [fileName, setFileName] = useState("");
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [loadError, setLoadError] = useState<string | null>(null);
|
||||
|
||||
const language =
|
||||
getCodeLanguage(presentingDocument.semantic_identifier || "") ||
|
||||
"plaintext";
|
||||
|
||||
const lineCount = useMemo(() => {
|
||||
if (!fileContent) return 0;
|
||||
return fileContent.split("\n").length;
|
||||
}, [fileContent]);
|
||||
|
||||
const fileSize = useMemo(() => {
|
||||
if (!fileContent) return "";
|
||||
const bytes = new TextEncoder().encode(fileContent).length;
|
||||
if (bytes < 1024) return `${bytes} B`;
|
||||
const kb = bytes / 1024;
|
||||
if (kb < 1024) return `${kb.toFixed(2)} KB`;
|
||||
const mb = kb / 1024;
|
||||
return `${mb.toFixed(2)} MB`;
|
||||
}, [fileContent]);
|
||||
|
||||
const headerDescription = useMemo(() => {
|
||||
if (!fileContent) return "";
|
||||
return `${language} - ${lineCount} ${
|
||||
lineCount === 1 ? "line" : "lines"
|
||||
} · ${fileSize}`;
|
||||
}, [fileContent, language, lineCount, fileSize]);
|
||||
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
setIsLoading(true);
|
||||
setLoadError(null);
|
||||
setFileContent("");
|
||||
const fileIdLocal =
|
||||
presentingDocument.document_id.split("__")[1] ||
|
||||
presentingDocument.document_id;
|
||||
|
||||
try {
|
||||
const response = await fetchChatFile(fileIdLocal);
|
||||
const blob = await response.blob();
|
||||
const url = window.URL.createObjectURL(blob);
|
||||
setFileUrl((prev) => {
|
||||
if (prev) {
|
||||
window.URL.revokeObjectURL(prev);
|
||||
}
|
||||
return url;
|
||||
});
|
||||
setFileName(presentingDocument.semantic_identifier || "document");
|
||||
setFileContent(await blob.text());
|
||||
} catch {
|
||||
setLoadError("Failed to load document.");
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
})();
|
||||
}, [presentingDocument]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (fileUrl) {
|
||||
window.URL.revokeObjectURL(fileUrl);
|
||||
}
|
||||
};
|
||||
}, [fileUrl]);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
open
|
||||
onOpenChange={(open) => {
|
||||
if (!open) {
|
||||
onClose();
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Modal.Content
|
||||
width="md"
|
||||
height="lg"
|
||||
preventAccidentalClose={false}
|
||||
onOpenAutoFocus={(e) => e.preventDefault()}
|
||||
>
|
||||
<Modal.Header
|
||||
title={fileName || "Code"}
|
||||
description={headerDescription}
|
||||
onClose={onClose}
|
||||
/>
|
||||
|
||||
<Modal.Body padding={0} gap={0}>
|
||||
<Section padding={0} gap={0}>
|
||||
{isLoading ? (
|
||||
<Section>
|
||||
<SimpleLoader className="h-8 w-8" />
|
||||
</Section>
|
||||
) : loadError ? (
|
||||
<Section padding={1}>
|
||||
<Text text03 mainUiBody>
|
||||
{loadError}
|
||||
</Text>
|
||||
</Section>
|
||||
) : (
|
||||
<MinimalMarkdown
|
||||
content={`\`\`\`${language}\n${fileContent}\n\`\`\``}
|
||||
className="w-full h-full break-words"
|
||||
components={{
|
||||
code: ({
|
||||
node,
|
||||
className: codeClassName,
|
||||
children,
|
||||
...props
|
||||
}: any) => {
|
||||
const codeText = extractCodeText(
|
||||
node,
|
||||
fileContent,
|
||||
children
|
||||
);
|
||||
return (
|
||||
<CodeBlock className="" codeText={codeText}>
|
||||
{children}
|
||||
</CodeBlock>
|
||||
);
|
||||
},
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Section>
|
||||
</Modal.Body>
|
||||
<Modal.Footer>
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="between"
|
||||
alignItems="center"
|
||||
>
|
||||
<Text text03 mainContentMuted>
|
||||
{lineCount} {lineCount === 1 ? "line" : "lines"}
|
||||
</Text>
|
||||
<Section flexDirection="row" gap={0.5} width="fit">
|
||||
<CopyIconButton
|
||||
getCopyText={() => fileContent}
|
||||
tooltip="Copy code"
|
||||
size="sm"
|
||||
/>
|
||||
<a
|
||||
href={fileUrl}
|
||||
download={fileName || presentingDocument.document_id}
|
||||
>
|
||||
<Button
|
||||
icon={SvgDownload}
|
||||
tooltip="Download"
|
||||
size="sm"
|
||||
prominence="tertiary"
|
||||
/>
|
||||
</a>
|
||||
</Section>
|
||||
</Section>
|
||||
</Modal.Footer>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
219
web/src/sections/modals/PreviewModal/PreviewModal.tsx
Normal file
219
web/src/sections/modals/PreviewModal/PreviewModal.tsx
Normal file
@@ -0,0 +1,219 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect, useCallback, useMemo } from "react";
|
||||
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { getCodeLanguage } from "@/lib/languages";
|
||||
import { fetchChatFile } from "@/lib/chat/svc";
|
||||
import { PreviewContext } from "@/sections/modals/PreviewModal/interfaces";
|
||||
import { resolveVariant } from "@/sections/modals/PreviewModal/variants";
|
||||
|
||||
function resolveMimeType(mimeType: string, fileName: string): string {
|
||||
if (mimeType !== "application/octet-stream") return mimeType;
|
||||
const lower = fileName.toLowerCase();
|
||||
if (lower.endsWith(".md") || lower.endsWith(".markdown"))
|
||||
return "text/markdown";
|
||||
if (lower.endsWith(".txt")) return "text/plain";
|
||||
if (lower.endsWith(".csv")) return "text/csv";
|
||||
return mimeType;
|
||||
}
|
||||
|
||||
interface PreviewModalProps {
|
||||
presentingDocument: MinimalOnyxDocument;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export default function PreviewModal({
|
||||
presentingDocument,
|
||||
onClose,
|
||||
}: PreviewModalProps) {
|
||||
const [fileContent, setFileContent] = useState("");
|
||||
const [fileUrl, setFileUrl] = useState("");
|
||||
const [fileName, setFileName] = useState("");
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [loadError, setLoadError] = useState<string | null>(null);
|
||||
const [mimeType, setMimeType] = useState("application/octet-stream");
|
||||
const [zoom, setZoom] = useState(100);
|
||||
|
||||
const variant = useMemo(
|
||||
() => resolveVariant(presentingDocument.semantic_identifier, mimeType),
|
||||
[presentingDocument.semantic_identifier, mimeType]
|
||||
);
|
||||
|
||||
const language = useMemo(
|
||||
() =>
|
||||
getCodeLanguage(presentingDocument.semantic_identifier || "") ||
|
||||
"plaintext",
|
||||
[presentingDocument.semantic_identifier]
|
||||
);
|
||||
|
||||
const lineCount = useMemo(() => {
|
||||
if (!fileContent) return 0;
|
||||
return fileContent.split("\n").length;
|
||||
}, [fileContent]);
|
||||
|
||||
const fileSize = useMemo(() => {
|
||||
if (!fileContent) return "";
|
||||
const bytes = new TextEncoder().encode(fileContent).length;
|
||||
if (bytes < 1024) return `${bytes} B`;
|
||||
const kb = bytes / 1024;
|
||||
if (kb < 1024) return `${kb.toFixed(2)} KB`;
|
||||
const mb = kb / 1024;
|
||||
return `${mb.toFixed(2)} MB`;
|
||||
}, [fileContent]);
|
||||
|
||||
const fetchFile = useCallback(async () => {
|
||||
setIsLoading(true);
|
||||
setLoadError(null);
|
||||
setFileContent("");
|
||||
const fileIdLocal =
|
||||
presentingDocument.document_id.split("__")[1] ||
|
||||
presentingDocument.document_id;
|
||||
|
||||
try {
|
||||
const response = await fetchChatFile(fileIdLocal);
|
||||
|
||||
const blob = await response.blob();
|
||||
const url = window.URL.createObjectURL(blob);
|
||||
setFileUrl((prev) => {
|
||||
if (prev) window.URL.revokeObjectURL(prev);
|
||||
return url;
|
||||
});
|
||||
|
||||
const originalFileName =
|
||||
presentingDocument.semantic_identifier || "document";
|
||||
setFileName(originalFileName);
|
||||
|
||||
const rawContentType =
|
||||
response.headers.get("Content-Type") || "application/octet-stream";
|
||||
const resolvedMime = resolveMimeType(rawContentType, originalFileName);
|
||||
setMimeType(resolvedMime);
|
||||
|
||||
const resolved = resolveVariant(
|
||||
presentingDocument.semantic_identifier,
|
||||
resolvedMime
|
||||
);
|
||||
if (resolved.needsTextContent) {
|
||||
setFileContent(await blob.text());
|
||||
}
|
||||
} catch {
|
||||
setLoadError("Failed to load document.");
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [presentingDocument]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchFile();
|
||||
}, [fetchFile]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (fileUrl) window.URL.revokeObjectURL(fileUrl);
|
||||
};
|
||||
}, [fileUrl]);
|
||||
|
||||
const handleZoomIn = useCallback(
|
||||
() => setZoom((prev) => Math.min(prev + 25, 200)),
|
||||
[]
|
||||
);
|
||||
const handleZoomOut = useCallback(
|
||||
() => setZoom((prev) => Math.max(prev - 25, 25)),
|
||||
[]
|
||||
);
|
||||
|
||||
const ctx: PreviewContext = useMemo(
|
||||
() => ({
|
||||
fileContent,
|
||||
fileUrl,
|
||||
fileName,
|
||||
language,
|
||||
lineCount,
|
||||
fileSize,
|
||||
zoom,
|
||||
onZoomIn: handleZoomIn,
|
||||
onZoomOut: handleZoomOut,
|
||||
}),
|
||||
[
|
||||
fileContent,
|
||||
fileUrl,
|
||||
fileName,
|
||||
language,
|
||||
lineCount,
|
||||
fileSize,
|
||||
zoom,
|
||||
handleZoomIn,
|
||||
handleZoomOut,
|
||||
]
|
||||
);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
open
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<Modal.Content
|
||||
width={variant.width}
|
||||
height={variant.height}
|
||||
preventAccidentalClose={false}
|
||||
onOpenAutoFocus={(e) => e.preventDefault()}
|
||||
>
|
||||
<Modal.Header
|
||||
title={fileName || "Document"}
|
||||
description={variant.headerDescription(ctx)}
|
||||
onClose={onClose}
|
||||
/>
|
||||
|
||||
{/* Body + floating footer wrapper */}
|
||||
<Modal.Body padding={0} gap={0}>
|
||||
<Section padding={0} gap={0}>
|
||||
{isLoading ? (
|
||||
<Section>
|
||||
<SimpleLoader className="h-8 w-8" />
|
||||
</Section>
|
||||
) : loadError ? (
|
||||
<Section padding={1}>
|
||||
<Text text03 mainUiBody>
|
||||
{loadError}
|
||||
</Text>
|
||||
</Section>
|
||||
) : (
|
||||
variant.renderContent(ctx)
|
||||
)}
|
||||
</Section>
|
||||
</Modal.Body>
|
||||
|
||||
{/* Floating footer */}
|
||||
{!isLoading && !loadError && (
|
||||
<div
|
||||
className={cn(
|
||||
"absolute bottom-0 left-0 right-0",
|
||||
"flex items-center justify-between",
|
||||
"p-4 pointer-events-none w-full"
|
||||
)}
|
||||
style={{
|
||||
background:
|
||||
"linear-gradient(to top, var(--background-tint-01) 40%, transparent)",
|
||||
}}
|
||||
>
|
||||
{/* Left slot */}
|
||||
<div className="pointer-events-auto">
|
||||
{variant.renderFooterLeft(ctx)}
|
||||
</div>
|
||||
|
||||
{/* Right slot */}
|
||||
<div className="pointer-events-auto rounded-12 bg-background-tint-00 p-1 shadow-lg">
|
||||
{variant.renderFooterRight(ctx)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
1
web/src/sections/modals/PreviewModal/index.ts
Normal file
1
web/src/sections/modals/PreviewModal/index.ts
Normal file
@@ -0,0 +1 @@
|
||||
export { default } from "@/sections/modals/PreviewModal/PreviewModal";
|
||||
30
web/src/sections/modals/PreviewModal/interfaces.ts
Normal file
30
web/src/sections/modals/PreviewModal/interfaces.ts
Normal file
@@ -0,0 +1,30 @@
|
||||
import React from "react";
|
||||
import { ModalContentProps } from "@/refresh-components/Modal";
|
||||
|
||||
export interface PreviewContext {
|
||||
fileContent: string;
|
||||
fileUrl: string;
|
||||
fileName: string;
|
||||
language: string;
|
||||
lineCount: number;
|
||||
fileSize: string;
|
||||
zoom: number;
|
||||
onZoomIn: () => void;
|
||||
onZoomOut: () => void;
|
||||
}
|
||||
|
||||
export interface PreviewVariant
|
||||
extends Required<Pick<ModalContentProps, "width" | "height">> {
|
||||
/** Return true if this variant should handle the given file. */
|
||||
matches: (semanticIdentifier: string | null, mimeType: string) => boolean;
|
||||
/** Whether the fetcher should read the blob as text. */
|
||||
needsTextContent: boolean;
|
||||
/** String shown below the title in the modal header. */
|
||||
headerDescription: (ctx: PreviewContext) => string;
|
||||
/** Body content. */
|
||||
renderContent: (ctx: PreviewContext) => React.ReactNode;
|
||||
/** Left side of the floating footer (e.g. line count text, zoom controls). Return null for nothing. */
|
||||
renderFooterLeft: (ctx: PreviewContext) => React.ReactNode;
|
||||
/** Right side of the floating footer (e.g. copy + download buttons). */
|
||||
renderFooterRight: (ctx: PreviewContext) => React.ReactNode;
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { getCodeLanguage } from "@/lib/languages";
|
||||
import { CodeBlock } from "@/app/app/message/CodeBlock";
|
||||
import { extractCodeText } from "@/app/app/message/codeUtils";
|
||||
import { PreviewVariant } from "@/sections/modals/PreviewModal/interfaces";
|
||||
import {
|
||||
CopyButton,
|
||||
DownloadButton,
|
||||
} from "@/sections/modals/PreviewModal/variants/shared";
|
||||
|
||||
export const codeVariant: PreviewVariant = {
|
||||
matches: (name) => !!getCodeLanguage(name || ""),
|
||||
width: "md",
|
||||
height: "lg",
|
||||
needsTextContent: true,
|
||||
|
||||
headerDescription: (ctx) =>
|
||||
ctx.fileContent
|
||||
? `${ctx.language} - ${ctx.lineCount} ${
|
||||
ctx.lineCount === 1 ? "line" : "lines"
|
||||
} · ${ctx.fileSize}`
|
||||
: "",
|
||||
|
||||
renderContent: (ctx) => (
|
||||
<MinimalMarkdown
|
||||
content={`\`\`\`${ctx.language}\n${ctx.fileContent}\n\n\`\`\``}
|
||||
className="w-full break-words h-full"
|
||||
components={{
|
||||
code: ({ node, children }: any) => {
|
||||
const codeText = extractCodeText(node, ctx.fileContent, children);
|
||||
return (
|
||||
<CodeBlock className="" codeText={codeText}>
|
||||
{children}
|
||||
</CodeBlock>
|
||||
);
|
||||
},
|
||||
}}
|
||||
/>
|
||||
),
|
||||
|
||||
renderFooterLeft: (ctx) => (
|
||||
<Text text03 mainUiBody className="select-none">
|
||||
{ctx.lineCount} {ctx.lineCount === 1 ? "line" : "lines"}
|
||||
</Text>
|
||||
),
|
||||
|
||||
renderFooterRight: (ctx) => (
|
||||
<Section flexDirection="row" width="fit">
|
||||
<CopyButton getText={() => ctx.fileContent} />
|
||||
<DownloadButton fileUrl={ctx.fileUrl} fileName={ctx.fileName} />
|
||||
</Section>
|
||||
),
|
||||
};
|
||||
102
web/src/sections/modals/PreviewModal/variants/csvVariant.tsx
Normal file
102
web/src/sections/modals/PreviewModal/variants/csvVariant.tsx
Normal file
@@ -0,0 +1,102 @@
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { PreviewVariant } from "@/sections/modals/PreviewModal/interfaces";
|
||||
import {
|
||||
CopyButton,
|
||||
DownloadButton,
|
||||
} from "@/sections/modals/PreviewModal/variants/shared";
|
||||
|
||||
interface CsvData {
|
||||
headers: string[];
|
||||
rows: string[][];
|
||||
}
|
||||
|
||||
function parseCsv(content: string): CsvData {
|
||||
const lines = content.split(/\r?\n/).filter((l) => l.length > 0);
|
||||
const headers = lines.length > 0 ? lines[0]?.split(",") ?? [] : [];
|
||||
const rows = lines.slice(1).map((line) => line.split(","));
|
||||
return { headers, rows };
|
||||
}
|
||||
|
||||
export const csvVariant: PreviewVariant = {
|
||||
matches: (name, mime) =>
|
||||
mime.startsWith("text/csv") || (name || "").toLowerCase().endsWith(".csv"),
|
||||
width: "lg",
|
||||
height: "full",
|
||||
needsTextContent: true,
|
||||
headerDescription: (ctx) => {
|
||||
if (!ctx.fileContent) return "";
|
||||
const { rows } = parseCsv(ctx.fileContent);
|
||||
return `CSV - ${rows.length} rows · ${ctx.fileSize}`;
|
||||
},
|
||||
|
||||
renderContent: (ctx) => {
|
||||
if (!ctx.fileContent) return null;
|
||||
const { headers, rows } = parseCsv(ctx.fileContent);
|
||||
return (
|
||||
<Section justifyContent="start" alignItems="start" padding={1}>
|
||||
<Table>
|
||||
<TableHeader className="sticky top-0 z-sticky">
|
||||
<TableRow className="bg-background-tint-02">
|
||||
{headers.map((h: string, i: number) => (
|
||||
<TableHead key={i}>
|
||||
<Text
|
||||
as="p"
|
||||
className="line-clamp-2 font-medium"
|
||||
text03
|
||||
mainUiBody
|
||||
>
|
||||
{h}
|
||||
</Text>
|
||||
</TableHead>
|
||||
))}
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{rows.map((row: string[], rIdx: number) => (
|
||||
<TableRow key={rIdx}>
|
||||
{headers.map((_: string, cIdx: number) => (
|
||||
<TableCell
|
||||
key={cIdx}
|
||||
className={cn(
|
||||
cIdx === 0 && "sticky left-0 bg-background-tint-01",
|
||||
"py-0 px-4 whitespace-normal break-words"
|
||||
)}
|
||||
>
|
||||
{row?.[cIdx] ?? ""}
|
||||
</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</Section>
|
||||
);
|
||||
},
|
||||
|
||||
renderFooterLeft: (ctx) => {
|
||||
if (!ctx.fileContent) return null;
|
||||
const { headers, rows } = parseCsv(ctx.fileContent);
|
||||
return (
|
||||
<Text text03 mainUiBody className="select-none">
|
||||
{headers.length} {headers.length === 1 ? "column" : "columns"} ·{" "}
|
||||
{rows.length} {rows.length === 1 ? "row" : "rows"}
|
||||
</Text>
|
||||
);
|
||||
},
|
||||
renderFooterRight: (ctx) => (
|
||||
<Section flexDirection="row" width="fit">
|
||||
<CopyButton getText={() => ctx.fileContent} />
|
||||
<DownloadButton fileUrl={ctx.fileUrl} fileName={ctx.fileName} />
|
||||
</Section>
|
||||
),
|
||||
};
|
||||
@@ -0,0 +1,45 @@
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import PreviewImage from "@/refresh-components/PreviewImage";
|
||||
import { PreviewVariant } from "@/sections/modals/PreviewModal/interfaces";
|
||||
import {
|
||||
DownloadButton,
|
||||
ZoomControls,
|
||||
} from "@/sections/modals/PreviewModal/variants/shared";
|
||||
|
||||
export const imageVariant: PreviewVariant = {
|
||||
matches: (_name, mime) => mime.startsWith("image/"),
|
||||
width: "lg",
|
||||
height: "full",
|
||||
needsTextContent: false,
|
||||
headerDescription: () => "",
|
||||
|
||||
renderContent: (ctx) => (
|
||||
<div
|
||||
className="flex flex-1 min-h-0 items-center justify-center p-4 transition-transform duration-300 ease-in-out"
|
||||
style={{
|
||||
transform: `scale(${ctx.zoom / 100})`,
|
||||
transformOrigin: "center",
|
||||
}}
|
||||
>
|
||||
<PreviewImage
|
||||
src={ctx.fileUrl}
|
||||
alt={ctx.fileName}
|
||||
className="max-w-full max-h-full"
|
||||
/>
|
||||
</div>
|
||||
),
|
||||
|
||||
renderFooterLeft: (ctx) => (
|
||||
<ZoomControls
|
||||
zoom={ctx.zoom}
|
||||
onZoomIn={ctx.onZoomIn}
|
||||
onZoomOut={ctx.onZoomOut}
|
||||
/>
|
||||
),
|
||||
|
||||
renderFooterRight: (ctx) => (
|
||||
<Section flexDirection="row" width="fit">
|
||||
<DownloadButton fileUrl={ctx.fileUrl} fileName={ctx.fileName} />
|
||||
</Section>
|
||||
),
|
||||
};
|
||||
25
web/src/sections/modals/PreviewModal/variants/index.ts
Normal file
25
web/src/sections/modals/PreviewModal/variants/index.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { PreviewVariant } from "@/sections/modals/PreviewModal/interfaces";
|
||||
import { codeVariant } from "@/sections/modals/PreviewModal/variants/codeVariant";
|
||||
import { imageVariant } from "@/sections/modals/PreviewModal/variants/imageVariant";
|
||||
import { pdfVariant } from "@/sections/modals/PreviewModal/variants/pdfVariant";
|
||||
import { csvVariant } from "@/sections/modals/PreviewModal/variants/csvVariant";
|
||||
import { markdownVariant } from "@/sections/modals/PreviewModal/variants/markdownVariant";
|
||||
import { unsupportedVariant } from "@/sections/modals/PreviewModal/variants/unsupportedVariant";
|
||||
|
||||
const PREVIEW_VARIANTS: PreviewVariant[] = [
|
||||
codeVariant,
|
||||
imageVariant,
|
||||
pdfVariant,
|
||||
csvVariant,
|
||||
markdownVariant,
|
||||
];
|
||||
|
||||
export function resolveVariant(
|
||||
semanticIdentifier: string | null,
|
||||
mimeType: string
|
||||
): PreviewVariant {
|
||||
return (
|
||||
PREVIEW_VARIANTS.find((v) => v.matches(semanticIdentifier, mimeType)) ??
|
||||
unsupportedVariant
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
import ScrollIndicatorDiv from "@/refresh-components/ScrollIndicatorDiv";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { PreviewVariant } from "@/sections/modals/PreviewModal/interfaces";
|
||||
import {
|
||||
CopyButton,
|
||||
DownloadButton,
|
||||
} from "@/sections/modals/PreviewModal/variants/shared";
|
||||
|
||||
const MARKDOWN_MIMES = [
|
||||
"text/markdown",
|
||||
"text/x-markdown",
|
||||
"text/plain",
|
||||
"text/x-rst",
|
||||
"text/x-org",
|
||||
];
|
||||
|
||||
export const markdownVariant: PreviewVariant = {
|
||||
matches: (name, mime) => {
|
||||
if (MARKDOWN_MIMES.some((m) => mime.startsWith(m))) return true;
|
||||
const lower = (name || "").toLowerCase();
|
||||
return (
|
||||
lower.endsWith(".md") ||
|
||||
lower.endsWith(".markdown") ||
|
||||
lower.endsWith(".txt") ||
|
||||
lower.endsWith(".rst") ||
|
||||
lower.endsWith(".org")
|
||||
);
|
||||
},
|
||||
width: "lg",
|
||||
height: "full",
|
||||
needsTextContent: true,
|
||||
headerDescription: () => "",
|
||||
|
||||
renderContent: (ctx) => (
|
||||
<ScrollIndicatorDiv className="flex-1 min-h-0 p-4" variant="shadow">
|
||||
<MinimalMarkdown
|
||||
content={ctx.fileContent}
|
||||
className="w-full pb-4 h-full text-lg break-words"
|
||||
/>
|
||||
</ScrollIndicatorDiv>
|
||||
),
|
||||
|
||||
renderFooterLeft: () => null,
|
||||
|
||||
renderFooterRight: (ctx) => (
|
||||
<Section flexDirection="row" width="fit">
|
||||
<CopyButton getText={() => ctx.fileContent} />
|
||||
<DownloadButton fileUrl={ctx.fileUrl} fileName={ctx.fileName} />
|
||||
</Section>
|
||||
),
|
||||
};
|
||||
26
web/src/sections/modals/PreviewModal/variants/pdfVariant.tsx
Normal file
26
web/src/sections/modals/PreviewModal/variants/pdfVariant.tsx
Normal file
@@ -0,0 +1,26 @@
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { PreviewVariant } from "@/sections/modals/PreviewModal/interfaces";
|
||||
import { DownloadButton } from "@/sections/modals/PreviewModal/variants/shared";
|
||||
|
||||
export const pdfVariant: PreviewVariant = {
|
||||
matches: (_name, mime) => mime === "application/pdf",
|
||||
width: "lg",
|
||||
height: "full",
|
||||
needsTextContent: false,
|
||||
headerDescription: () => "",
|
||||
|
||||
renderContent: (ctx) => (
|
||||
<iframe
|
||||
src={`${ctx.fileUrl}#toolbar=0`}
|
||||
className="w-full h-full flex-1 min-h-0 border-none"
|
||||
title="PDF Viewer"
|
||||
/>
|
||||
),
|
||||
|
||||
renderFooterLeft: () => null,
|
||||
renderFooterRight: (ctx) => (
|
||||
<Section flexDirection="row" width="fit">
|
||||
<DownloadButton fileUrl={ctx.fileUrl} fileName={ctx.fileName} />
|
||||
</Section>
|
||||
),
|
||||
};
|
||||
65
web/src/sections/modals/PreviewModal/variants/shared.tsx
Normal file
65
web/src/sections/modals/PreviewModal/variants/shared.tsx
Normal file
@@ -0,0 +1,65 @@
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgDownload, SvgZoomIn, SvgZoomOut } from "@opal/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
interface DownloadButtonProps {
|
||||
fileUrl: string;
|
||||
fileName: string;
|
||||
}
|
||||
|
||||
export function DownloadButton({ fileUrl, fileName }: DownloadButtonProps) {
|
||||
return (
|
||||
<a href={fileUrl} download={fileName}>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
icon={SvgDownload}
|
||||
tooltip="Download"
|
||||
/>
|
||||
</a>
|
||||
);
|
||||
}
|
||||
|
||||
interface CopyButtonProps {
|
||||
getText: () => string;
|
||||
}
|
||||
|
||||
export function CopyButton({ getText }: CopyButtonProps) {
|
||||
return (
|
||||
<CopyIconButton getCopyText={getText} tooltip="Copy content" size="sm" />
|
||||
);
|
||||
}
|
||||
|
||||
interface ZoomControlsProps {
|
||||
zoom: number;
|
||||
onZoomIn: () => void;
|
||||
onZoomOut: () => void;
|
||||
}
|
||||
|
||||
export function ZoomControls({ zoom, onZoomIn, onZoomOut }: ZoomControlsProps) {
|
||||
return (
|
||||
<div className="rounded-12 bg-background-tint-00 p-1 shadow-lg">
|
||||
<Section flexDirection="row" width="fit">
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
icon={SvgZoomOut}
|
||||
onClick={onZoomOut}
|
||||
tooltip="Zoom Out"
|
||||
/>
|
||||
<Text mainUiMono text03>
|
||||
{zoom}%
|
||||
</Text>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
icon={SvgZoomIn}
|
||||
onClick={onZoomIn}
|
||||
tooltip="Zoom In"
|
||||
/>
|
||||
</Section>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
import { Button } from "@opal/components";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { PreviewVariant } from "@/sections/modals/PreviewModal/interfaces";
|
||||
import { DownloadButton } from "@/sections/modals/PreviewModal/variants/shared";
|
||||
|
||||
export const unsupportedVariant: PreviewVariant = {
|
||||
matches: () => true,
|
||||
width: "lg",
|
||||
height: "full",
|
||||
needsTextContent: false,
|
||||
headerDescription: () => "",
|
||||
|
||||
renderContent: (ctx) => (
|
||||
<div className="flex flex-col items-center justify-center flex-1 min-h-0 gap-4 p-6">
|
||||
<Text as="p" text03 mainUiBody>
|
||||
This file format is not supported for preview.
|
||||
</Text>
|
||||
<a href={ctx.fileUrl} download={ctx.fileName}>
|
||||
<Button>Download File</Button>
|
||||
</a>
|
||||
</div>
|
||||
),
|
||||
|
||||
renderFooterLeft: () => null,
|
||||
renderFooterRight: (ctx) => (
|
||||
<DownloadButton fileUrl={ctx.fileUrl} fileName={ctx.fileName} />
|
||||
),
|
||||
};
|
||||
@@ -21,6 +21,7 @@ import {
|
||||
SvgZoomIn,
|
||||
SvgZoomOut,
|
||||
} from "@opal/icons";
|
||||
import PreviewImage from "@/refresh-components/PreviewImage";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import ScrollIndicatorDiv from "@/refresh-components/ScrollIndicatorDiv";
|
||||
import { cn } from "@/lib/utils";
|
||||
@@ -249,10 +250,10 @@ export default function TextViewModal({
|
||||
style={{ transform: `scale(${zoom / 100})` }}
|
||||
>
|
||||
{isImageFormat(fileType) ? (
|
||||
<img
|
||||
<PreviewImage
|
||||
src={fileUrl}
|
||||
alt={fileName}
|
||||
className="w-full flex-1 min-h-0 object-contain object-center"
|
||||
className="w-full flex-1 min-h-0"
|
||||
/>
|
||||
) : isSupportedIframeFormat(fileType) ? (
|
||||
<iframe
|
||||
|
||||
@@ -136,7 +136,7 @@ async function verifyAdminPageNavigation(
|
||||
|
||||
try {
|
||||
await expect(page.locator('[aria-label="admin-page-title"]')).toHaveText(
|
||||
pageTitle,
|
||||
new RegExp(`^${pageTitle}`),
|
||||
{
|
||||
timeout: 10000,
|
||||
}
|
||||
|
||||
@@ -21,12 +21,12 @@ function getToolSwitch(page: Page, toolName: string): Locator {
|
||||
}
|
||||
|
||||
/**
|
||||
* Click a tool switch and wait for the PATCH response to complete.
|
||||
* Click a button and wait for the PATCH response to complete.
|
||||
* Uses waitForResponse set up *before* the click to avoid race conditions.
|
||||
*/
|
||||
async function clickToolSwitchAndWaitForSave(
|
||||
async function clickAndWaitForPatch(
|
||||
page: Page,
|
||||
switchLocator: Locator
|
||||
buttonLocator: Locator
|
||||
): Promise<void> {
|
||||
const patchPromise = page.waitForResponse(
|
||||
(r) =>
|
||||
@@ -34,7 +34,7 @@ async function clickToolSwitchAndWaitForSave(
|
||||
r.request().method() === "PATCH",
|
||||
{ timeout: 8000 }
|
||||
);
|
||||
await switchLocator.click();
|
||||
await buttonLocator.click();
|
||||
await patchPromise;
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ test.describe("Chat Preferences Admin Page", () => {
|
||||
}) => {
|
||||
// Verify page loads with expected content
|
||||
await expect(page.locator('[aria-label="admin-page-title"]')).toHaveText(
|
||||
"Chat Preferences"
|
||||
/^Chat Preferences/
|
||||
);
|
||||
await expect(page.getByText("Actions & Tools")).toBeVisible();
|
||||
});
|
||||
@@ -215,7 +215,7 @@ test.describe("Chat Preferences Admin Page", () => {
|
||||
);
|
||||
|
||||
// Toggle back to original state
|
||||
await clickToolSwitchAndWaitForSave(page, searchSwitch);
|
||||
await clickAndWaitForPatch(page, searchSwitch);
|
||||
});
|
||||
|
||||
test("should toggle Web Search tool on and off", async ({ page }) => {
|
||||
@@ -267,7 +267,7 @@ test.describe("Chat Preferences Admin Page", () => {
|
||||
);
|
||||
|
||||
// Toggle back to original state
|
||||
await clickToolSwitchAndWaitForSave(page, webSearchSwitch);
|
||||
await clickAndWaitForPatch(page, webSearchSwitch);
|
||||
});
|
||||
|
||||
test("should toggle Image Generation tool on and off", async ({ page }) => {
|
||||
@@ -321,7 +321,7 @@ test.describe("Chat Preferences Admin Page", () => {
|
||||
);
|
||||
|
||||
// Toggle back to original state
|
||||
await clickToolSwitchAndWaitForSave(page, imageGenSwitch);
|
||||
await clickAndWaitForPatch(page, imageGenSwitch);
|
||||
});
|
||||
|
||||
test("should edit and save system prompt", async ({ page }) => {
|
||||
@@ -339,30 +339,12 @@ test.describe("Chat Preferences Admin Page", () => {
|
||||
const textarea = modal.getByPlaceholder("Enter your system prompt...");
|
||||
await textarea.fill(testPrompt);
|
||||
|
||||
// Set up response listener before the click to avoid race conditions
|
||||
const patchRespPromise = page.waitForResponse(
|
||||
(r) =>
|
||||
r.url().includes("/api/admin/default-assistant") &&
|
||||
r.request().method() === "PATCH",
|
||||
{ timeout: 8000 }
|
||||
// Click Save and wait for PATCH to complete
|
||||
await clickAndWaitForPatch(
|
||||
page,
|
||||
modal.getByRole("button", { name: "Save" })
|
||||
);
|
||||
|
||||
// Click Save in the modal footer
|
||||
await modal.getByRole("button", { name: "Save" }).click();
|
||||
|
||||
// Wait for PATCH to complete
|
||||
const patchResp = await patchRespPromise;
|
||||
console.log(
|
||||
`[prompt] Save PATCH status=${patchResp.status()} body=${(
|
||||
await patchResp.text()
|
||||
).slice(0, 300)}`
|
||||
);
|
||||
|
||||
// Wait for success toast
|
||||
await expect(page.getByText("System prompt updated")).toBeVisible({
|
||||
timeout: 5000,
|
||||
});
|
||||
|
||||
// Modal should close after save
|
||||
await expect(modal).not.toBeVisible();
|
||||
|
||||
@@ -396,11 +378,10 @@ test.describe("Chat Preferences Admin Page", () => {
|
||||
// If already empty, add some text first
|
||||
if (initialValue === "") {
|
||||
await textarea.fill("Temporary text");
|
||||
await modal.getByRole("button", { name: "Save" }).click();
|
||||
await expect(page.getByText("System prompt updated")).toBeVisible({
|
||||
timeout: 5000,
|
||||
});
|
||||
await page.waitForTimeout(500);
|
||||
await clickAndWaitForPatch(
|
||||
page,
|
||||
modal.getByRole("button", { name: "Save" })
|
||||
);
|
||||
// Reopen modal
|
||||
await page.getByText("Modify Prompt").click();
|
||||
await expect(modal).toBeVisible({ timeout: 5000 });
|
||||
@@ -409,28 +390,12 @@ test.describe("Chat Preferences Admin Page", () => {
|
||||
// Clear the textarea
|
||||
await textarea.fill("");
|
||||
|
||||
// Set up response listener before the click to avoid race conditions
|
||||
const patchRespPromise = page.waitForResponse(
|
||||
(r) =>
|
||||
r.url().includes("/api/admin/default-assistant") &&
|
||||
r.request().method() === "PATCH",
|
||||
{ timeout: 8000 }
|
||||
);
|
||||
|
||||
// Save
|
||||
await modal.getByRole("button", { name: "Save" }).click();
|
||||
|
||||
const patchResp = await patchRespPromise;
|
||||
console.log(
|
||||
`[prompt-empty] Save empty PATCH status=${patchResp.status()} body=${(
|
||||
await patchResp.text()
|
||||
).slice(0, 300)}`
|
||||
await clickAndWaitForPatch(
|
||||
page,
|
||||
modal.getByRole("button", { name: "Save" })
|
||||
);
|
||||
|
||||
await expect(page.getByText("System prompt updated")).toBeVisible({
|
||||
timeout: 5000,
|
||||
});
|
||||
|
||||
// Refresh page to verify persistence
|
||||
await page.reload();
|
||||
await page.waitForLoadState("networkidle");
|
||||
@@ -450,10 +415,10 @@ test.describe("Chat Preferences Admin Page", () => {
|
||||
// Restore original value if it wasn't already empty
|
||||
if (initialValue !== "") {
|
||||
await textareaAfter.fill(initialValue);
|
||||
await modalAfter.getByRole("button", { name: "Save" }).click();
|
||||
await expect(page.getByText("System prompt updated")).toBeVisible({
|
||||
timeout: 5000,
|
||||
});
|
||||
await clickAndWaitForPatch(
|
||||
page,
|
||||
modalAfter.getByRole("button", { name: "Save" })
|
||||
);
|
||||
} else {
|
||||
await modalAfter.getByRole("button", { name: "Cancel" }).click();
|
||||
}
|
||||
@@ -475,27 +440,12 @@ test.describe("Chat Preferences Admin Page", () => {
|
||||
|
||||
await textarea.fill(longPrompt);
|
||||
|
||||
// Set up response listener before the click to avoid race conditions
|
||||
const patchRespPromise = page.waitForResponse(
|
||||
(r) =>
|
||||
r.url().includes("/api/admin/default-assistant") &&
|
||||
r.request().method() === "PATCH",
|
||||
{ timeout: 8000 }
|
||||
);
|
||||
|
||||
// Save
|
||||
await modal.getByRole("button", { name: "Save" }).click();
|
||||
const patchResp = await patchRespPromise;
|
||||
console.log(
|
||||
`[prompt-long] Save PATCH status=${patchResp.status()} body=${(
|
||||
await patchResp.text()
|
||||
).slice(0, 300)}`
|
||||
await clickAndWaitForPatch(
|
||||
page,
|
||||
modal.getByRole("button", { name: "Save" })
|
||||
);
|
||||
|
||||
await expect(page.getByText("System prompt updated")).toBeVisible({
|
||||
timeout: 5000,
|
||||
});
|
||||
|
||||
// Verify persistence after reload
|
||||
await page.reload();
|
||||
await page.waitForLoadState("networkidle");
|
||||
@@ -513,10 +463,10 @@ test.describe("Chat Preferences Admin Page", () => {
|
||||
"Enter your system prompt..."
|
||||
);
|
||||
await restoreTextarea.fill(initialValue);
|
||||
await modalAfter.getByRole("button", { name: "Save" }).click();
|
||||
await expect(page.getByText("System prompt updated")).toBeVisible({
|
||||
timeout: 5000,
|
||||
});
|
||||
await clickAndWaitForPatch(
|
||||
page,
|
||||
modalAfter.getByRole("button", { name: "Save" })
|
||||
);
|
||||
} else {
|
||||
await modalAfter.getByRole("button", { name: "Cancel" }).click();
|
||||
}
|
||||
@@ -602,7 +552,7 @@ test.describe("Chat Preferences Admin Page", () => {
|
||||
const toolSwitch = getToolSwitch(page, toolName);
|
||||
const currentState = await toolSwitch.getAttribute("aria-checked");
|
||||
if (currentState === "true") {
|
||||
await clickToolSwitchAndWaitForSave(page, toolSwitch);
|
||||
await clickAndWaitForPatch(page, toolSwitch);
|
||||
const newState = await toolSwitch.getAttribute("aria-checked");
|
||||
console.log(`[toggle-all] Clicked ${toolName}, new state=${newState}`);
|
||||
}
|
||||
@@ -628,7 +578,7 @@ test.describe("Chat Preferences Admin Page", () => {
|
||||
const toolSwitch = getToolSwitch(page, toolName);
|
||||
const currentState = await toolSwitch.getAttribute("aria-checked");
|
||||
if (currentState === "false") {
|
||||
await clickToolSwitchAndWaitForSave(page, toolSwitch);
|
||||
await clickAndWaitForPatch(page, toolSwitch);
|
||||
const newState = await toolSwitch.getAttribute("aria-checked");
|
||||
console.log(`[toggle-all] Clicked ${toolName}, new state=${newState}`);
|
||||
}
|
||||
@@ -722,7 +672,7 @@ test.describe("Chat Preferences Admin Page", () => {
|
||||
const originalState = toolStates[toolName];
|
||||
|
||||
if (currentState !== originalState) {
|
||||
await clickToolSwitchAndWaitForSave(page, toolSwitch);
|
||||
await clickAndWaitForPatch(page, toolSwitch);
|
||||
needsSave = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ test.describe("Disable Default Assistant Setting @exclusive", () => {
|
||||
|
||||
// Wait for the page to fully render (page title signals form is loaded)
|
||||
await expect(page.locator('[aria-label="admin-page-title"]')).toHaveText(
|
||||
"Chat Preferences",
|
||||
/^Chat Preferences/,
|
||||
{ timeout: 10000 }
|
||||
);
|
||||
|
||||
@@ -224,7 +224,7 @@ test.describe("Disable Default Assistant Setting @exclusive", () => {
|
||||
|
||||
// Verify the page title
|
||||
await expect(page.locator('[aria-label="admin-page-title"]')).toHaveText(
|
||||
"Chat Preferences"
|
||||
/^Chat Preferences/
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -105,7 +105,7 @@ test.describe("LLM Provider Setup @exclusive", () => {
|
||||
await loginAs(page, "admin");
|
||||
await page.goto(LLM_SETUP_URL);
|
||||
await page.waitForLoadState("networkidle");
|
||||
await expect(page.getByLabel("admin-page-title")).toHaveText("LLM Setup");
|
||||
await expect(page.getByLabel("admin-page-title")).toHaveText(/^LLM Setup/);
|
||||
});
|
||||
|
||||
test.afterEach(async ({ page }) => {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { test, expect, Page, Browser } from "@playwright/test";
|
||||
import { loginAs, loginAsRandomUser } from "@tests/e2e/utils/auth";
|
||||
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
|
||||
import { expectScreenshot } from "@tests/e2e/utils/visualRegression";
|
||||
|
||||
// --- Locator Helper Functions ---
|
||||
const getNameInput = (page: Page) => page.locator('input[name="name"]');
|
||||
@@ -244,14 +245,14 @@ test.describe("Assistant Creation and Edit Verification", () => {
|
||||
await loginAsRandomUser(page);
|
||||
|
||||
// --- Initial Values ---
|
||||
const assistantName = `Test Assistant ${Date.now()}`;
|
||||
const assistantName = "Test Assistant 1";
|
||||
const assistantDescription = "This is a test assistant description.";
|
||||
const assistantInstructions = "These are the test instructions.";
|
||||
const assistantReminder = "Initial reminder.";
|
||||
const assistantStarterMessage = "Initial starter message?";
|
||||
|
||||
// --- Edited Values ---
|
||||
const editedAssistantName = `Edited Assistant ${Date.now()}`;
|
||||
const editedAssistantName = "Edited Assistant";
|
||||
const editedAssistantDescription = "This is the edited description.";
|
||||
const editedAssistantInstructions = "These are the edited instructions.";
|
||||
const editedAssistantReminder = "Edited reminder.";
|
||||
@@ -296,6 +297,7 @@ test.describe("Assistant Creation and Edit Verification", () => {
|
||||
expect(assistantIdMatch).toBeTruthy();
|
||||
const assistantId = assistantIdMatch ? assistantIdMatch[1] : null;
|
||||
expect(assistantId).not.toBeNull();
|
||||
await expectScreenshot(page, { name: "welcome-page-with-assistant" });
|
||||
|
||||
// Store assistant ID for cleanup
|
||||
knowledgeAssistantId = Number(assistantId);
|
||||
|
||||
@@ -137,7 +137,7 @@ test.describe("Signup flow", () => {
|
||||
|
||||
// Wait for error message to appear
|
||||
await expect(
|
||||
page.getByText("Unknown error", { exact: true })
|
||||
page.getByText("Disposable email addresses are not allowed").first()
|
||||
).toBeVisible();
|
||||
|
||||
// Capture the error state with hidden email to avoid non-deterministic diffs
|
||||
|
||||
@@ -107,7 +107,7 @@ test.describe("Default Assistant Tests", () => {
|
||||
|
||||
// Create a custom assistant to test non-default behavior
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.getByTestId("AgentsPage/new-agent-button").click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
await page
|
||||
.locator('input[name="name"]')
|
||||
.waitFor({ state: "visible", timeout: 10000 });
|
||||
@@ -150,7 +150,7 @@ test.describe("Default Assistant Tests", () => {
|
||||
}) => {
|
||||
// Create a custom assistant
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.getByTestId("AgentsPage/new-agent-button").click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
await page
|
||||
.locator('input[name="name"]')
|
||||
.waitFor({ state: "visible", timeout: 10000 });
|
||||
@@ -200,7 +200,7 @@ test.describe("Default Assistant Tests", () => {
|
||||
}) => {
|
||||
// Create a custom assistant with starter messages
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.getByTestId("AgentsPage/new-agent-button").click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
await page
|
||||
.locator('input[name="name"]')
|
||||
.waitFor({ state: "visible", timeout: 10000 });
|
||||
@@ -253,7 +253,7 @@ test.describe("Default Assistant Tests", () => {
|
||||
// Wait for modal or assistant list to appear
|
||||
// The selector might be in a modal or dropdown.
|
||||
await page
|
||||
.getByTestId("AgentsPage/new-agent-button")
|
||||
.getByLabel("AgentsPage/new-agent-button")
|
||||
.waitFor({ state: "visible", timeout: 5000 });
|
||||
|
||||
// Look for default assistant by name - it should NOT be there
|
||||
@@ -280,7 +280,7 @@ test.describe("Default Assistant Tests", () => {
|
||||
}) => {
|
||||
// Create a custom assistant
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.getByTestId("AgentsPage/new-agent-button").click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
await page
|
||||
.locator('input[name="name"]')
|
||||
.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
@@ -143,14 +143,14 @@ test.describe("File preview modal from chat file links", () => {
|
||||
// Verify the file name is shown in the header
|
||||
await expect(modal.getByText("notes.txt")).toBeVisible();
|
||||
|
||||
// Verify the download button exists
|
||||
await expect(modal.getByText("Download File")).toBeVisible();
|
||||
// Verify the download link exists
|
||||
await expect(modal.locator("a[download]")).toBeVisible();
|
||||
|
||||
// Verify the file content is rendered
|
||||
await expect(modal.getByText("Hello from the mock file!")).toBeVisible();
|
||||
});
|
||||
|
||||
test("clicking a code file link opens the CodeViewModal with syntax highlighting", async ({
|
||||
test("clicking a code file link opens the PreviewModal with syntax highlighting", async ({
|
||||
page,
|
||||
}) => {
|
||||
const mockContent = `Here is your script: [app.py](/api/chat/file/${MOCK_FILE_ID})`;
|
||||
@@ -173,7 +173,7 @@ test.describe("File preview modal from chat file links", () => {
|
||||
await expect(fileLink).toBeVisible({ timeout: 5000 });
|
||||
await fileLink.click();
|
||||
|
||||
// Verify the CodeViewModal opens
|
||||
// Verify the PreviewModal opens
|
||||
const modal = page.getByRole("dialog");
|
||||
await expect(modal).toBeVisible({ timeout: 5000 });
|
||||
|
||||
@@ -217,9 +217,9 @@ test.describe("File preview modal from chat file links", () => {
|
||||
const modal = page.getByRole("dialog");
|
||||
await expect(modal).toBeVisible({ timeout: 5000 });
|
||||
|
||||
// Click the download button and verify a download starts
|
||||
// Click the download link and verify a download starts
|
||||
const downloadPromise = page.waitForEvent("download");
|
||||
await modal.getByText("Download File").last().click();
|
||||
await modal.locator("a[download]").last().click();
|
||||
const download = await downloadPromise;
|
||||
|
||||
expect(download.suggestedFilename()).toContain("data.csv");
|
||||
|
||||
@@ -32,7 +32,7 @@ test("Chat workflow", async ({ page }) => {
|
||||
|
||||
// Test creation of a new assistant
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.getByTestId("AgentsPage/new-agent-button").click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
await page.locator('input[name="name"]').click();
|
||||
await page.locator('input[name="name"]').fill("Test Assistant");
|
||||
await page.locator('textarea[name="description"]').click();
|
||||
|
||||
@@ -9,6 +9,7 @@ import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
|
||||
|
||||
const PREFLIGHT_TIMEOUT_MS = 60_000;
|
||||
const PREFLIGHT_POLL_INTERVAL_MS = 2_000;
|
||||
const PREFLIGHT_WARN_AFTER_MS = 15_000;
|
||||
|
||||
/**
|
||||
* Poll the health endpoint until the server is ready or we time out.
|
||||
@@ -17,6 +18,8 @@ const PREFLIGHT_POLL_INTERVAL_MS = 2_000;
|
||||
async function waitForServer(baseURL: string): Promise<void> {
|
||||
const healthURL = baseURL;
|
||||
const deadline = Date.now() + PREFLIGHT_TIMEOUT_MS;
|
||||
const startTime = Date.now();
|
||||
let warned = false;
|
||||
|
||||
console.log(`[global-setup] Waiting for server at ${healthURL} ...`);
|
||||
|
||||
@@ -31,6 +34,18 @@ async function waitForServer(baseURL: string): Promise<void> {
|
||||
} catch {
|
||||
// Connection refused / DNS error — server not up yet.
|
||||
}
|
||||
|
||||
if (!warned && Date.now() - startTime >= PREFLIGHT_WARN_AFTER_MS) {
|
||||
warned = true;
|
||||
console.warn(
|
||||
`[global-setup] ⚠ Still waiting for server after ${
|
||||
PREFLIGHT_WARN_AFTER_MS / 1000
|
||||
}s.\n` +
|
||||
` Please verify that both the backend and frontend are running.\n` +
|
||||
` You can start them with: ods compose dev`
|
||||
);
|
||||
}
|
||||
|
||||
await new Promise((r) => setTimeout(r, PREFLIGHT_POLL_INTERVAL_MS));
|
||||
}
|
||||
|
||||
@@ -39,7 +54,7 @@ async function waitForServer(baseURL: string): Promise<void> {
|
||||
`Timed out after ${
|
||||
PREFLIGHT_TIMEOUT_MS / 1000
|
||||
}s waiting for ${healthURL} to return 200. ` +
|
||||
`Make sure the server is running (e.g. \`ods compose dev\`).`
|
||||
`Make sure the backend and frontend are running (e.g. \`ods compose dev\`).`
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -496,10 +496,7 @@ test.describe("Default Assistant MCP Integration", () => {
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.waitForURL("**/app/agents");
|
||||
|
||||
await page
|
||||
.getByTestId("AgentsPage/new-agent-button")
|
||||
.getByRole("link", { name: "New Agent" })
|
||||
.click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
await page.waitForURL("**/app/agents/create");
|
||||
|
||||
const assistantName = `MCP Assistant ${Date.now()}`;
|
||||
|
||||
@@ -20,7 +20,7 @@ export async function createAssistant(page: Page, params: AssistantParams) {
|
||||
|
||||
// Open Assistants modal/list
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.getByTestId("AgentsPage/new-agent-button").click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
|
||||
// Fill required fields
|
||||
await page.locator('input[name="name"]').fill(name);
|
||||
|
||||
Reference in New Issue
Block a user