mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-25 11:45:47 +00:00
Compare commits
17 Commits
nik/eng-36
...
ci_python_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ab44a2c66 | ||
|
|
65b74b974b | ||
|
|
784a99e24a | ||
|
|
da1f5a11f4 | ||
|
|
5633805890 | ||
|
|
0817b45ae1 | ||
|
|
af0e4bdebc | ||
|
|
4cd2320732 | ||
|
|
90a361f0e1 | ||
|
|
194efde97b | ||
|
|
d922a42262 | ||
|
|
f00c3a486e | ||
|
|
192080c9e4 | ||
|
|
c5787dc073 | ||
|
|
d424d6462c | ||
|
|
ecea86deb6 | ||
|
|
a5c1f50a8a |
@@ -11,6 +11,11 @@ permissions:
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
outputs:
|
||||
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
|
||||
pr_number: ${{ steps.gate.outputs.pr_number }}
|
||||
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
@@ -75,10 +80,82 @@ jobs:
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Create cherry-pick PR to latest release
|
||||
id: run_cherry_pick
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
continue-on-error: true
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
|
||||
run: |
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify
|
||||
set -o pipefail
|
||||
output_file="$(mktemp)"
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
|
||||
exit_code="${PIPESTATUS[0]}"
|
||||
|
||||
if [ "${exit_code}" -eq 0 ]; then
|
||||
echo "status=success" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "status=failure" >> "$GITHUB_OUTPUT"
|
||||
|
||||
reason="command-failed"
|
||||
if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then
|
||||
reason="merge-conflict"
|
||||
fi
|
||||
echo "reason=${reason}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
{
|
||||
echo "details<<EOF"
|
||||
tail -n 40 "$output_file"
|
||||
echo "EOF"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Mark workflow as failed if cherry-pick failed
|
||||
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
|
||||
run: |
|
||||
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
|
||||
exit 1
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
needs:
|
||||
- cherry-pick-to-latest-release
|
||||
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build cherry-pick failure summary
|
||||
id: failure-summary
|
||||
env:
|
||||
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
|
||||
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
|
||||
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
|
||||
run: |
|
||||
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
|
||||
|
||||
reason_text="cherry-pick command failed"
|
||||
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
|
||||
reason_text="merge conflict during cherry-pick"
|
||||
fi
|
||||
|
||||
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
|
||||
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
|
||||
if [ -n "${details_excerpt}" ]; then
|
||||
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
|
||||
fi
|
||||
|
||||
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Notify #cherry-pick-prs about cherry-pick failure
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
|
||||
title: "🚨 Automated Cherry-Pick Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
@@ -116,7 +116,6 @@ jobs:
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
|
||||
CODE_INTERPRETER_BETA_ENABLED=true
|
||||
DISABLE_TELEMETRY=true
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=true
|
||||
EOF
|
||||
|
||||
@@ -34,7 +34,6 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from ee.onyx.server.scim.filtering import ScimFilter
|
||||
from ee.onyx.server.scim.filtering import ScimFilterOperator
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from onyx.db.dal import DAL
|
||||
from onyx.db.models import ScimGroupMapping
|
||||
from onyx.db.models import ScimToken
|
||||
@@ -129,19 +128,12 @@ class ScimDAL(DAL):
|
||||
external_id: str,
|
||||
user_id: UUID,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserMapping:
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user."""
|
||||
f = fields or ScimMappingFields()
|
||||
mapping = ScimUserMapping(
|
||||
external_id=external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
department=f.department,
|
||||
manager=f.manager,
|
||||
given_name=f.given_name,
|
||||
family_name=f.family_name,
|
||||
scim_emails_json=f.scim_emails_json,
|
||||
)
|
||||
self._session.add(mapping)
|
||||
self._session.flush()
|
||||
@@ -319,14 +311,8 @@ class ScimDAL(DAL):
|
||||
user_id: UUID,
|
||||
new_external_id: str | None,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> None:
|
||||
"""Create, update, or delete the external ID mapping for a user.
|
||||
|
||||
When *fields* is provided, all mapping fields are written
|
||||
unconditionally — including ``None`` values — so that a caller can
|
||||
clear a previously-set field (e.g. removing a department).
|
||||
"""
|
||||
"""Create, update, or delete the external ID mapping for a user."""
|
||||
mapping = self.get_user_mapping_by_user_id(user_id)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
@@ -334,18 +320,11 @@ class ScimDAL(DAL):
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
else:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
elif mapping:
|
||||
self.delete_user_mapping(mapping.id)
|
||||
|
||||
@@ -26,14 +26,14 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimError
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
from ee.onyx.server.scim.models import ScimSchemaDefinition
|
||||
from ee.onyx.server.scim.models import ScimServiceProviderConfig
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
@@ -41,8 +41,6 @@ from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.base import serialize_emails
|
||||
from ee.onyx.server.scim.schema_definitions import ENTERPRISE_USER_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
@@ -50,28 +48,15 @@ from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
|
||||
media_type = "application/scim+json"
|
||||
|
||||
|
||||
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
|
||||
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
|
||||
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
|
||||
# changed to kebab-case.
|
||||
|
||||
|
||||
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
|
||||
_pw_helper = PasswordHelper()
|
||||
@@ -101,39 +86,15 @@ def get_service_provider_config() -> ScimServiceProviderConfig:
|
||||
|
||||
|
||||
@scim_router.get("/ResourceTypes")
|
||||
def get_resource_types() -> ScimJSONResponse:
|
||||
"""List available SCIM resource types (RFC 7643 §6).
|
||||
|
||||
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
|
||||
like Entra ID expect a JSON object, not a bare array.
|
||||
"""
|
||||
resources = [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
|
||||
return ScimJSONResponse(
|
||||
content={
|
||||
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
|
||||
"totalResults": len(resources),
|
||||
"Resources": [
|
||||
r.model_dump(exclude_none=True, by_alias=True) for r in resources
|
||||
],
|
||||
}
|
||||
)
|
||||
def get_resource_types() -> list[ScimResourceType]:
|
||||
"""List available SCIM resource types (RFC 7643 §6)."""
|
||||
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
|
||||
|
||||
|
||||
@scim_router.get("/Schemas")
|
||||
def get_schemas() -> ScimJSONResponse:
|
||||
"""Return SCIM schema definitions (RFC 7643 §7).
|
||||
|
||||
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
|
||||
like Entra ID expect a JSON object, not a bare array.
|
||||
"""
|
||||
schemas = [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF, ENTERPRISE_USER_SCHEMA_DEF]
|
||||
return ScimJSONResponse(
|
||||
content={
|
||||
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
|
||||
"totalResults": len(schemas),
|
||||
"Resources": [s.model_dump(exclude_none=True) for s in schemas],
|
||||
}
|
||||
)
|
||||
def get_schemas() -> list[ScimSchemaDefinition]:
|
||||
"""Return SCIM schema definitions (RFC 7643 §7)."""
|
||||
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -141,45 +102,15 @@ def get_schemas() -> ScimJSONResponse:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _scim_error_response(status: int, detail: str) -> ScimJSONResponse:
|
||||
def _scim_error_response(status: int, detail: str) -> JSONResponse:
|
||||
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
|
||||
logger.warning("SCIM error response: status=%s detail=%s", status, detail)
|
||||
body = ScimError(status=str(status), detail=detail)
|
||||
return ScimJSONResponse(
|
||||
return JSONResponse(
|
||||
status_code=status,
|
||||
content=body.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
def _parse_excluded_attributes(raw: str | None) -> set[str]:
|
||||
"""Parse the ``excludedAttributes`` query parameter (RFC 7644 §3.4.2.5).
|
||||
|
||||
Returns a set of lowercased attribute names to omit from responses.
|
||||
"""
|
||||
if not raw:
|
||||
return set()
|
||||
return {attr.strip().lower() for attr in raw.split(",") if attr.strip()}
|
||||
|
||||
|
||||
def _apply_exclusions(
|
||||
resource: ScimUserResource | ScimGroupResource,
|
||||
excluded: set[str],
|
||||
) -> dict:
|
||||
"""Serialize a SCIM resource, omitting attributes the IdP excluded.
|
||||
|
||||
RFC 7644 §3.4.2.5 lets the IdP pass ``?excludedAttributes=groups,emails``
|
||||
to reduce response payload size. We strip those fields after serialization
|
||||
so the rest of the pipeline doesn't need to know about them.
|
||||
"""
|
||||
data = resource.model_dump(exclude_none=True, by_alias=True)
|
||||
for attr in excluded:
|
||||
# Match case-insensitively against the camelCase field names
|
||||
keys_to_remove = [k for k in data if k.lower() == attr]
|
||||
for k in keys_to_remove:
|
||||
del data[k]
|
||||
return data
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
@@ -193,7 +124,7 @@ def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | ScimJSONResponse:
|
||||
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
|
||||
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
|
||||
try:
|
||||
uid = UUID(user_id)
|
||||
@@ -213,63 +144,10 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
|
||||
"""
|
||||
if not name:
|
||||
return None
|
||||
# If the client explicitly provides ``formatted``, prefer it — the client
|
||||
# knows what display string it wants. Otherwise build from components.
|
||||
if name.formatted:
|
||||
return name.formatted
|
||||
# Build from givenName/familyName first — IdPs like Okta may send a stale
|
||||
# ``formatted`` value while updating the individual name components.
|
||||
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
|
||||
return parts or None
|
||||
|
||||
|
||||
def _scim_resource_response(
|
||||
resource: ScimUserResource | ScimGroupResource | ScimListResponse,
|
||||
status_code: int = 200,
|
||||
) -> ScimJSONResponse:
|
||||
"""Serialize a SCIM resource as ``application/scim+json``."""
|
||||
content = resource.model_dump(exclude_none=True, by_alias=True)
|
||||
return ScimJSONResponse(
|
||||
status_code=status_code,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def _extract_enterprise_fields(
|
||||
resource: ScimUserResource,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Extract department and manager from enterprise extension."""
|
||||
ext = resource.enterprise_extension
|
||||
if not ext:
|
||||
return None, None
|
||||
department = ext.department
|
||||
manager = ext.manager.value if ext.manager else None
|
||||
return department, manager
|
||||
|
||||
|
||||
def _mapping_to_fields(
|
||||
mapping: ScimUserMapping | None,
|
||||
) -> ScimMappingFields | None:
|
||||
"""Extract round-trip fields from a SCIM user mapping."""
|
||||
if not mapping:
|
||||
return None
|
||||
return ScimMappingFields(
|
||||
department=mapping.department,
|
||||
manager=mapping.manager,
|
||||
given_name=mapping.given_name,
|
||||
family_name=mapping.family_name,
|
||||
scim_emails_json=mapping.scim_emails_json,
|
||||
)
|
||||
|
||||
|
||||
def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
|
||||
"""Build mapping fields from an incoming SCIM user resource."""
|
||||
department, manager = _extract_enterprise_fields(resource)
|
||||
return ScimMappingFields(
|
||||
department=department,
|
||||
manager=manager,
|
||||
given_name=resource.name.givenName if resource.name else None,
|
||||
family_name=resource.name.familyName if resource.name else None,
|
||||
scim_emails_json=serialize_emails(resource.emails),
|
||||
)
|
||||
return parts or name.formatted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -280,13 +158,12 @@ def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
|
||||
@scim_router.get("/Users", response_model=None)
|
||||
def list_users(
|
||||
filter: str | None = Query(None),
|
||||
excludedAttributes: str | None = None,
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | ScimJSONResponse:
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List users with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -308,66 +185,42 @@ def list_users(
|
||||
mapping.external_id if mapping else None,
|
||||
groups=user_groups_map.get(user.id, []),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
fields=_mapping_to_fields(mapping),
|
||||
)
|
||||
for user, mapping in users_with_mappings
|
||||
]
|
||||
|
||||
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
|
||||
excluded = _parse_excluded_attributes(excludedAttributes)
|
||||
if excluded:
|
||||
response = ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
)
|
||||
data = response.model_dump(exclude_none=True)
|
||||
data["Resources"] = [_apply_exclusions(r, excluded) for r in resources]
|
||||
return ScimJSONResponse(content=data)
|
||||
|
||||
list_resp = ScimListResponse(
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
return _scim_resource_response(list_resp)
|
||||
|
||||
|
||||
@scim_router.get("/Users/{user_id}", response_model=None)
|
||||
def get_user(
|
||||
user_id: str,
|
||||
excludedAttributes: str | None = None,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Get a single user by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
|
||||
resource = provider.build_user_resource(
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
mapping.external_id if mapping else None,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
fields=_mapping_to_fields(mapping),
|
||||
)
|
||||
|
||||
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
|
||||
excluded = _parse_excluded_attributes(excludedAttributes)
|
||||
if excluded:
|
||||
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
|
||||
|
||||
return _scim_resource_response(resource)
|
||||
|
||||
|
||||
@scim_router.post("/Users", status_code=201, response_model=None)
|
||||
def create_user(
|
||||
@@ -375,7 +228,7 @@ def create_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Create a new user from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -417,25 +270,13 @@ def create_user(
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
external_id=external_id, user_id=user.id, scim_username=scim_username
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
),
|
||||
status_code=201,
|
||||
)
|
||||
return provider.build_user_resource(user, external_id, scim_username=scim_username)
|
||||
|
||||
|
||||
@scim_router.put("/Users/{user_id}", response_model=None)
|
||||
@@ -445,13 +286,13 @@ def replace_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Replace a user entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
@@ -472,24 +313,15 @@ def replace_user(
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
dal.sync_user_external_id(
|
||||
user.id,
|
||||
new_external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.sync_user_external_id(user.id, new_external_id, scim_username=scim_username)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
)
|
||||
|
||||
|
||||
@@ -500,7 +332,7 @@ def patch_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Partially update a user (RFC 7644 §3.5.2).
|
||||
|
||||
This is the primary endpoint for user deprovisioning — Okta sends
|
||||
@@ -510,25 +342,23 @@ def patch_user(
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
current_scim_username = mapping.scim_username if mapping else None
|
||||
current_fields = _mapping_to_fields(mapping)
|
||||
|
||||
current = provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=current_scim_username,
|
||||
fields=current_fields,
|
||||
)
|
||||
|
||||
try:
|
||||
patched, ent_data = apply_user_patch(
|
||||
patched = apply_user_patch(
|
||||
patch_request.Operations, current, provider.ignored_patch_paths
|
||||
)
|
||||
except ScimPatchError as e:
|
||||
@@ -563,35 +393,17 @@ def patch_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Build updated fields by merging PATCH enterprise data with current values
|
||||
cf = current_fields or ScimMappingFields()
|
||||
fields = ScimMappingFields(
|
||||
department=ent_data.get("department", cf.department),
|
||||
manager=ent_data.get("manager", cf.manager),
|
||||
given_name=patched.name.givenName if patched.name else cf.given_name,
|
||||
family_name=patched.name.familyName if patched.name else cf.family_name,
|
||||
scim_emails_json=(
|
||||
serialize_emails(patched.emails) if patched.emails else cf.scim_emails_json
|
||||
),
|
||||
)
|
||||
|
||||
dal.sync_user_external_id(
|
||||
user.id,
|
||||
patched.externalId,
|
||||
scim_username=new_scim_username,
|
||||
fields=fields,
|
||||
user.id, patched.externalId, scim_username=new_scim_username
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
)
|
||||
|
||||
|
||||
@@ -600,29 +412,25 @@ def delete_user(
|
||||
user_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | ScimJSONResponse:
|
||||
) -> Response | JSONResponse:
|
||||
"""Delete a user (RFC 7644 §3.6).
|
||||
|
||||
Deactivates the user and removes the SCIM mapping. Note that Okta
|
||||
typically uses PATCH active=false instead of DELETE.
|
||||
A second DELETE returns 404 per RFC 7644 §3.6.
|
||||
"""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
# If no SCIM mapping exists, the user was already deleted from
|
||||
# SCIM's perspective — return 404 per RFC 7644 §3.6.
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
if not mapping:
|
||||
return _scim_error_response(404, f"User {user_id} not found")
|
||||
|
||||
dal.deactivate_user(user)
|
||||
dal.delete_user_mapping(mapping.id)
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
if mapping:
|
||||
dal.delete_user_mapping(mapping.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
@@ -634,7 +442,7 @@ def delete_user(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | ScimJSONResponse:
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
|
||||
"""Parse *group_id* as int, look up the group, or return a 404 error."""
|
||||
try:
|
||||
gid = int(group_id)
|
||||
@@ -689,13 +497,12 @@ def _validate_and_parse_members(
|
||||
@scim_router.get("/Groups", response_model=None)
|
||||
def list_groups(
|
||||
filter: str | None = Query(None),
|
||||
excludedAttributes: str | None = None,
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | ScimJSONResponse:
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -715,59 +522,37 @@ def list_groups(
|
||||
for group, ext_id in groups_with_ext_ids
|
||||
]
|
||||
|
||||
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
|
||||
excluded = _parse_excluded_attributes(excludedAttributes)
|
||||
if excluded:
|
||||
response = ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
)
|
||||
data = response.model_dump(exclude_none=True)
|
||||
data["Resources"] = [_apply_exclusions(r, excluded) for r in resources]
|
||||
return ScimJSONResponse(content=data)
|
||||
|
||||
return _scim_resource_response(
|
||||
ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Groups/{group_id}", response_model=None)
|
||||
def get_group(
|
||||
group_id: str,
|
||||
excludedAttributes: str | None = None,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Get a single group by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
members = dal.get_group_members(group.id)
|
||||
|
||||
resource = provider.build_group_resource(
|
||||
return provider.build_group_resource(
|
||||
group, members, mapping.external_id if mapping else None
|
||||
)
|
||||
|
||||
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
|
||||
excluded = _parse_excluded_attributes(excludedAttributes)
|
||||
if excluded:
|
||||
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
|
||||
|
||||
return _scim_resource_response(resource)
|
||||
|
||||
|
||||
@scim_router.post("/Groups", status_code=201, response_model=None)
|
||||
def create_group(
|
||||
@@ -775,7 +560,7 @@ def create_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Create a new group from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -811,10 +596,7 @@ def create_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(db_group.id)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(db_group, members, external_id),
|
||||
status_code=201,
|
||||
)
|
||||
return provider.build_group_resource(db_group, members, external_id)
|
||||
|
||||
|
||||
@scim_router.put("/Groups/{group_id}", response_model=None)
|
||||
@@ -824,13 +606,13 @@ def replace_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Replace a group entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
@@ -845,9 +627,7 @@ def replace_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(group, members, group_resource.externalId)
|
||||
)
|
||||
return provider.build_group_resource(group, members, group_resource.externalId)
|
||||
|
||||
|
||||
@scim_router.patch("/Groups/{group_id}", response_model=None)
|
||||
@@ -857,7 +637,7 @@ def patch_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Partially update a group (RFC 7644 §3.5.2).
|
||||
|
||||
Handles member add/remove operations from Okta and Azure AD.
|
||||
@@ -866,7 +646,7 @@ def patch_group(
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
@@ -905,9 +685,7 @@ def patch_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(group, members, patched.externalId)
|
||||
)
|
||||
return provider.build_group_resource(group, members, patched.externalId)
|
||||
|
||||
|
||||
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
|
||||
@@ -915,13 +693,13 @@ def delete_group(
|
||||
group_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | ScimJSONResponse:
|
||||
) -> Response | JSONResponse:
|
||||
"""Delete a group (RFC 7644 §3.6)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
|
||||
@@ -7,14 +7,12 @@ SCIM protocol schemas follow the wire format defined in:
|
||||
Admin API schemas are internal to Onyx and used for SCIM token management.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -33,9 +31,6 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
|
||||
)
|
||||
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
|
||||
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
|
||||
SCIM_ENTERPRISE_USER_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -75,36 +70,6 @@ class ScimUserGroupRef(BaseModel):
|
||||
display: str | None = None
|
||||
|
||||
|
||||
class ScimManagerRef(BaseModel):
|
||||
"""Manager sub-attribute for the enterprise extension (RFC 7643 §4.3)."""
|
||||
|
||||
value: str | None = None
|
||||
|
||||
|
||||
class ScimEnterpriseExtension(BaseModel):
|
||||
"""Enterprise User extension attributes (RFC 7643 §4.3)."""
|
||||
|
||||
department: str | None = None
|
||||
manager: ScimManagerRef | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScimMappingFields:
|
||||
"""Stored SCIM mapping fields that need to round-trip through the IdP.
|
||||
|
||||
Entra ID sends structured name components, email metadata, and enterprise
|
||||
extension attributes that must be returned verbatim in subsequent GET
|
||||
responses. These fields are persisted on ScimUserMapping and threaded
|
||||
through the DAL, provider, and endpoint layers.
|
||||
"""
|
||||
|
||||
department: str | None = None
|
||||
manager: str | None = None
|
||||
given_name: str | None = None
|
||||
family_name: str | None = None
|
||||
scim_emails_json: str | None = None
|
||||
|
||||
|
||||
class ScimUserResource(BaseModel):
|
||||
"""SCIM User resource representation (RFC 7643 §4.1).
|
||||
|
||||
@@ -113,8 +78,6 @@ class ScimUserResource(BaseModel):
|
||||
to match the SCIM wire format (not Python convention).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
|
||||
id: str | None = None # Onyx's internal user ID, set on responses
|
||||
externalId: str | None = None # IdP's identifier for this user
|
||||
@@ -125,10 +88,6 @@ class ScimUserResource(BaseModel):
|
||||
active: bool = True
|
||||
groups: list[ScimUserGroupRef] = Field(default_factory=list)
|
||||
meta: ScimMeta | None = None
|
||||
enterprise_extension: ScimEnterpriseExtension | None = Field(
|
||||
default=None,
|
||||
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
|
||||
)
|
||||
|
||||
|
||||
class ScimGroupMember(BaseModel):
|
||||
@@ -206,19 +165,6 @@ class ScimPatchOperation(BaseModel):
|
||||
path: str | None = None
|
||||
value: ScimPatchValue = None
|
||||
|
||||
@field_validator("op", mode="before")
|
||||
@classmethod
|
||||
def normalize_op(cls, v: object) -> object:
|
||||
"""Normalize op to lowercase for case-insensitive matching.
|
||||
|
||||
Some IdPs (e.g. Entra ID) send capitalized ops like ``"Replace"``
|
||||
instead of ``"replace"``. This is safe for all providers since the
|
||||
enum values are lowercase. If a future provider requires other
|
||||
pre-processing quirks, move patch deserialization into the provider
|
||||
subclass instead of adding more special cases here.
|
||||
"""
|
||||
return v.lower() if isinstance(v, str) else v
|
||||
|
||||
|
||||
class ScimPatchRequest(BaseModel):
|
||||
"""PATCH request body (RFC 7644 §3.5.2).
|
||||
|
||||
@@ -15,10 +15,7 @@ responsible for persisting changes.
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
@@ -27,51 +24,6 @@ from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimPatchValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
|
||||
# Lowercased enterprise extension URN for case-insensitive matching
|
||||
_ENTERPRISE_URN_LOWER = SCIM_ENTERPRISE_USER_SCHEMA.lower()
|
||||
|
||||
# Pattern for email filter paths: emails[primary eq true].value
|
||||
_EMAIL_FILTER_RE = re.compile(
|
||||
r"^emails\[primary\s+eq\s+true\]\.value$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Pattern for member removal path: members[value eq "user-id"]
|
||||
_MEMBER_FILTER_RE = re.compile(
|
||||
r'^members\[value\s+eq\s+"([^"]+)"\]$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dispatch tables for user PATCH paths
|
||||
#
|
||||
# Maps lowercased SCIM path → (camelCase key, target dict name).
|
||||
# "data" writes to the top-level resource dict, "name" writes to the
|
||||
# name sub-object dict. This replaces the elif chains for simple fields.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_USER_REPLACE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"active": ("active", "data"),
|
||||
"username": ("userName", "data"),
|
||||
"externalid": ("externalId", "data"),
|
||||
"name.givenname": ("givenName", "name"),
|
||||
"name.familyname": ("familyName", "name"),
|
||||
"name.formatted": ("formatted", "name"),
|
||||
}
|
||||
|
||||
_USER_REMOVE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"externalid": ("externalId", "data"),
|
||||
"name.givenname": ("givenName", "name"),
|
||||
"name.familyname": ("familyName", "name"),
|
||||
"name.formatted": ("formatted", "name"),
|
||||
"displayname": ("displayName", "data"),
|
||||
}
|
||||
|
||||
_GROUP_REPLACE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"displayname": ("displayName", "data"),
|
||||
"externalid": ("externalId", "data"),
|
||||
}
|
||||
|
||||
|
||||
class ScimPatchError(Exception):
|
||||
"""Raised when a PATCH operation cannot be applied."""
|
||||
@@ -82,25 +34,18 @@ class ScimPatchError(Exception):
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _UserPatchCtx:
|
||||
"""Bundles the mutable state for user PATCH operations."""
|
||||
|
||||
data: dict
|
||||
name_data: dict
|
||||
ent_data: dict[str, str | None] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User PATCH
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pattern for member removal path: members[value eq "user-id"]
|
||||
_MEMBER_FILTER_RE = re.compile(
|
||||
r'^members\[value\s+eq\s+"([^"]+)"\]$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def apply_user_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimUserResource,
|
||||
ignored_paths: frozenset[str] = frozenset(),
|
||||
) -> tuple[ScimUserResource, dict[str, str | None]]:
|
||||
) -> ScimUserResource:
|
||||
"""Apply SCIM PATCH operations to a user resource.
|
||||
|
||||
Args:
|
||||
@@ -108,183 +53,79 @@ def apply_user_patch(
|
||||
current: The current user resource state.
|
||||
ignored_paths: SCIM attribute paths to silently skip (from provider).
|
||||
|
||||
Returns:
|
||||
A tuple of (modified user resource, enterprise extension data dict).
|
||||
The enterprise dict has keys ``"department"`` and ``"manager"``
|
||||
with values set only when a PATCH operation touched them.
|
||||
Returns a new ``ScimUserResource`` with the modifications applied.
|
||||
The original object is not mutated.
|
||||
|
||||
Raises:
|
||||
ScimPatchError: If an operation targets an unsupported path.
|
||||
"""
|
||||
data = current.model_dump()
|
||||
ctx = _UserPatchCtx(data=data, name_data=data.get("name") or {})
|
||||
name_data = data.get("name") or {}
|
||||
|
||||
for op in operations:
|
||||
if op.op in (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD):
|
||||
_apply_user_replace(op, ctx, ignored_paths)
|
||||
elif op.op == ScimPatchOperationType.REMOVE:
|
||||
_apply_user_remove(op, ctx, ignored_paths)
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported operation '{op.op.value}' on User resource"
|
||||
)
|
||||
|
||||
ctx.data["name"] = ctx.name_data
|
||||
return ScimUserResource.model_validate(ctx.data), ctx.ent_data
|
||||
data["name"] = name_data
|
||||
return ScimUserResource.model_validate(data)
|
||||
|
||||
|
||||
def _apply_user_replace(
|
||||
op: ScimPatchOperation,
|
||||
ctx: _UserPatchCtx,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a replace/add operation to user data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
# No path — value is a resource dict of top-level attributes to set.
|
||||
# No path — value is a resource dict of top-level attributes to set
|
||||
if isinstance(op.value, ScimPatchResourceValue):
|
||||
for key, val in op.value.model_dump(exclude_unset=True).items():
|
||||
_set_user_field(key.lower(), val, ctx, ignored_paths, strict=False)
|
||||
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
_set_user_field(path, op.value, ctx, ignored_paths)
|
||||
|
||||
|
||||
def _apply_user_remove(
|
||||
op: ScimPatchOperation,
|
||||
ctx: _UserPatchCtx,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a remove operation to user data — clears the target field."""
|
||||
path = (op.path or "").lower()
|
||||
if not path:
|
||||
raise ScimPatchError("Remove operation requires a path")
|
||||
|
||||
if path in ignored_paths:
|
||||
return
|
||||
|
||||
entry = _USER_REMOVE_PATHS.get(path)
|
||||
if entry:
|
||||
key, target = entry
|
||||
target_dict = ctx.data if target == "data" else ctx.name_data
|
||||
target_dict[key] = None
|
||||
return
|
||||
|
||||
raise ScimPatchError(f"Unsupported remove path '{path}' for User PATCH")
|
||||
_set_user_field(path, op.value, data, name_data, ignored_paths)
|
||||
|
||||
|
||||
def _set_user_field(
|
||||
path: str,
|
||||
value: ScimPatchValue,
|
||||
ctx: _UserPatchCtx,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
*,
|
||||
strict: bool = True,
|
||||
) -> None:
|
||||
"""Set a single field on user data by SCIM path.
|
||||
|
||||
Args:
|
||||
strict: When ``False`` (path-less replace), unknown attributes are
|
||||
silently skipped. When ``True`` (explicit path), they raise.
|
||||
"""
|
||||
"""Set a single field on user data by SCIM path."""
|
||||
if path in ignored_paths:
|
||||
return
|
||||
|
||||
# Simple field writes handled by the dispatch table
|
||||
entry = _USER_REPLACE_PATHS.get(path)
|
||||
if entry:
|
||||
key, target = entry
|
||||
target_dict = ctx.data if target == "data" else ctx.name_data
|
||||
target_dict[key] = value
|
||||
return
|
||||
|
||||
# displayName sets both the top-level field and the name.formatted sub-field
|
||||
if path == "displayname":
|
||||
ctx.data["displayName"] = value
|
||||
ctx.name_data["formatted"] = value
|
||||
elif path == "name":
|
||||
if isinstance(value, dict):
|
||||
for k, v in value.items():
|
||||
ctx.name_data[k] = v
|
||||
elif path == "emails":
|
||||
if isinstance(value, list):
|
||||
ctx.data["emails"] = value
|
||||
elif _EMAIL_FILTER_RE.match(path):
|
||||
_update_primary_email(ctx.data, value)
|
||||
elif path.startswith(_ENTERPRISE_URN_LOWER):
|
||||
_set_enterprise_field(path, value, ctx.ent_data)
|
||||
elif not strict:
|
||||
return
|
||||
elif path == "active":
|
||||
data["active"] = value
|
||||
elif path == "username":
|
||||
data["userName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
elif path == "name.givenname":
|
||||
name_data["givenName"] = value
|
||||
elif path == "name.familyname":
|
||||
name_data["familyName"] = value
|
||||
elif path == "name.formatted":
|
||||
name_data["formatted"] = value
|
||||
elif path == "displayname":
|
||||
data["displayName"] = value
|
||||
name_data["formatted"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
|
||||
|
||||
|
||||
def _update_primary_email(data: dict, value: ScimPatchValue) -> None:
|
||||
"""Update the primary email entry via emails[primary eq true].value."""
|
||||
emails: list[dict] = data.get("emails") or []
|
||||
for email_entry in emails:
|
||||
if email_entry.get("primary"):
|
||||
email_entry["value"] = value
|
||||
break
|
||||
else:
|
||||
emails.append({"value": value, "type": "work", "primary": True})
|
||||
data["emails"] = emails
|
||||
|
||||
|
||||
def _to_dict(value: ScimPatchValue) -> dict | None:
|
||||
"""Coerce a SCIM patch value to a plain dict if possible.
|
||||
|
||||
Pydantic may parse raw dicts as ``ScimPatchResourceValue`` (which uses
|
||||
``extra="allow"``), so we also dump those back to a dict.
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, ScimPatchResourceValue):
|
||||
return value.model_dump(exclude_unset=True)
|
||||
return None
|
||||
|
||||
|
||||
def _set_enterprise_field(
|
||||
path: str,
|
||||
value: ScimPatchValue,
|
||||
ent_data: dict[str, str | None],
|
||||
) -> None:
|
||||
"""Handle enterprise extension URN paths or value dicts."""
|
||||
# Full URN as key with dict value (path-less PATCH)
|
||||
# e.g. key="urn:...:user", value={"department": "Eng", "manager": {...}}
|
||||
if path == _ENTERPRISE_URN_LOWER:
|
||||
d = _to_dict(value)
|
||||
if d is not None:
|
||||
if "department" in d:
|
||||
ent_data["department"] = d["department"]
|
||||
if "manager" in d:
|
||||
mgr = d["manager"]
|
||||
if isinstance(mgr, dict):
|
||||
ent_data["manager"] = mgr.get("value")
|
||||
return
|
||||
|
||||
# Dotted URN path, e.g. "urn:...:user:department"
|
||||
suffix = path[len(_ENTERPRISE_URN_LOWER) :].lstrip(":").lower()
|
||||
if suffix == "department":
|
||||
ent_data["department"] = str(value) if value is not None else None
|
||||
elif suffix == "manager":
|
||||
d = _to_dict(value)
|
||||
if d is not None:
|
||||
ent_data["manager"] = d.get("value")
|
||||
elif isinstance(value, str):
|
||||
ent_data["manager"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported enterprise extension attribute '{suffix}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group PATCH
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_group_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimGroupResource,
|
||||
@@ -394,14 +235,12 @@ def _set_group_field(
|
||||
"""Set a single field on group data by SCIM path."""
|
||||
if path in ignored_paths:
|
||||
return
|
||||
|
||||
entry = _GROUP_REPLACE_PATHS.get(path)
|
||||
if entry:
|
||||
key, _ = entry
|
||||
data[key] = value
|
||||
return
|
||||
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
|
||||
elif path == "displayname":
|
||||
data["displayName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
|
||||
|
||||
|
||||
def _apply_group_add(
|
||||
|
||||
@@ -2,20 +2,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from uuid import UUID
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimEnterpriseExtension
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimManagerRef
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
@@ -24,17 +17,6 @@ from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"id",
|
||||
"schemas",
|
||||
"meta",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ScimProvider(ABC):
|
||||
"""Base class for provider-specific SCIM behavior.
|
||||
|
||||
@@ -59,22 +41,12 @@ class ScimProvider(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def user_schemas(self) -> list[str]:
|
||||
"""Schema URIs to include in User resource responses.
|
||||
|
||||
Override in subclasses to advertise additional schemas (e.g. the
|
||||
enterprise extension for Entra ID).
|
||||
"""
|
||||
return [SCIM_USER_SCHEMA]
|
||||
|
||||
def build_user_resource(
|
||||
self,
|
||||
user: User,
|
||||
external_id: str | None = None,
|
||||
groups: list[tuple[int, str]] | None = None,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserResource:
|
||||
"""Build a SCIM User response from an Onyx User.
|
||||
|
||||
@@ -86,48 +58,27 @@ class ScimProvider(ABC):
|
||||
for newly-created users.
|
||||
scim_username: The original-case userName from the IdP. Falls
|
||||
back to ``user.email`` (lowercase) when not available.
|
||||
fields: Stored mapping fields that the IdP expects round-tripped.
|
||||
"""
|
||||
f = fields or ScimMappingFields()
|
||||
group_refs = [
|
||||
ScimUserGroupRef(value=str(gid), display=gname)
|
||||
for gid, gname in (groups or [])
|
||||
]
|
||||
|
||||
# Use original-case userName if stored, otherwise fall back to the
|
||||
# lowercased email from the User model.
|
||||
username = scim_username or user.email
|
||||
|
||||
# Build enterprise extension when at least one value is present.
|
||||
# Dynamically add the enterprise URN to schemas per RFC 7643 §3.0.
|
||||
enterprise_ext: ScimEnterpriseExtension | None = None
|
||||
schemas = list(self.user_schemas)
|
||||
if f.department is not None or f.manager is not None:
|
||||
manager_ref = (
|
||||
ScimManagerRef(value=f.manager) if f.manager is not None else None
|
||||
)
|
||||
enterprise_ext = ScimEnterpriseExtension(
|
||||
department=f.department,
|
||||
manager=manager_ref,
|
||||
)
|
||||
if SCIM_ENTERPRISE_USER_SCHEMA not in schemas:
|
||||
schemas.append(SCIM_ENTERPRISE_USER_SCHEMA)
|
||||
|
||||
name = self.build_scim_name(user, f)
|
||||
emails = _deserialize_emails(f.scim_emails_json, username)
|
||||
|
||||
resource = ScimUserResource(
|
||||
schemas=schemas,
|
||||
return ScimUserResource(
|
||||
id=str(user.id),
|
||||
externalId=external_id,
|
||||
userName=username,
|
||||
name=name,
|
||||
name=self._build_scim_name(user),
|
||||
displayName=user.personal_name,
|
||||
emails=emails,
|
||||
emails=[ScimEmail(value=username, type="work", primary=True)],
|
||||
active=user.is_active,
|
||||
groups=group_refs,
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
resource.enterprise_extension = enterprise_ext
|
||||
return resource
|
||||
|
||||
def build_group_resource(
|
||||
self,
|
||||
@@ -147,24 +98,9 @@ class ScimProvider(ABC):
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
def build_scim_name(
|
||||
self,
|
||||
user: User,
|
||||
fields: ScimMappingFields,
|
||||
) -> ScimName | None:
|
||||
"""Build SCIM name components for the response.
|
||||
|
||||
Round-trips stored ``given_name``/``family_name`` when available (so
|
||||
the IdP gets back what it sent). Falls back to splitting
|
||||
``personal_name`` for users provisioned before we stored components.
|
||||
Providers may override for custom behavior.
|
||||
"""
|
||||
if fields.given_name is not None or fields.family_name is not None:
|
||||
return ScimName(
|
||||
givenName=fields.given_name,
|
||||
familyName=fields.family_name,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
@staticmethod
|
||||
def _build_scim_name(user: User) -> ScimName | None:
|
||||
"""Extract SCIM name components from a user's personal name."""
|
||||
if not user.personal_name:
|
||||
return None
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
@@ -175,27 +111,6 @@ class ScimProvider(ABC):
|
||||
)
|
||||
|
||||
|
||||
def _deserialize_emails(stored_json: str | None, username: str) -> list[ScimEmail]:
|
||||
"""Deserialize stored email entries or build a default work email."""
|
||||
if stored_json:
|
||||
try:
|
||||
entries = json.loads(stored_json)
|
||||
if isinstance(entries, list) and entries:
|
||||
return [ScimEmail(**e) for e in entries]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(
|
||||
"Corrupt scim_emails_json, falling back to default: %s", stored_json
|
||||
)
|
||||
return [ScimEmail(value=username, type="work", primary=True)]
|
||||
|
||||
|
||||
def serialize_emails(emails: list[ScimEmail]) -> str | None:
|
||||
"""Serialize SCIM email entries to JSON for storage."""
|
||||
if not emails:
|
||||
return None
|
||||
return json.dumps([e.model_dump(exclude_none=True) for e in emails])
|
||||
|
||||
|
||||
def get_default_provider() -> ScimProvider:
|
||||
"""Return the default SCIM provider.
|
||||
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Entra ID (Azure AD) SCIM provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
_ENTRA_IGNORED_PATCH_PATHS = COMMON_IGNORED_PATCH_PATHS
|
||||
|
||||
|
||||
class EntraProvider(ScimProvider):
|
||||
"""Entra ID (Azure AD) SCIM provider.
|
||||
|
||||
Entra behavioral notes:
|
||||
- Sends capitalized PATCH ops (``"Add"``, ``"Replace"``, ``"Remove"``)
|
||||
— handled by ``ScimPatchOperation.normalize_op`` validator.
|
||||
- Sends the enterprise extension URN as a key in path-less PATCH value
|
||||
dicts — handled by ``_set_enterprise_field`` in ``patch.py`` to
|
||||
store department/manager values.
|
||||
- Expects the enterprise extension schema in ``schemas`` arrays and
|
||||
``/Schemas`` + ``/ResourceTypes`` discovery endpoints.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "entra"
|
||||
|
||||
@property
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
return _ENTRA_IGNORED_PATCH_PATHS
|
||||
|
||||
@property
|
||||
def user_schemas(self) -> list[str]:
|
||||
return [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
|
||||
@@ -23,4 +22,4 @@ class OktaProvider(ScimProvider):
|
||||
|
||||
@property
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
return COMMON_IGNORED_PATCH_PATHS
|
||||
return frozenset({"id", "schemas", "meta"})
|
||||
|
||||
@@ -4,7 +4,6 @@ Pre-built at import time — these never change at runtime. Separated from
|
||||
api.py to keep the endpoint module focused on request handling.
|
||||
"""
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
@@ -21,9 +20,6 @@ USER_RESOURCE_TYPE = ScimResourceType.model_validate(
|
||||
"endpoint": "/scim/v2/Users",
|
||||
"description": "SCIM User resource",
|
||||
"schema": SCIM_USER_SCHEMA,
|
||||
"schemaExtensions": [
|
||||
{"schema": SCIM_ENTERPRISE_USER_SCHEMA, "required": False}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -108,31 +104,6 @@ USER_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
],
|
||||
)
|
||||
|
||||
ENTERPRISE_USER_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_ENTERPRISE_USER_SCHEMA,
|
||||
name="EnterpriseUser",
|
||||
description="Enterprise User extension (RFC 7643 §4.3)",
|
||||
attributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="department",
|
||||
type="string",
|
||||
description="Department.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="manager",
|
||||
type="complex",
|
||||
description="The user's manager.",
|
||||
subAttributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="value",
|
||||
type="string",
|
||||
description="Manager user ID.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_GROUP_SCHEMA,
|
||||
name="Group",
|
||||
|
||||
@@ -58,6 +58,8 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
@@ -156,36 +158,7 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
|
||||
logger.warning(
|
||||
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
|
||||
)
|
||||
cleaned_doc = doc.model_copy()
|
||||
|
||||
# Postgres cannot handle NUL characters in text fields
|
||||
if "\x00" in cleaned_doc.id:
|
||||
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
|
||||
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
|
||||
|
||||
if cleaned_doc.title and "\x00" in cleaned_doc.title:
|
||||
logger.warning(
|
||||
f"NUL characters found in document title: {cleaned_doc.title}"
|
||||
)
|
||||
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
|
||||
|
||||
if "\x00" in cleaned_doc.semantic_identifier:
|
||||
logger.warning(
|
||||
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
|
||||
)
|
||||
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
|
||||
"\x00", ""
|
||||
)
|
||||
|
||||
for section in cleaned_doc.sections:
|
||||
if section.link is not None:
|
||||
section.link = section.link.replace("\x00", "")
|
||||
|
||||
# since text can be longer, just replace to avoid double scan
|
||||
if isinstance(section, TextSection) and section.text is not None:
|
||||
section.text = section.text.replace("\x00", "")
|
||||
|
||||
cleaned_batch.append(cleaned_doc)
|
||||
cleaned_batch.append(sanitize_document_for_postgres(doc))
|
||||
|
||||
return cleaned_batch
|
||||
|
||||
@@ -602,10 +575,13 @@ def connector_document_extraction(
|
||||
|
||||
# Process hierarchy nodes batch - upsert to Postgres and cache in Redis
|
||||
if hierarchy_node_batch:
|
||||
hierarchy_node_batch_cleaned = (
|
||||
sanitize_hierarchy_nodes_for_postgres(hierarchy_node_batch)
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=hierarchy_node_batch,
|
||||
nodes=hierarchy_node_batch_cleaned,
|
||||
source=db_connector.source,
|
||||
commit=True,
|
||||
is_connector_public=is_connector_public,
|
||||
@@ -624,7 +600,7 @@ def connector_document_extraction(
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Persisted and cached {len(hierarchy_node_batch)} hierarchy nodes "
|
||||
f"Persisted and cached {len(hierarchy_node_batch_cleaned)} hierarchy nodes "
|
||||
f"for attempt={index_attempt_id}"
|
||||
)
|
||||
|
||||
|
||||
21
backend/onyx/db/code_interpreter.py
Normal file
21
backend/onyx/db/code_interpreter.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import CodeInterpreterServer
|
||||
|
||||
|
||||
def fetch_code_interpreter_server(db_session: Session) -> CodeInterpreterServer:
    """Load the code interpreter server row.

    Args:
        db_session: Session used to run the query.

    Returns:
        The one ``CodeInterpreterServer`` row.

    Raises:
        sqlalchemy.exc.NoResultFound / MultipleResultsFound: the result is
            resolved with ``.one()``, which requires exactly one row.
    """
    return db_session.scalars(select(CodeInterpreterServer)).one()
|
||||
|
||||
|
||||
def update_code_interpreter_server_enabled(
    db_session: Session,
    enabled: bool,
) -> CodeInterpreterServer:
    """Persist a new ``server_enabled`` value on the code interpreter row.

    Args:
        db_session: Session used to load the row and commit the change.
        enabled: New value for ``server_enabled``.

    Returns:
        The updated ``CodeInterpreterServer`` row.
    """
    # Same single-row lookup (and the same exactly-one-row requirement) as the
    # read path.
    row = fetch_code_interpreter_server(db_session)
    row.server_enabled = enabled
    db_session.commit()
    return row
|
||||
@@ -49,6 +49,7 @@ from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import IndexingBatchAdapter
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.llm.factory import get_llm_for_contextual_rag
|
||||
@@ -228,6 +229,8 @@ def index_doc_batch_prepare(
|
||||
) -> DocumentBatchPrepareContext | None:
|
||||
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
|
||||
This precedes indexing it into the actual document index."""
|
||||
documents = sanitize_documents_for_postgres(documents)
|
||||
|
||||
# Create a trimmed list of docs that don't have a newer updated at
|
||||
# Shortcuts the time-consuming flow on connector index retries
|
||||
document_ids: list[str] = [document.id for document in documents]
|
||||
|
||||
150
backend/onyx/indexing/postgres_sanitization.py
Normal file
150
backend/onyx/indexing/postgres_sanitization.py
Normal file
@@ -0,0 +1,150 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
|
||||
|
||||
def _sanitize_string(value: str) -> str:
|
||||
return value.replace("\x00", "")
|
||||
|
||||
|
||||
def _sanitize_json_like(value: Any) -> Any:
    """Recursively strip NUL characters from a JSON-like value.

    Strings are cleaned directly. Lists, tuples and dicts are rebuilt with
    their contents cleaned; string dict keys are cleaned too. Any other value
    (numbers, booleans, ``None``, ...) is returned unchanged.

    Args:
        value: Arbitrary nested structure of dicts, lists, tuples and scalars.

    Returns:
        A cleaned copy for containers and strings, or ``value`` itself for
        every other type.
    """
    if isinstance(value, dict):
        # If two keys become equal after cleaning, the later entry wins.
        return {
            (_sanitize_string(k) if isinstance(k, str) else k): _sanitize_json_like(v)
            for k, v in value.items()
        }

    if isinstance(value, list):
        return [_sanitize_json_like(element) for element in value]

    if isinstance(value, tuple):
        return tuple(map(_sanitize_json_like, value))

    if isinstance(value, str):
        return _sanitize_string(value)

    return value
|
||||
|
||||
|
||||
def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
    """Return a copy of ``expert`` with NUL characters removed from its name/email fields.

    Fields that are ``None`` stay ``None``. The input object is not modified;
    the result comes from ``model_copy(update=...)``.
    """
    text_fields = (
        "display_name",
        "first_name",
        "middle_initial",
        "last_name",
        "email",
    )

    cleaned_fields: dict[str, str | None] = {}
    for field_name in text_fields:
        current = getattr(expert, field_name)
        cleaned_fields[field_name] = (
            None if current is None else _sanitize_string(current)
        )

    return expert.model_copy(update=cleaned_fields)
|
||||
|
||||
|
||||
def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess:
    """Build a new ``ExternalAccess`` whose emails and group ids are NUL-free.

    ``is_public`` is carried over unchanged.

    NOTE(review): only these three fields are passed to the constructor; if
    ``ExternalAccess`` has other fields they take their defaults. Confirm
    against the model definition.
    """
    cleaned_emails = set(map(_sanitize_string, external_access.external_user_emails))
    cleaned_group_ids = set(
        map(_sanitize_string, external_access.external_user_group_ids)
    )
    return ExternalAccess(
        external_user_emails=cleaned_emails,
        external_user_group_ids=cleaned_group_ids,
        is_public=external_access.is_public,
    )
|
||||
|
||||
|
||||
def sanitize_document_for_postgres(document: Document) -> Document:
    """Return a deep copy of ``document`` with NUL characters stripped.

    The input document is left untouched. Fields cleaned: ``id``,
    ``semantic_identifier``, ``title``, ``parent_hierarchy_raw_node_id``,
    ``metadata`` (keys and values), ``doc_metadata`` (recursively), primary
    and secondary owners, ``external_access``, and each section's ``link``,
    ``text`` and ``image_file_id``. Optional fields that are ``None`` are
    skipped.
    """
    doc = document.model_copy(deep=True)

    # Required string fields.
    doc.id = _sanitize_string(doc.id)
    doc.semantic_identifier = _sanitize_string(doc.semantic_identifier)

    # Optional string fields.
    title = doc.title
    if title is not None:
        doc.title = _sanitize_string(title)
    parent_raw_id = doc.parent_hierarchy_raw_node_id
    if parent_raw_id is not None:
        doc.parent_hierarchy_raw_node_id = _sanitize_string(parent_raw_id)

    # Metadata values are handled as either a list of strings or a single string.
    sanitized_metadata: dict[str, str | list[str]] = {}
    for key, value in doc.metadata.items():
        if isinstance(value, list):
            sanitized_value: str | list[str] = [
                _sanitize_string(item) for item in value
            ]
        else:
            sanitized_value = _sanitize_string(value)
        sanitized_metadata[_sanitize_string(key)] = sanitized_value
    doc.metadata = sanitized_metadata

    # Free-form metadata may be arbitrarily nested, so use the recursive cleaner.
    if doc.doc_metadata is not None:
        doc.doc_metadata = _sanitize_json_like(doc.doc_metadata)

    if doc.primary_owners is not None:
        doc.primary_owners = list(map(_sanitize_expert_info, doc.primary_owners))
    if doc.secondary_owners is not None:
        doc.secondary_owners = list(map(_sanitize_expert_info, doc.secondary_owners))

    if doc.external_access is not None:
        doc.external_access = _sanitize_external_access(doc.external_access)

    # Sections belong to the deep copy, so mutating them in place is safe.
    for section in doc.sections:
        for attr_name in ("link", "text", "image_file_id"):
            raw_value = getattr(section, attr_name)
            if raw_value is not None:
                setattr(section, attr_name, _sanitize_string(raw_value))

    return doc
|
||||
|
||||
|
||||
def sanitize_documents_for_postgres(documents: list[Document]) -> list[Document]:
    """Apply ``sanitize_document_for_postgres`` to each document, preserving order."""
    return list(map(sanitize_document_for_postgres, documents))
|
||||
|
||||
|
||||
def sanitize_hierarchy_node_for_postgres(node: HierarchyNode) -> HierarchyNode:
    """Return a deep copy of ``node`` with NUL characters stripped.

    Cleans ``raw_node_id``, ``display_name``, ``raw_parent_id``, ``link`` and
    ``external_access``; optional fields that are ``None`` are skipped. The
    input node is not modified.
    """
    cleaned = node.model_copy(deep=True)

    # Required string fields.
    for required_attr in ("raw_node_id", "display_name"):
        setattr(cleaned, required_attr, _sanitize_string(getattr(cleaned, required_attr)))

    # Optional string fields.
    for optional_attr in ("raw_parent_id", "link"):
        raw_value = getattr(cleaned, optional_attr)
        if raw_value is not None:
            setattr(cleaned, optional_attr, _sanitize_string(raw_value))

    access = cleaned.external_access
    if access is not None:
        cleaned.external_access = _sanitize_external_access(access)

    return cleaned
|
||||
|
||||
|
||||
def sanitize_hierarchy_nodes_for_postgres(
    nodes: list[HierarchyNode],
) -> list[HierarchyNode]:
    """Apply ``sanitize_hierarchy_node_for_postgres`` to each node, preserving order."""
    return list(map(sanitize_hierarchy_node_for_postgres, nodes))
|
||||
@@ -97,6 +97,9 @@ from onyx.server.features.web_search.api import router as web_search_router
|
||||
from onyx.server.federated.api import router as federated_router
|
||||
from onyx.server.kg.api import admin_router as kg_admin_router
|
||||
from onyx.server.manage.administrative import router as admin_router
|
||||
from onyx.server.manage.code_interpreter.api import (
|
||||
admin_router as code_interpreter_admin_router,
|
||||
)
|
||||
from onyx.server.manage.discord_bot.api import router as discord_bot_router
|
||||
from onyx.server.manage.embedding.api import admin_router as embedding_admin_router
|
||||
from onyx.server.manage.embedding.api import basic_router as embedding_router
|
||||
@@ -421,6 +424,9 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, llm_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, kg_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, llm_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, code_interpreter_admin_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, image_generation_admin_router
|
||||
)
|
||||
|
||||
@@ -1,14 +1,68 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from mistune import create_markdown
|
||||
from mistune import HTMLRenderer
|
||||
|
||||
_CITATION_LINK_PATTERN = re.compile(r"\[\[\d+\]\]\(")
|
||||
|
||||
|
||||
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
|
||||
"""Extract markdown link destination, allowing nested parentheses in the URL."""
|
||||
depth = 0
|
||||
i = start_idx
|
||||
|
||||
while i < len(message):
|
||||
curr = message[i]
|
||||
if curr == "\\":
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if curr == "(":
|
||||
depth += 1
|
||||
elif curr == ")":
|
||||
if depth == 0:
|
||||
return message[start_idx:i], i
|
||||
depth -= 1
|
||||
i += 1
|
||||
|
||||
return message[start_idx:], None
|
||||
|
||||
|
||||
def _normalize_citation_link_destinations(message: str) -> str:
    """Wrap citation URLs in angle brackets so markdown parsers handle parentheses safely.

    Every ``[[N]](destination)`` citation link has its destination rewritten
    to ``<destination>``, unless it is empty or already wrapped in ``<...>``.
    If a citation link has no closing parenthesis, the remainder of the
    message is emitted unchanged.
    """
    # Fast path: without "[[" there can be no citation link.
    if "[[" not in message:
        return message

    pieces: list[str] = []
    pos = 0
    match = _CITATION_LINK_PATTERN.search(message, pos)

    while match is not None:
        dest_start = match.end()
        # Copy everything up to and including the "[[N]](" prefix verbatim.
        pieces.append(message[pos:dest_start])

        destination, close_idx = _extract_link_destination(message, dest_start)
        if close_idx is None:
            # Unterminated link: keep the tail as-is and stop.
            pieces.append(message[dest_start:])
            return "".join(pieces)

        is_wrapped = destination.startswith("<") and destination.endswith(">")
        if destination and not is_wrapped:
            pieces.append(f"<{destination}>")
        else:
            pieces.append(destination)
        pieces.append(")")

        pos = close_idx + 1
        match = _CITATION_LINK_PATTERN.search(message, pos)

    pieces.append(message[pos:])
    return "".join(pieces)
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
    """Render a markdown message with ``SlackRenderer``.

    Citation link destinations are normalized first (wrapped in angle
    brackets) so that URLs containing parentheses are parsed correctly.

    Args:
        message: Markdown text, or ``None``.

    Returns:
        The rendered string; ``""`` when ``message`` is ``None``.
    """
    if message is None:
        return ""
    md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
    # Render only the normalized text. Rendering the raw message first was
    # wasted work because its result was discarded.
    normalized_message = _normalize_citation_link_destinations(message)
    result = md(normalized_message)
    # With HTMLRenderer, result is always str (not AST list)
    assert isinstance(result, str)
    return result
|
||||
|
||||
Binary file not shown.
47
backend/onyx/server/manage/code_interpreter/api.py
Normal file
47
backend/onyx/server/manage/code_interpreter/api.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.code_interpreter import fetch_code_interpreter_server
|
||||
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.server.manage.code_interpreter.models import CodeInterpreterServer
|
||||
from onyx.server.manage.code_interpreter.models import CodeInterpreterServerHealth
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/code-interpreter")
|
||||
|
||||
|
||||
@admin_router.get("/health")
def get_code_interpreter_health(
    _: User = Depends(current_admin_user),
) -> CodeInterpreterServerHealth:
    """Admin endpoint reporting whether the code interpreter service is healthy.

    A ``ValueError`` raised while building the client or the response is
    reported as ``healthy=False`` rather than as an error.
    NOTE(review): presumably the client constructor raises ``ValueError`` when
    the service is not configured - confirm.
    """
    try:
        status = CodeInterpreterServerHealth(healthy=CodeInterpreterClient().health())
    except ValueError:
        status = CodeInterpreterServerHealth(healthy=False)
    return status
|
||||
|
||||
|
||||
@admin_router.get("")
def get_code_interpreter(
    _: User = Depends(current_admin_user), db_session: Session = Depends(get_session)
) -> CodeInterpreterServer:
    """Admin endpoint returning the stored enabled flag of the code interpreter."""
    stored_enabled = fetch_code_interpreter_server(db_session).server_enabled
    return CodeInterpreterServer(enabled=stored_enabled)
|
||||
|
||||
|
||||
@admin_router.put("")
def update_code_interpreter(
    update: CodeInterpreterServer,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Admin endpoint that persists the ``enabled`` flag from the request body."""
    update_code_interpreter_server_enabled(db_session, update.enabled)
|
||||
9
backend/onyx/server/manage/code_interpreter/models.py
Normal file
9
backend/onyx/server/manage/code_interpreter/models.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeInterpreterServer(BaseModel):
    """API model carrying the code interpreter's enabled flag."""

    # Whether the code interpreter is enabled.
    enabled: bool
|
||||
|
||||
|
||||
class CodeInterpreterServerHealth(BaseModel):
    """API model carrying the result of a code interpreter health check."""

    # True when the health check succeeded.
    healthy: bool
|
||||
@@ -98,6 +98,17 @@ class CodeInterpreterClient:
|
||||
payload["files"] = files
|
||||
return payload
|
||||
|
||||
def health(self) -> bool:
    """Check if the Code Interpreter service is healthy.

    Issues ``GET {base_url}/health`` with a 5 second timeout.

    Returns:
        True only when the request succeeds with a non-error status and the
        JSON body has ``"status" == "ok"``. Any exception is logged as a
        warning and reported as unhealthy.
    """
    health_url = f"{self.base_url}/health"
    try:
        reply = self.session.get(health_url, timeout=5)
        reply.raise_for_status()
        body = reply.json()
        return body.get("status") == "ok"
    except Exception as e:
        logger.warning(f"Exception caught when checking health, e={e}")
        return False
|
||||
|
||||
def execute(
|
||||
self,
|
||||
code: str,
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_DEFAULT_TIMEOUT_MS
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_MAX_OUTPUT_LENGTH
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.code_interpreter import fetch_code_interpreter_server
|
||||
from onyx.file_store.utils import build_full_frontend_file_url
|
||||
from onyx.file_store.utils import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -103,8 +104,10 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
@override
@classmethod
def is_available(cls, db_session: Session) -> bool:
    """Report whether the Python tool can be offered.

    The tool is available only when ``CODE_INTERPRETER_BASE_URL`` is
    configured and the stored code interpreter row has ``server_enabled``
    set.

    Args:
        db_session: Session used to read the stored enabled flag.
    """
    # No base URL means the service is not configured; skip the DB lookup.
    if not CODE_INTERPRETER_BASE_URL:
        return False
    # The stored admin toggle must be consulted too. Returning right after the
    # URL check would make it unreachable.
    server = fetch_code_interpreter_server(db_session)
    return server.server_enabled
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
|
||||
@@ -317,7 +317,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.6.0
|
||||
onyx-devtools==0.6.1
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
|
||||
@@ -3,8 +3,8 @@ set -e
|
||||
|
||||
# Stop and remove every container this script starts. Failures (for example a
# container that does not exist) are ignored so cleanup always completes.
cleanup() {
    echo "Error occurred. Cleaning up..."
    # One stop/rm pair covering all containers, including the code interpreter.
    docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
    docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
}
|
||||
|
||||
# Trap errors and output a message, then cleanup
|
||||
@@ -20,8 +20,8 @@ MINIO_VOLUME=${4:-""} # Default is empty if not provided
|
||||
|
||||
# Stop and remove the existing containers (including the code interpreter).
# Errors are ignored so a first run with no containers does not abort.
echo "Stopping and removing existing containers..."
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
|
||||
# Start the PostgreSQL container with optional volume
|
||||
echo "Starting PostgreSQL container..."
|
||||
@@ -55,6 +55,10 @@ else
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
|
||||
fi
|
||||
|
||||
# Start the Code Interpreter container.
# It runs as root with the host Docker socket mounted and publishes port 8000.
echo "Starting Code Interpreter container..."
docker run --detach \
    --name onyx_code_interpreter \
    --publish 8000:8000 \
    --user root \
    -v /var/run/docker.sock:/var/run/docker.sock \
    onyxdotapp/code-interpreter:latest \
    bash ./entrypoint.sh code-interpreter-api
|
||||
|
||||
# Ensure alembic runs in the correct directory (backend/)
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.tool import ToolManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
CODE_INTERPRETER_URL = f"{API_SERVER_URL}/admin/code-interpreter"
|
||||
CODE_INTERPRETER_HEALTH_URL = f"{CODE_INTERPRETER_URL}/health"
|
||||
PYTHON_TOOL_NAME = "python"
|
||||
|
||||
|
||||
def test_get_code_interpreter_health_as_admin(admin_user: DATestUser) -> None:
    """Health endpoint should return a JSON object with a 'healthy' boolean."""
    health_response = requests.get(
        CODE_INTERPRETER_HEALTH_URL, headers=admin_user.headers
    )
    assert health_response.status_code == 200

    payload = health_response.json()
    assert "healthy" in payload
    assert isinstance(payload["healthy"], bool)
|
||||
|
||||
|
||||
def test_get_code_interpreter_status_as_admin(admin_user: DATestUser) -> None:
    """GET endpoint should return a JSON object with an 'enabled' boolean."""
    status_response = requests.get(CODE_INTERPRETER_URL, headers=admin_user.headers)
    assert status_response.status_code == 200

    payload = status_response.json()
    assert "enabled" in payload
    assert isinstance(payload["enabled"], bool)
|
||||
|
||||
|
||||
def test_update_code_interpreter_disable_and_enable(
    admin_user: DATestUser,
) -> None:
    """PUT endpoint should update the enabled flag and persist across reads.

    The flag is re-enabled in a ``finally`` block so that a failing assertion
    cannot leave the code interpreter disabled for tests that run afterwards.
    """
    try:
        # Disable
        response = requests.put(
            CODE_INTERPRETER_URL,
            json={"enabled": False},
            headers=admin_user.headers,
        )
        assert response.status_code == 200

        # Verify disabled
        response = requests.get(
            CODE_INTERPRETER_URL,
            headers=admin_user.headers,
        )
        assert response.status_code == 200
        assert response.json()["enabled"] is False

        # Re-enable
        response = requests.put(
            CODE_INTERPRETER_URL,
            json={"enabled": True},
            headers=admin_user.headers,
        )
        assert response.status_code == 200

        # Verify enabled
        response = requests.get(
            CODE_INTERPRETER_URL,
            headers=admin_user.headers,
        )
        assert response.status_code == 200
        assert response.json()["enabled"] is True
    finally:
        # Best-effort restore of the enabled state this test ends in on success.
        requests.put(
            CODE_INTERPRETER_URL,
            json={"enabled": True},
            headers=admin_user.headers,
        )
|
||||
|
||||
|
||||
def test_code_interpreter_endpoints_require_admin(basic_user: DATestUser) -> None:
    """All code interpreter endpoints should reject non-admin users."""
    # (HTTP method, URL, JSON body) for each endpoint, in the order they are probed.
    attempts: list[tuple[str, str, dict | None]] = [
        ("GET", CODE_INTERPRETER_HEALTH_URL, None),
        ("GET", CODE_INTERPRETER_URL, None),
        ("PUT", CODE_INTERPRETER_URL, {"enabled": True}),
    ]

    for method, url, body in attempts:
        response = requests.request(
            method,
            url,
            json=body,
            headers=basic_user.headers,
        )
        assert response.status_code == 403
|
||||
|
||||
|
||||
def test_python_tool_hidden_from_tool_list_when_disabled(
    admin_user: DATestUser,
) -> None:
    """When code interpreter is disabled, the Python tool should not appear
    in the GET /tool response (i.e. the frontend tool list).

    The flag is re-enabled in a ``finally`` block so that a failing assertion
    cannot leave the code interpreter disabled for tests that run afterwards.
    """
    try:
        # Disable
        response = requests.put(
            CODE_INTERPRETER_URL,
            json={"enabled": False},
            headers=admin_user.headers,
        )
        assert response.status_code == 200

        # Python tool should not be in the tool list
        tools = ToolManager.list_tools(user_performing_action=admin_user)
        tool_names = [t.name for t in tools]
        assert PYTHON_TOOL_NAME not in tool_names

        # Re-enable
        response = requests.put(
            CODE_INTERPRETER_URL,
            json={"enabled": True},
            headers=admin_user.headers,
        )
        assert response.status_code == 200

        # Python tool should reappear
        tools = ToolManager.list_tools(user_performing_action=admin_user)
        tool_names = [t.name for t in tools]
        assert PYTHON_TOOL_NAME in tool_names
    finally:
        # Best-effort restore of the enabled state this test ends in on success.
        requests.put(
            CODE_INTERPRETER_URL,
            json={"enabled": True},
            headers=admin_user.headers,
        )
|
||||
@@ -0,0 +1,322 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from onyx.configs import app_configs
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.tools.constants import SEARCH_TOOL_ID
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.tool import ToolManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.test_models import ToolName
|
||||
|
||||
|
||||
_ENV_PROVIDER = "NIGHTLY_LLM_PROVIDER"
|
||||
_ENV_MODELS = "NIGHTLY_LLM_MODELS"
|
||||
_ENV_API_KEY = "NIGHTLY_LLM_API_KEY"
|
||||
_ENV_API_BASE = "NIGHTLY_LLM_API_BASE"
|
||||
_ENV_CUSTOM_CONFIG_JSON = "NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
_ENV_STRICT = "NIGHTLY_LLM_STRICT"
|
||||
|
||||
|
||||
class NightlyProviderConfig(BaseModel):
    """Settings for one nightly provider run, populated from ``NIGHTLY_LLM_*`` env vars."""

    # frozen=True: instances cannot be modified after construction.
    model_config = ConfigDict(frozen=True)

    # Provider identifier, lower-cased by the loader.
    provider: str
    # Models to exercise, parsed from a comma-separated env var.
    model_names: list[str]
    # Provider API key, or None when unset/empty.
    api_key: str | None
    # Explicit API base URL, or None to use the per-provider default.
    api_base: str | None
    # Extra provider configuration parsed from a JSON object, if any.
    custom_config: dict[str, str] | None
    # When True, missing configuration fails the test instead of skipping it.
    strict: bool
|
||||
|
||||
|
||||
def _env_true(env_var: str, default: bool = False) -> bool:
|
||||
value = os.environ.get(env_var)
|
||||
if value is None:
|
||||
return default
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _split_csv_env(env_var: str) -> list[str]:
|
||||
return [
|
||||
part.strip() for part in os.environ.get(env_var, "").split(",") if part.strip()
|
||||
]
|
||||
|
||||
|
||||
def _load_provider_config() -> NightlyProviderConfig:
    """Build the nightly provider configuration from environment variables.

    Returns:
        A ``NightlyProviderConfig`` populated from the ``NIGHTLY_LLM_*``
        variables. Empty API key / API base values are treated as unset.

    Raises:
        ValueError: the custom-config variable holds JSON that is not an
            object (malformed JSON raises ``json.JSONDecodeError``, a
            ``ValueError`` subclass).
    """
    env = os.environ
    provider = env.get(_ENV_PROVIDER, "").strip().lower()
    api_key = env.get(_ENV_API_KEY) or None

    custom_config: dict[str, str] | None = None
    raw_custom_config = env.get(_ENV_CUSTOM_CONFIG_JSON, "").strip()
    if raw_custom_config:
        decoded = json.loads(raw_custom_config)
        if not isinstance(decoded, dict):
            raise ValueError(f"{_ENV_CUSTOM_CONFIG_JSON} must be a JSON object")
        # Keys and values are coerced to strings.
        custom_config = {str(key): str(value) for key, value in decoded.items()}

    # For ollama_chat, an API key with no explicit custom config is passed
    # through as OLLAMA_API_KEY in the custom config.
    if provider == "ollama_chat" and api_key and not custom_config:
        custom_config = {"OLLAMA_API_KEY": api_key}

    return NightlyProviderConfig(
        provider=provider,
        model_names=_split_csv_env(_ENV_MODELS),
        api_key=api_key,
        api_base=env.get(_ENV_API_BASE) or None,
        custom_config=custom_config,
        strict=_env_true(_ENV_STRICT, default=False),
    )
|
||||
|
||||
|
||||
def _skip_or_fail(strict: bool, message: str) -> None:
    """Abort the current test: fail it in strict mode, otherwise skip it.

    Both ``pytest.fail`` and ``pytest.skip`` raise, so this never returns
    normally.
    """
    abort = pytest.fail if strict else pytest.skip
    abort(message)
|
||||
|
||||
|
||||
def _validate_provider_config(config: NightlyProviderConfig) -> None:
    """Skip (or fail, in strict mode) the test when required settings are missing.

    Checks, in order: provider is set; at least one model is listed; an API
    key is present for every provider other than ``ollama_chat``; and an API
    base can be resolved for ``ollama_chat``.
    """
    strict = config.strict
    is_ollama = config.provider == "ollama_chat"

    if not config.provider:
        _skip_or_fail(strict=strict, message=f"{_ENV_PROVIDER} must be set")

    if not config.model_names:
        _skip_or_fail(
            strict=strict,
            message=f"{_ENV_MODELS} must include at least one model",
        )

    if not is_ollama and not config.api_key:
        _skip_or_fail(
            strict=strict,
            message=(f"{_ENV_API_KEY} is required for provider '{config.provider}'"),
        )

    if is_ollama:
        # The explicit API base wins; otherwise use the per-provider default.
        resolved_api_base = config.api_base or _default_api_base_for_provider(
            config.provider
        )
        if not resolved_api_base:
            _skip_or_fail(
                strict=strict,
                message=(
                    f"{_ENV_API_BASE} is required for provider '{config.provider}'"
                ),
            )
|
||||
|
||||
|
||||
def _assert_integration_mode_enabled() -> None:
    """Assert that ``app_configs.INTEGRATION_TESTS_MODE`` is exactly ``True``."""
    integration_mode = app_configs.INTEGRATION_TESTS_MODE
    assert (
        integration_mode is True
    ), "Integration tests require INTEGRATION_TESTS_MODE=true."
|
||||
|
||||
|
||||
def _seed_connector_for_search_tool(admin_user: DATestUser) -> None:
    """Create an ingestion-API connector pair so that SearchTool is exposed."""
    # SearchTool is only exposed when at least one non-default connector exists.
    CCPairManager.create_from_scratch(
        user_performing_action=admin_user,
        source=DocumentSource.INGESTION_API,
    )
|
||||
|
||||
|
||||
def _get_internal_search_tool_id(admin_user: DATestUser) -> int:
    """Return the id of the first listed tool whose ``in_code_tool_id`` is ``SEARCH_TOOL_ID``.

    Raises:
        AssertionError: no such tool is listed.
    """
    listed_tools = ToolManager.list_tools(user_performing_action=admin_user)
    search_tool = next(
        (
            candidate
            for candidate in listed_tools
            if candidate.in_code_tool_id == SEARCH_TOOL_ID
        ),
        None,
    )
    if search_tool is None:
        raise AssertionError("SearchTool must exist for this test")
    return search_tool.id
|
||||
|
||||
|
||||
def _default_api_base_for_provider(provider: str) -> str | None:
|
||||
if provider == "openrouter":
|
||||
return "https://openrouter.ai/api/v1"
|
||||
if provider == "ollama_chat":
|
||||
# host.docker.internal works when tests are running inside the integration test container.
|
||||
return "http://host.docker.internal:11434"
|
||||
return None
|
||||
|
||||
|
||||
def _create_provider_payload(
|
||||
provider: str,
|
||||
provider_name: str,
|
||||
model_name: str,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
custom_config: dict[str, str] | None,
|
||||
) -> dict:
|
||||
return {
|
||||
"name": provider_name,
|
||||
"provider": provider,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"custom_config": custom_config,
|
||||
"default_model_name": model_name,
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
"personas": [],
|
||||
"model_configurations": [{"name": model_name, "is_visible": True}],
|
||||
"api_key_changed": bool(api_key),
|
||||
"custom_config_changed": bool(custom_config),
|
||||
}
|
||||
|
||||
|
||||
def _ensure_provider_is_default(provider_id: int, admin_user: DATestUser) -> None:
    """Assert that the provider list marks ``provider_id`` as the default provider.

    Raises:
        requests.HTTPError: the provider list request returned an error status.
        AssertionError: no provider is marked default, or the first one marked
            default is a different provider.
    """
    listing = requests.get(
        f"{API_SERVER_URL}/admin/llm/provider",
        headers=admin_user.headers,
    )
    listing.raise_for_status()

    # Take the first entry flagged as the default provider, if any.
    current_default = None
    for entry in listing.json():
        if entry.get("is_default_provider"):
            current_default = entry
            break

    assert (
        current_default is not None
    ), "Expected a default provider after setting provider as default"
    assert (
        current_default["id"] == provider_id
    ), f"Expected provider {provider_id} to be default, found {current_default['id']}"
|
||||
|
||||
|
||||
def _run_chat_assertions(
    admin_user: DATestUser,
    search_tool_id: int,
    provider: str,
    model_name: str,
) -> None:
    """Send a chat message that forces internal search and verify the response.

    An attempt succeeds when the stream has no error, ``used_tools`` and
    ``tool_call_debug`` both show an internal search call, and the answer is
    non-empty. Up to two attempts are made, each in a fresh chat session; if
    both fail, the test fails with details of the last attempt.

    Args:
        admin_user: User performing the chat requests.
        search_tool_id: Tool id forced for the message.
        provider: Provider name, used only in failure messages.
        model_name: Model name, used only in failure messages.
    """
    max_attempts = 2
    last_error: str | None = None
    # Retry once to reduce transient nightly flakes due to provider-side blips.
    for attempt in range(1, max_attempts + 1):
        chat_session = ChatSessionManager.create(user_performing_action=admin_user)

        response = ChatSessionManager.send_message(
            chat_session_id=chat_session.id,
            message=(
                "Use internal_search to search for 'nightly-provider-regression-sentinel', "
                "then summarize the result in one short sentence."
            ),
            user_performing_action=admin_user,
            forced_tool_ids=[search_tool_id],
        )

        if response.error is None:
            used_internal_search = any(
                used_tool.tool_name == ToolName.INTERNAL_SEARCH
                for used_tool in response.used_tools
            )
            debug_has_internal_search = any(
                debug_tool_call.tool_name == "internal_search"
                for debug_tool_call in response.tool_call_debug
            )
            has_answer = bool(response.full_message.strip())

            if used_internal_search and debug_has_internal_search and has_answer:
                return

            last_error = (
                f"attempt={attempt} provider={provider} model={model_name} "
                f"used_internal_search={used_internal_search} "
                f"debug_internal_search={debug_has_internal_search} "
                f"has_answer={has_answer} "
                f"tool_call_debug={response.tool_call_debug}"
            )
        else:
            last_error = (
                f"attempt={attempt} provider={provider} model={model_name} "
                f"stream_error={response.error.error}"
            )

        # Back off only when another attempt follows; sleeping after the final
        # attempt would just delay the failure report.
        if attempt < max_attempts:
            time.sleep(attempt)

    pytest.fail(f"Chat/tool-call assertions failed: {last_error}")
|
||||
|
||||
|
||||
def _create_and_test_provider_for_model(
    admin_user: DATestUser,
    config: NightlyProviderConfig,
    model_name: str,
    search_tool_id: int,
) -> None:
    """Run the full provider workflow for a single model.

    Steps: validate the payload through the provider test endpoint, create the
    provider, set it as default, confirm it is the default, then run the chat
    assertions. Once the provider has been created it is always deleted at the
    end, whether or not the later steps succeed.
    """
    unique_suffix = uuid4().hex[:12]
    provider_name = f"nightly-{config.provider}-{unique_suffix}"
    # An explicit API base wins over the per-provider default.
    api_base = config.api_base or _default_api_base_for_provider(config.provider)

    payload = _create_provider_payload(
        provider=config.provider,
        provider_name=provider_name,
        model_name=model_name,
        api_key=config.api_key,
        api_base=api_base,
        custom_config=config.custom_config,
    )
    headers = admin_user.headers

    test_response = requests.post(
        f"{API_SERVER_URL}/admin/llm/test",
        headers=headers,
        json=payload,
    )
    assert test_response.status_code == 200, (
        f"Provider test endpoint failed for provider={config.provider} "
        f"model={model_name}: {test_response.status_code} {test_response.text}"
    )

    create_response = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
        headers=headers,
        json=payload,
    )
    assert create_response.status_code == 200, (
        f"Provider creation failed for provider={config.provider} "
        f"model={model_name}: {create_response.status_code} {create_response.text}"
    )
    provider_id = create_response.json()["id"]
    provider_url = f"{API_SERVER_URL}/admin/llm/provider/{provider_id}"

    try:
        set_default_response = requests.post(
            f"{provider_url}/default",
            headers=headers,
        )
        assert set_default_response.status_code == 200, (
            f"Setting default provider failed for provider={config.provider} "
            f"model={model_name}: {set_default_response.status_code} "
            f"{set_default_response.text}"
        )

        _ensure_provider_is_default(provider_id=provider_id, admin_user=admin_user)
        _run_chat_assertions(
            admin_user=admin_user,
            search_tool_id=search_tool_id,
            provider=config.provider,
            model_name=model_name,
        )
    finally:
        # Best-effort cleanup; the delete response is not checked.
        requests.delete(provider_url, headers=headers)
|
||||
|
||||
|
||||
def test_nightly_provider_chat_workflow(admin_user: DATestUser) -> None:
|
||||
"""Nightly regression test for provider setup + default selection + chat tool calls."""
|
||||
_assert_integration_mode_enabled()
|
||||
config = _load_provider_config()
|
||||
_validate_provider_config(config)
|
||||
|
||||
_seed_connector_for_search_tool(admin_user)
|
||||
search_tool_id = _get_internal_search_tool_id(admin_user)
|
||||
|
||||
for model_name in config.model_names:
|
||||
_create_and_test_provider_for_model(
|
||||
admin_user=admin_user,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
@@ -98,11 +98,6 @@ class TestScimDALUserMappings:
|
||||
"external_id": "ext-1",
|
||||
"user_id": user_id,
|
||||
"scim_username": None,
|
||||
"department": None,
|
||||
"manager": None,
|
||||
"given_name": None,
|
||||
"family_name": None,
|
||||
"scim_emails_json": None,
|
||||
}
|
||||
|
||||
def test_delete_user_mapping(
|
||||
|
||||
@@ -211,13 +211,16 @@ def test_openai_provider_rejects_reference_images_for_unsupported_model() -> Non
|
||||
)
|
||||
|
||||
|
||||
def test_openai_provider_rejects_multiple_reference_images_for_dalle2() -> None:
|
||||
def test_openai_provider_rejects_multiple_reference_images_for_dalle3() -> None:
|
||||
provider = OpenAIImageGenerationProvider(api_key="test-key")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="does not support image edits with reference images",
|
||||
):
|
||||
provider.generate_image(
|
||||
prompt="edit this image",
|
||||
model="dall-e-2",
|
||||
model="dall-e-3",
|
||||
size="1024x1024",
|
||||
n=1,
|
||||
reference_images=[
|
||||
@@ -307,17 +310,20 @@ def test_azure_provider_rejects_reference_images_for_unsupported_model() -> None
|
||||
)
|
||||
|
||||
|
||||
def test_azure_provider_rejects_multiple_reference_images_for_dalle2() -> None:
|
||||
def test_azure_provider_rejects_multiple_reference_images_for_dalle3() -> None:
|
||||
provider = AzureImageGenerationProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://azure.example.com",
|
||||
api_version="2024-05-01-preview",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="does not support image edits with reference images",
|
||||
):
|
||||
provider.generate_image(
|
||||
prompt="edit this image",
|
||||
model="dall-e-2",
|
||||
model="dall-e-3",
|
||||
size="1024x1024",
|
||||
n=1,
|
||||
reference_images=[
|
||||
|
||||
159
backend/tests/unit/onyx/indexing/test_postgres_sanitization.py
Normal file
159
backend/tests/unit/onyx/indexing/test_postgres_sanitization.py
Normal file
@@ -0,0 +1,159 @@
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentSource
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.indexing import indexing_pipeline
|
||||
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_node_for_postgres
|
||||
|
||||
|
||||
def test_sanitize_document_for_postgres_removes_nul_bytes() -> None:
|
||||
document = Document(
|
||||
id="doc\x00-id",
|
||||
source=DocumentSource.FILE,
|
||||
semantic_identifier="sem\x00-id",
|
||||
title="ti\x00tle",
|
||||
parent_hierarchy_raw_node_id="parent\x00-id",
|
||||
sections=[TextSection(link="lin\x00k", text="te\x00xt")],
|
||||
metadata={"ke\x00y": "va\x00lue", "list\x00key": ["a\x00", "b"]},
|
||||
doc_metadata={
|
||||
"j\x00son": {
|
||||
"in\x00ner": "va\x00l",
|
||||
"arr": ["x\x00", {"dee\x00p": "y\x00"}],
|
||||
}
|
||||
},
|
||||
primary_owners=[BasicExpertInfo(display_name="Ali\x00ce", email="a\x00@x.com")],
|
||||
secondary_owners=[BasicExpertInfo(first_name="Bo\x00b", last_name="Sm\x00ith")],
|
||||
external_access=ExternalAccess(
|
||||
external_user_emails={"user\x00@example.com"},
|
||||
external_user_group_ids={"gro\x00up-1"},
|
||||
is_public=False,
|
||||
),
|
||||
)
|
||||
|
||||
sanitized = sanitize_document_for_postgres(document)
|
||||
|
||||
assert sanitized.id == "doc-id"
|
||||
assert sanitized.semantic_identifier == "sem-id"
|
||||
assert sanitized.title == "title"
|
||||
assert sanitized.parent_hierarchy_raw_node_id == "parent-id"
|
||||
assert sanitized.sections[0].link == "link"
|
||||
assert sanitized.sections[0].text == "text"
|
||||
assert sanitized.metadata == {"key": "value", "listkey": ["a", "b"]}
|
||||
assert sanitized.doc_metadata == {
|
||||
"json": {"inner": "val", "arr": ["x", {"deep": "y"}]}
|
||||
}
|
||||
assert sanitized.primary_owners is not None
|
||||
assert sanitized.primary_owners[0].display_name == "Alice"
|
||||
assert sanitized.primary_owners[0].email == "a@x.com"
|
||||
assert sanitized.secondary_owners is not None
|
||||
assert sanitized.secondary_owners[0].first_name == "Bob"
|
||||
assert sanitized.secondary_owners[0].last_name == "Smith"
|
||||
assert sanitized.external_access is not None
|
||||
assert sanitized.external_access.external_user_emails == {"user@example.com"}
|
||||
assert sanitized.external_access.external_user_group_ids == {"group-1"}
|
||||
|
||||
# Ensure original document is not mutated
|
||||
assert document.id == "doc\x00-id"
|
||||
assert document.metadata == {"ke\x00y": "va\x00lue", "list\x00key": ["a\x00", "b"]}
|
||||
|
||||
|
||||
def test_sanitize_hierarchy_node_for_postgres_removes_nul_bytes() -> None:
|
||||
node = HierarchyNode(
|
||||
raw_node_id="raw\x00-id",
|
||||
raw_parent_id="paren\x00t-id",
|
||||
display_name="fol\x00der",
|
||||
link="https://exa\x00mple.com",
|
||||
node_type=HierarchyNodeType.FOLDER,
|
||||
external_access=ExternalAccess(
|
||||
external_user_emails={"a\x00@example.com"},
|
||||
external_user_group_ids={"g\x00-1"},
|
||||
is_public=True,
|
||||
),
|
||||
)
|
||||
|
||||
sanitized = sanitize_hierarchy_node_for_postgres(node)
|
||||
|
||||
assert sanitized.raw_node_id == "raw-id"
|
||||
assert sanitized.raw_parent_id == "parent-id"
|
||||
assert sanitized.display_name == "folder"
|
||||
assert sanitized.link == "https://example.com"
|
||||
assert sanitized.external_access is not None
|
||||
assert sanitized.external_access.external_user_emails == {"a@example.com"}
|
||||
assert sanitized.external_access.external_user_group_ids == {"g-1"}
|
||||
|
||||
|
||||
def test_index_doc_batch_prepare_sanitizes_before_db_ops(
|
||||
monkeypatch: MonkeyPatch,
|
||||
) -> None:
|
||||
document = Document(
|
||||
id="doc\x00id",
|
||||
source=DocumentSource.FILE,
|
||||
semantic_identifier="sem\x00id",
|
||||
sections=[TextSection(text="content", link="li\x00nk")],
|
||||
metadata={"ke\x00y": "va\x00lue"},
|
||||
)
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _get_documents_by_ids(db_session: object, document_ids: list[str]) -> list:
|
||||
_ = db_session, document_ids
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
indexing_pipeline, "get_documents_by_ids", _get_documents_by_ids
|
||||
)
|
||||
|
||||
def _capture_upsert_documents_in_db(**kwargs: object) -> None:
|
||||
captured["upsert_documents"] = kwargs["documents"]
|
||||
|
||||
monkeypatch.setattr(
|
||||
indexing_pipeline, "_upsert_documents_in_db", _capture_upsert_documents_in_db
|
||||
)
|
||||
|
||||
def _capture_doc_cc_pair(*args: object) -> None:
|
||||
captured["cc_pair_doc_ids"] = args[3]
|
||||
|
||||
monkeypatch.setattr(
|
||||
indexing_pipeline,
|
||||
"upsert_document_by_connector_credential_pair",
|
||||
_capture_doc_cc_pair,
|
||||
)
|
||||
|
||||
def _noop_link_hierarchy_nodes_to_documents(
|
||||
db_session: object,
|
||||
document_ids: list[str],
|
||||
source: DocumentSource,
|
||||
commit: bool,
|
||||
) -> int:
|
||||
_ = db_session, document_ids, source, commit
|
||||
return 0
|
||||
|
||||
monkeypatch.setattr(
|
||||
indexing_pipeline,
|
||||
"link_hierarchy_nodes_to_documents",
|
||||
_noop_link_hierarchy_nodes_to_documents,
|
||||
)
|
||||
|
||||
context = indexing_pipeline.index_doc_batch_prepare(
|
||||
documents=[document],
|
||||
index_attempt_metadata=IndexAttemptMetadata(connector_id=1, credential_id=2),
|
||||
db_session=object(), # type: ignore[arg-type]
|
||||
ignore_time_skip=True,
|
||||
)
|
||||
|
||||
assert context is not None
|
||||
assert context.updatable_docs[0].id == "docid"
|
||||
assert context.updatable_docs[0].semantic_identifier == "semid"
|
||||
assert context.updatable_docs[0].metadata == {"key": "value"}
|
||||
assert captured["cc_pair_doc_ids"] == ["docid"]
|
||||
|
||||
upsert_documents = captured["upsert_documents"]
|
||||
assert isinstance(upsert_documents, list)
|
||||
assert upsert_documents[0].id == "docid"
|
||||
52
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
52
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from onyx.onyxbot.slack.formatting import _normalize_citation_link_destinations
|
||||
from onyx.onyxbot.slack.formatting import format_slack_message
|
||||
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
|
||||
from onyx.utils.text_processing import decode_escapes
|
||||
|
||||
|
||||
def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
|
||||
message = (
|
||||
"See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
|
||||
)
|
||||
|
||||
normalized = _normalize_citation_link_destinations(message)
|
||||
|
||||
assert (
|
||||
"See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
|
||||
== normalized
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
|
||||
message = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
|
||||
|
||||
normalized = _normalize_citation_link_destinations(message)
|
||||
|
||||
assert message == normalized
|
||||
|
||||
|
||||
def test_normalize_citation_link_handles_multiple_links() -> None:
|
||||
message = (
|
||||
"[[1]](https://example.com/(USA)%20Guide.pdf) "
|
||||
"[[2]](https://example.com/Plan(s)%20Overview.pdf)"
|
||||
)
|
||||
|
||||
normalized = _normalize_citation_link_destinations(message)
|
||||
|
||||
assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in normalized
|
||||
assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in normalized
|
||||
|
||||
|
||||
def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> None:
|
||||
message = (
|
||||
"Download [[1]](https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf)"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
rendered = decode_escapes(remove_slack_text_interactions(formatted))
|
||||
|
||||
assert (
|
||||
"<https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf|[1]>"
|
||||
in rendered
|
||||
)
|
||||
assert "|[1]>%20Access%20ID%20Card" not in rendered
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
@@ -13,9 +12,7 @@ import pytest
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.scim.api import ScimJSONResponse
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
@@ -118,11 +115,6 @@ def make_user_mapping(**kwargs: Any) -> MagicMock:
|
||||
mapping.external_id = kwargs.get("external_id", "ext-default")
|
||||
mapping.user_id = kwargs.get("user_id", uuid4())
|
||||
mapping.scim_username = kwargs.get("scim_username", None)
|
||||
mapping.department = kwargs.get("department", None)
|
||||
mapping.manager = kwargs.get("manager", None)
|
||||
mapping.given_name = kwargs.get("given_name", None)
|
||||
mapping.family_name = kwargs.get("family_name", None)
|
||||
mapping.scim_emails_json = kwargs.get("scim_emails_json", None)
|
||||
return mapping
|
||||
|
||||
|
||||
@@ -130,35 +122,3 @@ def assert_scim_error(result: object, expected_status: int) -> None:
|
||||
"""Assert *result* is a JSONResponse with the given status code."""
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == expected_status
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response parsing helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_scim_user(result: object, *, status: int = 200) -> ScimUserResource:
|
||||
"""Assert *result* is a ScimJSONResponse and parse as ScimUserResource."""
|
||||
assert isinstance(
|
||||
result, ScimJSONResponse
|
||||
), f"Expected ScimJSONResponse, got {type(result).__name__}"
|
||||
assert result.status_code == status
|
||||
return ScimUserResource.model_validate(json.loads(result.body))
|
||||
|
||||
|
||||
def parse_scim_group(result: object, *, status: int = 200) -> ScimGroupResource:
|
||||
"""Assert *result* is a ScimJSONResponse and parse as ScimGroupResource."""
|
||||
assert isinstance(
|
||||
result, ScimJSONResponse
|
||||
), f"Expected ScimJSONResponse, got {type(result).__name__}"
|
||||
assert result.status_code == status
|
||||
return ScimGroupResource.model_validate(json.loads(result.body))
|
||||
|
||||
|
||||
def parse_scim_list(result: object) -> ScimListResponse:
|
||||
"""Assert *result* is a ScimJSONResponse and parse as ScimListResponse."""
|
||||
assert isinstance(
|
||||
result, ScimJSONResponse
|
||||
), f"Expected ScimJSONResponse, got {type(result).__name__}"
|
||||
assert result.status_code == 200
|
||||
return ScimListResponse.model_validate(json.loads(result.body))
|
||||
|
||||
@@ -1,983 +0,0 @@
|
||||
"""Comprehensive Entra ID (Azure AD) SCIM compatibility tests.
|
||||
|
||||
Covers the full Entra provisioning lifecycle: service discovery, user CRUD
|
||||
with enterprise extension schema, group CRUD with excludedAttributes, and
|
||||
all Entra-specific behavioral quirks (PascalCase ops, enterprise URN in
|
||||
PATCH value dicts).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import Response
|
||||
|
||||
from ee.onyx.server.scim.api import create_user
|
||||
from ee.onyx.server.scim.api import delete_user
|
||||
from ee.onyx.server.scim.api import get_group
|
||||
from ee.onyx.server.scim.api import get_resource_types
|
||||
from ee.onyx.server.scim.api import get_schemas
|
||||
from ee.onyx.server.scim.api import get_service_provider_config
|
||||
from ee.onyx.server.scim.api import get_user
|
||||
from ee.onyx.server.scim.api import list_groups
|
||||
from ee.onyx.server.scim.api import list_users
|
||||
from ee.onyx.server.scim.api import patch_group
|
||||
from ee.onyx.server.scim.api import patch_user
|
||||
from ee.onyx.server.scim.api import replace_user
|
||||
from ee.onyx.server.scim.api import ScimJSONResponse
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimEnterpriseExtension
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimManagerRef
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_group
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_user_mapping
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_group
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_list
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entra_provider() -> ScimProvider:
|
||||
"""An EntraProvider instance for Entra-specific endpoint tests."""
|
||||
return EntraProvider()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service Discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraServiceDiscovery:
|
||||
"""Entra expects enterprise extension in discovery endpoints."""
|
||||
|
||||
def test_service_provider_config_advertises_patch(self) -> None:
|
||||
config = get_service_provider_config()
|
||||
assert config.patch.supported is True
|
||||
|
||||
def test_resource_types_include_enterprise_extension(self) -> None:
|
||||
result = get_resource_types()
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert "Resources" in parsed
|
||||
user_type = next(rt for rt in parsed["Resources"] if rt["id"] == "User")
|
||||
extension_schemas = [ext["schema"] for ext in user_type["schemaExtensions"]]
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in extension_schemas
|
||||
|
||||
def test_schemas_include_enterprise_user(self) -> None:
|
||||
result = get_schemas()
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
schema_ids = [s["id"] for s in parsed["Resources"]]
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in schema_ids
|
||||
|
||||
def test_enterprise_schema_has_expected_attributes(self) -> None:
|
||||
result = get_schemas()
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
enterprise = next(
|
||||
s for s in parsed["Resources"] if s["id"] == SCIM_ENTERPRISE_USER_SCHEMA
|
||||
)
|
||||
attr_names = {a["name"] for a in enterprise["attributes"]}
|
||||
assert "department" in attr_names
|
||||
assert "manager" in attr_names
|
||||
|
||||
def test_service_discovery_content_type(self) -> None:
|
||||
"""SCIM responses must use application/scim+json content type."""
|
||||
result = get_resource_types()
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
assert result.media_type == "application/scim+json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User Lifecycle (Entra-specific)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraUserLifecycle:
|
||||
"""Test user CRUD through Entra's lens: enterprise schemas, PascalCase ops."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_create_user_includes_enterprise_schema(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(userName="alice@contoso.com")
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
|
||||
assert SCIM_USER_SCHEMA in resource.schemas
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_create_user_with_enterprise_extension(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Enterprise extension department/manager should round-trip on create."""
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(
|
||||
userName="alice@contoso.com",
|
||||
enterprise_extension=ScimEnterpriseExtension(
|
||||
department="Engineering",
|
||||
manager=ScimManagerRef(value="mgr-uuid-123"),
|
||||
),
|
||||
)
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert resource.enterprise_extension is not None
|
||||
assert resource.enterprise_extension.department == "Engineering"
|
||||
assert resource.enterprise_extension.manager is not None
|
||||
assert resource.enterprise_extension.manager.value == "mgr-uuid-123"
|
||||
|
||||
# Verify DAL received the enterprise fields
|
||||
mock_dal.create_user_mapping.assert_called_once()
|
||||
call_kwargs = mock_dal.create_user_mapping.call_args[1]
|
||||
assert call_kwargs["fields"] == ScimMappingFields(
|
||||
department="Engineering",
|
||||
manager="mgr-uuid-123",
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
)
|
||||
|
||||
def test_get_user_includes_enterprise_schema(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(email="alice@contoso.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
|
||||
result = get_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
|
||||
|
||||
def test_get_user_returns_enterprise_extension_data(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""GET should return stored enterprise extension data."""
|
||||
user = make_db_user(email="alice@contoso.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
mapping = make_user_mapping(user_id=user.id)
|
||||
mapping.department = "Sales"
|
||||
mapping.manager = "mgr-456"
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = mapping
|
||||
|
||||
result = get_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.enterprise_extension is not None
|
||||
assert resource.enterprise_extension.department == "Sales"
|
||||
assert resource.enterprise_extension.manager is not None
|
||||
assert resource.enterprise_extension.manager.value == "mgr-456"
|
||||
|
||||
def test_list_users_includes_enterprise_schema(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(email="alice@contoso.com")
|
||||
mapping = make_user_mapping(external_id="entra-ext-1", user_id=user.id)
|
||||
mock_dal.list_users.return_value = ([(user, mapping)], 1)
|
||||
|
||||
result = list_users(
|
||||
filter=None,
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_list(result)
|
||||
resource = parsed.Resources[0]
|
||||
assert isinstance(resource, ScimUserResource)
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
|
||||
|
||||
def test_patch_user_deactivate_with_pascal_case_replace(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends ``"Replace"`` (PascalCase) instead of ``"replace"``."""
|
||||
user = make_db_user(is_active=True)
|
||||
mock_dal.get_user.return_value = user
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op="Replace", # type: ignore[arg-type]
|
||||
path="active",
|
||||
value=False,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
# Mock doesn't propagate the change, so verify via the DAL call
|
||||
mock_dal.update_user.assert_called_once()
|
||||
call_kwargs = mock_dal.update_user.call_args
|
||||
assert call_kwargs[1]["is_active"] is False
|
||||
|
||||
def test_patch_user_add_external_id_with_pascal_case(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends ``"Add"`` (PascalCase) instead of ``"add"``."""
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op="Add", # type: ignore[arg-type]
|
||||
path="externalId",
|
||||
value="entra-ext-999",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
# Verify the patched externalId was synced to the DAL
|
||||
mock_dal.sync_user_external_id.assert_called_once()
|
||||
call_args = mock_dal.sync_user_external_id.call_args
|
||||
assert call_args[0][1] == "entra-ext-999"
|
||||
|
||||
def test_patch_user_enterprise_extension_in_value_dict(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends enterprise extension URN as key in path-less PATCH value
|
||||
dicts — enterprise data should be stored, not ignored."""
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
|
||||
value = ScimPatchResourceValue(active=False)
|
||||
assert value.__pydantic_extra__ is not None
|
||||
value.__pydantic_extra__[
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
] = {"department": "Engineering"}
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path=None,
|
||||
value=value,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
# Verify active=False was applied
|
||||
mock_dal.update_user.assert_called_once()
|
||||
call_kwargs = mock_dal.update_user.call_args
|
||||
assert call_kwargs[1]["is_active"] is False
|
||||
# Verify enterprise data was passed to DAL
|
||||
mock_dal.sync_user_external_id.assert_called_once()
|
||||
sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
|
||||
assert sync_kwargs["fields"] == ScimMappingFields(
|
||||
department="Engineering",
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
scim_emails_json='[{"value": "test@example.com", "type": "work", "primary": true}]',
|
||||
)
|
||||
|
||||
def test_patch_user_remove_external_id(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""PATCH remove op should clear the target field."""
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
mapping = make_user_mapping(user_id=user.id)
|
||||
mapping.external_id = "ext-to-remove"
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = mapping
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REMOVE,
|
||||
path="externalId",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
# externalId should be cleared (None)
|
||||
mock_dal.sync_user_external_id.assert_called_once()
|
||||
call_args = mock_dal.sync_user_external_id.call_args
|
||||
assert call_args[0][1] is None
|
||||
|
||||
def test_patch_user_emails_primary_eq_true_value(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""PATCH with path emails[primary eq true].value should update
|
||||
the primary email entry, not userName."""
|
||||
user = make_db_user(email="old@contoso.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path="emails[primary eq true].value",
|
||||
value="new@contoso.com",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
# userName should remain unchanged — emails and userName are separate
|
||||
assert resource.userName == "old@contoso.com"
|
||||
# Primary email should be updated
|
||||
primary_emails = [e for e in resource.emails if e.primary]
|
||||
assert len(primary_emails) == 1
|
||||
assert primary_emails[0].value == "new@contoso.com"
|
||||
|
||||
def test_patch_user_enterprise_urn_department_path(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""PATCH with dotted enterprise URN path should store department."""
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:department",
|
||||
value="Marketing",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
mock_dal.sync_user_external_id.assert_called_once()
|
||||
sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
|
||||
assert sync_kwargs["fields"] == ScimMappingFields(
|
||||
department="Marketing",
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
scim_emails_json='[{"value": "test@example.com", "type": "work", "primary": true}]',
|
||||
)
|
||||
|
||||
def test_replace_user_includes_enterprise_schema(
    self,
    mock_db_session: MagicMock,
    mock_token: MagicMock,
    mock_dal: MagicMock,
    entra_provider: ScimProvider,
) -> None:
    """PUT through the Entra provider lists the enterprise schema in the response."""
    db_user = make_db_user(email="old@contoso.com")
    mock_dal.get_user.return_value = db_user

    new_name = ScimName(givenName="New", familyName="Name")
    incoming = make_scim_user(userName="new@contoso.com", name=new_name)

    response = replace_user(
        user_id=str(db_user.id),
        user_resource=incoming,
        _token=mock_token,
        provider=entra_provider,
        db_session=mock_db_session,
    )

    returned = parse_scim_user(response)
    assert SCIM_ENTERPRISE_USER_SCHEMA in returned.schemas
|
||||
|
||||
def test_replace_user_with_enterprise_extension(
    self,
    mock_db_session: MagicMock,
    mock_token: MagicMock,
    mock_dal: MagicMock,
    entra_provider: ScimProvider,
) -> None:
    """PUT with enterprise extension should store the fields.

    The request carries a department and a manager reference; both must
    appear in the ``fields`` keyword passed to
    ``mock_dal.sync_user_external_id``.
    """
    db_user = make_db_user(email="alice@contoso.com")
    mock_dal.get_user.return_value = db_user

    extension = ScimEnterpriseExtension(
        department="HR",
        manager=ScimManagerRef(value="boss-id"),
    )
    incoming = make_scim_user(
        userName="alice@contoso.com",
        enterprise_extension=extension,
    )

    response = replace_user(
        user_id=str(db_user.id),
        user_resource=incoming,
        _token=mock_token,
        provider=entra_provider,
        db_session=mock_db_session,
    )

    # Return value intentionally discarded; only the helper's own checks matter here.
    parse_scim_user(response)

    mock_dal.sync_user_external_id.assert_called_once()
    forwarded_kwargs = mock_dal.sync_user_external_id.call_args.kwargs
    expected_fields = ScimMappingFields(
        department="HR",
        manager="boss-id",
        given_name="Test",
        family_name="User",
    )
    assert forwarded_kwargs["fields"] == expected_fields
|
||||
|
||||
def test_delete_user_returns_204(
    self,
    mock_db_session: MagicMock,
    mock_token: MagicMock,
    mock_dal: MagicMock,
) -> None:
    """DELETE on a user that still has a SCIM mapping yields a 204 ``Response``."""
    db_user = make_db_user()
    mock_dal.get_user.return_value = db_user
    # A mapping object (any id) is returned by the DAL lookup for this user.
    mock_dal.get_user_mapping_by_user_id.return_value = MagicMock(id=1)

    response = delete_user(
        user_id=str(db_user.id),
        _token=mock_token,
        db_session=mock_db_session,
    )

    assert isinstance(response, Response)
    assert response.status_code == 204
|
||||
|
||||
def test_double_delete_returns_404(
    self,
    mock_db_session: MagicMock,
    mock_token: MagicMock,
    mock_dal: MagicMock,
) -> None:
    """Second DELETE should return 404 — the SCIM mapping is gone.

    The user row is still returned by the DAL, but the mapping lookup
    returns ``None``, which is how an already-deleted user is simulated.
    """
    db_user = make_db_user()
    mock_dal.get_user.return_value = db_user
    mock_dal.get_user_mapping_by_user_id.return_value = None

    response = delete_user(
        user_id=str(db_user.id),
        _token=mock_token,
        db_session=mock_db_session,
    )

    assert isinstance(response, ScimJSONResponse)
    assert response.status_code == 404
|
||||
|
||||
def test_name_formatted_preserved_on_create(
    self,
    mock_db_session: MagicMock,
    mock_token: MagicMock,
    mock_dal: MagicMock,
    entra_provider: ScimProvider,
) -> None:
    """When name.formatted is provided, it should be used as personal_name.

    The seat-availability check is patched to return ``None`` for the
    duration of the ``create_user`` call.
    """
    # No user is found for this email, so the create path is taken.
    mock_dal.get_user_by_email.return_value = None

    full_name = ScimName(
        givenName="Alice",
        familyName="Smith",
        formatted="Dr. Alice Smith",
    )
    incoming = make_scim_user(userName="alice@contoso.com", name=full_name)

    seat_check = patch(
        "ee.onyx.server.scim.api._check_seat_availability", return_value=None
    )
    with seat_check:
        response = create_user(
            user_resource=incoming,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

    parse_scim_user(response, status=201)

    # The object handed to the DAL must carry the formatted name verbatim.
    mock_dal.add_user.assert_called_once()
    persisted_user = mock_dal.add_user.call_args.args[0]
    assert persisted_user.personal_name == "Dr. Alice Smith"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group Lifecycle (Entra-specific)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraGroupLifecycle:
    """Test group CRUD with Entra-specific behaviors.

    Covers plain group GET, the ``excludedAttributes=members`` query on
    list and single-group reads, and PATCH operations whose ``op`` value
    is PascalCase (``"Add"`` / ``"Remove"``).
    """

    def test_get_group_standard_response(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """GET on a group with one member returns its display name and that member."""
        db_group = make_db_group(id=10, name="Contoso Engineering")
        mock_dal.get_group.return_value = db_group
        member_id = uuid4()
        mock_dal.get_group_members.return_value = [(member_id, "alice@contoso.com")]

        response = get_group(
            group_id="10",
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        parsed_group = parse_scim_group(response)
        assert parsed_group.displayName == "Contoso Engineering"
        assert len(parsed_group.members) == 1

    def test_list_groups_with_excluded_attributes_members(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ?excludedAttributes=members on group list queries."""
        db_group = make_db_group(id=10, name="Engineering")
        member_id = uuid4()
        mock_dal.list_groups.return_value = ([(db_group, "ext-g-1")], 1)
        mock_dal.get_group_members.return_value = [(member_id, "alice@contoso.com")]

        response = list_groups(
            filter=None,
            excludedAttributes="members",
            startIndex=1,
            count=100,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        assert isinstance(response, ScimJSONResponse)
        payload = json.loads(response.body)
        assert payload["totalResults"] == 1
        first_entry = payload["Resources"][0]
        assert "members" not in first_entry
        assert first_entry["displayName"] == "Engineering"

    def test_get_group_with_excluded_attributes_members(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ?excludedAttributes=members on single group GET."""
        db_group = make_db_group(id=10, name="Engineering")
        member_id = uuid4()
        mock_dal.get_group.return_value = db_group
        mock_dal.get_group_members.return_value = [(member_id, "alice@contoso.com")]

        response = get_group(
            group_id="10",
            excludedAttributes="members",
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        assert isinstance(response, ScimJSONResponse)
        payload = json.loads(response.body)
        assert "members" not in payload
        assert payload["displayName"] == "Engineering"

    @patch("ee.onyx.server.scim.api.apply_group_patch")
    def test_patch_group_add_members_with_pascal_case(
        self,
        mock_apply: MagicMock,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ``"Add"`` (PascalCase) for group member additions."""
        db_group = make_db_group(id=10)
        mock_dal.get_group.return_value = db_group
        mock_dal.get_group_members.return_value = []
        mock_dal.validate_member_ids.return_value = []

        member_id = str(uuid4())
        # apply_group_patch is replaced by the decorator; its stubbed return is
        # the patched resource plus one id in the second slot and none in the third.
        resource_after_patch = ScimGroupResource(
            id="10",
            displayName="Engineering",
            members=[ScimGroupMember(value=member_id)],
        )
        mock_apply.return_value = (resource_after_patch, [member_id], [])

        add_op = ScimPatchOperation(
            op="Add",  # type: ignore[arg-type]
            path="members",
            value=[ScimGroupMember(value=member_id)],
        )
        request_body = ScimPatchRequest(Operations=[add_op])

        response = patch_group(
            group_id="10",
            patch_request=request_body,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        parse_scim_group(response)
        mock_dal.upsert_group_members.assert_called_once()

    @patch("ee.onyx.server.scim.api.apply_group_patch")
    def test_patch_group_remove_member_with_pascal_case(
        self,
        mock_apply: MagicMock,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ``"Remove"`` (PascalCase) for group member removals."""
        db_group = make_db_group(id=10)
        mock_dal.get_group.return_value = db_group
        mock_dal.get_group_members.return_value = []

        member_id = str(uuid4())
        # Stubbed apply_group_patch result: empty member list, nothing in the
        # second slot, and the single id in the third slot.
        resource_after_patch = ScimGroupResource(
            id="10", displayName="Engineering", members=[]
        )
        mock_apply.return_value = (resource_after_patch, [], [member_id])

        remove_op = ScimPatchOperation(
            op="Remove",  # type: ignore[arg-type]
            path=f'members[value eq "{member_id}"]',
        )
        request_body = ScimPatchRequest(Operations=[remove_op])

        response = patch_group(
            group_id="10",
            patch_request=request_body,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        parse_scim_group(response)
        mock_dal.remove_group_members.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# excludedAttributes (RFC 7644 §3.4.2.5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExcludedAttributes:
    """Test excludedAttributes query parameter on GET endpoints.

    Each test calls a list/get endpoint function directly and inspects the
    JSON body of the returned ``ScimJSONResponse`` for the presence or
    absence of specific keys.
    """

    def test_list_groups_excludes_members(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Group list with ``excludedAttributes="members"`` omits ``members``."""
        db_group = make_db_group(id=1, name="Team")
        member_id = uuid4()
        mock_dal.list_groups.return_value = ([(db_group, None)], 1)
        mock_dal.get_group_members.return_value = [(member_id, "user@example.com")]

        response = list_groups(
            filter=None,
            excludedAttributes="members",
            startIndex=1,
            count=100,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        assert isinstance(response, ScimJSONResponse)
        payload = json.loads(response.body)
        first_entry = payload["Resources"][0]
        assert "members" not in first_entry
        assert "displayName" in first_entry

    def test_get_group_excludes_members(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Single-group GET with ``excludedAttributes="members"`` omits ``members``."""
        db_group = make_db_group(id=1, name="Team")
        member_id = uuid4()
        mock_dal.get_group.return_value = db_group
        mock_dal.get_group_members.return_value = [(member_id, "user@example.com")]

        response = get_group(
            group_id="1",
            excludedAttributes="members",
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        assert isinstance(response, ScimJSONResponse)
        payload = json.loads(response.body)
        assert "members" not in payload
        assert "displayName" in payload

    def test_list_users_excludes_groups(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """User list with ``excludedAttributes="groups"`` omits ``groups``."""
        db_user = make_db_user()
        user_mapping = make_user_mapping(user_id=db_user.id)
        mock_dal.list_users.return_value = ([(db_user, user_mapping)], 1)
        mock_dal.get_users_groups_batch.return_value = {
            db_user.id: [(1, "Engineering")]
        }

        response = list_users(
            filter=None,
            excludedAttributes="groups",
            startIndex=1,
            count=100,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        assert isinstance(response, ScimJSONResponse)
        payload = json.loads(response.body)
        first_entry = payload["Resources"][0]
        assert "groups" not in first_entry
        assert "userName" in first_entry

    def test_get_user_excludes_groups(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Single-user GET with ``excludedAttributes="groups"`` omits ``groups``."""
        db_user = make_db_user()
        mock_dal.get_user.return_value = db_user
        mock_dal.get_user_groups.return_value = [(1, "Engineering")]

        response = get_user(
            user_id=str(db_user.id),
            excludedAttributes="groups",
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        assert isinstance(response, ScimJSONResponse)
        payload = json.loads(response.body)
        assert "groups" not in payload
        assert "userName" in payload

    def test_multiple_excluded_attributes(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """A comma-separated ``excludedAttributes`` value drops every listed key."""
        db_group = make_db_group(id=1, name="Team")
        mock_dal.get_group.return_value = db_group
        mock_dal.get_group_members.return_value = []

        response = get_group(
            group_id="1",
            excludedAttributes="members,externalId",
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        assert isinstance(response, ScimJSONResponse)
        payload = json.loads(response.body)
        assert "members" not in payload
        assert "externalId" not in payload
        assert "displayName" in payload

    def test_no_excluded_attributes_returns_full_response(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Without ``excludedAttributes`` the group's member is present in the result."""
        db_group = make_db_group(id=1, name="Team")
        member_id = uuid4()
        mock_dal.get_group.return_value = db_group
        mock_dal.get_group_members.return_value = [(member_id, "user@example.com")]

        response = get_group(
            group_id="1",
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        parsed_group = parse_scim_group(response)
        assert len(parsed_group.members) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entra Connection Probe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraConnectionProbe:
    """Entra sends a probe request during initial SCIM setup."""

    def test_filter_for_nonexistent_user_returns_empty_list(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra probes with: GET /Users?filter=userName eq "non-existent"&count=1

        With the DAL reporting no matches, the list response must carry zero
        results and an empty ``Resources`` list.
        """
        mock_dal.list_users.return_value = ([], 0)

        response = list_users(
            filter='userName eq "non-existent@contoso.com"',
            startIndex=1,
            count=1,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        listing = parse_scim_list(response)
        assert listing.totalResults == 0
        assert listing.Resources == []
|
||||
@@ -16,6 +16,7 @@ from ee.onyx.server.scim.api import patch_group
|
||||
from ee.onyx.server.scim.api import replace_group
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
@@ -24,8 +25,6 @@ from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from tests.unit.onyx.server.scim.conftest import assert_scim_error
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_group
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_group
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_group
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_list
|
||||
|
||||
|
||||
class TestListGroups:
|
||||
@@ -49,9 +48,9 @@ class TestListGroups:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 0
|
||||
assert parsed.Resources == []
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 0
|
||||
assert result.Resources == []
|
||||
|
||||
def test_unsupported_filter_returns_400(
|
||||
self,
|
||||
@@ -96,9 +95,9 @@ class TestListGroups:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 1
|
||||
resource = parsed.Resources[0]
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 1
|
||||
resource = result.Resources[0]
|
||||
assert isinstance(resource, ScimGroupResource)
|
||||
assert resource.displayName == "Engineering"
|
||||
assert resource.externalId == "ext-g-1"
|
||||
@@ -127,9 +126,9 @@ class TestGetGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_group(result)
|
||||
assert resource.displayName == "Engineering"
|
||||
assert resource.id == "5"
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
assert result.displayName == "Engineering"
|
||||
assert result.id == "5"
|
||||
|
||||
def test_non_integer_id_returns_404(
|
||||
self,
|
||||
@@ -191,8 +190,8 @@ class TestCreateGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_group(result, status=201)
|
||||
assert resource.displayName == "New Group"
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
assert result.displayName == "New Group"
|
||||
mock_dal.add_group.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@@ -284,7 +283,7 @@ class TestCreateGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_group(result, status=201)
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.create_group_mapping.assert_called_once()
|
||||
|
||||
|
||||
@@ -315,7 +314,7 @@ class TestReplaceGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_group(result)
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.update_group.assert_called_once_with(group, name="New Name")
|
||||
mock_dal.replace_group_members.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
@@ -428,7 +427,7 @@ class TestPatchGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_group(result)
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.update_group.assert_called_once_with(group, name="New Name")
|
||||
|
||||
def test_not_found_returns_404(
|
||||
@@ -535,7 +534,7 @@ class TestPatchGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_group(result)
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.validate_member_ids.assert_called_once()
|
||||
mock_dal.upsert_group_members.assert_called_once()
|
||||
|
||||
@@ -615,7 +614,7 @@ class TestPatchGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_group(result)
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.remove_group_members.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
@@ -13,11 +12,9 @@ from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
_OKTA_IGNORED = OktaProvider().ignored_patch_paths
|
||||
_ENTRA_IGNORED = EntraProvider().ignored_patch_paths
|
||||
|
||||
|
||||
def _make_user(**kwargs: object) -> ScimUserResource:
|
||||
@@ -59,36 +56,36 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_deactivate_user(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch([_replace_op("active", False)], user)
|
||||
result = apply_user_patch([_replace_op("active", False)], user)
|
||||
assert result.active is False
|
||||
assert result.userName == "test@example.com"
|
||||
|
||||
def test_activate_user(self) -> None:
|
||||
user = _make_user(active=False)
|
||||
result, _ = apply_user_patch([_replace_op("active", True)], user)
|
||||
result = apply_user_patch([_replace_op("active", True)], user)
|
||||
assert result.active is True
|
||||
|
||||
def test_replace_given_name(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch([_replace_op("name.givenName", "NewFirst")], user)
|
||||
result = apply_user_patch([_replace_op("name.givenName", "NewFirst")], user)
|
||||
assert result.name is not None
|
||||
assert result.name.givenName == "NewFirst"
|
||||
assert result.name.familyName == "User"
|
||||
|
||||
def test_replace_family_name(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch([_replace_op("name.familyName", "NewLast")], user)
|
||||
result = apply_user_patch([_replace_op("name.familyName", "NewLast")], user)
|
||||
assert result.name is not None
|
||||
assert result.name.familyName == "NewLast"
|
||||
|
||||
def test_replace_username(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch([_replace_op("userName", "new@example.com")], user)
|
||||
result = apply_user_patch([_replace_op("userName", "new@example.com")], user)
|
||||
assert result.userName == "new@example.com"
|
||||
|
||||
def test_replace_without_path_uses_dict(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch(
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -102,7 +99,7 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_multiple_operations(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch(
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op("active", False),
|
||||
_replace_op("name.givenName", "Updated"),
|
||||
@@ -115,7 +112,7 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_case_insensitive_path(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch([_replace_op("Active", False)], user)
|
||||
result = apply_user_patch([_replace_op("Active", False)], user)
|
||||
assert result.active is False
|
||||
|
||||
def test_original_not_mutated(self) -> None:
|
||||
@@ -128,22 +125,15 @@ class TestApplyUserPatch:
|
||||
with pytest.raises(ScimPatchError, match="Unsupported path"):
|
||||
apply_user_patch([_replace_op("unknownField", "value")], user)
|
||||
|
||||
def test_remove_op_clears_field(self) -> None:
|
||||
"""Remove op should clear the target field (not raise)."""
|
||||
user = _make_user(externalId="ext-123")
|
||||
result, _ = apply_user_patch([_remove_op("externalId")], user)
|
||||
assert result.externalId is None
|
||||
|
||||
def test_remove_unsupported_path_raises(self) -> None:
|
||||
"""Remove op on unsupported path (e.g. 'active') should raise."""
|
||||
def test_remove_op_on_user_raises(self) -> None:
|
||||
user = _make_user()
|
||||
with pytest.raises(ScimPatchError, match="Unsupported remove path"):
|
||||
with pytest.raises(ScimPatchError, match="Unsupported operation"):
|
||||
apply_user_patch([_remove_op("active")], user)
|
||||
|
||||
def test_replace_without_path_ignores_id(self) -> None:
|
||||
"""Okta sends 'id' alongside actual changes — it should be silently ignored."""
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch(
|
||||
result = apply_user_patch(
|
||||
[_replace_op(None, ScimPatchResourceValue(active=False, id="some-uuid"))],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
@@ -153,7 +143,7 @@ class TestApplyUserPatch:
|
||||
def test_replace_without_path_ignores_schemas(self) -> None:
|
||||
"""The 'schemas' key in a value dict should be silently ignored."""
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch(
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -171,7 +161,7 @@ class TestApplyUserPatch:
|
||||
def test_okta_deactivation_payload(self) -> None:
|
||||
"""Exact Okta deactivation payload: path-less replace with id + active."""
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch(
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -186,7 +176,7 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_replace_displayname(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch(
|
||||
result = apply_user_patch(
|
||||
[_replace_op("displayName", "New Display Name")], user
|
||||
)
|
||||
assert result.displayName == "New Display Name"
|
||||
@@ -197,7 +187,7 @@ class TestApplyUserPatch:
|
||||
"""Okta sends id/schemas/meta alongside actual changes — complex types
|
||||
(lists, nested dicts) must not cause Pydantic validation errors."""
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch(
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -217,101 +207,9 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_add_operation_works_like_replace(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch([_add_op("externalId", "ext-456")], user)
|
||||
result = apply_user_patch([_add_op("externalId", "ext-456")], user)
|
||||
assert result.externalId == "ext-456"
|
||||
|
||||
def test_entra_capitalized_replace_op(self) -> None:
|
||||
"""Entra ID sends ``"Replace"`` instead of ``"replace"``."""
|
||||
user = _make_user()
|
||||
op = ScimPatchOperation(op="Replace", path="active", value=False) # type: ignore[arg-type]
|
||||
result, _ = apply_user_patch([op], user)
|
||||
assert result.active is False
|
||||
|
||||
def test_entra_capitalized_add_op(self) -> None:
|
||||
"""Entra ID sends ``"Add"`` instead of ``"add"``."""
|
||||
user = _make_user()
|
||||
op = ScimPatchOperation(op="Add", path="externalId", value="ext-999") # type: ignore[arg-type]
|
||||
result, _ = apply_user_patch([op], user)
|
||||
assert result.externalId == "ext-999"
|
||||
|
||||
def test_entra_enterprise_extension_handled(self) -> None:
|
||||
"""Entra sends the enterprise extension URN as a key in path-less
|
||||
PATCH value dicts — enterprise data should be captured in ent_data."""
|
||||
user = _make_user()
|
||||
value = ScimPatchResourceValue(active=False)
|
||||
# Simulate Entra including the enterprise extension URN as extra data
|
||||
assert value.__pydantic_extra__ is not None
|
||||
value.__pydantic_extra__[
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
] = {"department": "Engineering"}
|
||||
result, ent_data = apply_user_patch(
|
||||
[_replace_op(None, value)],
|
||||
user,
|
||||
ignored_paths=_ENTRA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
assert result.userName == "test@example.com"
|
||||
assert ent_data["department"] == "Engineering"
|
||||
|
||||
def test_okta_handles_enterprise_extension_urn(self) -> None:
|
||||
"""Enterprise extension URN paths are handled universally, even
|
||||
for Okta — the data is captured in the enterprise data dict."""
|
||||
user = _make_user()
|
||||
value = ScimPatchResourceValue(active=False)
|
||||
assert value.__pydantic_extra__ is not None
|
||||
value.__pydantic_extra__[
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
] = {"department": "Engineering"}
|
||||
result, ent_data = apply_user_patch(
|
||||
[_replace_op(None, value)],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
assert ent_data["department"] == "Engineering"
|
||||
|
||||
def test_emails_primary_eq_true_value(self) -> None:
|
||||
"""emails[primary eq true].value should update the primary email entry."""
|
||||
user = _make_user(
|
||||
emails=[ScimEmail(value="old@example.com", type="work", primary=True)]
|
||||
)
|
||||
result, _ = apply_user_patch(
|
||||
[_replace_op("emails[primary eq true].value", "new@example.com")], user
|
||||
)
|
||||
# userName should remain unchanged — emails and userName are separate
|
||||
assert result.userName == "test@example.com"
|
||||
assert len(result.emails) == 1
|
||||
assert result.emails[0].value == "new@example.com"
|
||||
assert result.emails[0].primary is True
|
||||
|
||||
def test_enterprise_urn_department_path(self) -> None:
|
||||
"""Dotted enterprise URN path should set department in ent_data."""
|
||||
user = _make_user()
|
||||
_, ent_data = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:department",
|
||||
"Marketing",
|
||||
)
|
||||
],
|
||||
user,
|
||||
)
|
||||
assert ent_data["department"] == "Marketing"
|
||||
|
||||
def test_enterprise_urn_manager_path(self) -> None:
|
||||
"""Dotted enterprise URN path for manager should set manager."""
|
||||
user = _make_user()
|
||||
_, ent_data = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:manager",
|
||||
ScimPatchResourceValue.model_validate({"value": "boss-id"}),
|
||||
)
|
||||
],
|
||||
user,
|
||||
)
|
||||
assert ent_data["manager"] == "boss-id"
|
||||
|
||||
|
||||
class TestApplyGroupPatch:
|
||||
"""Tests for SCIM group PATCH operations."""
|
||||
|
||||
@@ -2,8 +2,6 @@ from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
@@ -11,10 +9,7 @@ from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.entra import _ENTRA_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
|
||||
@@ -44,7 +39,9 @@ class TestOktaProvider:
|
||||
assert OktaProvider().name == "okta"
|
||||
|
||||
def test_ignored_patch_paths(self) -> None:
|
||||
assert OktaProvider().ignored_patch_paths == COMMON_IGNORED_PATCH_PATHS
|
||||
assert OktaProvider().ignored_patch_paths == frozenset(
|
||||
{"id", "schemas", "meta"}
|
||||
)
|
||||
|
||||
def test_build_user_resource_basic(self) -> None:
|
||||
provider = OktaProvider()
|
||||
@@ -63,12 +60,6 @@ class TestOktaProvider:
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
def test_build_user_resource_has_core_schema_only(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-123")
|
||||
assert result.schemas == [SCIM_USER_SCHEMA]
|
||||
|
||||
def test_build_user_resource_with_groups(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user()
|
||||
@@ -170,42 +161,6 @@ class TestOktaProvider:
|
||||
assert result.members == []
|
||||
|
||||
|
||||
class TestEntraProvider:
|
||||
def test_name(self) -> None:
|
||||
assert EntraProvider().name == "entra"
|
||||
|
||||
def test_ignored_patch_paths(self) -> None:
|
||||
paths = EntraProvider().ignored_patch_paths
|
||||
assert paths == _ENTRA_IGNORED_PATCH_PATHS
|
||||
# Enterprise extension URN is now handled (not ignored)
|
||||
assert paths >= COMMON_IGNORED_PATCH_PATHS
|
||||
|
||||
def test_build_user_resource_includes_enterprise_schema(self) -> None:
|
||||
provider = EntraProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-entra-1")
|
||||
|
||||
assert result.schemas == [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]
|
||||
|
||||
def test_build_user_resource_basic(self) -> None:
|
||||
provider = EntraProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-entra-1")
|
||||
|
||||
assert result == ScimUserResource(
|
||||
schemas=[SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA],
|
||||
id=str(user.id),
|
||||
externalId="ext-entra-1",
|
||||
userName="test@example.com",
|
||||
name=ScimName(givenName="Test", familyName="User", formatted="Test User"),
|
||||
displayName="Test User",
|
||||
emails=[ScimEmail(value="test@example.com", type="work", primary=True)],
|
||||
active=True,
|
||||
groups=[],
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
|
||||
class TestGetDefaultProvider:
|
||||
def test_returns_okta(self) -> None:
|
||||
provider = get_default_provider()
|
||||
|
||||
@@ -16,7 +16,7 @@ from ee.onyx.server.scim.api import get_user
|
||||
from ee.onyx.server.scim.api import list_users
|
||||
from ee.onyx.server.scim.api import patch_user
|
||||
from ee.onyx.server.scim.api import replace_user
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
@@ -28,8 +28,6 @@ from tests.unit.onyx.server.scim.conftest import assert_scim_error
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_user_mapping
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_list
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_user
|
||||
|
||||
|
||||
class TestListUsers:
|
||||
@@ -53,9 +51,9 @@ class TestListUsers:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 0
|
||||
assert parsed.Resources == []
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 0
|
||||
assert result.Resources == []
|
||||
|
||||
def test_returns_users_with_scim_shape(
|
||||
self,
|
||||
@@ -79,10 +77,10 @@ class TestListUsers:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 1
|
||||
assert len(parsed.Resources) == 1
|
||||
resource = parsed.Resources[0]
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 1
|
||||
assert len(result.Resources) == 1
|
||||
resource = result.Resources[0]
|
||||
assert isinstance(resource, ScimUserResource)
|
||||
assert resource.userName == "Alice@example.com"
|
||||
assert resource.externalId == "ext-abc"
|
||||
@@ -148,9 +146,9 @@ class TestGetUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.userName == "alice@example.com"
|
||||
assert resource.id == str(user.id)
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "alice@example.com"
|
||||
assert result.id == str(user.id)
|
||||
|
||||
def test_invalid_uuid_returns_404(
|
||||
self,
|
||||
@@ -209,8 +207,8 @@ class TestCreateUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert resource.userName == "new@example.com"
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "new@example.com"
|
||||
mock_dal.add_user.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@@ -316,8 +314,8 @@ class TestCreateUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert resource.externalId == "ext-123"
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.externalId == "ext-123"
|
||||
mock_dal.create_user_mapping.assert_called_once()
|
||||
|
||||
|
||||
@@ -346,7 +344,7 @@ class TestReplaceUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
assert isinstance(result, ScimUserResource)
|
||||
mock_dal.update_user.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@@ -414,15 +412,9 @@ class TestReplaceUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
assert isinstance(result, ScimUserResource)
|
||||
mock_dal.sync_user_external_id.assert_called_once_with(
|
||||
user.id,
|
||||
None,
|
||||
scim_username="test@example.com",
|
||||
fields=ScimMappingFields(
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
),
|
||||
user.id, None, scim_username="test@example.com"
|
||||
)
|
||||
|
||||
|
||||
@@ -456,7 +448,7 @@ class TestPatchUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
assert isinstance(result, ScimUserResource)
|
||||
mock_dal.update_user.assert_called_once()
|
||||
|
||||
def test_not_found_returns_404(
|
||||
@@ -515,7 +507,7 @@ class TestPatchUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
assert isinstance(result, ScimUserResource)
|
||||
# Verify the update_user call received the new display name
|
||||
call_kwargs = mock_dal.update_user.call_args
|
||||
assert call_kwargs[1]["personal_name"] == "New Display Name"
|
||||
@@ -613,12 +605,10 @@ class TestDeleteUser:
|
||||
class TestScimNameToStr:
|
||||
"""Tests for _scim_name_to_str helper."""
|
||||
|
||||
def test_prefers_formatted_over_components(self) -> None:
|
||||
"""When client provides formatted, use it — the client knows what it wants."""
|
||||
name = ScimName(
|
||||
givenName="Jane", familyName="Smith", formatted="Dr. Jane Smith"
|
||||
)
|
||||
assert _scim_name_to_str(name) == "Dr. Jane Smith"
|
||||
def test_prefers_given_family_over_formatted(self) -> None:
|
||||
"""Okta may send stale formatted while updating givenName/familyName."""
|
||||
name = ScimName(givenName="Jane", familyName="Smith", formatted="Old Name")
|
||||
assert _scim_name_to_str(name) == "Jane Smith"
|
||||
|
||||
def test_given_name_only(self) -> None:
|
||||
name = ScimName(givenName="Jane")
|
||||
@@ -663,9 +653,9 @@ class TestEmailCasePreservation:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert resource.userName == "Alice@Example.COM"
|
||||
assert resource.emails[0].value == "Alice@Example.COM"
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "Alice@Example.COM"
|
||||
assert result.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
def test_get_preserves_username_case(
|
||||
self,
|
||||
@@ -691,6 +681,6 @@ class TestEmailCasePreservation:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.userName == "Alice@Example.COM"
|
||||
assert resource.emails[0].value == "Alice@Example.COM"
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "Alice@Example.COM"
|
||||
assert result.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
"""Tests for PythonTool availability based on server_enabled flag.
|
||||
|
||||
Verifies that PythonTool reports itself as unavailable when either:
|
||||
- CODE_INTERPRETER_BASE_URL is not set, or
|
||||
- CodeInterpreterServer.server_enabled is False in the database.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Unavailable when CODE_INTERPRETER_BASE_URL is not set
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
None,
|
||||
)
|
||||
def test_python_tool_unavailable_without_base_url() -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"",
|
||||
)
|
||||
def test_python_tool_unavailable_with_empty_base_url() -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Unavailable when server_enabled is False
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://localhost:8000",
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
|
||||
)
|
||||
def test_python_tool_unavailable_when_server_disabled(
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server.server_enabled = False
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Available when both conditions are met
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://localhost:8000",
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
|
||||
)
|
||||
def test_python_tool_available_when_server_enabled(
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server.server_enabled = True
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is True
|
||||
@@ -487,16 +487,7 @@ services:
|
||||
|
||||
code-interpreter:
|
||||
image: onyxdotapp/code-interpreter:${CODE_INTERPRETER_IMAGE_TAG:-latest}
|
||||
entrypoint: ["/bin/bash", "-c"]
|
||||
command: >
|
||||
"
|
||||
if [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"True\" ] || [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"true\" ]; then
|
||||
exec bash ./entrypoint.sh code-interpreter-api;
|
||||
else
|
||||
echo 'Skipping code interpreter';
|
||||
exec tail -f /dev/null;
|
||||
fi
|
||||
"
|
||||
command: ["bash", "./entrypoint.sh", "code-interpreter-api"]
|
||||
restart: unless-stopped
|
||||
env_file:
|
||||
- path: .env
|
||||
|
||||
@@ -69,6 +69,4 @@ services:
|
||||
inference_model_server:
|
||||
profiles: ["inference"]
|
||||
|
||||
# Code interpreter is not needed in minimal mode.
|
||||
code-interpreter:
|
||||
profiles: ["code-interpreter"]
|
||||
code-interpreter: {}
|
||||
|
||||
@@ -315,16 +315,7 @@ services:
|
||||
|
||||
code-interpreter:
|
||||
image: onyxdotapp/code-interpreter:${CODE_INTERPRETER_IMAGE_TAG:-latest}
|
||||
entrypoint: ["/bin/bash", "-c"]
|
||||
command: >
|
||||
"
|
||||
if [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"True\" ] || [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"true\" ]; then
|
||||
exec bash ./entrypoint.sh code-interpreter-api;
|
||||
else
|
||||
echo 'Skipping code interpreter';
|
||||
exec tail -f /dev/null;
|
||||
fi
|
||||
"
|
||||
command: ["bash", "./entrypoint.sh", "code-interpreter-api"]
|
||||
restart: unless-stopped
|
||||
env_file:
|
||||
- path: .env
|
||||
|
||||
@@ -352,16 +352,7 @@ services:
|
||||
|
||||
code-interpreter:
|
||||
image: onyxdotapp/code-interpreter:${CODE_INTERPRETER_IMAGE_TAG:-latest}
|
||||
entrypoint: ["/bin/bash", "-c"]
|
||||
command: >
|
||||
"
|
||||
if [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"True\" ] || [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"true\" ]; then
|
||||
exec bash ./entrypoint.sh code-interpreter-api;
|
||||
else
|
||||
echo 'Skipping code interpreter';
|
||||
exec tail -f /dev/null;
|
||||
fi
|
||||
"
|
||||
command: ["bash", "./entrypoint.sh", "code-interpreter-api"]
|
||||
restart: unless-stopped
|
||||
env_file:
|
||||
- path: .env
|
||||
|
||||
@@ -527,16 +527,7 @@ services:
|
||||
|
||||
code-interpreter:
|
||||
image: onyxdotapp/code-interpreter:${CODE_INTERPRETER_IMAGE_TAG:-latest}
|
||||
entrypoint: ["/bin/bash", "-c"]
|
||||
command: >
|
||||
"
|
||||
if [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"True\" ] || [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"true\" ]; then
|
||||
exec bash ./entrypoint.sh code-interpreter-api;
|
||||
else
|
||||
echo 'Skipping code interpreter';
|
||||
exec tail -f /dev/null;
|
||||
fi
|
||||
"
|
||||
command: ["bash", "./entrypoint.sh", "code-interpreter-api"]
|
||||
restart: unless-stopped
|
||||
env_file:
|
||||
- path: .env
|
||||
|
||||
@@ -19,6 +19,6 @@ dependencies:
|
||||
version: 5.4.0
|
||||
- name: code-interpreter
|
||||
repository: https://onyx-dot-app.github.io/python-sandbox/
|
||||
version: 0.2.1
|
||||
digest: sha256:aedc211d9732c934be8b79735b62f8caa9bcd235e03fd0dd10b49e0a13ed15b7
|
||||
generated: "2026-02-20T11:19:47.957449-08:00"
|
||||
version: 0.3.0
|
||||
digest: sha256:cf8f01906d46034962c6ce894770621ee183ac761e6942951118aeb48540eddd
|
||||
generated: "2026-02-24T10:59:38.78318-08:00"
|
||||
|
||||
@@ -45,6 +45,6 @@ dependencies:
|
||||
repository: https://charts.min.io/
|
||||
condition: minio.enabled
|
||||
- name: code-interpreter
|
||||
version: 0.2.1
|
||||
version: 0.3.0
|
||||
repository: https://onyx-dot-app.github.io/python-sandbox/
|
||||
condition: codeInterpreter.enabled
|
||||
|
||||
@@ -957,7 +957,7 @@ minio:
|
||||
|
||||
# Code Interpreter - Python code execution service (beta feature)
|
||||
codeInterpreter:
|
||||
enabled: false # Disabled by default (beta feature)
|
||||
enabled: true
|
||||
|
||||
replicaCount: 1
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ dev = [
|
||||
"matplotlib==3.10.8",
|
||||
"mypy-extensions==1.0.0",
|
||||
"mypy==1.13.0",
|
||||
"onyx-devtools==0.6.0",
|
||||
"onyx-devtools==0.6.1",
|
||||
"openapi-generator-cli==7.17.0",
|
||||
"pandas-stubs~=2.3.3",
|
||||
"pre-commit==3.2.2",
|
||||
|
||||
@@ -222,6 +222,7 @@ ods run-ci 7353
|
||||
### `cherry-pick` - Backport Commits to Release Branches
|
||||
|
||||
Cherry-pick one or more commits to release branches and automatically create PRs.
|
||||
Cherry-pick PRs created by this command are labeled `cherry-pick 🍒`.
|
||||
|
||||
```shell
|
||||
ods cherry-pick <commit-sha> [<commit-sha>...] [--release <version>]
|
||||
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/prompt"
|
||||
)
|
||||
|
||||
const cherryPickPRLabel = "cherry-pick 🍒"
|
||||
|
||||
// CherryPickOptions holds options for the cherry-pick command
|
||||
type CherryPickOptions struct {
|
||||
Releases []string
|
||||
@@ -510,6 +512,7 @@ func createCherryPickPR(headBranch, baseBranch, title string, commitSHAs, commit
|
||||
"--head", headBranch,
|
||||
"--title", title,
|
||||
"--body", body,
|
||||
"--label", cherryPickPRLabel,
|
||||
}
|
||||
|
||||
for _, assignee := range assignees {
|
||||
|
||||
18
uv.lock
generated
18
uv.lock
generated
@@ -4654,7 +4654,7 @@ requires-dist = [
|
||||
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
|
||||
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
|
||||
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.6.2" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.0" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.1" },
|
||||
{ name = "openai", specifier = "==2.14.0" },
|
||||
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
|
||||
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
|
||||
@@ -4759,20 +4759,20 @@ requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable =
|
||||
|
||||
[[package]]
|
||||
name = "onyx-devtools"
|
||||
version = "0.6.0"
|
||||
version = "0.6.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "fastapi" },
|
||||
{ name = "openapi-generator-cli" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/f9/79d66c1f06e4d1dca0a9df30afcd65ec1a69219fdf17c45349396d1ec668/onyx_devtools-0.6.0-py3-none-any.whl", hash = "sha256:26049075a6d3eb794f44c1bbe55a7cfc0c5427de681ed29319064e2deb956a15", size = 3777572, upload-time = "2026-02-19T23:05:51.823Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/40/37/0abff5ab8d79c90f9d57eeaf4998f668145b01e81da0307df56c3b15d16c/onyx_devtools-0.6.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:a7c00f2f1924c231b2480edcd3b6aa83398e13e4587c213fe1c97e0f6d3cfce1", size = 3822965, upload-time = "2026-02-19T23:06:02.992Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/59/79/a8c23e456b7f1bb4cb741875af6c323fba11d5ef1ba121ea8b44587c236f/onyx_devtools-0.6.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:0e67fc47dfffb510826a6487dd5029a65b4a5b3f8a42e0e1208b6faee353518c", size = 3570391, upload-time = "2026-02-19T23:05:48.853Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c5/c5/d166bf2c98b80fd83d76abe88e57d63a8cb55880ba40a3d34c831361e3cf/onyx_devtools-0.6.0-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:0fdbd085f82788b900620424798d04dc1b10c3b1baf9be821ac178adc41c6858", size = 3432611, upload-time = "2026-02-19T23:05:51.924Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/18/8e/c53fb7f7781acbf37ca80ebcee5d1274d54c6d853606adefc517df715f9a/onyx_devtools-0.6.0-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:3915ad5ea245e597a8ad91bd2ba5efc2b6a336ca59c7f3670bd89530cc9ab00f", size = 3777586, upload-time = "2026-02-19T23:05:51.877Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e5/57/194ded4aa5151d96911b021829e015370b4f1fc7493ac584d445fd96f97b/onyx_devtools-0.6.0-py3-none-win_amd64.whl", hash = "sha256:478cdae03ae2e797345396397318446622c7472df0a7d9dbd58d3e96489198b2", size = 3871835, upload-time = "2026-02-19T23:05:51.209Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/e9/cc7d204b9b1103b2f33f8f62d29076083f40f44697b398e83b3d44daca23/onyx_devtools-0.6.0-py3-none-win_arm64.whl", hash = "sha256:4bff060fd5f017ddceaf753252e0bc16699922d9a0a88506a56505aad4580824", size = 3492854, upload-time = "2026-02-19T23:05:51.856Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bf/3c/fc0c152ecc403b8d4c929eacc7ea4c3d6cba2094f3cfa51d9e5c4d3bda3d/onyx_devtools-0.6.1-py3-none-any.whl", hash = "sha256:a9ad90ca4536ebe9aaeb604f82c418f3fd148100f14cca7749df0d076ee5c4b0", size = 3781440, upload-time = "2026-02-25T00:59:03.565Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/1c/2df5a06eed5490057f0852153940142f9987ff9b865c9c185b733fa360b1/onyx_devtools-0.6.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:769a656737e2389312e8e24bf3e9dd559dcb00160f323228dfe34d005ab47af3", size = 3827421, upload-time = "2026-02-25T00:58:59.672Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a2/e3/389644eb9ba0a3cfa975cc015a48140702b05abc9093542b2a3ba6cc5cc1/onyx_devtools-0.6.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:93886332e97e6efa5f3d7a1d1e4facf1442d301df379f65dfc2a328ed43c8f39", size = 3573060, upload-time = "2026-02-25T00:59:02.582Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/68/fe/dd0f32e08f7e7fb1861a28b82431e0a43cf6ab33e04fb2938f4ee20c891b/onyx_devtools-0.6.1-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:cf896e420c78c08c541135473627ffcab0a0156e0e462e71bcb476f560c324fa", size = 3435936, upload-time = "2026-02-25T00:59:02.313Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/3a/4376cba6adcf86b9fc55f146493450955497d988920eaa37a8aec9f9f897/onyx_devtools-0.6.1-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:4cb5a1b44a4e74c2fc68164a5caa34bce3f6d2dd5639e48438c1d04f09c4c7c6", size = 3781457, upload-time = "2026-02-25T00:59:02.126Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/0d/d2ecf7edc02354d16d9a1d9bd7d8d35f46cdde08b86635ba02075e4d3c7c/onyx_devtools-0.6.1-py3-none-win_amd64.whl", hash = "sha256:0c6c6a667851b9ab215980f1b391216bc2f157c8a29d0cfa96c32c6d10116a5c", size = 3875146, upload-time = "2026-02-25T00:59:02.364Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c5/c3/04783dcfad36b18f48befb6d85bf4f9a9f36fd4cd6e08077676c72c9c504/onyx_devtools-0.6.1-py3-none-win_arm64.whl", hash = "sha256:f095e58b4dad0671c7127a452c5d5f411f55070ebf586a2e47f9193ab753ce44", size = 3496971, upload-time = "2026-02-25T00:59:17.98Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -394,7 +394,7 @@
|
||||
}
|
||||
.interactive[data-interactive-base-variant="select"][data-disabled] {
|
||||
@apply bg-transparent;
|
||||
--interactive-foreground: var(--text-02);
|
||||
--interactive-foreground: var(--text-01);
|
||||
}
|
||||
.interactive[data-interactive-base-variant="select"][data-selected="true"][data-disabled] {
|
||||
--interactive-foreground: var(--action-link-03);
|
||||
|
||||
@@ -119,10 +119,10 @@ function HorizontalInputLayout({
|
||||
justifyContent="between"
|
||||
alignItems={center ? "center" : "start"}
|
||||
>
|
||||
<div className="flex flex-col self-stretch flex-[2]">
|
||||
<div className="flex flex-col flex-1 self-stretch">
|
||||
<TitleLayout {...titleLayoutProps} />
|
||||
</div>
|
||||
<div className="flex flex-col flex-[1] items-end">{children}</div>
|
||||
<div className="flex flex-col items-end">{children}</div>
|
||||
</Section>
|
||||
{name && <ErrorLayout name={name} />}
|
||||
</Section>
|
||||
|
||||
Reference in New Issue
Block a user