Compare commits

..

1 Commits

Author SHA1 Message Date
Bo-Onyx
3e25853eb8 chore(hook): Add feature control 2026-03-12 19:28:48 -07:00
46 changed files with 834 additions and 3652 deletions

View File

@@ -1,102 +1,67 @@
name: Post-Merge Beta Cherry-Pick
on:
pull_request_target:
types:
- closed
push:
branches:
- main
# SECURITY NOTE:
# This workflow intentionally uses pull_request_target so post-merge automation can
# use base-repo credentials. Do not checkout PR head refs in this workflow
# (e.g. github.event.pull_request.head.sha). Only trusted base refs are allowed.
permissions:
contents: read
jobs:
resolve-cherry-pick-request:
if: >-
github.event.pull_request.merged == true
&& github.event.pull_request.base.ref == 'main'
&& github.event.pull_request.head.repo.full_name == github.repository
outputs:
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
pr_number: ${{ steps.gate.outputs.pr_number }}
merge_commit_sha: ${{ steps.gate.outputs.merge_commit_sha }}
merged_by: ${{ steps.gate.outputs.merged_by }}
gate_error: ${{ steps.gate.outputs.gate_error }}
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- name: Resolve merged PR and checkbox state
id: gate
env:
GH_TOKEN: ${{ github.token }}
PR_NUMBER: ${{ github.event.pull_request.number }}
# SECURITY: keep PR body in env/plain-text handling; avoid directly
# inlining github.event.pull_request.body into shell commands.
PR_BODY: ${{ github.event.pull_request.body }}
MERGE_COMMIT_SHA: ${{ github.event.pull_request.merge_commit_sha }}
MERGED_BY: ${{ github.event.pull_request.merged_by.login }}
# GitHub team slug authorized to trigger cherry-picks (e.g. "core-eng").
# For private/secret teams the GITHUB_TOKEN may need org:read scope;
# visible teams work with the default token.
ALLOWED_TEAM: "onyx-core-team"
run: |
echo "pr_number=${PR_NUMBER}" >> "$GITHUB_OUTPUT"
echo "merged_by=${MERGED_BY}" >> "$GITHUB_OUTPUT"
if ! echo "${PR_BODY}" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox not checked for PR #${PR_NUMBER}. Skipping."
exit 0
fi
# Keep should_cherrypick output before any possible exit 1 below so
# notify-slack can still gate on this output even if this job fails.
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox checked for PR #${PR_NUMBER}."
if [ -z "${MERGE_COMMIT_SHA}" ] || [ "${MERGE_COMMIT_SHA}" = "null" ]; then
echo "gate_error=missing-merge-commit-sha" >> "$GITHUB_OUTPUT"
echo "::error::PR #${PR_NUMBER} requested cherry-pick, but merge_commit_sha is missing."
exit 1
fi
echo "merge_commit_sha=${MERGE_COMMIT_SHA}" >> "$GITHUB_OUTPUT"
member_state_file="$(mktemp)"
member_err_file="$(mktemp)"
if ! gh api "orgs/${GITHUB_REPOSITORY_OWNER}/teams/${ALLOWED_TEAM}/memberships/${MERGED_BY}" --jq '.state' >"${member_state_file}" 2>"${member_err_file}"; then
api_err="$(tr '\n' ' ' < "${member_err_file}" | sed 's/[[:space:]]\+/ /g' | cut -c1-300)"
echo "gate_error=team-api-error" >> "$GITHUB_OUTPUT"
echo "::error::Team membership API call failed for ${MERGED_BY} in ${ALLOWED_TEAM}: ${api_err}"
exit 1
fi
member_state="$(cat "${member_state_file}")"
if [ "${member_state}" != "active" ]; then
echo "gate_error=not-team-member" >> "$GITHUB_OUTPUT"
echo "::error::${MERGED_BY} is not an active member of team ${ALLOWED_TEAM} (state: ${member_state}). Failing cherry-pick gate."
exit 1
fi
exit 0
cherry-pick-to-latest-release:
needs:
- resolve-cherry-pick-request
if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success'
permissions:
contents: write
pull-requests: write
outputs:
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
pr_number: ${{ steps.gate.outputs.pr_number }}
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Resolve merged PR and checkbox state
id: gate
env:
GH_TOKEN: ${{ github.token }}
run: |
# For the commit that triggered this workflow (HEAD on main), fetch all
# associated PRs and keep only the PR that was actually merged into main
# with this exact merge commit SHA.
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
if [ "${match_count}" -gt 1 ]; then
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
fi
if [ -z "$pr_number" ]; then
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
exit 0
fi
# Read the PR once so we can gate behavior and infer preferred actor.
pr_json="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}")"
pr_body="$(printf '%s' "$pr_json" | jq -r '.body // ""')"
merged_by="$(printf '%s' "$pr_json" | jq -r '.merged_by.login // ""')"
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
echo "merged_by=$merged_by" >> "$GITHUB_OUTPUT"
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox checked for PR #${pr_number}."
exit 0
fi
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
- name: Checkout repository
# SECURITY: keep checkout pinned to trusted base branch; do not switch to PR head refs.
if: steps.gate.outputs.should_cherrypick == 'true'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
fetch-depth: 0
@@ -104,37 +69,31 @@ jobs:
ref: main
- name: Install the latest version of uv
if: steps.gate.outputs.should_cherrypick == 'true'
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Configure git identity
if: steps.gate.outputs.should_cherrypick == 'true'
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Create cherry-pick PR to latest release
id: run_cherry_pick
if: steps.gate.outputs.should_cherrypick == 'true'
continue-on-error: true
env:
GH_TOKEN: ${{ github.token }}
GITHUB_TOKEN: ${{ github.token }}
CHERRY_PICK_ASSIGNEE: ${{ needs.resolve-cherry-pick-request.outputs.merged_by }}
MERGE_COMMIT_SHA: ${{ needs.resolve-cherry-pick-request.outputs.merge_commit_sha }}
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
run: |
set -o pipefail
output_file="$(mktemp)"
set +e
uv run --no-sync --with onyx-devtools ods cherry-pick "${MERGE_COMMIT_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
pipe_statuses=("${PIPESTATUS[@]}")
exit_code="${pipe_statuses[0]}"
tee_exit="${pipe_statuses[1]:-0}"
set -e
if [ "${tee_exit}" -ne 0 ]; then
echo "status=failure" >> "$GITHUB_OUTPUT"
echo "reason=output-capture-failed" >> "$GITHUB_OUTPUT"
echo "::error::tee failed to capture cherry-pick output (exit ${tee_exit}); cannot classify result."
exit 1
fi
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
exit_code="${PIPESTATUS[0]}"
if [ "${exit_code}" -eq 0 ]; then
echo "status=success" >> "$GITHUB_OUTPUT"
@@ -156,7 +115,7 @@ jobs:
} >> "$GITHUB_OUTPUT"
- name: Mark workflow as failed if cherry-pick failed
if: steps.run_cherry_pick.outputs.status == 'failure'
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
env:
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
run: |
@@ -165,9 +124,8 @@ jobs:
notify-slack-on-cherry-pick-failure:
needs:
- resolve-cherry-pick-request
- cherry-pick-to-latest-release
if: always() && needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && (needs.resolve-cherry-pick-request.result == 'failure' || needs.cherry-pick-to-latest-release.result == 'failure')
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
runs-on: ubuntu-slim
timeout-minutes: 10
steps:
@@ -176,49 +134,22 @@ jobs:
with:
persist-credentials: false
- name: Fail if Slack webhook secret is missing
env:
CHERRY_PICK_PRS_WEBHOOK: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
run: |
if [ -z "${CHERRY_PICK_PRS_WEBHOOK}" ]; then
echo "::error::CHERRY_PICK_PRS_WEBHOOK is not configured."
exit 1
fi
- name: Build cherry-pick failure summary
id: failure-summary
env:
SOURCE_PR_NUMBER: ${{ needs.resolve-cherry-pick-request.outputs.pr_number }}
MERGE_COMMIT_SHA: ${{ needs.resolve-cherry-pick-request.outputs.merge_commit_sha }}
GATE_ERROR: ${{ needs.resolve-cherry-pick-request.outputs.gate_error }}
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
run: |
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
reason_text="cherry-pick command failed"
if [ "${GATE_ERROR}" = "missing-merge-commit-sha" ]; then
reason_text="requested cherry-pick but merge commit SHA was missing"
elif [ "${GATE_ERROR}" = "team-api-error" ]; then
reason_text="team membership lookup failed while validating cherry-pick permissions"
elif [ "${GATE_ERROR}" = "not-team-member" ]; then
reason_text="merger is not an active member of the allowed team"
elif [ "${CHERRY_PICK_REASON}" = "output-capture-failed" ]; then
reason_text="failed to capture cherry-pick output for classification"
elif [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
reason_text="merge conflict during cherry-pick"
fi
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
if [ -n "${GATE_ERROR}" ]; then
failed_job_label="resolve-cherry-pick-request"
else
failed_job_label="cherry-pick-to-latest-release"
fi
failed_jobs="• ${failed_job_label}\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
if [ -n "${MERGE_COMMIT_SHA}" ]; then
failed_jobs="${failed_jobs}\\n• merge SHA: ${MERGE_COMMIT_SHA}"
fi
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
if [ -n "${details_excerpt}" ]; then
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
fi
@@ -231,4 +162,4 @@ jobs:
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
title: "🚨 Automated Cherry-Pick Failed"
ref-name: ${{ github.event.pull_request.base.ref }}
ref-name: ${{ github.ref_name }}

View File

@@ -1,5 +1,3 @@
import base64
import hashlib
import json
import os
import random
@@ -31,7 +29,6 @@ from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi.responses import JSONResponse
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager
@@ -58,7 +55,6 @@ from fastapi_users.router.common import ErrorModel
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import GetAccessTokenError
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import nulls_last
@@ -124,10 +120,6 @@ from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
from onyx.error_handling.exceptions import onyx_error_to_json_response
from onyx.error_handling.exceptions import OnyxError
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
@@ -1629,7 +1621,6 @@ STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
STATE_TOKEN_LIFETIME_SECONDS = 3600
CSRF_TOKEN_KEY = "csrftoken"
CSRF_TOKEN_COOKIE_NAME = "fastapiusersoauthcsrf"
PKCE_COOKIE_NAME_PREFIX = "fastapiusersoauthpkce"
class OAuth2AuthorizeResponse(BaseModel):
@@ -1650,21 +1641,6 @@ def generate_csrf_token() -> str:
return secrets.token_urlsafe(32)
def _base64url_encode(data: bytes) -> str:
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
def generate_pkce_pair() -> tuple[str, str]:
verifier = secrets.token_urlsafe(64)
challenge = _base64url_encode(hashlib.sha256(verifier.encode("ascii")).digest())
return verifier, challenge
def get_pkce_cookie_name(state: str) -> str:
state_hash = hashlib.sha256(state.encode("utf-8")).hexdigest()
return f"{PKCE_COOKIE_NAME_PREFIX}_{state_hash}"
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_onyx_oauth_router(
oauth_client: BaseOAuth2,
@@ -1673,7 +1649,6 @@ def create_onyx_oauth_router(
redirect_url: Optional[str] = None,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
enable_pkce: bool = False,
) -> APIRouter:
return get_oauth_router(
oauth_client,
@@ -1683,7 +1658,6 @@ def create_onyx_oauth_router(
redirect_url,
associate_by_email,
is_verified_by_default,
enable_pkce=enable_pkce,
)
@@ -1702,7 +1676,6 @@ def get_oauth_router(
csrf_token_cookie_secure: Optional[bool] = None,
csrf_token_cookie_httponly: bool = True,
csrf_token_cookie_samesite: Optional[Literal["lax", "strict", "none"]] = "lax",
enable_pkce: bool = False,
) -> APIRouter:
"""Generate a router with the OAuth routes."""
router = APIRouter()
@@ -1719,13 +1692,6 @@ def get_oauth_router(
route_name=callback_route_name,
)
async def null_access_token_state() -> tuple[OAuth2Token, Optional[str]] | None:
return None
access_token_state_dependency = (
oauth2_authorize_callback if not enable_pkce else null_access_token_state
)
if csrf_token_cookie_secure is None:
csrf_token_cookie_secure = WEB_DOMAIN.startswith("https")
@@ -1759,26 +1725,13 @@ def get_oauth_router(
CSRF_TOKEN_KEY: csrf_token,
}
state = generate_state_token(state_data, state_secret)
pkce_cookie: tuple[str, str] | None = None
if enable_pkce:
code_verifier, code_challenge = generate_pkce_pair()
pkce_cookie_name = get_pkce_cookie_name(state)
pkce_cookie = (pkce_cookie_name, code_verifier)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
state,
scopes,
code_challenge=code_challenge,
code_challenge_method="S256",
)
else:
# Get the basic authorization URL
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
state,
scopes,
)
# Get the basic authorization URL
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
state,
scopes,
)
# For Google OAuth, add parameters to request refresh tokens
if oauth_client.name == "google":
@@ -1786,15 +1739,11 @@ def get_oauth_router(
authorization_url, {"access_type": "offline", "prompt": "consent"}
)
def set_oauth_cookie(
target_response: Response,
*,
key: str,
value: str,
) -> None:
target_response.set_cookie(
key=key,
value=value,
if redirect:
redirect_response = RedirectResponse(authorization_url, status_code=302)
redirect_response.set_cookie(
key=csrf_token_cookie_name,
value=csrf_token,
max_age=STATE_TOKEN_LIFETIME_SECONDS,
path=csrf_token_cookie_path,
domain=csrf_token_cookie_domain,
@@ -1802,28 +1751,18 @@ def get_oauth_router(
httponly=csrf_token_cookie_httponly,
samesite=csrf_token_cookie_samesite,
)
return redirect_response
response_with_cookies: Response
if redirect:
response_with_cookies = RedirectResponse(authorization_url, status_code=302)
else:
response_with_cookies = response
set_oauth_cookie(
response_with_cookies,
response.set_cookie(
key=csrf_token_cookie_name,
value=csrf_token,
max_age=STATE_TOKEN_LIFETIME_SECONDS,
path=csrf_token_cookie_path,
domain=csrf_token_cookie_domain,
secure=csrf_token_cookie_secure,
httponly=csrf_token_cookie_httponly,
samesite=csrf_token_cookie_samesite,
)
if pkce_cookie is not None:
pkce_cookie_name, code_verifier = pkce_cookie
set_oauth_cookie(
response_with_cookies,
key=pkce_cookie_name,
value=code_verifier,
)
if redirect:
return response_with_cookies
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
@@ -1854,242 +1793,119 @@ def get_oauth_router(
)
async def callback(
request: Request,
access_token_state: Tuple[OAuth2Token, Optional[str]] | None = Depends(
access_token_state_dependency
access_token_state: Tuple[OAuth2Token, str] = Depends(
oauth2_authorize_callback
),
code: Optional[str] = None,
state: Optional[str] = None,
error: Optional[str] = None,
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
) -> Response:
pkce_cookie_name: str | None = None
) -> RedirectResponse:
token, state = access_token_state
account_id, account_email = await oauth_client.get_id_email(
token["access_token"]
)
def delete_pkce_cookie(response: Response) -> None:
if enable_pkce and pkce_cookie_name:
response.delete_cookie(
key=pkce_cookie_name,
path=csrf_token_cookie_path,
domain=csrf_token_cookie_domain,
secure=csrf_token_cookie_secure,
httponly=csrf_token_cookie_httponly,
samesite=csrf_token_cookie_samesite,
)
def build_error_response(exc: OnyxError) -> JSONResponse:
log_onyx_error(exc)
error_response = onyx_error_to_json_response(exc)
delete_pkce_cookie(error_response)
return error_response
def decode_and_validate_state(state_value: str) -> Dict[str, str]:
try:
state_data = decode_jwt(
state_value, state_secret, [STATE_TOKEN_AUDIENCE]
)
except jwt.DecodeError:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
getattr(
ErrorCode,
"ACCESS_TOKEN_DECODE_ERROR",
"ACCESS_TOKEN_DECODE_ERROR",
),
)
except jwt.ExpiredSignatureError:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
getattr(
ErrorCode,
"ACCESS_TOKEN_ALREADY_EXPIRED",
"ACCESS_TOKEN_ALREADY_EXPIRED",
),
)
except jwt.PyJWTError:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
getattr(
ErrorCode,
"ACCESS_TOKEN_DECODE_ERROR",
"ACCESS_TOKEN_DECODE_ERROR",
),
)
cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
if (
not cookie_csrf_token
or not state_csrf_token
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
):
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"),
)
return state_data
token: OAuth2Token
state_data: Dict[str, str]
# `code`, `state`, and `error` are read directly only in the PKCE path.
# In the non-PKCE path, `oauth2_authorize_callback` consumes them.
if enable_pkce:
if state is not None:
pkce_cookie_name = get_pkce_cookie_name(state)
if error is not None:
return build_error_response(
OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Authorization request failed or was denied",
)
)
if code is None:
return build_error_response(
OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Missing authorization code in OAuth callback",
)
)
if state is None:
return build_error_response(
OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Missing state parameter in OAuth callback",
)
)
state_value = state
if redirect_url is not None:
callback_redirect_url = redirect_url
else:
callback_path = request.app.url_path_for(callback_route_name)
callback_redirect_url = f"{WEB_DOMAIN}{callback_path}"
code_verifier = request.cookies.get(cast(str, pkce_cookie_name))
if not code_verifier:
return build_error_response(
OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Missing PKCE verifier cookie in OAuth callback",
)
)
try:
state_data = decode_and_validate_state(state_value)
except OnyxError as e:
return build_error_response(e)
try:
token = await oauth_client.get_access_token(
code, callback_redirect_url, code_verifier
)
except GetAccessTokenError:
return build_error_response(
OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Authorization code exchange failed",
)
)
else:
if access_token_state is None:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR, "Missing OAuth callback state"
)
token, callback_state = access_token_state
if callback_state is None:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Missing state parameter in OAuth callback",
)
state_data = decode_and_validate_state(callback_state)
async def complete_login_flow(
token: OAuth2Token, state_data: Dict[str, str]
) -> RedirectResponse:
account_id, account_email = await oauth_client.get_id_email(
token["access_token"]
if account_email is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
)
if account_email is None:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
)
try:
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
except jwt.DecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(
ErrorCode, "ACCESS_TOKEN_DECODE_ERROR", "ACCESS_TOKEN_DECODE_ERROR"
),
)
except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(
ErrorCode,
"ACCESS_TOKEN_ALREADY_EXPIRED",
"ACCESS_TOKEN_ALREADY_EXPIRED",
),
)
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(account_email)
except exceptions.UserNotExists:
tenant_id = None
cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
if (
not cookie_csrf_token
or not state_csrf_token
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"),
)
request.state.referral_source = referral_source
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(account_email)
except exceptions.UserNotExists:
tenant_id = None
# Proceed to authenticate or create the user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
token["access_token"],
account_id,
account_email,
token.get("expires_at"),
token.get("refresh_token"),
request,
associate_by_email=associate_by_email,
is_verified_by_default=is_verified_by_default,
)
except UserAlreadyExists:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
ErrorCode.OAUTH_USER_ALREADY_EXISTS,
)
request.state.referral_source = referral_source
if not user.is_active:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
ErrorCode.LOGIN_BAD_CREDENTIALS,
)
# Proceed to authenticate or create the user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
token["access_token"],
account_id,
account_email,
token.get("expires_at"),
token.get("refresh_token"),
request,
associate_by_email=associate_by_email,
is_verified_by_default=is_verified_by_default,
)
except UserAlreadyExists:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS,
)
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
)
# Prepare redirect response
if tenant_id is None:
# Use URL utility to add parameters
redirect_destination = add_url_params(next_url, {"new_team": "true"})
redirect_response = RedirectResponse(
redirect_destination, status_code=302
)
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
if tenant_id is None:
# Use URL utility to add parameters
redirect_url = add_url_params(next_url, {"new_team": "true"})
redirect_response = RedirectResponse(redirect_url, status_code=302)
else:
# No parameters to add
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers from auth response to redirect response, with special handling for Set-Cookie
for header_name, header_value in response.headers.items():
# FastAPI can have multiple Set-Cookie headers as a list
if header_name.lower() == "set-cookie" and isinstance(header_value, list):
for cookie_value in header_value:
redirect_response.headers.append(header_name, cookie_value)
else:
# No parameters to add
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers from auth response to redirect response, with special handling for Set-Cookie
for header_name, header_value in response.headers.items():
header_name_lower = header_name.lower()
if header_name_lower == "set-cookie":
redirect_response.headers.append(header_name, header_value)
continue
if header_name_lower in {"location", "content-length"}:
continue
redirect_response.headers[header_name] = header_value
return redirect_response
if hasattr(response, "body"):
redirect_response.body = response.body
if hasattr(response, "status_code"):
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
if enable_pkce:
try:
redirect_response = await complete_login_flow(token, state_data)
except OnyxError as e:
return build_error_response(e)
delete_pkce_cookie(redirect_response)
return redirect_response
return await complete_login_flow(token, state_data)
return redirect_response
return router

View File

@@ -196,10 +196,6 @@ if _OIDC_SCOPE_OVERRIDE:
except Exception:
pass
# Enables PKCE for OIDC login flow. Disabled by default to preserve
# backwards compatibility for existing OIDC deployments.
OIDC_PKCE_ENABLED = os.environ.get("OIDC_PKCE_ENABLED", "").lower() == "true"
# Applicable for SAML Auth
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/onyx/configs/saml_config"
@@ -1046,6 +1042,8 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
HOOK_ENABLED = os.environ.get("HOOK_ENABLED", "").lower() == "true"
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"
#####

View File

@@ -33,7 +33,6 @@ from office365.runtime.queries.client_query import ClientQuery # type: ignore[i
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from pydantic import BaseModel
from pydantic import Field
from requests.exceptions import HTTPError
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
@@ -278,7 +277,7 @@ def _log_and_raise_for_status(response: requests.Response) -> None:
try:
response.raise_for_status()
except Exception:
logger.error(f"HTTP request failed: {response.text}")
logger.error(f"Graph API request failed: {response.text}")
raise
@@ -1259,14 +1258,7 @@ class SharepointConnector(
total_yielded = 0
while page_url:
try:
data = self._graph_api_get_json(page_url, params)
except HTTPError as e:
if e.response.status_code == 404:
logger.warning(f"Site page not found: {page_url}")
break
raise
data = self._graph_api_get_json(page_url, params)
params = None # nextLink already embeds query params
for page in data.get("value", []):

View File

@@ -59,22 +59,6 @@ class OnyxError(Exception):
return self._status_code_override or self.error_code.status_code
def log_onyx_error(exc: OnyxError) -> None:
detail = exc.detail
status_code = exc.status_code
if status_code >= 500:
logger.error(f"OnyxError {exc.error_code.code}: {detail}")
elif status_code >= 400:
logger.warning(f"OnyxError {exc.error_code.code}: {detail}")
def onyx_error_to_json_response(exc: OnyxError) -> JSONResponse:
return JSONResponse(
status_code=exc.status_code,
content=exc.error_code.detail(exc.detail),
)
def register_onyx_exception_handlers(app: FastAPI) -> None:
"""Register a global handler that converts ``OnyxError`` to JSON responses.
@@ -87,5 +71,13 @@ def register_onyx_exception_handlers(app: FastAPI) -> None:
request: Request, # noqa: ARG001
exc: OnyxError,
) -> JSONResponse:
log_onyx_error(exc)
return onyx_error_to_json_response(exc)
status_code = exc.status_code
if status_code >= 500:
logger.error(f"OnyxError {exc.error_code.code}: {exc.detail}")
elif status_code >= 400:
logger.warning(f"OnyxError {exc.error_code.code}: {exc.detail}")
return JSONResponse(
status_code=status_code,
content=exc.error_code.detail(exc.detail),
)

View File

View File

@@ -0,0 +1,26 @@
from onyx.configs.app_configs import HOOK_ENABLED
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from shared_configs.configs import MULTI_TENANT
def require_hook_enabled() -> None:
"""FastAPI dependency that gates all hook management endpoints.
Hooks are only available in single-tenant / self-hosted deployments with
HOOK_ENABLED=true explicitly set. Two layers of protection:
1. MULTI_TENANT check — rejects even if HOOK_ENABLED is accidentally set true
2. HOOK_ENABLED flag — explicit opt-in by the operator
Use as: Depends(require_hook_enabled)
"""
if MULTI_TENANT:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
"Custom code hooks are not available in multi-tenant deployments",
)
if not HOOK_ENABLED:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
"Custom code hooks are not enabled. Set HOOK_ENABLED=true to enable.",
)

View File

@@ -44,7 +44,6 @@ from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import OAUTH_ENABLED
from onyx.configs.app_configs import OIDC_PKCE_ENABLED
from onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
from onyx.configs.app_configs import OPENID_CONFIG_URL
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
@@ -598,7 +597,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
associate_by_email=True,
is_verified_by_default=True,
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
enable_pkce=OIDC_PKCE_ENABLED,
),
prefix="/auth/oidc",
)

View File

@@ -1,4 +1,3 @@
import re
from collections.abc import Iterator
from pathlib import Path
from uuid import UUID
@@ -41,9 +40,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_TEMPLATES_DIR = Path(__file__).parent / "templates"
_WEBAPP_HMR_FIXER_TEMPLATE = (_TEMPLATES_DIR / "webapp_hmr_fixer.js").read_text()
def require_onyx_craft_enabled(user: User = Depends(current_user)) -> User:
"""
@@ -243,62 +239,18 @@ def _stream_response(response: httpx.Response) -> Iterator[bytes]:
yield chunk
def _inject_hmr_fixer(content: bytes, session_id: str) -> bytes:
"""Inject a script that stubs root-scoped Next HMR websocket connections."""
base = f"/api/build/sessions/{session_id}/webapp"
script = f"<script>{_WEBAPP_HMR_FIXER_TEMPLATE.replace('__WEBAPP_BASE__', base)}</script>"
text = content.decode("utf-8")
text = re.sub(
r"(<head\b[^>]*>)",
lambda m: m.group(0) + script,
text,
count=1,
flags=re.IGNORECASE,
)
return text.encode("utf-8")
def _rewrite_asset_paths(content: bytes, session_id: str) -> bytes:
"""Rewrite Next.js asset paths to go through the proxy."""
import re
# Base path includes session_id for routing
webapp_base_path = f"/api/build/sessions/{session_id}/webapp"
escaped_webapp_base_path = webapp_base_path.replace("/", r"\/")
hmr_paths = ("/_next/webpack-hmr", "/_next/hmr")
text = content.decode("utf-8")
# Anchor on delimiter so already-prefixed URLs (from assetPrefix) aren't double-rewritten.
for delim in ('"', "'", "("):
text = text.replace(f"{delim}/_next/", f"{delim}{webapp_base_path}/_next/")
text = re.sub(
rf"{re.escape(delim)}https?://[^/\"')]+/_next/",
f"{delim}{webapp_base_path}/_next/",
text,
)
text = re.sub(
rf"{re.escape(delim)}wss?://[^/\"')]+/_next/",
f"{delim}{webapp_base_path}/_next/",
text,
)
text = text.replace(r"\/_next\/", rf"{escaped_webapp_base_path}\/_next\/")
text = re.sub(
r"https?:\\\/\\\/[^\"']+?\\\/_next\\\/",
rf"{escaped_webapp_base_path}\/_next\/",
text,
)
text = re.sub(
r"wss?:\\\/\\\/[^\"']+?\\\/_next\\\/",
rf"{escaped_webapp_base_path}\/_next\/",
text,
)
for hmr_path in hmr_paths:
escaped_hmr_path = hmr_path.replace("/", r"\/")
text = text.replace(
f"{webapp_base_path}{hmr_path}",
hmr_path,
)
text = text.replace(
f"{escaped_webapp_base_path}{escaped_hmr_path}",
escaped_hmr_path,
)
# Rewrite /_next/ paths to go through our proxy
text = text.replace("/_next/", f"{webapp_base_path}/_next/")
# Rewrite JSON data file fetch paths (e.g., /data.json, /data/tickets.json)
# Matches paths like "/filename.json" or "/path/to/file.json"
text = re.sub(
r'"(/(?:[a-zA-Z0-9_-]+/)*[a-zA-Z0-9_-]+\.json)"',
f'"{webapp_base_path}\\1"',
@@ -309,29 +261,11 @@ def _rewrite_asset_paths(content: bytes, session_id: str) -> bytes:
f"'{webapp_base_path}\\1'",
text,
)
# Rewrite favicon
text = text.replace('"/favicon.ico', f'"{webapp_base_path}/favicon.ico')
return text.encode("utf-8")
def _rewrite_proxy_response_headers(
headers: dict[str, str], session_id: str
) -> dict[str, str]:
"""Rewrite response headers that can leak root-scoped asset URLs."""
link = headers.get("link")
if link:
webapp_base_path = f"/api/build/sessions/{session_id}/webapp"
rewritten_link = re.sub(
r"<https?://[^>]+/_next/",
f"<{webapp_base_path}/_next/",
link,
)
rewritten_link = rewritten_link.replace(
"</_next/", f"<{webapp_base_path}/_next/"
)
headers["link"] = rewritten_link
return headers
# Content types that may contain asset path references that need rewriting
REWRITABLE_CONTENT_TYPES = {
"text/html",
@@ -408,17 +342,12 @@ def _proxy_request(
for key, value in response.headers.items()
if key.lower() not in EXCLUDED_HEADERS
}
response_headers = _rewrite_proxy_response_headers(
response_headers, str(session_id)
)
content_type = response.headers.get("content-type", "")
# For HTML/CSS/JS responses, rewrite asset paths
if any(ct in content_type for ct in REWRITABLE_CONTENT_TYPES):
content = _rewrite_asset_paths(response.content, str(session_id))
if "text/html" in content_type:
content = _inject_hmr_fixer(content, str(session_id))
return Response(
content=content,
status_code=response.status_code,
@@ -462,7 +391,7 @@ def _check_webapp_access(
return session
_OFFLINE_HTML_PATH = _TEMPLATES_DIR / "webapp_offline.html"
_OFFLINE_HTML_PATH = Path(__file__).parent / "templates" / "webapp_offline.html"
def _offline_html_response() -> Response:
@@ -470,7 +399,6 @@ def _offline_html_response() -> Response:
Design mirrors the default Craft web template (outputs/web/app/page.tsx):
terminal window aesthetic with Minecraft-themed typing animation.
"""
html = _OFFLINE_HTML_PATH.read_text()
return Response(content=html, status_code=503, media_type="text/html")

View File

@@ -1,135 +0,0 @@
(function () {
var WEBAPP_BASE = "__WEBAPP_BASE__";
var PROXIED_NEXT_PREFIX = WEBAPP_BASE + "/_next/";
var PROXIED_HMR_PREFIX = WEBAPP_BASE + "/_next/webpack-hmr";
var PROXIED_ALT_HMR_PREFIX = WEBAPP_BASE + "/_next/hmr";
function isHmrWebSocketUrl(url) {
if (!url) return false;
try {
var parsedUrl = new URL(String(url), window.location.href);
return (
parsedUrl.pathname.indexOf("/_next/webpack-hmr") === 0 ||
parsedUrl.pathname.indexOf("/_next/hmr") === 0 ||
parsedUrl.pathname.indexOf(PROXIED_HMR_PREFIX) === 0 ||
parsedUrl.pathname.indexOf(PROXIED_ALT_HMR_PREFIX) === 0
);
} catch (e) {}
if (typeof url === "string") {
return (
url.indexOf("/_next/webpack-hmr") === 0 ||
url.indexOf("/_next/hmr") === 0 ||
url.indexOf(PROXIED_HMR_PREFIX) === 0 ||
url.indexOf(PROXIED_ALT_HMR_PREFIX) === 0
);
}
return false;
}
function rewriteNextAssetUrl(url) {
if (!url) return url;
try {
var parsedUrl = new URL(String(url), window.location.href);
if (parsedUrl.pathname.indexOf(PROXIED_NEXT_PREFIX) === 0) {
return parsedUrl.pathname + parsedUrl.search + parsedUrl.hash;
}
if (parsedUrl.pathname.indexOf("/_next/") === 0) {
return (
WEBAPP_BASE + parsedUrl.pathname + parsedUrl.search + parsedUrl.hash
);
}
} catch (e) {}
if (typeof url === "string") {
if (url.indexOf(PROXIED_NEXT_PREFIX) === 0) {
return url;
}
if (url.indexOf("/_next/") === 0) {
return WEBAPP_BASE + url;
}
}
return url;
}
function createEvent(eventType) {
return typeof Event === "function"
? new Event(eventType)
: { type: eventType };
}
function MockHmrWebSocket(url) {
this.url = String(url);
this.readyState = 1;
this.bufferedAmount = 0;
this.extensions = "";
this.protocol = "";
this.binaryType = "blob";
this.onopen = null;
this.onmessage = null;
this.onerror = null;
this.onclose = null;
this._l = {};
var socket = this;
setTimeout(function () {
socket._d("open", createEvent("open"));
}, 0);
}
MockHmrWebSocket.CONNECTING = 0;
MockHmrWebSocket.OPEN = 1;
MockHmrWebSocket.CLOSING = 2;
MockHmrWebSocket.CLOSED = 3;
MockHmrWebSocket.prototype.addEventListener = function (eventType, callback) {
(this._l[eventType] || (this._l[eventType] = [])).push(callback);
};
MockHmrWebSocket.prototype.removeEventListener = function (
eventType,
callback,
) {
var listeners = this._l[eventType] || [];
this._l[eventType] = listeners.filter(function (listener) {
return listener !== callback;
});
};
MockHmrWebSocket.prototype._d = function (eventType, eventValue) {
var listeners = this._l[eventType] || [];
for (var i = 0; i < listeners.length; i++) {
listeners[i].call(this, eventValue);
}
var handler = this["on" + eventType];
if (typeof handler === "function") {
handler.call(this, eventValue);
}
};
MockHmrWebSocket.prototype.send = function () {};
MockHmrWebSocket.prototype.close = function (code, reason) {
if (this.readyState >= 2) return;
this.readyState = 3;
var closeEvent = createEvent("close");
closeEvent.code = code === undefined ? 1000 : code;
closeEvent.reason = reason || "";
closeEvent.wasClean = true;
this._d("close", closeEvent);
};
if (window.WebSocket) {
var OriginalWebSocket = window.WebSocket;
window.WebSocket = function (url, protocols) {
if (isHmrWebSocketUrl(url)) {
return new MockHmrWebSocket(rewriteNextAssetUrl(url));
}
return protocols === undefined
? new OriginalWebSocket(url)
: new OriginalWebSocket(url, protocols);
};
window.WebSocket.prototype = OriginalWebSocket.prototype;
Object.setPrototypeOf(window.WebSocket, OriginalWebSocket);
["CONNECTING", "OPEN", "CLOSING", "CLOSED"].forEach(function (stateKey) {
window.WebSocket[stateKey] = OriginalWebSocket[stateKey];
});
}
})();

View File

@@ -1,256 +0,0 @@
"""Unit tests for webapp proxy path rewriting/injection."""
from types import SimpleNamespace
from typing import cast
from typing import Literal
from uuid import UUID
import httpx
import pytest
from fastapi import Request
from sqlalchemy.orm import Session
from onyx.server.features.build.api import api
from onyx.server.features.build.api.api import _inject_hmr_fixer
from onyx.server.features.build.api.api import _rewrite_asset_paths
from onyx.server.features.build.api.api import _rewrite_proxy_response_headers
SESSION_ID = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
BASE = f"/api/build/sessions/{SESSION_ID}/webapp"
def rewrite(html: str) -> str:
return _rewrite_asset_paths(html.encode(), SESSION_ID).decode()
def inject(html: str) -> str:
return _inject_hmr_fixer(html.encode(), SESSION_ID).decode()
class TestNextjsPathRewriting:
def test_rewrites_bare_next_script_src(self) -> None:
html = '<script src="/_next/static/chunks/main.js">'
result = rewrite(html)
assert f'src="{BASE}/_next/static/chunks/main.js"' in result
assert '"/_next/' not in result
def test_rewrites_bare_next_in_single_quotes(self) -> None:
html = "<link href='/_next/static/css/app.css'>"
result = rewrite(html)
assert f"'{BASE}/_next/static/css/app.css'" in result
def test_rewrites_bare_next_in_url_parens(self) -> None:
html = "background: url(/_next/static/media/font.woff2)"
result = rewrite(html)
assert f"url({BASE}/_next/static/media/font.woff2)" in result
def test_no_double_prefix_when_already_proxied(self) -> None:
"""assetPrefix makes Next.js emit already-prefixed URLs — must not double-rewrite."""
already_prefixed = f'<script src="{BASE}/_next/static/chunks/main.js">'
result = rewrite(already_prefixed)
# Should be unchanged
assert result == already_prefixed
# Specifically, no double path
assert f"{BASE}/{BASE}" not in result
def test_rewrites_favicon(self) -> None:
html = '<link rel="icon" href="/favicon.ico">'
result = rewrite(html)
assert f'"{BASE}/favicon.ico"' in result
def test_rewrites_json_data_path_double_quoted(self) -> None:
html = 'fetch("/data/tickets.json")'
result = rewrite(html)
assert f'"{BASE}/data/tickets.json"' in result
def test_rewrites_json_data_path_single_quoted(self) -> None:
html = "fetch('/data/items.json')"
result = rewrite(html)
assert f"'{BASE}/data/items.json'" in result
def test_rewrites_escaped_next_font_path_in_json_script(self) -> None:
"""Next dev can embed font asset paths in JSON-escaped script payloads."""
html = r'{"src":"\/_next\/static\/media\/font.woff2"}'
result = rewrite(html)
assert (
r'{"src":"\/api\/build\/sessions\/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee\/webapp\/_next\/static\/media\/font.woff2"}'
in result
)
def test_rewrites_escaped_next_font_path_in_style_payload(self) -> None:
"""Keep dynamically generated next/font URLs inside the session proxy."""
html = r'{"css":"@font-face{src:url(\"\/_next\/static\/media\/font.woff2\")"}'
result = rewrite(html)
assert (
r"\/api\/build\/sessions\/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee\/webapp\/_next\/static\/media\/font.woff2"
in result
)
def test_rewrites_absolute_next_font_url(self) -> None:
html = (
'<link rel="preload" as="font" '
'href="https://craft-dev.onyx.app/_next/static/media/font.woff2">'
)
result = rewrite(html)
assert f'"{BASE}/_next/static/media/font.woff2"' in result
def test_rewrites_root_hmr_path(self) -> None:
html = 'new WebSocket("wss://craft-dev.onyx.app/_next/webpack-hmr?id=abc")'
result = rewrite(html)
assert '"wss://craft-dev.onyx.app/_next/webpack-hmr?id=abc"' not in result
assert '"/_next/webpack-hmr?id=abc"' in result
def test_rewrites_escaped_absolute_next_font_url(self) -> None:
html = (
r'{"href":"https:\/\/craft-dev.onyx.app\/_next\/static\/media\/font.woff2"}'
)
result = rewrite(html)
assert (
r'{"href":"\/api\/build\/sessions\/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee\/webapp\/_next\/static\/media\/font.woff2"}'
in result
)
class TestRuntimeFixerInjection:
def test_injects_websocket_rewrite_shim(self) -> None:
html = "<html><head></head><body></body></html>"
result = inject(html)
assert "window.WebSocket = function (url, protocols)" in result
assert f'var WEBAPP_BASE = "{BASE}"' in result
def test_injects_hmr_websocket_stub(self) -> None:
html = "<html><head></head><body></body></html>"
result = inject(html)
assert "function MockHmrWebSocket(url)" in result
assert "return new MockHmrWebSocket(rewriteNextAssetUrl(url));" in result
def test_injects_before_head_contents(self) -> None:
html = "<html><head><title>x</title></head><body></body></html>"
result = inject(html)
assert result.index(
"window.WebSocket = function (url, protocols)"
) < result.index("<title>x</title>")
def test_rewritten_hmr_url_still_matches_shim_intercept_logic(self) -> None:
html = (
"<html><head></head><body>"
'new WebSocket("wss://craft-dev.onyx.app/_next/webpack-hmr?id=abc")'
"</body></html>"
)
rewritten = rewrite(html)
assert '"wss://craft-dev.onyx.app/_next/webpack-hmr?id=abc"' not in rewritten
assert 'new WebSocket("/_next/webpack-hmr?id=abc")' in rewritten
injected = inject(rewritten)
assert 'new WebSocket("/_next/webpack-hmr?id=abc")' in injected
assert 'parsedUrl.pathname.indexOf("/_next/webpack-hmr") === 0' in injected
class TestProxyHeaderRewriting:
def test_rewrites_link_header_font_preload_paths(self) -> None:
headers = {
"link": (
'</_next/static/media/font.woff2>; rel=preload; as="font"; crossorigin, '
'</_next/static/media/font2.woff2>; rel=preload; as="font"; crossorigin'
)
}
result = _rewrite_proxy_response_headers(headers, SESSION_ID)
assert f"<{BASE}/_next/static/media/font.woff2>" in result["link"]
class TestProxyRequestWiring:
def test_proxy_request_rewrites_link_header_on_html_response(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
html = b"<html><head></head><body>ok</body></html>"
upstream = httpx.Response(
200,
headers={
"content-type": "text/html; charset=utf-8",
"link": '</_next/static/media/font.woff2>; rel=preload; as="font"',
},
content=html,
)
monkeypatch.setattr(api, "_get_sandbox_url", lambda *_args: "http://sandbox")
class FakeClient:
def __init__(self, *_args: object, **_kwargs: object) -> None:
pass
def __enter__(self) -> "FakeClient":
return self
def __exit__(self, *_args: object) -> Literal[False]:
return False
def get(self, _url: str, headers: dict[str, str]) -> httpx.Response:
assert "host" not in {key.lower() for key in headers}
return upstream
monkeypatch.setattr(api.httpx, "Client", FakeClient)
request = cast(Request, SimpleNamespace(headers={}, query_params=""))
response = api._proxy_request(
"", request, UUID(SESSION_ID), cast(Session, SimpleNamespace())
)
assert response.headers["link"] == (
f'<{BASE}/_next/static/media/font.woff2>; rel=preload; as="font"'
)
def test_proxy_request_injects_hmr_fixer_for_html_response(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
upstream = httpx.Response(
200,
headers={"content-type": "text/html; charset=utf-8"},
content=b"<html><head><title>x</title></head><body></body></html>",
)
monkeypatch.setattr(api, "_get_sandbox_url", lambda *_args: "http://sandbox")
class FakeClient:
def __init__(self, *_args: object, **_kwargs: object) -> None:
pass
def __enter__(self) -> "FakeClient":
return self
def __exit__(self, *_args: object) -> Literal[False]:
return False
def get(self, _url: str, headers: dict[str, str]) -> httpx.Response:
assert "host" not in {key.lower() for key in headers}
return upstream
monkeypatch.setattr(api.httpx, "Client", FakeClient)
request = cast(Request, SimpleNamespace(headers={}, query_params=""))
response = api._proxy_request(
"", request, UUID(SESSION_ID), cast(Session, SimpleNamespace())
)
body = cast(bytes, response.body).decode("utf-8")
assert "window.WebSocket = function (url, protocols)" in body
assert body.index("window.WebSocket = function (url, protocols)") < body.index(
"<title>x</title>"
)
def test_rewrites_absolute_link_header_font_preload_paths(self) -> None:
headers = {
"link": (
"<https://craft-dev.onyx.app/_next/static/media/font.woff2>; "
'rel=preload; as="font"; crossorigin'
)
}
result = _rewrite_proxy_response_headers(headers, SESSION_ID)
assert f"<{BASE}/_next/static/media/font.woff2>" in result["link"]

View File

@@ -1,400 +0,0 @@
from typing import Any
from typing import cast
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
from urllib.parse import parse_qs
from urllib.parse import urlparse
from fastapi import FastAPI
from fastapi import Response
from fastapi.testclient import TestClient
from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.jwt import generate_jwt
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import GetAccessTokenError
from onyx.auth.users import CSRF_TOKEN_COOKIE_NAME
from onyx.auth.users import CSRF_TOKEN_KEY
from onyx.auth.users import get_oauth_router
from onyx.auth.users import get_pkce_cookie_name
from onyx.auth.users import PKCE_COOKIE_NAME_PREFIX
from onyx.auth.users import STATE_TOKEN_AUDIENCE
from onyx.error_handling.exceptions import register_onyx_exception_handlers
class _StubOAuthClient:
def __init__(self) -> None:
self.name = "openid"
self.authorization_calls: list[dict[str, str | list[str] | None]] = []
self.access_token_calls: list[dict[str, str | None]] = []
async def get_authorization_url(
self,
redirect_uri: str,
state: str | None = None,
scope: list[str] | None = None,
code_challenge: str | None = None,
code_challenge_method: str | None = None,
) -> str:
self.authorization_calls.append(
{
"redirect_uri": redirect_uri,
"state": state,
"scope": scope,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
}
)
return f"https://idp.example.com/authorize?state={state}"
async def get_access_token(
self, code: str, redirect_uri: str, code_verifier: str | None = None
) -> dict[str, str | int]:
self.access_token_calls.append(
{
"code": code,
"redirect_uri": redirect_uri,
"code_verifier": code_verifier,
}
)
return {
"access_token": "oidc_access_token",
"refresh_token": "oidc_refresh_token",
"expires_at": 1730000000,
}
async def get_id_email(self, _access_token: str) -> tuple[str, str | None]:
return ("oidc_account_id", "oidc_user@example.com")
def _build_test_client(
enable_pkce: bool,
login_status_code: int = 302,
) -> tuple[TestClient, _StubOAuthClient, MagicMock]:
oauth_client = _StubOAuthClient()
transport = CookieTransport(cookie_name="testsession")
async def get_strategy() -> MagicMock:
return MagicMock()
backend = AuthenticationBackend(
name="test_backend",
transport=transport,
get_strategy=get_strategy,
)
login_response = Response(status_code=login_status_code)
if login_status_code in {301, 302, 303, 307, 308}:
login_response.headers["location"] = "/app"
login_response.set_cookie("testsession", "session-token")
backend.login = AsyncMock(return_value=login_response) # type: ignore[method-assign]
user = MagicMock()
user.is_active = True
user_manager = MagicMock()
user_manager.oauth_callback = AsyncMock(return_value=user)
user_manager.on_after_login = AsyncMock()
async def get_user_manager() -> MagicMock:
return user_manager
router = get_oauth_router(
oauth_client=cast(BaseOAuth2[Any], oauth_client),
backend=backend,
get_user_manager=get_user_manager,
state_secret="test-secret",
redirect_url="http://localhost/auth/oidc/callback",
associate_by_email=True,
is_verified_by_default=True,
enable_pkce=enable_pkce,
)
app = FastAPI()
app.include_router(router, prefix="/auth/oidc")
register_onyx_exception_handlers(app)
client = TestClient(app, raise_server_exceptions=False)
return client, oauth_client, user_manager
def _extract_state_from_authorize_response(response: Any) -> str:
auth_url = response.json()["authorization_url"]
return parse_qs(urlparse(auth_url).query)["state"][0]
def test_oidc_authorize_omits_pkce_when_flag_disabled() -> None:
client, oauth_client, _ = _build_test_client(enable_pkce=False)
response = client.get("/auth/oidc/authorize")
assert response.status_code == 200
assert oauth_client.authorization_calls[0]["code_challenge"] is None
assert oauth_client.authorization_calls[0]["code_challenge_method"] is None
assert "fastapiusersoauthcsrf" in response.cookies.keys()
assert not any(
key.startswith(PKCE_COOKIE_NAME_PREFIX) for key in response.cookies.keys()
)
def test_oidc_authorize_adds_pkce_when_flag_enabled() -> None:
client, oauth_client, _ = _build_test_client(enable_pkce=True)
response = client.get("/auth/oidc/authorize")
assert response.status_code == 200
assert oauth_client.authorization_calls[0]["code_challenge"] is not None
assert oauth_client.authorization_calls[0]["code_challenge_method"] == "S256"
assert any(
key.startswith(PKCE_COOKIE_NAME_PREFIX) for key in response.cookies.keys()
)
def test_oidc_callback_fails_when_pkce_cookie_missing() -> None:
client, oauth_client, _ = _build_test_client(enable_pkce=True)
authorize_response = client.get("/auth/oidc/authorize")
state = _extract_state_from_authorize_response(authorize_response)
for key in list(client.cookies.keys()):
if key.startswith(PKCE_COOKIE_NAME_PREFIX):
del client.cookies[key]
response = client.get(
"/auth/oidc/callback", params={"code": "abc123", "state": state}
)
assert response.status_code == 400
assert response.json()["error_code"] == "VALIDATION_ERROR"
assert oauth_client.access_token_calls == []
assert "Max-Age=0" in response.headers.get("set-cookie", "")
def test_oidc_callback_rejects_bad_state_before_token_exchange() -> None:
client, oauth_client, _ = _build_test_client(enable_pkce=True)
client.get("/auth/oidc/authorize")
tampered_state = "not-a-valid-state-jwt"
client.cookies.set(get_pkce_cookie_name(tampered_state), "verifier123")
response = client.get(
"/auth/oidc/callback", params={"code": "abc123", "state": tampered_state}
)
assert response.status_code == 400
assert response.json()["error_code"] == "VALIDATION_ERROR"
assert oauth_client.access_token_calls == []
assert "Max-Age=0" in response.headers.get("set-cookie", "")
def test_oidc_callback_rejects_wrongly_signed_state_before_token_exchange() -> None:
client, oauth_client, _ = _build_test_client(enable_pkce=True)
client.get("/auth/oidc/authorize")
csrf_token = client.cookies.get(CSRF_TOKEN_COOKIE_NAME)
assert csrf_token is not None
tampered_state = generate_jwt(
{
"aud": STATE_TOKEN_AUDIENCE,
CSRF_TOKEN_KEY: csrf_token,
},
"wrong-secret",
3600,
)
client.cookies.set(get_pkce_cookie_name(tampered_state), "verifier123")
response = client.get(
"/auth/oidc/callback",
params={"code": "abc123", "state": tampered_state},
)
assert response.status_code == 400
assert response.json()["error_code"] == "VALIDATION_ERROR"
assert response.json()["detail"] == "ACCESS_TOKEN_DECODE_ERROR"
assert oauth_client.access_token_calls == []
assert "Max-Age=0" in response.headers.get("set-cookie", "")
def test_oidc_callback_rejects_csrf_mismatch_in_pkce_path() -> None:
client, oauth_client, _ = _build_test_client(enable_pkce=True)
authorize_response = client.get("/auth/oidc/authorize")
state = _extract_state_from_authorize_response(authorize_response)
# Keep PKCE verifier cookie intact, but invalidate CSRF match against state JWT.
client.cookies.set("fastapiusersoauthcsrf", "wrong-csrf-token")
response = client.get(
"/auth/oidc/callback",
params={"code": "abc123", "state": state},
)
assert response.status_code == 400
assert response.json()["error_code"] == "VALIDATION_ERROR"
assert oauth_client.access_token_calls == []
assert "Max-Age=0" in response.headers.get("set-cookie", "")
def test_oidc_callback_get_access_token_error_is_400() -> None:
client, oauth_client, _ = _build_test_client(enable_pkce=True)
authorize_response = client.get("/auth/oidc/authorize")
state = _extract_state_from_authorize_response(authorize_response)
with patch.object(
oauth_client,
"get_access_token",
AsyncMock(side_effect=GetAccessTokenError("token exchange failed")),
):
response = client.get(
"/auth/oidc/callback", params={"code": "abc123", "state": state}
)
assert response.status_code == 400
assert response.json()["error_code"] == "VALIDATION_ERROR"
assert response.json()["detail"] == "Authorization code exchange failed"
assert "Max-Age=0" in response.headers.get("set-cookie", "")
def test_oidc_callback_cleans_pkce_cookie_on_idp_error_with_state() -> None:
client, oauth_client, _ = _build_test_client(enable_pkce=True)
authorize_response = client.get("/auth/oidc/authorize")
state = _extract_state_from_authorize_response(authorize_response)
response = client.get(
"/auth/oidc/callback",
params={"error": "access_denied", "state": state},
)
assert response.status_code == 400
assert response.json()["error_code"] == "VALIDATION_ERROR"
assert response.json()["detail"] == "Authorization request failed or was denied"
assert oauth_client.access_token_calls == []
assert "Max-Age=0" in response.headers.get("set-cookie", "")
def test_oidc_callback_cleans_pkce_cookie_on_missing_email() -> None:
client, oauth_client, _ = _build_test_client(enable_pkce=True)
authorize_response = client.get("/auth/oidc/authorize")
state = _extract_state_from_authorize_response(authorize_response)
with patch.object(
oauth_client, "get_id_email", AsyncMock(return_value=("oidc_account_id", None))
):
response = client.get(
"/auth/oidc/callback", params={"code": "abc123", "state": state}
)
assert response.status_code == 400
assert response.json()["error_code"] == "VALIDATION_ERROR"
assert "Max-Age=0" in response.headers.get("set-cookie", "")
def test_oidc_callback_rejects_wrong_audience_state_before_token_exchange() -> None:
client, oauth_client, _ = _build_test_client(enable_pkce=True)
client.get("/auth/oidc/authorize")
csrf_token = client.cookies.get(CSRF_TOKEN_COOKIE_NAME)
assert csrf_token is not None
wrong_audience_state = generate_jwt(
{
"aud": "wrong-audience",
CSRF_TOKEN_KEY: csrf_token,
},
"test-secret",
3600,
)
client.cookies.set(get_pkce_cookie_name(wrong_audience_state), "verifier123")
response = client.get(
"/auth/oidc/callback",
params={"code": "abc123", "state": wrong_audience_state},
)
assert response.status_code == 400
assert response.json()["error_code"] == "VALIDATION_ERROR"
assert response.json()["detail"] == "ACCESS_TOKEN_DECODE_ERROR"
assert oauth_client.access_token_calls == []
assert "Max-Age=0" in response.headers.get("set-cookie", "")
def test_oidc_callback_uses_code_verifier_when_pkce_enabled() -> None:
client, oauth_client, user_manager = _build_test_client(enable_pkce=True)
authorize_response = client.get("/auth/oidc/authorize")
state = _extract_state_from_authorize_response(authorize_response)
with patch(
"onyx.auth.users.fetch_ee_implementation_or_noop",
return_value=lambda _email: "tenant_1",
):
response = client.get(
"/auth/oidc/callback",
params={"code": "abc123", "state": state},
follow_redirects=False,
)
assert response.status_code == 302
assert response.headers.get("location") == "/"
assert oauth_client.access_token_calls[0]["code_verifier"] is not None
user_manager.oauth_callback.assert_awaited_once()
assert "Max-Age=0" in response.headers.get("set-cookie", "")
def test_oidc_callback_works_without_pkce_when_flag_disabled() -> None:
client, oauth_client, user_manager = _build_test_client(enable_pkce=False)
authorize_response = client.get("/auth/oidc/authorize")
state = _extract_state_from_authorize_response(authorize_response)
with patch(
"onyx.auth.users.fetch_ee_implementation_or_noop",
return_value=lambda _email: "tenant_1",
):
response = client.get(
"/auth/oidc/callback",
params={"code": "abc123", "state": state},
follow_redirects=False,
)
assert response.status_code == 302
assert oauth_client.access_token_calls[0]["code_verifier"] is None
user_manager.oauth_callback.assert_awaited_once()
def test_oidc_callback_pkce_preserves_redirect_when_backend_login_is_non_redirect() -> (
None
):
client, oauth_client, user_manager = _build_test_client(
enable_pkce=True,
login_status_code=200,
)
authorize_response = client.get("/auth/oidc/authorize")
state = _extract_state_from_authorize_response(authorize_response)
with patch(
"onyx.auth.users.fetch_ee_implementation_or_noop",
return_value=lambda _email: "tenant_1",
):
response = client.get(
"/auth/oidc/callback",
params={"code": "abc123", "state": state},
follow_redirects=False,
)
assert response.status_code == 302
assert response.headers.get("location") == "/"
assert oauth_client.access_token_calls[0]["code_verifier"] is not None
user_manager.oauth_callback.assert_awaited_once()
assert "Max-Age=0" in response.headers.get("set-cookie", "")
def test_oidc_callback_non_pkce_rejects_csrf_mismatch() -> None:
client, oauth_client, _ = _build_test_client(enable_pkce=False)
authorize_response = client.get("/auth/oidc/authorize")
state = _extract_state_from_authorize_response(authorize_response)
client.cookies.set(CSRF_TOKEN_COOKIE_NAME, "wrong-csrf-token")
response = client.get(
"/auth/oidc/callback",
params={"code": "abc123", "state": state},
)
assert response.status_code == 400
assert response.json()["error_code"] == "VALIDATION_ERROR"
assert response.json()["detail"] == "OAUTH_INVALID_STATE"
# NOTE: In the non-PKCE path, oauth2_authorize_callback exchanges the code
# before route-body CSRF validation runs. This is a known ordering trade-off.
assert oauth_client.access_token_calls

View File

@@ -1,179 +0,0 @@
"""Unit tests for SharepointConnector._fetch_site_pages 404 handling.
The Graph Pages API returns 404 for classic sites or sites without
modern pages enabled. _fetch_site_pages should gracefully skip these
rather than crashing the entire indexing run.
"""
from __future__ import annotations
from typing import Any
import pytest
from requests import Response
from requests.exceptions import HTTPError
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.connectors.sharepoint.connector import SiteDescriptor
SITE_URL = "https://tenant.sharepoint.com/sites/ClassicSite"
FAKE_SITE_ID = "tenant.sharepoint.com,abc123,def456"
def _site_descriptor() -> SiteDescriptor:
return SiteDescriptor(url=SITE_URL, drive_name=None, folder_path=None)
def _make_http_error(status_code: int) -> HTTPError:
response = Response()
response.status_code = status_code
response._content = b'{"error":{"code":"itemNotFound","message":"Item not found"}}'
return HTTPError(response=response)
def _setup_connector(
monkeypatch: pytest.MonkeyPatch, # noqa: ARG001
) -> SharepointConnector:
"""Create a connector with the graph client and site resolution mocked."""
connector = SharepointConnector(sites=[SITE_URL])
connector.graph_api_base = "https://graph.microsoft.com/v1.0"
mock_sites = type(
"FakeSites",
(),
{
"get_by_url": staticmethod(
lambda url: type( # noqa: ARG005
"Q",
(),
{
"execute_query": lambda self: None, # noqa: ARG005
"id": FAKE_SITE_ID,
},
)()
),
},
)()
connector._graph_client = type("FakeGraphClient", (), {"sites": mock_sites})()
return connector
def _patch_graph_api_get_json(
monkeypatch: pytest.MonkeyPatch,
fake_fn: Any,
) -> None:
monkeypatch.setattr(SharepointConnector, "_graph_api_get_json", fake_fn)
class TestFetchSitePages404:
def test_404_yields_no_pages(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""A 404 from the Pages API should result in zero yielded pages."""
connector = _setup_connector(monkeypatch)
def fake_get_json(
self: SharepointConnector, # noqa: ARG001
url: str, # noqa: ARG001
params: dict[str, str] | None = None, # noqa: ARG001
) -> dict[str, Any]:
raise _make_http_error(404)
_patch_graph_api_get_json(monkeypatch, fake_get_json)
pages = list(connector._fetch_site_pages(_site_descriptor()))
assert pages == []
def test_404_does_not_raise(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""A 404 must not propagate as an exception."""
connector = _setup_connector(monkeypatch)
def fake_get_json(
self: SharepointConnector, # noqa: ARG001
url: str, # noqa: ARG001
params: dict[str, str] | None = None, # noqa: ARG001
) -> dict[str, Any]:
raise _make_http_error(404)
_patch_graph_api_get_json(monkeypatch, fake_get_json)
for _ in connector._fetch_site_pages(_site_descriptor()):
pass
def test_non_404_http_error_still_raises(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Non-404 HTTP errors (e.g. 403) must still propagate."""
connector = _setup_connector(monkeypatch)
def fake_get_json(
self: SharepointConnector, # noqa: ARG001
url: str, # noqa: ARG001
params: dict[str, str] | None = None, # noqa: ARG001
) -> dict[str, Any]:
raise _make_http_error(403)
_patch_graph_api_get_json(monkeypatch, fake_get_json)
with pytest.raises(HTTPError):
list(connector._fetch_site_pages(_site_descriptor()))
def test_successful_fetch_yields_pages(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""When the API succeeds, pages should be yielded normally."""
connector = _setup_connector(monkeypatch)
fake_page = {
"id": "page-1",
"title": "Hello World",
"webUrl": f"{SITE_URL}/SitePages/Hello.aspx",
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
}
def fake_get_json(
self: SharepointConnector, # noqa: ARG001
url: str, # noqa: ARG001
params: dict[str, str] | None = None, # noqa: ARG001
) -> dict[str, Any]:
return {"value": [fake_page]}
_patch_graph_api_get_json(monkeypatch, fake_get_json)
pages = list(connector._fetch_site_pages(_site_descriptor()))
assert len(pages) == 1
assert pages[0]["id"] == "page-1"
def test_404_on_second_page_stops_pagination(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""If the first API page succeeds but a nextLink returns 404,
already-yielded pages are kept and iteration stops cleanly."""
connector = _setup_connector(monkeypatch)
call_count = 0
first_page = {
"id": "page-1",
"title": "First",
"webUrl": f"{SITE_URL}/SitePages/First.aspx",
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
}
def fake_get_json(
self: SharepointConnector, # noqa: ARG001
url: str, # noqa: ARG001
params: dict[str, str] | None = None, # noqa: ARG001
) -> dict[str, Any]:
nonlocal call_count
call_count += 1
if call_count == 1:
return {
"value": [first_page],
"@odata.nextLink": "https://graph.microsoft.com/next",
}
raise _make_http_error(404)
_patch_graph_api_get_json(monkeypatch, fake_get_json)
pages = list(connector._fetch_site_pages(_site_descriptor()))
assert len(pages) == 1
assert pages[0]["id"] == "page-1"

View File

@@ -0,0 +1,40 @@
"""Unit tests for the hooks feature gate."""
from unittest.mock import patch
import pytest
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.api_dependencies import require_hook_enabled
class TestRequireHookEnabled:
def test_raises_when_multi_tenant(self) -> None:
with (
patch("onyx.hooks.api_dependencies.MULTI_TENANT", True),
patch("onyx.hooks.api_dependencies.HOOK_ENABLED", True),
):
with pytest.raises(OnyxError) as exc_info:
require_hook_enabled()
assert exc_info.value.error_code is OnyxErrorCode.NOT_FOUND
assert exc_info.value.status_code == 404
assert "multi-tenant" in exc_info.value.detail
def test_raises_when_flag_disabled(self) -> None:
with (
patch("onyx.hooks.api_dependencies.MULTI_TENANT", False),
patch("onyx.hooks.api_dependencies.HOOK_ENABLED", False),
):
with pytest.raises(OnyxError) as exc_info:
require_hook_enabled()
assert exc_info.value.error_code is OnyxErrorCode.NOT_FOUND
assert exc_info.value.status_code == 404
assert "HOOK_ENABLED" in exc_info.value.detail
def test_passes_when_enabled_single_tenant(self) -> None:
with (
patch("onyx.hooks.api_dependencies.MULTI_TENANT", False),
patch("onyx.hooks.api_dependencies.HOOK_ENABLED", True),
):
require_hook_enabled() # must not raise

View File

@@ -33,7 +33,6 @@ SECRET=
# OpenID Connect (OIDC)
#OPENID_CONFIG_URL=
#OIDC_PKCE_ENABLED=
# SAML config directory for OneLogin compatible setups
#SAML_CONF_DIR=

View File

@@ -167,7 +167,6 @@ LOG_ONYX_MODEL_INTERACTIONS=False
# OAUTH_CLIENT_ID=
# OAUTH_CLIENT_SECRET=
# OPENID_CONFIG_URL=
# OIDC_PKCE_ENABLED=
# TRACK_EXTERNAL_IDP_EXPIRY=
# CORS_ALLOWED_ORIGIN=
# INTEGRATION_TESTS_MODE=

View File

@@ -1203,8 +1203,6 @@ configMap:
# UPGRADE NOTE: Default changed from "disabled" to "basic" in 0.4.34.
# You must also set auth.userauth.values.user_auth_secret.
AUTH_TYPE: "basic"
# Enable PKCE for OIDC login flow. Leave empty/false for backward compatibility.
OIDC_PKCE_ENABLED: ""
# 1 Day Default
SESSION_EXPIRE_TIME_SECONDS: "86400"
# Can be something like onyx.app, as an extra double-check

View File

@@ -52,10 +52,6 @@
{
"scope": [],
"content": "Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code."
},
{
"scope": [],
"content": "Use `contributing_guides/best_practices.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly."
}
],
"rules": [
@@ -75,14 +71,6 @@
"scope": [],
"rule": "When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant."
},
{
"scope": [],
"rule": "Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies."
},
{
"scope": [],
"rule": "Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments."
},
{
"scope": ["backend/**/*.py"],
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"message\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
@@ -98,21 +86,6 @@
"scope": [],
"path": "CLAUDE.md",
"description": "Project instructions and coding standards"
},
{
"scope": [],
"path": "backend/alembic/README.md",
"description": "Migration guidance, including multi-tenant migration behavior"
},
{
"scope": [],
"path": "deployment/helm/charts/onyx/values-lite.yaml",
"description": "Lite deployment Helm values and service assumptions"
},
{
"scope": [],
"path": "deployment/docker_compose/docker-compose.onyx-lite.yml",
"description": "Lite deployment Docker Compose overlay and disabled service behavior"
}
]
}

View File

@@ -9,8 +9,6 @@ import { cn } from "@opal/utils";
type TagColor = "green" | "purple" | "blue" | "gray" | "amber";
type TagSize = "sm" | "md";
interface TagProps {
/** Optional icon component. */
icon?: IconFunctionComponent;
@@ -20,9 +18,6 @@ interface TagProps {
/** Color variant. Default: `"gray"`. */
color?: TagColor;
/** Size variant. Default: `"sm"`. */
size?: TagSize;
}
// ---------------------------------------------------------------------------
@@ -41,11 +36,11 @@ const COLOR_CONFIG: Record<TagColor, { bg: string; text: string }> = {
// Tag
// ---------------------------------------------------------------------------
function Tag({ icon: Icon, title, color = "gray", size = "sm" }: TagProps) {
function Tag({ icon: Icon, title, color = "gray" }: TagProps) {
const config = COLOR_CONFIG[color];
return (
<div className={cn("opal-auxiliary-tag", config.bg)} data-size={size}>
<div className={cn("opal-auxiliary-tag", config.bg)}>
{Icon && (
<div className="opal-auxiliary-tag-icon-container">
<Icon className={cn("opal-auxiliary-tag-icon", config.text)} />
@@ -53,8 +48,7 @@ function Tag({ icon: Icon, title, color = "gray", size = "sm" }: TagProps) {
)}
<span
className={cn(
"opal-auxiliary-tag-title px-[2px]",
size === "md" ? "font-secondary-body" : "font-figure-small-value",
"opal-auxiliary-tag-title px-[2px] font-figure-small-value",
config.text
)}
>
@@ -64,4 +58,4 @@ function Tag({ icon: Icon, title, color = "gray", size = "sm" }: TagProps) {
);
}
export { Tag, type TagProps, type TagColor, type TagSize };
export { Tag, type TagProps, type TagColor };

View File

@@ -13,12 +13,6 @@
gap: 0;
}
.opal-auxiliary-tag[data-size="md"] {
height: 1.375rem;
padding: 0 0.375rem;
border-radius: 0.375rem;
}
.opal-auxiliary-tag-icon-container {
display: flex;
align-items: center;

View File

@@ -1,21 +0,0 @@
import type { IconProps } from "@opal/types";
const SvgFilterPlus = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M9.5 12.5L6.83334 11.1667V7.80667L1.5 1.5H14.8333L12.1667 4.65333M12.1667 7V9.5M12.1667 9.5V12M12.1667 9.5H9.66667M12.1667 9.5H14.6667"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgFilterPlus;

View File

@@ -72,7 +72,6 @@ export { default as SvgFileChartPie } from "@opal/icons/file-chart-pie";
export { default as SvgFileSmall } from "@opal/icons/file-small";
export { default as SvgFileText } from "@opal/icons/file-text";
export { default as SvgFilter } from "@opal/icons/filter";
export { default as SvgFilterPlus } from "@opal/icons/filter-plus";
export { default as SvgFold } from "@opal/icons/fold";
export { default as SvgFolder } from "@opal/icons/folder";
export { default as SvgFolderIn } from "@opal/icons/folder-in";

View File

@@ -1 +1,342 @@
export { default } from "@/refresh-pages/admin/UsersPage";
"use client";
import { useState } from "react";
import SimpleTabs from "@/refresh-components/SimpleTabs";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import InvitedUserTable from "@/components/admin/users/InvitedUserTable";
import SignedUpUserTable from "@/components/admin/users/SignedUpUserTable";
import Modal from "@/refresh-components/Modal";
import { ThreeDotsLoader } from "@/components/Loading";
import { toast } from "@/hooks/useToast";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { errorHandlingFetcher } from "@/lib/fetcher";
import useSWR, { mutate } from "swr";
import { ErrorCallout } from "@/components/ErrorCallout";
import BulkAdd, { EmailInviteStatus } from "@/components/admin/users/BulkAdd";
import Text from "@/refresh-components/texts/Text";
import { InvitedUserSnapshot } from "@/lib/types";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import PendingUsersTable from "@/components/admin/users/PendingUsersTable";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import { Spinner } from "@/components/Spinner";
import { SvgDownloadCloud, SvgUserPlus } from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.USERS]!;
interface CountDisplayProps {
label: string;
value: number | null;
isLoading: boolean;
}
function CountDisplay({ label, value, isLoading }: CountDisplayProps) {
const displayValue = isLoading
? "..."
: value === null
? "-"
: value.toLocaleString();
return (
<div className="flex items-center gap-1 px-1 py-0.5 rounded-06">
<Text as="p" mainUiMuted text03>
{label}
</Text>
<Text as="p" headingH3 text05>
{displayValue}
</Text>
</div>
);
}
function UsersTables({
q,
isDownloadingUsers,
setIsDownloadingUsers,
}: {
q: string;
isDownloadingUsers: boolean;
setIsDownloadingUsers: (loading: boolean) => void;
}) {
const [currentUsersCount, setCurrentUsersCount] = useState<number | null>(
null
);
const [currentUsersLoading, setCurrentUsersLoading] = useState<boolean>(true);
const downloadAllUsers = async () => {
setIsDownloadingUsers(true);
const startTime = Date.now();
const minDurationMsForSpinner = 1000;
try {
const response = await fetch("/api/manage/users/download");
if (!response.ok) {
throw new Error("Failed to download all users");
}
const blob = await response.blob();
const url = window.URL.createObjectURL(blob);
const anchor_tag = document.createElement("a");
anchor_tag.href = url;
anchor_tag.download = "users.csv";
document.body.appendChild(anchor_tag);
anchor_tag.click();
//Clean up URL after download to avoid memory leaks
window.URL.revokeObjectURL(url);
document.body.removeChild(anchor_tag);
} catch (error) {
toast.error(`Failed to download all users - ${error}`);
} finally {
//Ensure spinner is visible for at least 1 second
//This is to avoid the spinner disappearing too quickly
const endTime = Date.now();
const duration = endTime - startTime;
await new Promise((resolve) =>
setTimeout(resolve, minDurationMsForSpinner - duration)
);
setIsDownloadingUsers(false);
}
};
const {
data: invitedUsers,
error: invitedUsersError,
isLoading: invitedUsersLoading,
mutate: invitedUsersMutate,
} = useSWR<InvitedUserSnapshot[]>(
"/api/manage/users/invited",
errorHandlingFetcher
);
const { data: validDomains, error: domainsError } = useSWR<string[]>(
"/api/manage/admin/valid-domains",
errorHandlingFetcher
);
const {
data: pendingUsers,
error: pendingUsersError,
isLoading: pendingUsersLoading,
mutate: pendingUsersMutate,
} = useSWR<InvitedUserSnapshot[]>(
NEXT_PUBLIC_CLOUD_ENABLED ? "/api/tenants/users/pending" : null,
errorHandlingFetcher
);
const invitedUsersCount =
invitedUsers === undefined ? null : invitedUsers.length;
const pendingUsersCount =
pendingUsers === undefined ? null : pendingUsers.length;
// Show loading animation only during the initial data fetch
if (!validDomains) {
return <ThreeDotsLoader />;
}
if (domainsError) {
return (
<ErrorCallout
errorTitle="Error loading valid domains"
errorMsg={domainsError?.info?.detail}
/>
);
}
const tabs = SimpleTabs.generateTabs({
current: {
name: "Current Users",
content: (
<Card className="w-full">
<CardHeader>
<div className="flex justify-between items-center gap-1">
<CardTitle>Current Users</CardTitle>
<Disabled disabled={isDownloadingUsers}>
<Button
icon={SvgDownloadCloud}
onClick={() => downloadAllUsers()}
>
{isDownloadingUsers ? "Downloading..." : "Download CSV"}
</Button>
</Disabled>
</div>
</CardHeader>
<CardContent>
<SignedUpUserTable
invitedUsers={invitedUsers || []}
q={q}
invitedUsersMutate={invitedUsersMutate}
countDisplay={
<CountDisplay
label="Total users"
value={currentUsersCount}
isLoading={currentUsersLoading}
/>
}
onTotalItemsChange={(count) => setCurrentUsersCount(count)}
onLoadingChange={(loading) => {
setCurrentUsersLoading(loading);
if (loading) {
setCurrentUsersCount(null);
}
}}
/>
</CardContent>
</Card>
),
},
invited: {
name: "Invited Users",
content: (
<Card className="w-full">
<CardHeader>
<div className="flex justify-between items-center gap-1">
<CardTitle>Invited Users</CardTitle>
<CountDisplay
label="Total invited"
value={invitedUsersCount}
isLoading={invitedUsersLoading}
/>
</div>
</CardHeader>
<CardContent>
<InvitedUserTable
users={invitedUsers || []}
mutate={invitedUsersMutate}
error={invitedUsersError}
isLoading={invitedUsersLoading}
q={q}
/>
</CardContent>
</Card>
),
},
...(NEXT_PUBLIC_CLOUD_ENABLED && {
pending: {
name: "Pending Users",
content: (
<Card>
<CardHeader>
<div className="flex justify-between items-center gap-1">
<CardTitle>Pending Users</CardTitle>
<CountDisplay
label="Total pending"
value={pendingUsersCount}
isLoading={pendingUsersLoading}
/>
</div>
</CardHeader>
<CardContent>
<PendingUsersTable
users={pendingUsers || []}
mutate={pendingUsersMutate}
error={pendingUsersError}
isLoading={pendingUsersLoading}
q={q}
/>
</CardContent>
</Card>
),
},
}),
});
return <SimpleTabs tabs={tabs} defaultValue="current" />;
}
function SearchableTables() {
const [query, setQuery] = useState("");
const [isDownloadingUsers, setIsDownloadingUsers] = useState(false);
return (
<div>
{isDownloadingUsers && <Spinner />}
<div className="flex flex-col gap-y-4">
<div className="flex flex-row items-center gap-2">
<InputTypeIn
placeholder="Search"
value={query}
onChange={(event) => setQuery(event.target.value)}
/>
<AddUserButton />
</div>
<UsersTables
q={query}
isDownloadingUsers={isDownloadingUsers}
setIsDownloadingUsers={setIsDownloadingUsers}
/>
</div>
</div>
);
}
function AddUserButton() {
const [bulkAddUsersModal, setBulkAddUsersModal] = useState(false);
const onSuccess = (emailInviteStatus: EmailInviteStatus) => {
mutate(
(key) => typeof key === "string" && key.startsWith("/api/manage/users")
);
setBulkAddUsersModal(false);
if (emailInviteStatus === "NOT_CONFIGURED") {
toast.warning(
"Users added, but no email notification was sent. There is no SMTP server set up for email sending."
);
} else if (emailInviteStatus === "SEND_FAILED") {
toast.warning(
"Users added, but email sending failed. Check your SMTP configuration and try again."
);
} else {
toast.success("Users invited!");
}
};
const onFailure = async (res: Response) => {
const error = (await res.json()).detail;
toast.error(`Failed to invite users - ${error}`);
};
const handleInviteClick = () => {
setBulkAddUsersModal(true);
};
return (
<>
<CreateButton primary onClick={handleInviteClick}>
Invite Users
</CreateButton>
{bulkAddUsersModal && (
<Modal open onOpenChange={() => setBulkAddUsersModal(false)}>
<Modal.Content>
<Modal.Header
icon={SvgUserPlus}
title="Bulk Add Users"
onClose={() => setBulkAddUsersModal(false)}
/>
<Modal.Body>
<div className="flex flex-col gap-2">
<Text as="p">
Add the email addresses to import, separated by whitespaces.
Invited users will be able to login to this domain with their
email address.
</Text>
<BulkAdd onSuccess={onSuccess} onFailure={onFailure} />
</div>
</Modal.Body>
</Modal.Content>
</Modal>
)}
</>
);
}
export default function Page() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header title={route.title} icon={route.icon} separator />
<SettingsLayouts.Body>
<SearchableTables />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -0,0 +1 @@
export { default } from "@/refresh-pages/admin/UsersPage";

View File

@@ -31,6 +31,7 @@ const SETTINGS_LAYOUT_PREFIXES = [
ADMIN_PATHS.LLM_MODELS,
ADMIN_PATHS.AGENTS,
ADMIN_PATHS.USERS,
ADMIN_PATHS.USERS_V2,
ADMIN_PATHS.TOKEN_RATE_LIMITS,
ADMIN_PATHS.SEARCH_SETTINGS,
ADMIN_PATHS.DOCUMENT_PROCESSING,

View File

@@ -136,14 +136,13 @@ function HorizontalInputLayout({
justifyContent="between"
alignItems={center ? "center" : "start"}
>
<div className="flex flex-col flex-1 min-w-0 self-stretch">
<div className="flex flex-col flex-1 self-stretch">
<Content
title={title}
description={description}
optional={optional}
sizePreset={sizePreset}
variant="section"
widthVariant="full"
/>
</div>
<div className="flex flex-col items-end">{children}</div>

View File

@@ -58,6 +58,7 @@ export const ADMIN_PATHS = {
DOCUMENT_PROCESSING: "/admin/configuration/document-processing",
KNOWLEDGE_GRAPH: "/admin/kg",
USERS: "/admin/users",
USERS_V2: "/admin/users2",
API_KEYS: "/admin/api-key",
TOKEN_RATE_LIMITS: "/admin/token-rate-limits",
USAGE: "/admin/performance/usage",
@@ -187,9 +188,14 @@ export const ADMIN_ROUTE_CONFIG: Record<string, AdminRouteConfig> = {
},
[ADMIN_PATHS.USERS]: {
icon: SvgUser,
title: "Users & Requests",
title: "Manage Users",
sidebarLabel: "Users",
},
[ADMIN_PATHS.USERS_V2]: {
icon: SvgUser,
title: "Users & Requests",
sidebarLabel: "Users v2",
},
[ADMIN_PATHS.API_KEYS]: {
icon: SvgKey,
title: "API Keys",

View File

@@ -79,7 +79,7 @@ export const USER_STATUS_LABELS: Record<UserStatus, string> = {
[UserStatus.ACTIVE]: "Active",
[UserStatus.INACTIVE]: "Inactive",
[UserStatus.INVITED]: "Invite Pending",
[UserStatus.REQUESTED]: "Request to Join",
[UserStatus.REQUESTED]: "Requested",
};
export const INVALID_ROLE_HOVER_TEXT: Partial<Record<UserRole, string>> = {

View File

@@ -133,7 +133,6 @@ export default function DataTable<TData>(props: DataTableProps<TData>) {
height,
headerBackground,
serverSide,
emptyState,
} = props;
const effectivePageSize = pageSize ?? (footer ? 10 : data.length);
@@ -274,7 +273,6 @@ export default function DataTable<TData>(props: DataTableProps<TData>) {
currentPage={currentPage}
totalPages={totalPages}
onPageChange={setPage}
leftExtra={footerConfig.leftExtra}
/>
);
}
@@ -303,25 +301,7 @@ export default function DataTable<TData>(props: DataTableProps<TData>) {
: undefined),
}}
>
<Table
width={
Object.keys(columnWidths).length > 0
? Object.values(columnWidths).reduce((sum, w) => sum + w, 0)
: undefined
}
>
<colgroup>
{table.getAllLeafColumns().map((col) => (
<col
key={col.id}
style={
columnWidths[col.id] != null
? { width: columnWidths[col.id] }
: undefined
}
/>
))}
</colgroup>
<Table>
<TableHeader>
{table.getHeaderGroups().map((headerGroup) => (
<TableRow key={headerGroup.id}>
@@ -448,13 +428,6 @@ export default function DataTable<TData>(props: DataTableProps<TData>) {
: undefined
}
>
{emptyState && table.getRowModel().rows.length === 0 && (
<tr>
<td colSpan={table.getVisibleLeafColumns().length}>
{emptyState}
</td>
</tr>
)}
{table.getRowModel().rows.map((row) => {
const rowId = hasDraggable ? getRowId(row.original) : undefined;

View File

@@ -61,8 +61,6 @@ interface FooterSummaryModeProps {
totalPages: number;
/** Called when the user navigates to a different page. */
onPageChange: (page: number) => void;
/** Optional extra element rendered after the summary text (e.g. a download icon). */
leftExtra?: React.ReactNode;
/** Controls overall footer sizing. `"regular"` (default) or `"small"`. */
size?: TableSize;
className?: string;
@@ -117,15 +115,12 @@ export default function Footer(props: FooterProps) {
isSmall={isSmall}
/>
) : (
<>
<SummaryLeft
rangeStart={props.rangeStart}
rangeEnd={props.rangeEnd}
totalItems={props.totalItems}
isSmall={isSmall}
/>
{props.leftExtra}
</>
<SummaryLeft
rangeStart={props.rangeStart}
rangeEnd={props.rangeEnd}
totalItems={props.totalItems}
isSmall={isSmall}
/>
)}
</div>

View File

@@ -21,13 +21,13 @@ export default function TableCell({
const resolvedSize = size ?? contextSize;
return (
<td
className="tbl-cell overflow-hidden"
className="tbl-cell"
data-size={resolvedSize}
style={width != null ? { width } : undefined}
{...props}
>
<div
className={cn("tbl-cell-inner", "flex items-center overflow-hidden")}
className={cn("tbl-cell-inner", "flex items-center")}
data-size={resolvedSize}
>
{children}

View File

@@ -141,8 +141,6 @@ export interface DataTableFooterSelection {
export interface DataTableFooterSummary {
mode: "summary";
/** Optional extra element rendered after the summary text (e.g. a download icon). */
leftExtra?: ReactNode;
}
export type DataTableFooterConfig =
@@ -192,6 +190,4 @@ export interface DataTableProps<TData> {
* - Fires separate callbacks for sorting, pagination, and search changes
*/
serverSide?: ServerSideConfig;
/** Content to render inside the table body when there are no rows. */
emptyState?: React.ReactNode;
}

View File

@@ -1,332 +0,0 @@
"use client";
import { useState, useMemo, useRef, useCallback } from "react";
import { Button } from "@opal/components";
import { SvgUsers, SvgUser, SvgLogOut, SvgCheck } from "@opal/icons";
import { Disabled } from "@opal/core";
import { ContentAction } from "@opal/layouts";
import Modal from "@/refresh-components/Modal";
import Text from "@/refresh-components/texts/Text";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import LineItem from "@/refresh-components/buttons/LineItem";
import Separator from "@/refresh-components/Separator";
import ShadowDiv from "@/refresh-components/ShadowDiv";
import { Section } from "@/layouts/general-layouts";
import { toast } from "@/hooks/useToast";
import { UserRole, USER_ROLE_LABELS } from "@/lib/types";
import useGroups from "@/hooks/useGroups";
import { addUserToGroup, removeUserFromGroup, setUserRole } from "./svc";
import type { UserRow } from "./interfaces";
// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------
const ASSIGNABLE_ROLES: UserRole[] = [
UserRole.ADMIN,
UserRole.GLOBAL_CURATOR,
UserRole.BASIC,
];
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
interface EditGroupsModalProps {
user: UserRow & { id: string };
onClose: () => void;
onMutate: () => void;
}
// ---------------------------------------------------------------------------
// Component
// ---------------------------------------------------------------------------
export default function EditGroupsModal({
user,
onClose,
onMutate,
}: EditGroupsModalProps) {
const { data: allGroups, isLoading: groupsLoading } = useGroups();
const [searchTerm, setSearchTerm] = useState("");
const [dropdownOpen, setDropdownOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const containerRef = useRef<HTMLDivElement>(null);
const closeDropdown = useCallback(() => {
// Delay to allow click events on dropdown items to fire before closing
setTimeout(() => {
if (!containerRef.current?.contains(document.activeElement)) {
setDropdownOpen(false);
}
}, 0);
}, []);
const [selectedRole, setSelectedRole] = useState<UserRole | "">(
user.role ?? ""
);
const initialMemberGroupIds = useMemo(
() => new Set(user.groups.map((g) => g.id)),
[user.groups]
);
const [memberGroupIds, setMemberGroupIds] = useState<Set<number>>(
() => new Set(initialMemberGroupIds)
);
// Dropdown shows all groups filtered by search term
const dropdownGroups = useMemo(() => {
if (!allGroups) return [];
if (searchTerm.length === 0) return allGroups;
const lower = searchTerm.toLowerCase();
return allGroups.filter((g) => g.name.toLowerCase().includes(lower));
}, [allGroups, searchTerm]);
// Joined groups shown in the modal body
const joinedGroups = useMemo(() => {
if (!allGroups) return [];
return allGroups.filter((g) => memberGroupIds.has(g.id));
}, [allGroups, memberGroupIds]);
const hasGroupChanges = useMemo(() => {
if (memberGroupIds.size !== initialMemberGroupIds.size) return true;
return Array.from(memberGroupIds).some(
(id) => !initialMemberGroupIds.has(id)
);
}, [memberGroupIds, initialMemberGroupIds]);
const hasRoleChange =
user.role !== null && selectedRole !== "" && selectedRole !== user.role;
const hasChanges = hasGroupChanges || hasRoleChange;
const toggleGroup = (groupId: number) => {
setMemberGroupIds((prev) => {
const next = new Set(prev);
if (next.has(groupId)) {
next.delete(groupId);
} else {
next.add(groupId);
}
return next;
});
};
const handleSave = async () => {
setIsSubmitting(true);
try {
const toAdd = Array.from(memberGroupIds).filter(
(id) => !initialMemberGroupIds.has(id)
);
const toRemove = Array.from(initialMemberGroupIds).filter(
(id) => !memberGroupIds.has(id)
);
if (user.id) {
for (const groupId of toAdd) {
await addUserToGroup(groupId, user.id);
}
for (const groupId of toRemove) {
const group = allGroups?.find((g) => g.id === groupId);
if (group) {
const currentUserIds = group.users.map((u) => u.id);
const ccPairIds = group.cc_pairs.map((cc) => cc.id);
await removeUserFromGroup(
groupId,
currentUserIds,
user.id,
ccPairIds
);
}
}
}
if (
user.role !== null &&
selectedRole !== "" &&
selectedRole !== user.role
) {
await setUserRole(user.email, selectedRole);
}
onMutate();
toast.success("User updated");
onClose();
} catch (err) {
onMutate(); // refresh to show partially-applied state
toast.error(err instanceof Error ? err.message : "An error occurred");
} finally {
setIsSubmitting(false);
}
};
const displayName = user.personal_name ?? user.email;
return (
<Modal open onOpenChange={(isOpen) => !isOpen && onClose()}>
<Modal.Content width="sm">
<Modal.Header
icon={SvgUsers}
title="Edit User's Groups & Roles"
description={
user.personal_name
? `${user.personal_name} (${user.email})`
: user.email
}
onClose={onClose}
/>
<Modal.Body twoTone>
<Section
gap={1}
height="auto"
alignItems="stretch"
justifyContent="start"
>
{/* Subsection: white card behind search + groups */}
<div className="relative">
<div className="absolute -inset-2 bg-background-neutral-00 rounded-12" />
<Section
gap={0.5}
height="auto"
alignItems="stretch"
justifyContent="start"
>
<div ref={containerRef} className="relative">
<InputTypeIn
value={searchTerm}
onChange={(e) => {
setSearchTerm(e.target.value);
if (!dropdownOpen) setDropdownOpen(true);
}}
onFocus={() => setDropdownOpen(true)}
onBlur={closeDropdown}
placeholder="Search groups to join..."
leftSearchIcon
/>
{dropdownOpen && (
<div className="absolute top-full left-0 right-0 z-50 mt-1 bg-background-neutral-00 border border-border-02 rounded-12 shadow-md p-1">
{groupsLoading ? (
<Text as="p" text03 secondaryBody className="px-3 py-2">
Loading groups...
</Text>
) : dropdownGroups.length === 0 ? (
<Text as="p" text03 secondaryBody className="px-3 py-2">
No groups found
</Text>
) : (
<ShadowDiv className="max-h-[200px] flex flex-col gap-1">
{dropdownGroups.map((group) => {
const isMember = memberGroupIds.has(group.id);
return (
<LineItem
key={group.id}
icon={isMember ? SvgCheck : SvgUsers}
description={`${group.users.length} ${
group.users.length === 1 ? "user" : "users"
}`}
selected={isMember}
emphasized={isMember}
onMouseDown={(e: React.MouseEvent) =>
e.preventDefault()
}
onClick={() => toggleGroup(group.id)}
>
{group.name}
</LineItem>
);
})}
</ShadowDiv>
)}
</div>
)}
</div>
{joinedGroups.length === 0 ? (
<LineItem
icon={SvgUsers}
description={`${displayName} is not in any groups.`}
muted
>
No groups joined
</LineItem>
) : (
<ShadowDiv className="flex flex-col gap-1 max-h-[200px]">
{joinedGroups.map((group) => (
<div
key={group.id}
className="bg-background-tint-01 rounded-08"
>
<LineItem
icon={SvgUsers}
description={`${group.users.length} ${
group.users.length === 1 ? "user" : "users"
}`}
rightChildren={
<SvgLogOut className="w-4 h-4 text-text-03" />
}
onClick={() => toggleGroup(group.id)}
>
{group.name}
</LineItem>
</div>
))}
</ShadowDiv>
)}
</Section>
</div>
{user.role && (
<>
<Separator noPadding />
<ContentAction
title="User Role"
description="This controls their general permissions."
sizePreset="main-ui"
variant="section"
paddingVariant="fit"
rightChildren={
<InputSelect
value={selectedRole}
onValueChange={(v) => setSelectedRole(v as UserRole)}
>
<InputSelect.Trigger />
<InputSelect.Content>
{user.role && !ASSIGNABLE_ROLES.includes(user.role) && (
<InputSelect.Item
key={user.role}
value={user.role}
icon={SvgUser}
>
{USER_ROLE_LABELS[user.role]}
</InputSelect.Item>
)}
{ASSIGNABLE_ROLES.map((role) => (
<InputSelect.Item
key={role}
value={role}
icon={SvgUser}
>
{USER_ROLE_LABELS[role]}
</InputSelect.Item>
))}
</InputSelect.Content>
</InputSelect>
}
/>
</>
)}
</Section>
</Modal.Body>
<Modal.Footer>
<Button prominence="secondary" onClick={onClose}>
Cancel
</Button>
<Disabled disabled={isSubmitting || !hasChanges}>
<Button onClick={handleSave}>Save Changes</Button>
</Disabled>
</Modal.Footer>
</Modal.Content>
</Modal>
);
}

View File

@@ -1,195 +0,0 @@
"use client";
import {
useState,
useRef,
useLayoutEffect,
useCallback,
useEffect,
} from "react";
import { SvgEdit } from "@opal/icons";
import { Tag } from "@opal/components";
import IconButton from "@/refresh-components/buttons/IconButton";
import Text from "@/refresh-components/texts/Text";
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
import EditGroupsModal from "./EditGroupsModal";
import type { UserRow, UserGroupInfo } from "./interfaces";
interface GroupsCellProps {
groups: UserGroupInfo[];
user: UserRow;
onMutate: () => void;
}
/**
* Measures how many Tag pills fit in the container, accounting for a "+N"
* overflow counter when not all tags are visible. Uses a two-phase render:
* first renders all tags (clipped by overflow:hidden) for measurement, then
* re-renders with only the visible subset + "+N".
*
* Hovering the cell shows a tooltip with ALL groups. Clicking opens the
* edit groups modal.
*/
export default function GroupsCell({
groups,
user,
onMutate,
}: GroupsCellProps) {
const [showModal, setShowModal] = useState(false);
const [visibleCount, setVisibleCount] = useState<number | null>(null);
const containerRef = useRef<HTMLDivElement>(null);
const computeVisibleCount = useCallback(() => {
const container = containerRef.current;
if (!container || groups.length <= 1) {
setVisibleCount(groups.length);
return;
}
const tags = container.querySelectorAll<HTMLElement>("[data-group-tag]");
if (tags.length === 0) return;
const containerWidth = container.clientWidth;
const gap = 4; // gap-1
const counterWidth = 32; // "+N" Tag approximate width
let used = 0;
let count = 0;
for (let i = 0; i < tags.length; i++) {
const tagWidth = tags[i]!.offsetWidth;
const gapBefore = count > 0 ? gap : 0;
const hasMore = i < tags.length - 1;
const reserve = hasMore ? gap + counterWidth : 0;
if (used + gapBefore + tagWidth + reserve <= containerWidth) {
used += gapBefore + tagWidth;
count++;
} else {
break;
}
}
setVisibleCount(Math.max(1, count));
}, [groups]);
// Reset to measurement phase when groups change
useLayoutEffect(() => {
setVisibleCount(null);
}, [groups]);
// Measure after the "show all" render
useLayoutEffect(() => {
if (visibleCount !== null) return;
computeVisibleCount();
}, [visibleCount, computeVisibleCount]);
// Re-measure when the container width changes (e.g. window resize).
// Track width so height-only changes (from the measurement cycle toggling
// visible tags) don't cause an infinite render loop.
const lastWidthRef = useRef(0);
useEffect(() => {
const node = containerRef.current;
if (!node) return;
const observer = new ResizeObserver((entries) => {
const width = entries[0]?.contentRect.width ?? 0;
if (Math.abs(width - lastWidthRef.current) < 1) return;
lastWidthRef.current = width;
setVisibleCount(null);
});
observer.observe(node);
return () => observer.disconnect();
}, [groups]);
const isMeasuring = visibleCount === null;
const effectiveVisible = visibleCount ?? groups.length;
const overflowCount = groups.length - effectiveVisible;
const hasOverflow = !isMeasuring && overflowCount > 0;
const allGroupsTooltip = (
<div className="flex flex-wrap gap-1 max-w-[14rem]">
{groups.map((g) => (
<div key={g.id} className="max-w-[10rem]">
<Tag title={g.name} size="md" />
</div>
))}
</div>
);
const tagsContent = (
<>
{(isMeasuring ? groups : groups.slice(0, effectiveVisible)).map((g) => (
<div key={g.id} data-group-tag className="flex-shrink-0">
<Tag title={g.name} size="md" />
</div>
))}
{hasOverflow && (
<div className="flex-shrink-0">
<Tag title={`+${overflowCount}`} size="md" />
</div>
)}
</>
);
return (
<>
<div
className={`group/groups relative flex items-center w-full min-w-0 ${
user.id ? "cursor-pointer" : ""
}`}
onClick={user.id ? () => setShowModal(true) : undefined}
>
{groups.length === 0 ? (
<div
ref={containerRef}
className="flex items-center gap-1 overflow-hidden flex-nowrap min-w-0 pr-7"
>
<Text as="span" secondaryBody text03>
</Text>
</div>
) : (
<SimpleTooltip
side="bottom"
align="start"
tooltip={allGroupsTooltip}
disabled={!hasOverflow}
className="bg-background-neutral-01 shadow-sm"
delayDuration={200}
>
<div
ref={containerRef}
className="flex items-center gap-1 overflow-hidden flex-nowrap min-w-0 pr-7"
>
{tagsContent}
</div>
</SimpleTooltip>
)}
{user.id && (
<IconButton
tertiary
icon={SvgEdit}
tooltip="Edit"
toolTipPosition="left"
tooltipSize="sm"
className="absolute right-0 opacity-0 group-hover/groups:opacity-100 transition-opacity"
onClick={(e) => {
e.stopPropagation();
setShowModal(true);
}}
/>
)}
</div>
{showModal && user.id != null && (
<EditGroupsModal
user={{ ...user, id: user.id }}
onClose={() => setShowModal(false)}
onMutate={onMutate}
/>
)}
</>
);
}

View File

@@ -1,20 +1,14 @@
"use client";
import { useState } from "react";
import {
SvgCheck,
SvgSlack,
SvgUser,
SvgUserManage,
SvgUsers,
} from "@opal/icons";
import { SvgCheck, SvgSlack, SvgUser, SvgUsers } from "@opal/icons";
import type { IconFunctionComponent } from "@opal/types";
import FilterButton from "@/refresh-components/buttons/FilterButton";
import Popover from "@/refresh-components/Popover";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import LineItem from "@/refresh-components/buttons/LineItem";
import Text from "@/refresh-components/texts/Text";
import ShadowDiv from "@/refresh-components/ShadowDiv";
import Separator from "@/refresh-components/Separator";
import {
UserRole,
UserStatus,
@@ -24,20 +18,29 @@ import {
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import type { GroupOption, StatusFilter, StatusCountMap } from "./interfaces";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
interface UserFiltersProps {
selectedRoles: UserRole[];
onRolesChange: (roles: UserRole[]) => void;
selectedGroups: number[];
onGroupsChange: (groupIds: number[]) => void;
groups: GroupOption[];
selectedStatuses: StatusFilter;
onStatusesChange: (statuses: StatusFilter) => void;
roleCounts: Record<string, number>;
statusCounts: StatusCountMap;
}
// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------
const VISIBLE_FILTER_ROLES: UserRole[] = [
UserRole.ADMIN,
UserRole.GLOBAL_CURATOR,
UserRole.BASIC,
UserRole.SLACK_USER,
];
const FILTERABLE_ROLES = VISIBLE_FILTER_ROLES.map(
(role) => [role, USER_ROLE_LABELS[role]] as [UserRole, string]
);
const FILTERABLE_ROLES = Object.entries(USER_ROLE_LABELS).filter(
([role]) => role !== UserRole.EXT_PERM_USER
) as [UserRole, string][];
const FILTERABLE_STATUSES = (
Object.entries(USER_STATUS_LABELS) as [UserStatus, string][]
@@ -46,7 +49,6 @@ const FILTERABLE_STATUSES = (
);
const ROLE_ICONS: Partial<Record<UserRole, IconFunctionComponent>> = {
[UserRole.ADMIN]: SvgUserManage,
[UserRole.SLACK_USER]: SvgSlack,
};
@@ -74,18 +76,6 @@ function CountBadge({ count }: { count: number | undefined }) {
// Component
// ---------------------------------------------------------------------------
interface UserFiltersProps {
selectedRoles: UserRole[];
onRolesChange: (roles: UserRole[]) => void;
selectedGroups: number[];
onGroupsChange: (groupIds: number[]) => void;
groups: GroupOption[];
selectedStatuses: StatusFilter;
onStatusesChange: (statuses: StatusFilter) => void;
roleCounts: Record<string, number>;
statusCounts: StatusCountMap;
}
export default function UserFilters({
selectedRoles,
onRolesChange,
@@ -111,22 +101,6 @@ export default function UserFilters({
}
};
const toggleGroup = (groupId: number) => {
if (selectedGroups.includes(groupId)) {
onGroupsChange(selectedGroups.filter((id) => id !== groupId));
} else {
onGroupsChange([...selectedGroups, groupId]);
}
};
const toggleStatus = (status: UserStatus) => {
if (selectedStatuses.includes(status)) {
onStatusesChange(selectedStatuses.filter((s) => s !== status));
} else {
onStatusesChange([...selectedStatuses, status]);
}
};
const roleLabel = hasRoleFilter
? FILTERABLE_ROLES.filter(([role]) => selectedRoles.includes(role))
.map(([, label]) => label)
@@ -135,6 +109,14 @@ export default function UserFilters({
(selectedRoles.length > 2 ? `, +${selectedRoles.length - 2}` : "")
: "All Account Types";
const toggleGroup = (groupId: number) => {
if (selectedGroups.includes(groupId)) {
onGroupsChange(selectedGroups.filter((id) => id !== groupId));
} else {
onGroupsChange([...selectedGroups, groupId]);
}
};
const groupLabel = hasGroupFilter
? groups
.filter((g) => selectedGroups.includes(g.id))
@@ -144,6 +126,14 @@ export default function UserFilters({
(selectedGroups.length > 2 ? `, +${selectedGroups.length - 2}` : "")
: "All Groups";
const toggleStatus = (status: UserStatus) => {
if (selectedStatuses.includes(status)) {
onStatusesChange(selectedStatuses.filter((s) => s !== status));
} else {
onStatusesChange([...selectedStatuses, status]);
}
};
const statusLabel = hasStatusFilter
? FILTERABLE_STATUSES.filter(([status]) =>
selectedStatuses.includes(status)
@@ -176,13 +166,13 @@ export default function UserFilters({
<Popover.Content align="start">
<div className="flex flex-col gap-1 p-1 min-w-[200px]">
<LineItem
icon={!hasRoleFilter ? SvgCheck : SvgUsers}
icon={SvgUsers}
selected={!hasRoleFilter}
emphasized={!hasRoleFilter}
onClick={() => onRolesChange([])}
>
All Account Types
</LineItem>
<Separator noPadding />
{FILTERABLE_ROLES.map(([role, label]) => {
const isSelected = selectedRoles.includes(role);
const roleIcon = ROLE_ICONS[role] ?? SvgUser;
@@ -191,7 +181,6 @@ export default function UserFilters({
key={role}
icon={isSelected ? SvgCheck : roleIcon}
selected={isSelected}
emphasized={isSelected}
onClick={() => toggleRole(role)}
rightChildren={<CountBadge count={roleCounts[role]} />}
>
@@ -222,30 +211,30 @@ export default function UserFilters({
</Popover.Trigger>
<Popover.Content align="start">
<div className="flex flex-col gap-1 p-1 min-w-[200px]">
<InputTypeIn
value={groupSearch}
onChange={(e) => setGroupSearch(e.target.value)}
placeholder="Search groups..."
leftSearchIcon
variant="internal"
/>
<div className="px-1 pt-1">
<InputTypeIn
value={groupSearch}
onChange={(e) => setGroupSearch(e.target.value)}
placeholder="Search groups..."
leftSearchIcon
/>
</div>
<LineItem
icon={!hasGroupFilter ? SvgCheck : SvgUsers}
icon={SvgUsers}
selected={!hasGroupFilter}
emphasized={!hasGroupFilter}
onClick={() => onGroupsChange([])}
>
All Groups
</LineItem>
<ShadowDiv className="flex flex-col gap-1 max-h-[240px]">
<Separator noPadding />
<div className="flex flex-col gap-1 max-h-[240px] overflow-y-auto">
{filteredGroups.map((group) => {
const isSelected = selectedGroups.includes(group.id);
return (
<LineItem
key={group.id}
icon={isSelected ? SvgCheck : SvgUsers}
icon={isSelected ? SvgCheck : undefined}
selected={isSelected}
emphasized={isSelected}
onClick={() => toggleGroup(group.id)}
rightChildren={<CountBadge count={group.memberCount} />}
>
@@ -258,7 +247,7 @@ export default function UserFilters({
No groups found
</Text>
)}
</ShadowDiv>
</div>
</div>
</Popover.Content>
</Popover>
@@ -277,22 +266,21 @@ export default function UserFilters({
<Popover.Content align="start">
<div className="flex flex-col gap-1 p-1 min-w-[200px]">
<LineItem
icon={!hasStatusFilter ? SvgCheck : SvgUser}
icon={!hasStatusFilter ? SvgCheck : undefined}
selected={!hasStatusFilter}
emphasized={!hasStatusFilter}
onClick={() => onStatusesChange([])}
>
All Status
</LineItem>
<Separator noPadding />
{FILTERABLE_STATUSES.map(([status, label]) => {
const isSelected = selectedStatuses.includes(status);
const countKey = STATUS_COUNT_KEY[status];
return (
<LineItem
key={status}
icon={isSelected ? SvgCheck : SvgUser}
icon={isSelected ? SvgCheck : undefined}
selected={isSelected}
emphasized={isSelected}
onClick={() => toggleStatus(status)}
rightChildren={<CountBadge count={statusCounts[countKey]} />}
>

View File

@@ -90,46 +90,44 @@ export default function UserRoleCell({ user, onMutate }: UserRoleCellProps) {
const currentIcon = ROLE_ICONS[user.role] ?? SvgUser;
return (
<div className="[&_button]:rounded-08">
<Disabled disabled={isUpdating}>
<Popover open={open} onOpenChange={setOpen}>
<Popover.Trigger asChild>
<OpenButton
icon={currentIcon}
variant="select-tinted"
width="full"
justifyContent="between"
>
{USER_ROLE_LABELS[user.role]}
</OpenButton>
</Popover.Trigger>
<Popover.Content align="start">
<div className="flex flex-col gap-1 p-1 min-w-[160px]">
{SELECTABLE_ROLES.map((role) => {
if (
role === UserRole.GLOBAL_CURATOR &&
!isPaidEnterpriseFeaturesEnabled
) {
return null;
}
const isSelected = user.role === role;
const icon = ROLE_ICONS[role] ?? SvgUser;
return (
<LineItem
key={role}
icon={isSelected ? SvgCheck : icon}
selected={isSelected}
emphasized={isSelected}
onClick={() => handleSelect(role)}
>
{USER_ROLE_LABELS[role]}
</LineItem>
);
})}
</div>
</Popover.Content>
</Popover>
</Disabled>
</div>
<Disabled disabled={isUpdating}>
<Popover open={open} onOpenChange={setOpen}>
<Popover.Trigger asChild>
<OpenButton
icon={currentIcon}
variant="select-tinted"
width="full"
justifyContent="between"
>
{USER_ROLE_LABELS[user.role]}
</OpenButton>
</Popover.Trigger>
<Popover.Content align="start">
<div className="flex flex-col gap-1 p-1 min-w-[160px]">
{SELECTABLE_ROLES.map((role) => {
if (
role === UserRole.GLOBAL_CURATOR &&
!isPaidEnterpriseFeaturesEnabled
) {
return null;
}
const isSelected = user.role === role;
const icon = ROLE_ICONS[role] ?? SvgUser;
return (
<LineItem
key={role}
icon={isSelected ? SvgCheck : icon}
selected={isSelected}
emphasized={isSelected}
onClick={() => handleSelect(role)}
>
{USER_ROLE_LABELS[role]}
</LineItem>
);
})}
</div>
</Popover.Content>
</Popover>
</Disabled>
);
}

View File

@@ -2,40 +2,21 @@
import { useState } from "react";
import { Button } from "@opal/components";
import {
SvgMoreHorizontal,
SvgUsers,
SvgXCircle,
SvgTrash,
SvgCheck,
} from "@opal/icons";
import { SvgMoreHorizontal, SvgXCircle, SvgTrash, SvgCheck } from "@opal/icons";
import { Disabled } from "@opal/core";
import Popover from "@/refresh-components/Popover";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import Text from "@/refresh-components/texts/Text";
import { UserStatus } from "@/lib/types";
import { toast } from "@/hooks/useToast";
import {
deactivateUser,
activateUser,
deleteUser,
cancelInvite,
approveRequest,
} from "./svc";
import EditGroupsModal from "./EditGroupsModal";
import { deactivateUser, activateUser, deleteUser } from "./svc";
import type { UserRow } from "./interfaces";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
type ModalType =
| "deactivate"
| "activate"
| "delete"
| "cancelInvite"
| "editGroups"
| null;
type ModalType = "deactivate" | "activate" | "delete" | null;
interface UserRowActionsProps {
user: UserRow;
@@ -71,101 +52,14 @@ export default function UserRowActions({
}
}
const openModal = (type: ModalType) => {
setPopoverOpen(false);
setModal(type);
};
// Status-aware action menus
const actionButtons = (() => {
switch (user.status) {
case UserStatus.INVITED:
return (
<Button
prominence="tertiary"
variant="danger"
icon={SvgXCircle}
onClick={() => openModal("cancelInvite")}
>
Cancel Invite
</Button>
);
case UserStatus.REQUESTED:
return (
<Button
prominence="tertiary"
icon={SvgCheck}
onClick={() => {
setPopoverOpen(false);
handleAction(
() => approveRequest(user.email),
"Request approved"
);
}}
>
Approve
</Button>
);
case UserStatus.ACTIVE:
return (
<>
{user.id && (
<Button
prominence="tertiary"
icon={SvgUsers}
onClick={() => openModal("editGroups")}
>
Groups
</Button>
)}
<Button
prominence="tertiary"
icon={SvgXCircle}
onClick={() => openModal("deactivate")}
>
Deactivate User
</Button>
</>
);
case UserStatus.INACTIVE:
return (
<>
{user.id && (
<Button
prominence="tertiary"
icon={SvgUsers}
onClick={() => openModal("editGroups")}
>
Groups
</Button>
)}
<Button
prominence="tertiary"
icon={SvgCheck}
onClick={() => openModal("activate")}
>
Activate User
</Button>
<Button
prominence="tertiary"
variant="danger"
icon={SvgTrash}
onClick={() => openModal("delete")}
>
Delete User
</Button>
</>
);
default: {
const _exhaustive: never = user.status;
return null;
}
}
})();
// Only show actions for accepted users (active or inactive).
// Invited/requested users have no row actions in this PR.
if (
user.status !== UserStatus.ACTIVE &&
user.status !== UserStatus.INACTIVE
) {
return null;
}
// SCIM-managed users cannot be modified from the UI — changes would be
// overwritten on the next IdP sync.
@@ -180,47 +74,46 @@ export default function UserRowActions({
<Button prominence="tertiary" icon={SvgMoreHorizontal} />
</Popover.Trigger>
<Popover.Content align="end">
<div className="flex flex-col gap-0.5 p-1">{actionButtons}</div>
</Popover.Content>
</Popover>
{modal === "editGroups" && user.id && (
<EditGroupsModal
user={user as UserRow & { id: string }}
onClose={() => setModal(null)}
onMutate={onMutate}
/>
)}
{modal === "cancelInvite" && (
<ConfirmationModalLayout
icon={SvgXCircle}
title="Cancel Invite"
onClose={() => setModal(null)}
submit={
<Disabled disabled={isSubmitting}>
<div className="flex flex-col gap-0.5 p-1">
{user.status === UserStatus.ACTIVE ? (
<Button
variant="danger"
prominence="tertiary"
icon={SvgXCircle}
onClick={() => {
handleAction(
() => cancelInvite(user.email),
"Invite cancelled"
);
setPopoverOpen(false);
setModal("deactivate");
}}
>
Cancel
Deactivate User
</Button>
</Disabled>
}
>
<Text as="p" text03>
<Text as="span" text05>
{user.email}
</Text>{" "}
will no longer be able to join Onyx with this invite.
</Text>
</ConfirmationModalLayout>
)}
) : (
<>
<Button
prominence="tertiary"
icon={SvgCheck}
onClick={() => {
setPopoverOpen(false);
setModal("activate");
}}
>
Activate User
</Button>
<Button
prominence="tertiary"
variant="danger"
icon={SvgTrash}
onClick={() => {
setPopoverOpen(false);
setModal("delete");
}}
>
Delete User
</Button>
</>
)}
</div>
</Popover.Content>
</Popover>
{modal === "deactivate" && (
<ConfirmationModalLayout
@@ -248,8 +141,7 @@ export default function UserRowActions({
{user.email}
</Text>{" "}
will immediately lose access to Onyx. Their sessions and agents will
be preserved. Their license seat will be freed. You can reactivate
this account later.
be preserved. You can reactivate this account later.
</Text>
</ConfirmationModalLayout>
)}
@@ -309,7 +201,7 @@ export default function UserRowActions({
{user.email}
</Text>{" "}
will be permanently removed from Onyx. All of their session history
will be deleted. Deletion cannot be undone.
will be deleted. This cannot be undone.
</Text>
</ConfirmationModalLayout>
)}

View File

@@ -1,15 +1,14 @@
import { SvgArrowUpRight, SvgFilterPlus, SvgUserSync } from "@opal/icons";
import { SvgArrowUpRight, SvgFilter, SvgUserSync } from "@opal/icons";
import { ContentAction } from "@opal/layouts";
import { Button } from "@opal/components";
import { Section } from "@/layouts/general-layouts";
import Card from "@/refresh-components/cards/Card";
import IconButton from "@/refresh-components/buttons/IconButton";
import Text from "@/refresh-components/texts/Text";
import Link from "next/link";
import { ADMIN_PATHS } from "@/lib/admin-routes";
// ---------------------------------------------------------------------------
// Stats cell — number + label + hover filter icon
// Stats cell — number + label
// ---------------------------------------------------------------------------
type StatCellProps = {
@@ -35,18 +34,12 @@ function StatCell({ value, label, onFilter }: StatCellProps) {
{label}
</Text>
{onFilter && (
<IconButton
tertiary
icon={SvgFilterPlus}
tooltip="Add Filter"
toolTipPosition="left"
tooltipSize="sm"
className="absolute right-1 top-1 opacity-0 group-hover/stat:opacity-100 transition-opacity"
onClick={(e) => {
e.stopPropagation();
onFilter();
}}
/>
<div className="absolute right-2 top-2 flex items-center gap-1 opacity-0 group-hover/stat:opacity-100 transition-opacity">
<Text as="span" secondaryBody text03>
Filter
</Text>
<SvgFilter size={16} className="text-text-03" />
</div>
)}
</div>
);

View File

@@ -4,8 +4,6 @@ import { useMemo, useState } from "react";
import DataTable from "@/refresh-components/table/DataTable";
import { createTableColumns } from "@/refresh-components/table/columns";
import { Content } from "@opal/layouts";
import { Button } from "@opal/components";
import { SvgDownload } from "@opal/icons";
import SvgNoResult from "@opal/illustrations/no-result";
import { IllustrationContent } from "@opal/layouts";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
@@ -13,16 +11,14 @@ import { UserRole, UserStatus, USER_STATUS_LABELS } from "@/lib/types";
import { timeAgo } from "@/lib/time";
import Text from "@/refresh-components/texts/Text";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import { toast } from "@/hooks/useToast";
import useAdminUsers from "@/hooks/useAdminUsers";
import useGroups from "@/hooks/useGroups";
import { downloadUsersCsv } from "./svc";
import UserFilters from "./UserFilters";
import GroupsCell from "./GroupsCell";
import UserRowActions from "./UserRowActions";
import UserRoleCell from "./UserRoleCell";
import type {
UserRow,
UserGroupInfo,
GroupOption,
StatusFilter,
StatusCountMap,
@@ -44,6 +40,37 @@ function renderNameColumn(email: string, row: UserRow) {
);
}
function renderGroupsColumn(groups: UserGroupInfo[]) {
if (!groups.length) {
return (
<Text as="span" secondaryBody text03>
{"\u2014"}
</Text>
);
}
const visible = groups.slice(0, 2);
const overflow = groups.length - visible.length;
return (
<div className="flex items-center gap-1 flex-nowrap overflow-hidden min-w-0">
{visible.map((g) => (
<span
key={g.id}
className="inline-flex items-center flex-shrink-0 rounded-md bg-background-tint-02 px-2 py-0.5 whitespace-nowrap"
>
<Text as="span" secondaryBody text03>
{g.name}
</Text>
</span>
))}
{overflow > 0 && (
<Text as="span" secondaryBody text03>
+{overflow}
</Text>
)}
</div>
);
}
function renderStatusColumn(value: UserStatus, row: UserRow) {
return (
<div className="flex flex-col">
@@ -62,7 +89,7 @@ function renderStatusColumn(value: UserStatus, row: UserRow) {
function renderLastUpdatedColumn(value: string | null) {
return (
<Text as="span" secondaryBody text03>
{value ? timeAgo(value) ?? "\u2014" : "\u2014"}
{timeAgo(value) ?? "\u2014"}
</Text>
);
}
@@ -91,9 +118,7 @@ function buildColumns(onMutate: () => void) {
weight: 24,
minWidth: 200,
enableSorting: false,
cell: (value, row) => (
<GroupsCell groups={value} user={row} onMutate={onMutate} />
),
cell: renderGroupsColumn,
}),
tc.column("role", {
header: "Account Type",
@@ -216,40 +241,22 @@ export default function UsersTable({
roleCounts={roleCounts}
statusCounts={statusCounts}
/>
<DataTable
data={filteredUsers}
columns={columns}
getRowId={(row) => row.id ?? row.email}
pageSize={PAGE_SIZE}
searchTerm={searchTerm}
emptyState={
<IllustrationContent
illustration={SvgNoResult}
title="No users found"
description="No users match the current filters."
/>
}
footer={{
mode: "summary",
leftExtra: (
<Button
icon={SvgDownload}
prominence="tertiary"
size="sm"
tooltip="Download CSV"
onClick={() => {
downloadUsersCsv().catch((err) => {
toast.error(
err instanceof Error
? err.message
: "Failed to download CSV"
);
});
}}
/>
),
}}
/>
{filteredUsers.length === 0 ? (
<IllustrationContent
illustration={SvgNoResult}
title="No users found"
description="No users match the current filters."
/>
) : (
<DataTable
data={filteredUsers}
columns={columns}
getRowId={(row) => row.id ?? row.email}
pageSize={PAGE_SIZE}
searchTerm={searchTerm}
footer={{ mode: "summary" }}
/>
)}
</div>
);
}

View File

@@ -59,63 +59,6 @@ export async function setUserRole(
}
}
export async function addUserToGroup(
groupId: number,
userId: string
): Promise<void> {
const res = await fetch(`/api/manage/admin/user-group/${groupId}/add-users`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ user_ids: [userId] }),
});
if (!res.ok) {
throw new Error(await parseErrorDetail(res, "Failed to add user to group"));
}
}
export async function removeUserFromGroup(
groupId: number,
currentUserIds: string[],
userIdToRemove: string,
ccPairIds: number[]
): Promise<void> {
const res = await fetch(`/api/manage/admin/user-group/${groupId}`, {
method: "PATCH",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
user_ids: currentUserIds.filter((id) => id !== userIdToRemove),
cc_pair_ids: ccPairIds,
}),
});
if (!res.ok) {
throw new Error(
await parseErrorDetail(res, "Failed to remove user from group")
);
}
}
export async function cancelInvite(email: string): Promise<void> {
const res = await fetch("/api/manage/admin/remove-invited-user", {
method: "PATCH",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ user_email: email }),
});
if (!res.ok) {
throw new Error(await parseErrorDetail(res, "Failed to cancel invite"));
}
}
export async function approveRequest(email: string): Promise<void> {
const res = await fetch("/api/tenants/users/invite/approve", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ email }),
});
if (!res.ok) {
throw new Error(await parseErrorDetail(res, "Failed to approve request"));
}
}
export async function inviteUsers(emails: string[]): Promise<void> {
const res = await fetch("/api/manage/admin/users", {
method: "PUT",
@@ -126,20 +69,3 @@ export async function inviteUsers(emails: string[]): Promise<void> {
throw new Error(await parseErrorDetail(res, "Failed to invite users"));
}
}
export async function downloadUsersCsv(): Promise<void> {
const res = await fetch("/api/manage/users/download");
if (!res.ok) {
throw new Error(
await parseErrorDetail(res, "Failed to download users CSV")
);
}
const blob = await res.blob();
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
const ts = new Date().toISOString().replace(/[:.]/g, "-").slice(0, 19);
a.download = `onyx_users_${ts}.csv`;
a.click();
URL.revokeObjectURL(url);
}

View File

@@ -121,6 +121,7 @@ const collections = (
{
name: "User Management",
items: [
sidebarItem(ADMIN_PATHS.USERS),
...(enableEnterprise ? [sidebarItem(ADMIN_PATHS.GROUPS)] : []),
sidebarItem(ADMIN_PATHS.API_KEYS),
sidebarItem(ADMIN_PATHS.TOKEN_RATE_LIMITS),
@@ -129,7 +130,8 @@ const collections = (
{
name: "Permissions",
items: [
sidebarItem(ADMIN_PATHS.USERS),
// TODO (nikolas): Uncommented in switchover PR once Users v2 is ready
// sidebarItem(ADMIN_PATHS.USERS_V2),
...(enableEnterprise ? [sidebarItem(ADMIN_PATHS.SCIM)] : []),
],
},

View File

@@ -1,248 +0,0 @@
/**
* Page Object Model for the Admin Users page (/admin/users).
*
* Encapsulates all locators and interactions so specs remain declarative.
*/
import { type Page, type Locator, expect } from "@playwright/test";
export class UsersAdminPage {
readonly page: Page;
// Top-level elements
readonly inviteButton: Locator;
readonly searchInput: Locator;
// Filter buttons
readonly accountTypesFilter: Locator;
readonly groupsFilter: Locator;
readonly statusFilter: Locator;
// Table
readonly table: Locator;
readonly tableRows: Locator;
// Pagination
readonly paginationSummary: Locator;
constructor(page: Page) {
this.page = page;
this.inviteButton = page.getByRole("button", { name: "Invite Users" });
this.searchInput = page.getByPlaceholder("Search users...");
this.accountTypesFilter = page.getByRole("button", {
name: /Account Types/,
});
this.groupsFilter = page.getByRole("button", { name: /Groups/ });
this.statusFilter = page.getByRole("button", { name: /Status/ });
this.table = page.getByRole("table");
this.tableRows = page.getByRole("table").locator("tbody tr");
this.paginationSummary = page.getByText(/Showing \d/);
}
// ---------------------------------------------------------------------------
// Navigation
// ---------------------------------------------------------------------------
async goto() {
await this.page.goto("/admin/users");
await expect(this.page.getByText("Users & Requests")).toBeVisible({
timeout: 15000,
});
}
// ---------------------------------------------------------------------------
// Search
// ---------------------------------------------------------------------------
async search(term: string) {
await this.searchInput.fill(term);
await this.page.waitForTimeout(300);
}
async clearSearch() {
await this.searchInput.fill("");
await this.page.waitForTimeout(300);
}
// ---------------------------------------------------------------------------
// Filters
// ---------------------------------------------------------------------------
async openAccountTypesFilter() {
await this.accountTypesFilter.click();
await expect(
this.page
.getByRole("dialog")
.or(this.page.locator("[data-radix-popper-content-wrapper]"))
).toBeVisible();
}
async selectAccountType(label: string) {
const popover = this.page.locator("[data-radix-popper-content-wrapper]");
await popover.getByRole("button", { name: new RegExp(label) }).click();
}
async openStatusFilter() {
await this.statusFilter.click();
await expect(
this.page.locator("[data-radix-popper-content-wrapper]")
).toBeVisible();
}
async selectStatus(label: string) {
const popover = this.page.locator("[data-radix-popper-content-wrapper]");
await popover.getByRole("button", { name: new RegExp(label) }).click();
}
async openGroupsFilter() {
await this.groupsFilter.click();
await expect(
this.page.locator("[data-radix-popper-content-wrapper]")
).toBeVisible();
}
async selectGroup(label: string) {
const popover = this.page.locator("[data-radix-popper-content-wrapper]");
await popover.getByRole("button", { name: new RegExp(label) }).click();
}
async closePopover() {
await this.page.keyboard.press("Escape");
await this.page.waitForTimeout(200);
}
// ---------------------------------------------------------------------------
// Table interactions
// ---------------------------------------------------------------------------
async getVisibleRowCount(): Promise<number> {
return await this.tableRows.count();
}
getRowByEmail(email: string): Locator {
return this.table.getByRole("row").filter({ hasText: email });
}
async sortByColumn(columnName: string) {
const header = this.table
.getByRole("columnheader")
.filter({ hasText: columnName });
await header.getByRole("button").first().click();
await this.page.waitForTimeout(300);
}
// ---------------------------------------------------------------------------
// Row actions
// ---------------------------------------------------------------------------
async openRowActions(email: string) {
const row = this.getRowByEmail(email);
const actionsButton = row.getByRole("button").last();
await actionsButton.click();
await expect(
this.page.locator("[data-radix-popper-content-wrapper]")
).toBeVisible();
}
async clickRowAction(actionName: string) {
const popover = this.page.locator("[data-radix-popper-content-wrapper]");
await popover.getByRole("button", { name: actionName }).click();
}
// ---------------------------------------------------------------------------
// Confirmation modals
// ---------------------------------------------------------------------------
get dialog(): Locator {
return this.page.getByRole("dialog");
}
async confirmModalAction(buttonName: string) {
await this.dialog.getByRole("button", { name: buttonName }).click();
}
async cancelModal() {
await this.dialog.getByRole("button", { name: "Cancel" }).click();
}
async expectToast(message: string | RegExp) {
await expect(this.page.getByText(message)).toBeVisible({ timeout: 10000 });
}
// ---------------------------------------------------------------------------
// Invite modal
// ---------------------------------------------------------------------------
async openInviteModal() {
await this.inviteButton.click();
await expect(this.dialog.getByText("Invite Users")).toBeVisible();
}
async addInviteEmail(email: string) {
const input = this.dialog.getByPlaceholder(
"Add emails to invite, comma separated"
);
await input.fill(email + ",");
await this.page.waitForTimeout(200);
}
async submitInvite() {
await this.dialog.getByRole("button", { name: "Invite" }).click();
}
// ---------------------------------------------------------------------------
// Inline role editing (Popover + OpenButton + LineItem)
// ---------------------------------------------------------------------------
async openRoleDropdown(email: string) {
const row = this.getRowByEmail(email);
// The role cell renders an OpenButton inside a Popover.Trigger
const roleButton = row
.locator("button")
.filter({ hasText: /Basic|Admin|Global Curator|Slack User/ });
await roleButton.click();
await expect(
this.page.locator("[data-radix-popper-content-wrapper]")
).toBeVisible();
}
async selectRole(roleName: string) {
const popover = this.page
.locator("[data-radix-popper-content-wrapper]")
.last();
await popover.getByRole("button", { name: roleName }).click();
await this.page.waitForTimeout(500);
}
// ---------------------------------------------------------------------------
// Edit groups modal
// ---------------------------------------------------------------------------
async openEditGroupsModal(email: string) {
await this.openRowActions(email);
await this.clickRowAction("Groups");
await expect(
this.dialog.getByText("Edit User's Groups & Roles")
).toBeVisible();
}
async searchGroupsInModal(term: string) {
await this.dialog.getByPlaceholder("Search groups to join...").fill(term);
await this.page.waitForTimeout(300);
}
async toggleGroupInModal(groupName: string) {
await this.dialog
.getByRole("button", { name: new RegExp(groupName) })
.first()
.click();
await this.page.waitForTimeout(200);
}
async saveGroupsModal() {
await this.dialog.getByRole("button", { name: "Save Changes" }).click();
}
}

View File

@@ -1,37 +0,0 @@
/**
* Playwright fixtures for Admin Users page tests.
*
* Provides:
* - Authenticated admin page
* - OnyxApiClient for API-level setup/teardown
* - UsersAdminPage page object
*/
import { test as base, expect, type Page } from "@playwright/test";
import { loginAs } from "@tests/e2e/utils/auth";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
import { UsersAdminPage } from "./UsersAdminPage";
export const test = base.extend<{
adminPage: Page;
api: OnyxApiClient;
usersPage: UsersAdminPage;
}>({
adminPage: async ({ page }, use) => {
await page.context().clearCookies();
await loginAs(page, "admin");
await use(page);
},
api: async ({ adminPage }, use) => {
const client = new OnyxApiClient(adminPage.request);
await use(client);
},
usersPage: async ({ adminPage }, use) => {
const usersPage = new UsersAdminPage(adminPage);
await use(usersPage);
},
});
export { expect };

View File

@@ -1,754 +0,0 @@
/**
* E2E Tests: Admin Users Page
*
* Tests the full users management page — search, filters, sorting,
* inline role editing, row actions, invite modal, and group management.
*
* All tests create their own data via API and clean up after themselves.
*
* Tagged @exclusive because tests mutate user state and must run serially.
*/
import { test, expect } from "./fixtures";
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
function uniqueEmail(prefix: string): string {
return `e2e-${prefix}-${Date.now()}@test.onyx`;
}
const TEST_PASSWORD = "TestPassword123!";
// ---------------------------------------------------------------------------
// Page load & layout
// ---------------------------------------------------------------------------
test.describe("Users page — layout @exclusive", () => {
test("renders page title, invite button, search, and stats bar", async ({
usersPage,
}) => {
await usersPage.goto();
await expect(usersPage.page.getByText("Users & Requests")).toBeVisible();
await expect(usersPage.inviteButton).toBeVisible();
await expect(usersPage.searchInput).toBeVisible();
await expect(usersPage.page.getByText(/active users/i)).toBeVisible();
});
test("table renders with correct column headers", async ({ usersPage }) => {
await usersPage.goto();
for (const header of [
"Name",
"Groups",
"Account Type",
"Status",
"Last Updated",
]) {
await expect(
usersPage.table.getByRole("columnheader", { name: header })
).toBeVisible();
}
});
test("pagination shows summary and controls", async ({ usersPage }) => {
await usersPage.goto();
await expect(usersPage.paginationSummary).toBeVisible();
await expect(usersPage.paginationSummary).toContainText("Showing");
});
test("CSV download button is visible in footer", async ({ usersPage }) => {
await usersPage.goto();
// The download button is an icon-only button with a tooltip
const downloadBtn = usersPage.page.getByRole("button", {
name: /Download CSV/i,
});
await expect(downloadBtn).toBeVisible();
});
});
// ---------------------------------------------------------------------------
// Search
// ---------------------------------------------------------------------------
test.describe("Users page — search @exclusive", () => {
let testEmail: string;
const personalName = `Zephyr${Date.now()}`;
test.beforeAll(async ({ browser }) => {
const adminCtx = await browser.newContext({
storageState: "admin_auth.json",
});
try {
const { OnyxApiClient } = await import("@tests/e2e/utils/onyxApiClient");
const adminApi = new OnyxApiClient(adminCtx.request);
testEmail = uniqueEmail("search");
await adminApi.registerUser(testEmail, TEST_PASSWORD);
// Log in as the new user to set their personal name
const userCtx = await browser.newContext();
try {
await userCtx.request.post(
`${process.env.BASE_URL || "http://localhost:3000"}/api/auth/login`,
{ form: { username: testEmail, password: TEST_PASSWORD } }
);
const userApi = new OnyxApiClient(userCtx.request);
await userApi.setPersonalName(personalName);
} finally {
await userCtx.close();
}
} finally {
await adminCtx.close();
}
});
test("search filters table rows by email", async ({ usersPage }) => {
await usersPage.goto();
await usersPage.search(testEmail);
const row = usersPage.getRowByEmail(testEmail);
await expect(row).toBeVisible({ timeout: 10000 });
const rowCount = await usersPage.getVisibleRowCount();
expect(rowCount).toBeGreaterThanOrEqual(1);
expect(rowCount).toBeLessThanOrEqual(8);
});
test("search matches by personal name", async ({ usersPage }) => {
await usersPage.goto();
await usersPage.search(personalName);
const row = usersPage.getRowByEmail(testEmail);
await expect(row).toBeVisible({ timeout: 10000 });
await expect(row).toContainText(personalName);
});
test("search with no results shows empty state", async ({ usersPage }) => {
await usersPage.goto();
await usersPage.search("zzz-no-match-exists-xyz@nowhere.invalid");
await expect(usersPage.page.getByText("No users found")).toBeVisible();
});
test("clearing search restores all results", async ({ usersPage }) => {
await usersPage.goto();
await usersPage.search("zzz-no-match-exists-xyz@nowhere.invalid");
await expect(usersPage.page.getByText("No users found")).toBeVisible();
await usersPage.clearSearch();
await expect(usersPage.table).toBeVisible();
const rowCount = await usersPage.getVisibleRowCount();
expect(rowCount).toBeGreaterThan(0);
});
test.afterAll(async ({ browser }) => {
const context = await browser.newContext({
storageState: "admin_auth.json",
});
try {
const { OnyxApiClient } = await import("@tests/e2e/utils/onyxApiClient");
const api = new OnyxApiClient(context.request);
await api.deactivateUser(testEmail).catch(() => {});
await api.deleteUser(testEmail).catch(() => {});
} finally {
await context.close();
}
});
});
// ---------------------------------------------------------------------------
// Filters
// ---------------------------------------------------------------------------
test.describe("Users page — filters @exclusive", () => {
let activeEmail: string;
let inactiveEmail: string;
test.beforeAll(async ({ browser }) => {
const context = await browser.newContext({
storageState: "admin_auth.json",
});
try {
const { OnyxApiClient } = await import("@tests/e2e/utils/onyxApiClient");
const api = new OnyxApiClient(context.request);
activeEmail = uniqueEmail("filt-active");
await api.registerUser(activeEmail, TEST_PASSWORD);
inactiveEmail = uniqueEmail("filt-inactive");
await api.registerUser(inactiveEmail, TEST_PASSWORD);
await api.deactivateUser(inactiveEmail);
} finally {
await context.close();
}
});
test("account types filter shows expected roles", async ({ usersPage }) => {
await usersPage.goto();
await usersPage.openAccountTypesFilter();
const popover = usersPage.page.locator(
"[data-radix-popper-content-wrapper]"
);
await expect(popover.getByText("All Account Types")).toBeVisible();
await expect(popover.getByText("Admin")).toBeVisible();
await expect(popover.getByText("Basic")).toBeVisible();
await usersPage.closePopover();
});
test("filtering by Admin role shows only admin users", async ({
usersPage,
}) => {
await usersPage.goto();
await usersPage.openAccountTypesFilter();
await usersPage.selectAccountType("Admin");
await usersPage.closePopover();
await expect(usersPage.accountTypesFilter).toContainText("Admin");
const rowCount = await usersPage.getVisibleRowCount();
expect(rowCount).toBeGreaterThan(0);
});
test("status filter for Active shows the active user", async ({
usersPage,
}) => {
await usersPage.goto();
await usersPage.openStatusFilter();
await usersPage.selectStatus("Active");
await usersPage.closePopover();
await expect(usersPage.statusFilter).toContainText("Active");
await usersPage.search(activeEmail);
const row = usersPage.getRowByEmail(activeEmail);
await expect(row).toBeVisible({ timeout: 10000 });
await expect(row).toContainText("Active");
});
test("status filter for Inactive shows the inactive user", async ({
usersPage,
}) => {
await usersPage.goto();
await usersPage.openStatusFilter();
await usersPage.selectStatus("Inactive");
await usersPage.closePopover();
await expect(usersPage.statusFilter).toContainText("Inactive");
await usersPage.search(inactiveEmail);
const row = usersPage.getRowByEmail(inactiveEmail);
await expect(row).toBeVisible({ timeout: 10000 });
await expect(row).toContainText("Inactive");
});
test("resetting filter shows all users again", async ({ usersPage }) => {
await usersPage.goto();
await usersPage.openStatusFilter();
await usersPage.selectStatus("Active");
await usersPage.closePopover();
const filteredCount = await usersPage.getVisibleRowCount();
await usersPage.openStatusFilter();
await usersPage.selectStatus("All Status");
await usersPage.closePopover();
const allCount = await usersPage.getVisibleRowCount();
expect(allCount).toBeGreaterThanOrEqual(filteredCount);
});
test.afterAll(async ({ browser }) => {
const context = await browser.newContext({
storageState: "admin_auth.json",
});
try {
const { OnyxApiClient } = await import("@tests/e2e/utils/onyxApiClient");
const api = new OnyxApiClient(context.request);
await api.deactivateUser(activeEmail).catch(() => {});
await api.deleteUser(activeEmail).catch(() => {});
await api.deleteUser(inactiveEmail).catch(() => {});
} finally {
await context.close();
}
});
});
// ---------------------------------------------------------------------------
// Sorting
// ---------------------------------------------------------------------------
test.describe("Users page — sorting @exclusive", () => {
test("clicking Name sort toggles row order", async ({ usersPage }) => {
await usersPage.goto();
const firstRowBefore = await usersPage.tableRows.first().textContent();
await usersPage.sortByColumn("Name");
const firstRowAfter = await usersPage.tableRows.first().textContent();
expect(firstRowBefore).toBeDefined();
expect(firstRowAfter).toBeDefined();
});
test("clicking Status sort keeps table rendered", async ({ usersPage }) => {
await usersPage.goto();
await usersPage.sortByColumn("Status");
const rowCount = await usersPage.getVisibleRowCount();
expect(rowCount).toBeGreaterThan(0);
});
});
// ---------------------------------------------------------------------------
// Pagination
// ---------------------------------------------------------------------------
test.describe("Users page — pagination @exclusive", () => {
test("next/previous page buttons navigate between pages", async ({
usersPage,
}) => {
await usersPage.goto();
const summaryBefore = await usersPage.paginationSummary.textContent();
// Click next page if available
const nextButton = usersPage.page.getByRole("button", { name: /next/i });
if (await nextButton.isEnabled()) {
await nextButton.click();
await usersPage.page.waitForTimeout(300);
const summaryAfter = await usersPage.paginationSummary.textContent();
expect(summaryAfter).not.toBe(summaryBefore);
// Go back
const prevButton = usersPage.page.getByRole("button", {
name: /previous/i,
});
await prevButton.click();
await usersPage.page.waitForTimeout(300);
}
});
});
// ---------------------------------------------------------------------------
// Invite users
// ---------------------------------------------------------------------------
test.describe("Users page — invite users @exclusive", () => {
test("invite modal opens with correct structure", async ({ usersPage }) => {
await usersPage.goto();
await usersPage.openInviteModal();
await expect(usersPage.dialog.getByText("Invite Users")).toBeVisible();
await expect(
usersPage.dialog.getByPlaceholder("Add emails to invite, comma separated")
).toBeVisible();
await expect(usersPage.dialog.getByText("User Role")).toBeVisible();
await usersPage.cancelModal();
await expect(usersPage.dialog).not.toBeVisible();
});
test("invite a user and verify Invite Pending status", async ({
usersPage,
api,
}) => {
const email = uniqueEmail("invite");
await usersPage.goto();
await usersPage.openInviteModal();
await usersPage.addInviteEmail(email);
await usersPage.submitInvite();
await usersPage.expectToast(/Invited 1 user/);
// Reload and search
await usersPage.goto();
await usersPage.search(email);
const row = usersPage.getRowByEmail(email);
await expect(row).toBeVisible({ timeout: 10000 });
await expect(row).toContainText("Invite Pending");
// Cleanup
await api.cancelInvite(email);
});
test("invite multiple users at once", async ({ usersPage, api }) => {
const email1 = uniqueEmail("multi1");
const email2 = uniqueEmail("multi2");
await usersPage.goto();
await usersPage.openInviteModal();
const input = usersPage.dialog.getByPlaceholder(
"Add emails to invite, comma separated"
);
await input.fill(`${email1}, ${email2},`);
await usersPage.page.waitForTimeout(200);
await usersPage.submitInvite();
await usersPage.expectToast(/Invited 2 users/);
// Cleanup
await api.cancelInvite(email1);
await api.cancelInvite(email2);
});
test("invite modal shows error icon for invalid emails", async ({
usersPage,
}) => {
await usersPage.goto();
await usersPage.openInviteModal();
const input = usersPage.dialog.getByPlaceholder(
"Add emails to invite, comma separated"
);
await input.fill("not-an-email,");
await usersPage.page.waitForTimeout(200);
// The chip should be rendered with an error state
await expect(usersPage.dialog.getByText("not-an-email")).toBeVisible();
await usersPage.cancelModal();
});
});
// ---------------------------------------------------------------------------
// Row actions — deactivate / activate
// ---------------------------------------------------------------------------
test.describe("Users page — deactivate & activate @exclusive", () => {
let testUserEmail: string;
test.beforeAll(async ({ browser }) => {
const context = await browser.newContext({
storageState: "admin_auth.json",
});
try {
const { OnyxApiClient } = await import("@tests/e2e/utils/onyxApiClient");
const api = new OnyxApiClient(context.request);
testUserEmail = uniqueEmail("deact");
await api.registerUser(testUserEmail, TEST_PASSWORD);
} finally {
await context.close();
}
});
test("deactivate and then reactivate a user", async ({ usersPage }) => {
await usersPage.goto();
await usersPage.search(testUserEmail);
const row = usersPage.getRowByEmail(testUserEmail);
await expect(row).toBeVisible({ timeout: 10000 });
await expect(row).toContainText("Active");
// Deactivate
await usersPage.openRowActions(testUserEmail);
await usersPage.clickRowAction("Deactivate User");
await expect(usersPage.dialog.getByText("Deactivate User")).toBeVisible();
await expect(usersPage.dialog.getByText(testUserEmail)).toBeVisible();
await expect(
usersPage.dialog.getByText("will immediately lose access")
).toBeVisible();
await usersPage.confirmModalAction("Deactivate");
await usersPage.expectToast("User deactivated");
// Verify Inactive
await usersPage.page.waitForTimeout(500);
await usersPage.search(testUserEmail);
const inactiveRow = usersPage.getRowByEmail(testUserEmail);
await expect(inactiveRow).toContainText("Inactive");
// Reactivate
await usersPage.openRowActions(testUserEmail);
await usersPage.clickRowAction("Activate User");
await expect(usersPage.dialog.getByText("Activate User")).toBeVisible();
await usersPage.confirmModalAction("Activate");
await usersPage.expectToast("User activated");
// Verify Active again
await usersPage.page.waitForTimeout(500);
await usersPage.search(testUserEmail);
const reactivatedRow = usersPage.getRowByEmail(testUserEmail);
await expect(reactivatedRow).toContainText("Active");
});
test.afterAll(async ({ browser }) => {
const context = await browser.newContext({
storageState: "admin_auth.json",
});
try {
const { OnyxApiClient } = await import("@tests/e2e/utils/onyxApiClient");
const api = new OnyxApiClient(context.request);
await api.deactivateUser(testUserEmail).catch(() => {});
await api.deleteUser(testUserEmail).catch(() => {});
} finally {
await context.close();
}
});
});
// ---------------------------------------------------------------------------
// Row actions — delete user
// ---------------------------------------------------------------------------
test.describe("Users page — delete user @exclusive", () => {
test("delete an inactive user", async ({ usersPage, api }) => {
const email = uniqueEmail("delete");
await api.registerUser(email, TEST_PASSWORD);
await api.deactivateUser(email);
await usersPage.goto();
await usersPage.search(email);
const row = usersPage.getRowByEmail(email);
await expect(row).toBeVisible({ timeout: 10000 });
await expect(row).toContainText("Inactive");
await usersPage.openRowActions(email);
await usersPage.clickRowAction("Delete User");
await expect(usersPage.dialog.getByText("Delete User")).toBeVisible();
await expect(
usersPage.dialog.getByText("will be permanently removed")
).toBeVisible();
await usersPage.confirmModalAction("Delete");
await usersPage.expectToast("User deleted");
// User gone
await usersPage.page.waitForTimeout(500);
await usersPage.search(email);
await expect(usersPage.page.getByText("No users found")).toBeVisible({
timeout: 10000,
});
});
});
// ---------------------------------------------------------------------------
// Row actions — cancel invite
// ---------------------------------------------------------------------------
test.describe("Users page — cancel invite @exclusive", () => {
test("cancel a pending invite", async ({ usersPage, api }) => {
const email = uniqueEmail("cancel-inv");
await api.inviteUsers([email]);
await usersPage.goto();
await usersPage.search(email);
const row = usersPage.getRowByEmail(email);
await expect(row).toBeVisible({ timeout: 10000 });
await expect(row).toContainText("Invite Pending");
await usersPage.openRowActions(email);
await usersPage.clickRowAction("Cancel Invite");
await expect(usersPage.dialog.getByText("Cancel Invite")).toBeVisible();
await usersPage.confirmModalAction("Cancel");
await usersPage.expectToast("Invite cancelled");
// User gone
await usersPage.page.waitForTimeout(500);
await usersPage.search(email);
await expect(usersPage.page.getByText("No users found")).toBeVisible({
timeout: 10000,
});
});
});
// ---------------------------------------------------------------------------
// Inline role editing
// ---------------------------------------------------------------------------
test.describe("Users page — inline role editing @exclusive", () => {
let testUserEmail: string;
test.beforeAll(async ({ browser }) => {
const context = await browser.newContext({
storageState: "admin_auth.json",
});
try {
const { OnyxApiClient } = await import("@tests/e2e/utils/onyxApiClient");
const api = new OnyxApiClient(context.request);
testUserEmail = uniqueEmail("role");
await api.registerUser(testUserEmail, TEST_PASSWORD);
} finally {
await context.close();
}
});
test("change user role from Basic to Admin and back", async ({
usersPage,
}) => {
await usersPage.goto();
await usersPage.search(testUserEmail);
const row = usersPage.getRowByEmail(testUserEmail);
await expect(row).toBeVisible({ timeout: 10000 });
// Initially Basic — the OpenButton shows the role label
await expect(row.getByText("Basic")).toBeVisible();
// Change to Admin
await usersPage.openRoleDropdown(testUserEmail);
await usersPage.selectRole("Admin");
await usersPage.page.waitForTimeout(500);
await expect(row.getByText("Admin")).toBeVisible();
// Change back to Basic
await usersPage.openRoleDropdown(testUserEmail);
await usersPage.selectRole("Basic");
await usersPage.page.waitForTimeout(500);
await expect(row.getByText("Basic")).toBeVisible();
});
test.afterAll(async ({ browser }) => {
const context = await browser.newContext({
storageState: "admin_auth.json",
});
try {
const { OnyxApiClient } = await import("@tests/e2e/utils/onyxApiClient");
const api = new OnyxApiClient(context.request);
await api.deactivateUser(testUserEmail).catch(() => {});
await api.deleteUser(testUserEmail).catch(() => {});
} finally {
await context.close();
}
});
});
// ---------------------------------------------------------------------------
// Group management
// ---------------------------------------------------------------------------
test.describe("Users page — group management @exclusive", () => {
let testUserEmail: string;
let testGroupId: number;
const groupName = `E2E-UsersTest-${Date.now()}`;
test.beforeAll(async ({ browser }) => {
const context = await browser.newContext({
storageState: "admin_auth.json",
});
try {
const { OnyxApiClient } = await import("@tests/e2e/utils/onyxApiClient");
const api = new OnyxApiClient(context.request);
testUserEmail = uniqueEmail("grp");
await api.registerUser(testUserEmail, TEST_PASSWORD);
testGroupId = await api.createUserGroup(groupName);
} finally {
await context.close();
}
});
test("add user to group via edit groups modal", async ({ usersPage }) => {
await usersPage.goto();
await usersPage.search(testUserEmail);
const row = usersPage.getRowByEmail(testUserEmail);
await expect(row).toBeVisible({ timeout: 10000 });
await usersPage.openEditGroupsModal(testUserEmail);
await usersPage.searchGroupsInModal(groupName);
await usersPage.toggleGroupInModal(groupName);
await usersPage.saveGroupsModal();
await usersPage.expectToast("User updated");
// Verify group shows in the row
await usersPage.page.waitForTimeout(500);
await usersPage.search(testUserEmail);
const rowWithGroup = usersPage.getRowByEmail(testUserEmail);
await expect(rowWithGroup).toContainText(groupName);
});
test("remove user from group via edit groups modal", async ({
usersPage,
}) => {
await usersPage.goto();
await usersPage.search(testUserEmail);
const row = usersPage.getRowByEmail(testUserEmail);
await expect(row).toBeVisible({ timeout: 10000 });
await usersPage.openEditGroupsModal(testUserEmail);
// Group shows as joined — click to remove
await usersPage.toggleGroupInModal(groupName);
await usersPage.saveGroupsModal();
await usersPage.expectToast("User updated");
// Verify group removed
await usersPage.page.waitForTimeout(500);
await usersPage.search(testUserEmail);
await expect(usersPage.getRowByEmail(testUserEmail)).not.toContainText(
groupName
);
});
test.afterAll(async ({ browser }) => {
const context = await browser.newContext({
storageState: "admin_auth.json",
});
try {
const { OnyxApiClient } = await import("@tests/e2e/utils/onyxApiClient");
const api = new OnyxApiClient(context.request);
await api.deleteUserGroup(testGroupId).catch(() => {});
await api.deactivateUser(testUserEmail).catch(() => {});
await api.deleteUser(testUserEmail).catch(() => {});
} finally {
await context.close();
}
});
});
// ---------------------------------------------------------------------------
// Stats bar
// ---------------------------------------------------------------------------
test.describe("Users page — stats bar @exclusive", () => {
test("stats bar shows active users count", async ({ usersPage }) => {
await usersPage.goto();
await expect(usersPage.page.getByText(/\d+ active users/i)).toBeVisible();
});
test("stats bar updates after inviting a user", async ({
usersPage,
api,
}) => {
const email = uniqueEmail("stats");
// Get initial pending count text
await usersPage.goto();
await usersPage.openInviteModal();
await usersPage.addInviteEmail(email);
await usersPage.submitInvite();
await usersPage.expectToast(/Invited 1 user/);
// Stats bar should reflect the new invite
await usersPage.goto();
await expect(usersPage.page.getByText(/pending invites/i)).toBeVisible();
// Cleanup
await api.cancelInvite(email);
});
});

View File

@@ -1073,62 +1073,6 @@ export class OnyxApiClient {
);
}
// === User Management Methods ===
async deactivateUser(email: string): Promise<void> {
const response = await this.request.patch(
`${this.baseUrl}/manage/admin/deactivate-user`,
{ data: { user_email: email } }
);
await this.handleResponse(response, `Failed to deactivate user ${email}`);
this.log(`Deactivated user: ${email}`);
}
async activateUser(email: string): Promise<void> {
const response = await this.request.patch(
`${this.baseUrl}/manage/admin/activate-user`,
{ data: { user_email: email } }
);
await this.handleResponse(response, `Failed to activate user ${email}`);
this.log(`Activated user: ${email}`);
}
async deleteUser(email: string): Promise<void> {
const response = await this.request.delete(
`${this.baseUrl}/manage/admin/delete-user`,
{ data: { user_email: email } }
);
await this.handleResponse(response, `Failed to delete user ${email}`);
this.log(`Deleted user: ${email}`);
}
async cancelInvite(email: string): Promise<void> {
const response = await this.request.patch(
`${this.baseUrl}/manage/admin/remove-invited-user`,
{ data: { user_email: email } }
);
await this.handleResponse(response, `Failed to cancel invite for ${email}`);
this.log(`Cancelled invite for: ${email}`);
}
async inviteUsers(emails: string[]): Promise<void> {
const response = await this.put("/manage/admin/users", { emails });
await this.handleResponse(response, `Failed to invite users`);
this.log(`Invited users: ${emails.join(", ")}`);
}
async setPersonalName(name: string): Promise<void> {
const response = await this.request.patch(
`${this.baseUrl}/user/personalization`,
{ data: { name } }
);
await this.handleResponse(
response,
`Failed to set personal name to ${name}`
);
this.log(`Set personal name: ${name}`);
}
// === Chat Session Methods ===
/**