Compare commits

..

17 Commits

Author SHA1 Message Date
Dane Urban
1ab44a2c66 . 2026-02-25 00:08:02 -08:00
Dane Urban
65b74b974b . 2026-02-24 20:13:48 -08:00
roshan
784a99e24a updated demo data (#8748) 2026-02-24 19:59:46 -08:00
Justin Tahara
da1f5a11f4 chore(cherry-pick): Alerting on Failed Cherry-Picks (#8744) 2026-02-25 02:09:19 +00:00
Justin Tahara
5633805890 chore(devtools): Upgrade ods from 0.6.0 -> 0.6.1 (#8743) 2026-02-25 02:01:20 +00:00
Danelegend
0817b45ae1 feat: Get code interpreter config route (#8739) 2026-02-25 01:49:30 +00:00
Justin Tahara
af0e4bdebc fix(slack): Cleaning up URL Links (#8569) 2026-02-25 01:42:12 +00:00
Justin Tahara
4cd2320732 chore(cherry-pick): Add Github Label for PRs (#8736) 2026-02-25 00:46:12 +00:00
Danelegend
90a361f0e1 feat: code interpreter routes (#8670) 2026-02-24 16:27:10 -08:00
Justin Tahara
194efde97b chore(llm): Scaffolding for Nightly LLM Tests (#8704) 2026-02-25 00:06:24 +00:00
Danelegend
d922a42262 feat: code interpreter docker default deploy (#8672) 2026-02-24 23:51:19 +00:00
Danelegend
f00c3a486e feat: default deploy code interpreter - helm & bump version 0.3.0 (#8685) 2026-02-24 23:40:46 +00:00
Danelegend
192080c9e4 feat: default deploy code interpreter - restart_script (#8686) 2026-02-24 23:40:36 +00:00
Justin Tahara
c5787dc073 chore(image): Update test to be for Dall E 3 instead of 2 (#8732) 2026-02-24 22:53:31 +00:00
Justin Tahara
d424d6462c fix(sanitization): Centralizing DB Filters (#8730) 2026-02-24 22:28:25 +00:00
Jamison Lahman
ecea86deb6 chore(fe): only left input items flex (#8734) 2026-02-24 22:25:04 +00:00
Jamison Lahman
a5c1f50a8a chore(fe): update disabled "select" button color (#8733) 2026-02-24 22:03:52 +00:00
51 changed files with 1377 additions and 2089 deletions

View File

@@ -11,6 +11,11 @@ permissions:
jobs:
cherry-pick-to-latest-release:
outputs:
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
pr_number: ${{ steps.gate.outputs.pr_number }}
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
@@ -75,10 +80,82 @@ jobs:
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Create cherry-pick PR to latest release
id: run_cherry_pick
if: steps.gate.outputs.should_cherrypick == 'true'
continue-on-error: true
env:
GH_TOKEN: ${{ github.token }}
GITHUB_TOKEN: ${{ github.token }}
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
run: |
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify
set -o pipefail
output_file="$(mktemp)"
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
exit_code="${PIPESTATUS[0]}"
if [ "${exit_code}" -eq 0 ]; then
echo "status=success" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "status=failure" >> "$GITHUB_OUTPUT"
reason="command-failed"
if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then
reason="merge-conflict"
fi
echo "reason=${reason}" >> "$GITHUB_OUTPUT"
{
echo "details<<EOF"
tail -n 40 "$output_file"
echo "EOF"
} >> "$GITHUB_OUTPUT"
- name: Mark workflow as failed if cherry-pick failed
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
run: |
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
exit 1
notify-slack-on-cherry-pick-failure:
needs:
- cherry-pick-to-latest-release
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
runs-on: ubuntu-slim
timeout-minutes: 10
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build cherry-pick failure summary
id: failure-summary
env:
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
run: |
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
reason_text="cherry-pick command failed"
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
reason_text="merge conflict during cherry-pick"
fi
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
if [ -n "${details_excerpt}" ]; then
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
fi
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
- name: Notify #cherry-pick-prs about cherry-pick failure
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
title: "🚨 Automated Cherry-Pick Failed"
ref-name: ${{ github.ref_name }}

View File

@@ -116,7 +116,6 @@ jobs:
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
CODE_INTERPRETER_BETA_ENABLED=true
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF

View File

@@ -34,7 +34,6 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from ee.onyx.server.scim.models import ScimMappingFields
from onyx.db.dal import DAL
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
@@ -129,19 +128,12 @@ class ScimDAL(DAL):
external_id: str,
user_id: UUID,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> ScimUserMapping:
"""Create a mapping between a SCIM externalId and an Onyx user."""
f = fields or ScimMappingFields()
mapping = ScimUserMapping(
external_id=external_id,
user_id=user_id,
scim_username=scim_username,
department=f.department,
manager=f.manager,
given_name=f.given_name,
family_name=f.family_name,
scim_emails_json=f.scim_emails_json,
)
self._session.add(mapping)
self._session.flush()
@@ -319,14 +311,8 @@ class ScimDAL(DAL):
user_id: UUID,
new_external_id: str | None,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> None:
"""Create, update, or delete the external ID mapping for a user.
When *fields* is provided, all mapping fields are written
unconditionally — including ``None`` values — so that a caller can
clear a previously-set field (e.g. removing a department).
"""
"""Create, update, or delete the external ID mapping for a user."""
mapping = self.get_user_mapping_by_user_id(user_id)
if new_external_id:
if mapping:
@@ -334,18 +320,11 @@ class ScimDAL(DAL):
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
if fields is not None:
mapping.department = fields.department
mapping.manager = fields.manager
mapping.given_name = fields.given_name
mapping.family_name = fields.family_name
mapping.scim_emails_json = fields.scim_emails_json
else:
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
fields=fields,
)
elif mapping:
self.delete_user_mapping(mapping.id)

View File

@@ -26,14 +26,14 @@ from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
@@ -41,8 +41,6 @@ from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.providers.base import get_default_provider
from ee.onyx.server.scim.providers.base import ScimProvider
from ee.onyx.server.scim.providers.base import serialize_emails
from ee.onyx.server.scim.schema_definitions import ENTERPRISE_USER_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
@@ -50,28 +48,15 @@ from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
class ScimJSONResponse(JSONResponse):
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
media_type = "application/scim+json"
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
# changed to kebab-case.
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
@@ -101,39 +86,15 @@ def get_service_provider_config() -> ScimServiceProviderConfig:
@scim_router.get("/ResourceTypes")
def get_resource_types() -> ScimJSONResponse:
"""List available SCIM resource types (RFC 7643 §6).
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
like Entra ID expect a JSON object, not a bare array.
"""
resources = [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
return ScimJSONResponse(
content={
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
"totalResults": len(resources),
"Resources": [
r.model_dump(exclude_none=True, by_alias=True) for r in resources
],
}
)
def get_resource_types() -> list[ScimResourceType]:
"""List available SCIM resource types (RFC 7643 §6)."""
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
@scim_router.get("/Schemas")
def get_schemas() -> ScimJSONResponse:
"""Return SCIM schema definitions (RFC 7643 §7).
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
like Entra ID expect a JSON object, not a bare array.
"""
schemas = [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF, ENTERPRISE_USER_SCHEMA_DEF]
return ScimJSONResponse(
content={
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
"totalResults": len(schemas),
"Resources": [s.model_dump(exclude_none=True) for s in schemas],
}
)
def get_schemas() -> list[ScimSchemaDefinition]:
"""Return SCIM schema definitions (RFC 7643 §7)."""
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
# ---------------------------------------------------------------------------
@@ -141,45 +102,15 @@ def get_schemas() -> ScimJSONResponse:
# ---------------------------------------------------------------------------
def _scim_error_response(status: int, detail: str) -> ScimJSONResponse:
def _scim_error_response(status: int, detail: str) -> JSONResponse:
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
logger.warning("SCIM error response: status=%s detail=%s", status, detail)
body = ScimError(status=str(status), detail=detail)
return ScimJSONResponse(
return JSONResponse(
status_code=status,
content=body.model_dump(exclude_none=True),
)
def _parse_excluded_attributes(raw: str | None) -> set[str]:
"""Parse the ``excludedAttributes`` query parameter (RFC 7644 §3.4.2.5).
Returns a set of lowercased attribute names to omit from responses.
"""
if not raw:
return set()
return {attr.strip().lower() for attr in raw.split(",") if attr.strip()}
def _apply_exclusions(
resource: ScimUserResource | ScimGroupResource,
excluded: set[str],
) -> dict:
"""Serialize a SCIM resource, omitting attributes the IdP excluded.
RFC 7644 §3.4.2.5 lets the IdP pass ``?excludedAttributes=groups,emails``
to reduce response payload size. We strip those fields after serialization
so the rest of the pipeline doesn't need to know about them.
"""
data = resource.model_dump(exclude_none=True, by_alias=True)
for attr in excluded:
# Match case-insensitively against the camelCase field names
keys_to_remove = [k for k in data if k.lower() == attr]
for k in keys_to_remove:
del data[k]
return data
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None."""
check_fn = fetch_ee_implementation_or_noop(
@@ -193,7 +124,7 @@ def _check_seat_availability(dal: ScimDAL) -> str | None:
return None
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | ScimJSONResponse:
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
try:
uid = UUID(user_id)
@@ -213,63 +144,10 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
"""
if not name:
return None
# If the client explicitly provides ``formatted``, prefer it — the client
# knows what display string it wants. Otherwise build from components.
if name.formatted:
return name.formatted
# Build from givenName/familyName first — IdPs like Okta may send a stale
# ``formatted`` value while updating the individual name components.
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
return parts or None
def _scim_resource_response(
resource: ScimUserResource | ScimGroupResource | ScimListResponse,
status_code: int = 200,
) -> ScimJSONResponse:
"""Serialize a SCIM resource as ``application/scim+json``."""
content = resource.model_dump(exclude_none=True, by_alias=True)
return ScimJSONResponse(
status_code=status_code,
content=content,
)
def _extract_enterprise_fields(
resource: ScimUserResource,
) -> tuple[str | None, str | None]:
"""Extract department and manager from enterprise extension."""
ext = resource.enterprise_extension
if not ext:
return None, None
department = ext.department
manager = ext.manager.value if ext.manager else None
return department, manager
def _mapping_to_fields(
mapping: ScimUserMapping | None,
) -> ScimMappingFields | None:
"""Extract round-trip fields from a SCIM user mapping."""
if not mapping:
return None
return ScimMappingFields(
department=mapping.department,
manager=mapping.manager,
given_name=mapping.given_name,
family_name=mapping.family_name,
scim_emails_json=mapping.scim_emails_json,
)
def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
"""Build mapping fields from an incoming SCIM user resource."""
department, manager = _extract_enterprise_fields(resource)
return ScimMappingFields(
department=department,
manager=manager,
given_name=resource.name.givenName if resource.name else None,
family_name=resource.name.familyName if resource.name else None,
scim_emails_json=serialize_emails(resource.emails),
)
return parts or name.formatted
# ---------------------------------------------------------------------------
@@ -280,13 +158,12 @@ def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
@scim_router.get("/Users", response_model=None)
def list_users(
filter: str | None = Query(None),
excludedAttributes: str | None = None,
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | ScimJSONResponse:
) -> ScimListResponse | JSONResponse:
"""List users with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
@@ -308,66 +185,42 @@ def list_users(
mapping.external_id if mapping else None,
groups=user_groups_map.get(user.id, []),
scim_username=mapping.scim_username if mapping else None,
fields=_mapping_to_fields(mapping),
)
for user, mapping in users_with_mappings
]
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
excluded = _parse_excluded_attributes(excludedAttributes)
if excluded:
response = ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
)
data = response.model_dump(exclude_none=True)
data["Resources"] = [_apply_exclusions(r, excluded) for r in resources]
return ScimJSONResponse(content=data)
list_resp = ScimListResponse(
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
)
return _scim_resource_response(list_resp)
@scim_router.get("/Users/{user_id}", response_model=None)
def get_user(
user_id: str,
excludedAttributes: str | None = None,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
) -> ScimUserResource | JSONResponse:
"""Get a single user by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
resource = provider.build_user_resource(
return provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=dal.get_user_groups(user.id),
scim_username=mapping.scim_username if mapping else None,
fields=_mapping_to_fields(mapping),
)
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
excluded = _parse_excluded_attributes(excludedAttributes)
if excluded:
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
return _scim_resource_response(resource)
@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
@@ -375,7 +228,7 @@ def create_user(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
) -> ScimUserResource | JSONResponse:
"""Create a new user from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
@@ -417,25 +270,13 @@ def create_user(
# Create SCIM mapping (externalId is validated above, always present)
external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
fields = _fields_from_resource(user_resource)
dal.create_user_mapping(
external_id=external_id,
user_id=user.id,
scim_username=scim_username,
fields=fields,
external_id=external_id, user_id=user.id, scim_username=scim_username
)
dal.commit()
return _scim_resource_response(
provider.build_user_resource(
user,
external_id,
scim_username=scim_username,
fields=fields,
),
status_code=201,
)
return provider.build_user_resource(user, external_id, scim_username=scim_username)
@scim_router.put("/Users/{user_id}", response_model=None)
@@ -445,13 +286,13 @@ def replace_user(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
) -> ScimUserResource | JSONResponse:
"""Replace a user entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
user = result
@@ -472,24 +313,15 @@ def replace_user(
new_external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
fields = _fields_from_resource(user_resource)
dal.sync_user_external_id(
user.id,
new_external_id,
scim_username=scim_username,
fields=fields,
)
dal.sync_user_external_id(user.id, new_external_id, scim_username=scim_username)
dal.commit()
return _scim_resource_response(
provider.build_user_resource(
user,
new_external_id,
groups=dal.get_user_groups(user.id),
scim_username=scim_username,
fields=fields,
)
return provider.build_user_resource(
user,
new_external_id,
groups=dal.get_user_groups(user.id),
scim_username=scim_username,
)
@@ -500,7 +332,7 @@ def patch_user(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
) -> ScimUserResource | JSONResponse:
"""Partially update a user (RFC 7644 §3.5.2).
This is the primary endpoint for user deprovisioning — Okta sends
@@ -510,25 +342,23 @@ def patch_user(
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
external_id = mapping.external_id if mapping else None
current_scim_username = mapping.scim_username if mapping else None
current_fields = _mapping_to_fields(mapping)
current = provider.build_user_resource(
user,
external_id,
groups=dal.get_user_groups(user.id),
scim_username=current_scim_username,
fields=current_fields,
)
try:
patched, ent_data = apply_user_patch(
patched = apply_user_patch(
patch_request.Operations, current, provider.ignored_patch_paths
)
except ScimPatchError as e:
@@ -563,35 +393,17 @@ def patch_user(
personal_name=personal_name,
)
# Build updated fields by merging PATCH enterprise data with current values
cf = current_fields or ScimMappingFields()
fields = ScimMappingFields(
department=ent_data.get("department", cf.department),
manager=ent_data.get("manager", cf.manager),
given_name=patched.name.givenName if patched.name else cf.given_name,
family_name=patched.name.familyName if patched.name else cf.family_name,
scim_emails_json=(
serialize_emails(patched.emails) if patched.emails else cf.scim_emails_json
),
)
dal.sync_user_external_id(
user.id,
patched.externalId,
scim_username=new_scim_username,
fields=fields,
user.id, patched.externalId, scim_username=new_scim_username
)
dal.commit()
return _scim_resource_response(
provider.build_user_resource(
user,
patched.externalId,
groups=dal.get_user_groups(user.id),
scim_username=new_scim_username,
fields=fields,
)
return provider.build_user_resource(
user,
patched.externalId,
groups=dal.get_user_groups(user.id),
scim_username=new_scim_username,
)
@@ -600,29 +412,25 @@ def delete_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | ScimJSONResponse:
) -> Response | JSONResponse:
"""Delete a user (RFC 7644 §3.6).
Deactivates the user and removes the SCIM mapping. Note that Okta
typically uses PATCH active=false instead of DELETE.
A second DELETE returns 404 per RFC 7644 §3.6.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
user = result
# If no SCIM mapping exists, the user was already deleted from
# SCIM's perspective — return 404 per RFC 7644 §3.6.
mapping = dal.get_user_mapping_by_user_id(user.id)
if not mapping:
return _scim_error_response(404, f"User {user_id} not found")
dal.deactivate_user(user)
dal.delete_user_mapping(mapping.id)
mapping = dal.get_user_mapping_by_user_id(user.id)
if mapping:
dal.delete_user_mapping(mapping.id)
dal.commit()
@@ -634,7 +442,7 @@ def delete_user(
# ---------------------------------------------------------------------------
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | ScimJSONResponse:
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
"""Parse *group_id* as int, look up the group, or return a 404 error."""
try:
gid = int(group_id)
@@ -689,13 +497,12 @@ def _validate_and_parse_members(
@scim_router.get("/Groups", response_model=None)
def list_groups(
filter: str | None = Query(None),
excludedAttributes: str | None = None,
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | ScimJSONResponse:
) -> ScimListResponse | JSONResponse:
"""List groups with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
@@ -715,59 +522,37 @@ def list_groups(
for group, ext_id in groups_with_ext_ids
]
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
excluded = _parse_excluded_attributes(excludedAttributes)
if excluded:
response = ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
)
data = response.model_dump(exclude_none=True)
data["Resources"] = [_apply_exclusions(r, excluded) for r in resources]
return ScimJSONResponse(content=data)
return _scim_resource_response(
ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
)
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
)
@scim_router.get("/Groups/{group_id}", response_model=None)
def get_group(
group_id: str,
excludedAttributes: str | None = None,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
) -> ScimGroupResource | JSONResponse:
"""Get a single group by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
members = dal.get_group_members(group.id)
resource = provider.build_group_resource(
return provider.build_group_resource(
group, members, mapping.external_id if mapping else None
)
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
excluded = _parse_excluded_attributes(excludedAttributes)
if excluded:
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
return _scim_resource_response(resource)
@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
@@ -775,7 +560,7 @@ def create_group(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
) -> ScimGroupResource | JSONResponse:
"""Create a new group from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
@@ -811,10 +596,7 @@ def create_group(
dal.commit()
members = dal.get_group_members(db_group.id)
return _scim_resource_response(
provider.build_group_resource(db_group, members, external_id),
status_code=201,
)
return provider.build_group_resource(db_group, members, external_id)
@scim_router.put("/Groups/{group_id}", response_model=None)
@@ -824,13 +606,13 @@ def replace_group(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
) -> ScimGroupResource | JSONResponse:
"""Replace a group entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
group = result
@@ -845,9 +627,7 @@ def replace_group(
dal.commit()
members = dal.get_group_members(group.id)
return _scim_resource_response(
provider.build_group_resource(group, members, group_resource.externalId)
)
return provider.build_group_resource(group, members, group_resource.externalId)
@scim_router.patch("/Groups/{group_id}", response_model=None)
@@ -857,7 +637,7 @@ def patch_group(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
) -> ScimGroupResource | JSONResponse:
"""Partially update a group (RFC 7644 §3.5.2).
Handles member add/remove operations from Okta and Azure AD.
@@ -866,7 +646,7 @@ def patch_group(
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
group = result
@@ -905,9 +685,7 @@ def patch_group(
dal.commit()
members = dal.get_group_members(group.id)
return _scim_resource_response(
provider.build_group_resource(group, members, patched.externalId)
)
return provider.build_group_resource(group, members, patched.externalId)
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
@@ -915,13 +693,13 @@ def delete_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | ScimJSONResponse:
) -> Response | JSONResponse:
"""Delete a group (RFC 7644 §3.6)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, ScimJSONResponse):
if isinstance(result, JSONResponse):
return result
group = result

View File

@@ -7,14 +7,12 @@ SCIM protocol schemas follow the wire format defined in:
Admin API schemas are internal to Onyx and used for SCIM token management.
"""
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator
# ---------------------------------------------------------------------------
@@ -33,9 +31,6 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
)
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
SCIM_ENTERPRISE_USER_SCHEMA = (
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
)
# ---------------------------------------------------------------------------
@@ -75,36 +70,6 @@ class ScimUserGroupRef(BaseModel):
display: str | None = None
class ScimManagerRef(BaseModel):
"""Manager sub-attribute for the enterprise extension (RFC 7643 §4.3)."""
value: str | None = None
class ScimEnterpriseExtension(BaseModel):
"""Enterprise User extension attributes (RFC 7643 §4.3)."""
department: str | None = None
manager: ScimManagerRef | None = None
@dataclass
class ScimMappingFields:
"""Stored SCIM mapping fields that need to round-trip through the IdP.
Entra ID sends structured name components, email metadata, and enterprise
extension attributes that must be returned verbatim in subsequent GET
responses. These fields are persisted on ScimUserMapping and threaded
through the DAL, provider, and endpoint layers.
"""
department: str | None = None
manager: str | None = None
given_name: str | None = None
family_name: str | None = None
scim_emails_json: str | None = None
class ScimUserResource(BaseModel):
"""SCIM User resource representation (RFC 7643 §4.1).
@@ -113,8 +78,6 @@ class ScimUserResource(BaseModel):
to match the SCIM wire format (not Python convention).
"""
model_config = ConfigDict(populate_by_name=True)
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
id: str | None = None # Onyx's internal user ID, set on responses
externalId: str | None = None # IdP's identifier for this user
@@ -125,10 +88,6 @@ class ScimUserResource(BaseModel):
active: bool = True
groups: list[ScimUserGroupRef] = Field(default_factory=list)
meta: ScimMeta | None = None
enterprise_extension: ScimEnterpriseExtension | None = Field(
default=None,
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
)
class ScimGroupMember(BaseModel):
@@ -206,19 +165,6 @@ class ScimPatchOperation(BaseModel):
path: str | None = None
value: ScimPatchValue = None
@field_validator("op", mode="before")
@classmethod
def normalize_op(cls, v: object) -> object:
"""Normalize op to lowercase for case-insensitive matching.
Some IdPs (e.g. Entra ID) send capitalized ops like ``"Replace"``
instead of ``"replace"``. This is safe for all providers since the
enum values are lowercase. If a future provider requires other
pre-processing quirks, move patch deserialization into the provider
subclass instead of adding more special cases here.
"""
return v.lower() if isinstance(v, str) else v
class ScimPatchRequest(BaseModel):
"""PATCH request body (RFC 7644 §3.5.2).

View File

@@ -15,10 +15,7 @@ responsible for persisting changes.
from __future__ import annotations
import re
from dataclasses import dataclass
from dataclasses import field
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimPatchOperation
@@ -27,51 +24,6 @@ from ee.onyx.server.scim.models import ScimPatchResourceValue
from ee.onyx.server.scim.models import ScimPatchValue
from ee.onyx.server.scim.models import ScimUserResource
# Lowercased enterprise extension URN for case-insensitive matching
_ENTERPRISE_URN_LOWER = SCIM_ENTERPRISE_USER_SCHEMA.lower()
# Pattern for email filter paths: emails[primary eq true].value
_EMAIL_FILTER_RE = re.compile(
r"^emails\[primary\s+eq\s+true\]\.value$",
re.IGNORECASE,
)
# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
r'^members\[value\s+eq\s+"([^"]+)"\]$',
re.IGNORECASE,
)
# ---------------------------------------------------------------------------
# Dispatch tables for user PATCH paths
#
# Maps lowercased SCIM path → (camelCase key, target dict name).
# "data" writes to the top-level resource dict, "name" writes to the
# name sub-object dict. This replaces the elif chains for simple fields.
# ---------------------------------------------------------------------------
_USER_REPLACE_PATHS: dict[str, tuple[str, str]] = {
"active": ("active", "data"),
"username": ("userName", "data"),
"externalid": ("externalId", "data"),
"name.givenname": ("givenName", "name"),
"name.familyname": ("familyName", "name"),
"name.formatted": ("formatted", "name"),
}
_USER_REMOVE_PATHS: dict[str, tuple[str, str]] = {
"externalid": ("externalId", "data"),
"name.givenname": ("givenName", "name"),
"name.familyname": ("familyName", "name"),
"name.formatted": ("formatted", "name"),
"displayname": ("displayName", "data"),
}
_GROUP_REPLACE_PATHS: dict[str, tuple[str, str]] = {
"displayname": ("displayName", "data"),
"externalid": ("externalId", "data"),
}
class ScimPatchError(Exception):
"""Raised when a PATCH operation cannot be applied."""
@@ -82,25 +34,18 @@ class ScimPatchError(Exception):
super().__init__(detail)
@dataclass
class _UserPatchCtx:
"""Bundles the mutable state for user PATCH operations."""
data: dict
name_data: dict
ent_data: dict[str, str | None] = field(default_factory=dict)
# ---------------------------------------------------------------------------
# User PATCH
# ---------------------------------------------------------------------------
# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
r'^members\[value\s+eq\s+"([^"]+)"\]$',
re.IGNORECASE,
)
def apply_user_patch(
operations: list[ScimPatchOperation],
current: ScimUserResource,
ignored_paths: frozenset[str] = frozenset(),
) -> tuple[ScimUserResource, dict[str, str | None]]:
) -> ScimUserResource:
"""Apply SCIM PATCH operations to a user resource.
Args:
@@ -108,183 +53,79 @@ def apply_user_patch(
current: The current user resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns:
A tuple of (modified user resource, enterprise extension data dict).
The enterprise dict has keys ``"department"`` and ``"manager"``
with values set only when a PATCH operation touched them.
Returns a new ``ScimUserResource`` with the modifications applied.
The original object is not mutated.
Raises:
ScimPatchError: If an operation targets an unsupported path.
"""
data = current.model_dump()
ctx = _UserPatchCtx(data=data, name_data=data.get("name") or {})
name_data = data.get("name") or {}
for op in operations:
if op.op in (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD):
_apply_user_replace(op, ctx, ignored_paths)
elif op.op == ScimPatchOperationType.REMOVE:
_apply_user_remove(op, ctx, ignored_paths)
if op.op == ScimPatchOperationType.REPLACE:
_apply_user_replace(op, data, name_data, ignored_paths)
elif op.op == ScimPatchOperationType.ADD:
_apply_user_replace(op, data, name_data, ignored_paths)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on User resource"
)
ctx.data["name"] = ctx.name_data
return ScimUserResource.model_validate(ctx.data), ctx.ent_data
data["name"] = name_data
return ScimUserResource.model_validate(data)
def _apply_user_replace(
op: ScimPatchOperation,
ctx: _UserPatchCtx,
data: dict,
name_data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace/add operation to user data."""
path = (op.path or "").lower()
if not path:
# No path — value is a resource dict of top-level attributes to set.
# No path — value is a resource dict of top-level attributes to set
if isinstance(op.value, ScimPatchResourceValue):
for key, val in op.value.model_dump(exclude_unset=True).items():
_set_user_field(key.lower(), val, ctx, ignored_paths, strict=False)
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
_set_user_field(path, op.value, ctx, ignored_paths)
def _apply_user_remove(
op: ScimPatchOperation,
ctx: _UserPatchCtx,
ignored_paths: frozenset[str],
) -> None:
"""Apply a remove operation to user data — clears the target field."""
path = (op.path or "").lower()
if not path:
raise ScimPatchError("Remove operation requires a path")
if path in ignored_paths:
return
entry = _USER_REMOVE_PATHS.get(path)
if entry:
key, target = entry
target_dict = ctx.data if target == "data" else ctx.name_data
target_dict[key] = None
return
raise ScimPatchError(f"Unsupported remove path '{path}' for User PATCH")
_set_user_field(path, op.value, data, name_data, ignored_paths)
def _set_user_field(
path: str,
value: ScimPatchValue,
ctx: _UserPatchCtx,
data: dict,
name_data: dict,
ignored_paths: frozenset[str],
*,
strict: bool = True,
) -> None:
"""Set a single field on user data by SCIM path.
Args:
strict: When ``False`` (path-less replace), unknown attributes are
silently skipped. When ``True`` (explicit path), they raise.
"""
"""Set a single field on user data by SCIM path."""
if path in ignored_paths:
return
# Simple field writes handled by the dispatch table
entry = _USER_REPLACE_PATHS.get(path)
if entry:
key, target = entry
target_dict = ctx.data if target == "data" else ctx.name_data
target_dict[key] = value
return
# displayName sets both the top-level field and the name.formatted sub-field
if path == "displayname":
ctx.data["displayName"] = value
ctx.name_data["formatted"] = value
elif path == "name":
if isinstance(value, dict):
for k, v in value.items():
ctx.name_data[k] = v
elif path == "emails":
if isinstance(value, list):
ctx.data["emails"] = value
elif _EMAIL_FILTER_RE.match(path):
_update_primary_email(ctx.data, value)
elif path.startswith(_ENTERPRISE_URN_LOWER):
_set_enterprise_field(path, value, ctx.ent_data)
elif not strict:
return
elif path == "active":
data["active"] = value
elif path == "username":
data["userName"] = value
elif path == "externalid":
data["externalId"] = value
elif path == "name.givenname":
name_data["givenName"] = value
elif path == "name.familyname":
name_data["familyName"] = value
elif path == "name.formatted":
name_data["formatted"] = value
elif path == "displayname":
data["displayName"] = value
name_data["formatted"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
def _update_primary_email(data: dict, value: ScimPatchValue) -> None:
"""Update the primary email entry via emails[primary eq true].value."""
emails: list[dict] = data.get("emails") or []
for email_entry in emails:
if email_entry.get("primary"):
email_entry["value"] = value
break
else:
emails.append({"value": value, "type": "work", "primary": True})
data["emails"] = emails
def _to_dict(value: ScimPatchValue) -> dict | None:
"""Coerce a SCIM patch value to a plain dict if possible.
Pydantic may parse raw dicts as ``ScimPatchResourceValue`` (which uses
``extra="allow"``), so we also dump those back to a dict.
"""
if isinstance(value, dict):
return value
if isinstance(value, ScimPatchResourceValue):
return value.model_dump(exclude_unset=True)
return None
def _set_enterprise_field(
path: str,
value: ScimPatchValue,
ent_data: dict[str, str | None],
) -> None:
"""Handle enterprise extension URN paths or value dicts."""
# Full URN as key with dict value (path-less PATCH)
# e.g. key="urn:...:user", value={"department": "Eng", "manager": {...}}
if path == _ENTERPRISE_URN_LOWER:
d = _to_dict(value)
if d is not None:
if "department" in d:
ent_data["department"] = d["department"]
if "manager" in d:
mgr = d["manager"]
if isinstance(mgr, dict):
ent_data["manager"] = mgr.get("value")
return
# Dotted URN path, e.g. "urn:...:user:department"
suffix = path[len(_ENTERPRISE_URN_LOWER) :].lstrip(":").lower()
if suffix == "department":
ent_data["department"] = str(value) if value is not None else None
elif suffix == "manager":
d = _to_dict(value)
if d is not None:
ent_data["manager"] = d.get("value")
elif isinstance(value, str):
ent_data["manager"] = value
else:
raise ScimPatchError(f"Unsupported enterprise extension attribute '{suffix}'")
# ---------------------------------------------------------------------------
# Group PATCH
# ---------------------------------------------------------------------------
def apply_group_patch(
operations: list[ScimPatchOperation],
current: ScimGroupResource,
@@ -394,14 +235,12 @@ def _set_group_field(
"""Set a single field on group data by SCIM path."""
if path in ignored_paths:
return
entry = _GROUP_REPLACE_PATHS.get(path)
if entry:
key, _ = entry
data[key] = value
return
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
elif path == "displayname":
data["displayName"] = value
elif path == "externalid":
data["externalId"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
def _apply_group_add(

View File

@@ -2,20 +2,13 @@
from __future__ import annotations
import json
import logging
from abc import ABC
from abc import abstractmethod
from uuid import UUID
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimEnterpriseExtension
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimManagerRef
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserGroupRef
@@ -24,17 +17,6 @@ from onyx.db.models import User
from onyx.db.models import UserGroup
logger = logging.getLogger(__name__)
COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset(
{
"id",
"schemas",
"meta",
}
)
class ScimProvider(ABC):
"""Base class for provider-specific SCIM behavior.
@@ -59,22 +41,12 @@ class ScimProvider(ABC):
"""
...
@property
def user_schemas(self) -> list[str]:
"""Schema URIs to include in User resource responses.
Override in subclasses to advertise additional schemas (e.g. the
enterprise extension for Entra ID).
"""
return [SCIM_USER_SCHEMA]
def build_user_resource(
self,
user: User,
external_id: str | None = None,
groups: list[tuple[int, str]] | None = None,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> ScimUserResource:
"""Build a SCIM User response from an Onyx User.
@@ -86,48 +58,27 @@ class ScimProvider(ABC):
for newly-created users.
scim_username: The original-case userName from the IdP. Falls
back to ``user.email`` (lowercase) when not available.
fields: Stored mapping fields that the IdP expects round-tripped.
"""
f = fields or ScimMappingFields()
group_refs = [
ScimUserGroupRef(value=str(gid), display=gname)
for gid, gname in (groups or [])
]
# Use original-case userName if stored, otherwise fall back to the
# lowercased email from the User model.
username = scim_username or user.email
# Build enterprise extension when at least one value is present.
# Dynamically add the enterprise URN to schemas per RFC 7643 §3.0.
enterprise_ext: ScimEnterpriseExtension | None = None
schemas = list(self.user_schemas)
if f.department is not None or f.manager is not None:
manager_ref = (
ScimManagerRef(value=f.manager) if f.manager is not None else None
)
enterprise_ext = ScimEnterpriseExtension(
department=f.department,
manager=manager_ref,
)
if SCIM_ENTERPRISE_USER_SCHEMA not in schemas:
schemas.append(SCIM_ENTERPRISE_USER_SCHEMA)
name = self.build_scim_name(user, f)
emails = _deserialize_emails(f.scim_emails_json, username)
resource = ScimUserResource(
schemas=schemas,
return ScimUserResource(
id=str(user.id),
externalId=external_id,
userName=username,
name=name,
name=self._build_scim_name(user),
displayName=user.personal_name,
emails=emails,
emails=[ScimEmail(value=username, type="work", primary=True)],
active=user.is_active,
groups=group_refs,
meta=ScimMeta(resourceType="User"),
)
resource.enterprise_extension = enterprise_ext
return resource
def build_group_resource(
self,
@@ -147,24 +98,9 @@ class ScimProvider(ABC):
meta=ScimMeta(resourceType="Group"),
)
def build_scim_name(
self,
user: User,
fields: ScimMappingFields,
) -> ScimName | None:
"""Build SCIM name components for the response.
Round-trips stored ``given_name``/``family_name`` when available (so
the IdP gets back what it sent). Falls back to splitting
``personal_name`` for users provisioned before we stored components.
Providers may override for custom behavior.
"""
if fields.given_name is not None or fields.family_name is not None:
return ScimName(
givenName=fields.given_name,
familyName=fields.family_name,
formatted=user.personal_name,
)
@staticmethod
def _build_scim_name(user: User) -> ScimName | None:
"""Extract SCIM name components from a user's personal name."""
if not user.personal_name:
return None
parts = user.personal_name.split(" ", 1)
@@ -175,27 +111,6 @@ class ScimProvider(ABC):
)
def _deserialize_emails(stored_json: str | None, username: str) -> list[ScimEmail]:
"""Deserialize stored email entries or build a default work email."""
if stored_json:
try:
entries = json.loads(stored_json)
if isinstance(entries, list) and entries:
return [ScimEmail(**e) for e in entries]
except (json.JSONDecodeError, TypeError):
logger.warning(
"Corrupt scim_emails_json, falling back to default: %s", stored_json
)
return [ScimEmail(value=username, type="work", primary=True)]
def serialize_emails(emails: list[ScimEmail]) -> str | None:
"""Serialize SCIM email entries to JSON for storage."""
if not emails:
return None
return json.dumps([e.model_dump(exclude_none=True) for e in emails])
def get_default_provider() -> ScimProvider:
"""Return the default SCIM provider.

View File

@@ -1,36 +0,0 @@
"""Entra ID (Azure AD) SCIM provider."""
from __future__ import annotations
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import ScimProvider
_ENTRA_IGNORED_PATCH_PATHS = COMMON_IGNORED_PATCH_PATHS
class EntraProvider(ScimProvider):
"""Entra ID (Azure AD) SCIM provider.
Entra behavioral notes:
- Sends capitalized PATCH ops (``"Add"``, ``"Replace"``, ``"Remove"``)
— handled by ``ScimPatchOperation.normalize_op`` validator.
- Sends the enterprise extension URN as a key in path-less PATCH value
dicts — handled by ``_set_enterprise_field`` in ``patch.py`` to
store department/manager values.
- Expects the enterprise extension schema in ``schemas`` arrays and
``/Schemas`` + ``/ResourceTypes`` discovery endpoints.
"""
@property
def name(self) -> str:
return "entra"
@property
def ignored_patch_paths(self) -> frozenset[str]:
return _ENTRA_IGNORED_PATCH_PATHS
@property
def user_schemas(self) -> list[str]:
return [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import ScimProvider
@@ -23,4 +22,4 @@ class OktaProvider(ScimProvider):
@property
def ignored_patch_paths(self) -> frozenset[str]:
return COMMON_IGNORED_PATCH_PATHS
return frozenset({"id", "schemas", "meta"})

View File

@@ -4,7 +4,6 @@ Pre-built at import time — these never change at runtime. Separated from
api.py to keep the endpoint module focused on request handling.
"""
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimResourceType
@@ -21,9 +20,6 @@ USER_RESOURCE_TYPE = ScimResourceType.model_validate(
"endpoint": "/scim/v2/Users",
"description": "SCIM User resource",
"schema": SCIM_USER_SCHEMA,
"schemaExtensions": [
{"schema": SCIM_ENTERPRISE_USER_SCHEMA, "required": False}
],
}
)
@@ -108,31 +104,6 @@ USER_SCHEMA_DEF = ScimSchemaDefinition(
],
)
ENTERPRISE_USER_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_ENTERPRISE_USER_SCHEMA,
name="EnterpriseUser",
description="Enterprise User extension (RFC 7643 §4.3)",
attributes=[
ScimSchemaAttribute(
name="department",
type="string",
description="Department.",
),
ScimSchemaAttribute(
name="manager",
type="complex",
description="The user's manager.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="Manager user ID.",
),
],
),
],
)
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_GROUP_SCHEMA,
name="Group",

View File

@@ -58,6 +58,8 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
@@ -156,36 +158,7 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
logger.warning(
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
)
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
if cleaned_doc.title and "\x00" in cleaned_doc.title:
logger.warning(
f"NUL characters found in document title: {cleaned_doc.title}"
)
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
)
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
"\x00", ""
)
for section in cleaned_doc.sections:
if section.link is not None:
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
if isinstance(section, TextSection) and section.text is not None:
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
cleaned_batch.append(sanitize_document_for_postgres(doc))
return cleaned_batch
@@ -602,10 +575,13 @@ def connector_document_extraction(
# Process hierarchy nodes batch - upsert to Postgres and cache in Redis
if hierarchy_node_batch:
hierarchy_node_batch_cleaned = (
sanitize_hierarchy_nodes_for_postgres(hierarchy_node_batch)
)
with get_session_with_current_tenant() as db_session:
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=hierarchy_node_batch,
nodes=hierarchy_node_batch_cleaned,
source=db_connector.source,
commit=True,
is_connector_public=is_connector_public,
@@ -624,7 +600,7 @@ def connector_document_extraction(
)
logger.debug(
f"Persisted and cached {len(hierarchy_node_batch)} hierarchy nodes "
f"Persisted and cached {len(hierarchy_node_batch_cleaned)} hierarchy nodes "
f"for attempt={index_attempt_id}"
)

View File

@@ -0,0 +1,21 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import CodeInterpreterServer
def fetch_code_interpreter_server(
    db_session: Session,
) -> CodeInterpreterServer:
    """Fetch the code interpreter server row.

    Uses ``.one()``, which raises if the row is missing or if more than one
    row exists, so callers can rely on a non-None result.
    """
    return db_session.scalars(select(CodeInterpreterServer)).one()
def update_code_interpreter_server_enabled(
    db_session: Session,
    enabled: bool,
) -> CodeInterpreterServer:
    """Set the code interpreter server's enabled flag and commit.

    Loads the single CodeInterpreterServer row via ``.one()`` (raises if
    missing or duplicated), updates ``server_enabled``, commits the session,
    and returns the updated row. Note: the commit also flushes any other
    pending changes on ``db_session``.
    """
    server = db_session.scalars(select(CodeInterpreterServer)).one()
    server.server_enabled = enabled
    db_session.commit()
    return server

View File

@@ -49,6 +49,7 @@ from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.factory import get_llm_for_contextual_rag
@@ -228,6 +229,8 @@ def index_doc_batch_prepare(
) -> DocumentBatchPrepareContext | None:
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This preceeds indexing it into the actual document index."""
documents = sanitize_documents_for_postgres(documents)
# Create a trimmed list of docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
document_ids: list[str] = [document.id for document in documents]

View File

@@ -0,0 +1,150 @@
from typing import Any
from onyx.access.models import ExternalAccess
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
def _sanitize_string(value: str) -> str:
return value.replace("\x00", "")
def _sanitize_json_like(value: Any) -> Any:
if isinstance(value, str):
return _sanitize_string(value)
if isinstance(value, list):
return [_sanitize_json_like(item) for item in value]
if isinstance(value, tuple):
return tuple(_sanitize_json_like(item) for item in value)
if isinstance(value, dict):
sanitized: dict[Any, Any] = {}
for key, nested_value in value.items():
cleaned_key = _sanitize_string(key) if isinstance(key, str) else key
sanitized[cleaned_key] = _sanitize_json_like(nested_value)
return sanitized
return value
def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
    """Return a copy of ``expert`` with NUL bytes stripped from every
    optional text field; ``None`` fields stay ``None``."""

    def _clean(text: str | None) -> str | None:
        return text.replace("\x00", "") if text is not None else None

    return expert.model_copy(
        update={
            "display_name": _clean(expert.display_name),
            "first_name": _clean(expert.first_name),
            "middle_initial": _clean(expert.middle_initial),
            "last_name": _clean(expert.last_name),
            "email": _clean(expert.email),
        }
    )
def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess:
    """Rebuild ``external_access`` with NUL bytes stripped from user emails
    and group ids; the ``is_public`` flag is carried over unchanged."""
    cleaned_emails = {
        email.replace("\x00", "") for email in external_access.external_user_emails
    }
    cleaned_group_ids = {
        group_id.replace("\x00", "")
        for group_id in external_access.external_user_group_ids
    }
    return ExternalAccess(
        external_user_emails=cleaned_emails,
        external_user_group_ids=cleaned_group_ids,
        is_public=external_access.is_public,
    )
def sanitize_document_for_postgres(document: Document) -> Document:
    """Return a deep copy of ``document`` with NUL (``\\x00``) bytes stripped
    from every string field that is persisted to Postgres.

    Postgres text columns cannot store NUL characters, so ids, titles,
    metadata keys/values, owner info, external access entries, and section
    contents are all cleaned. The input document is not mutated.
    """
    # Deep copy so nested section/owner objects can be mutated safely.
    cleaned_doc = document.model_copy(deep=True)
    cleaned_doc.id = _sanitize_string(cleaned_doc.id)
    cleaned_doc.semantic_identifier = _sanitize_string(cleaned_doc.semantic_identifier)
    if cleaned_doc.title is not None:
        cleaned_doc.title = _sanitize_string(cleaned_doc.title)
    if cleaned_doc.parent_hierarchy_raw_node_id is not None:
        cleaned_doc.parent_hierarchy_raw_node_id = _sanitize_string(
            cleaned_doc.parent_hierarchy_raw_node_id
        )
    # Metadata values are either a string or a list of strings; clean both
    # the keys and the values.
    cleaned_doc.metadata = {
        _sanitize_string(key): (
            [_sanitize_string(item) for item in value]
            if isinstance(value, list)
            else _sanitize_string(value)
        )
        for key, value in cleaned_doc.metadata.items()
    }
    # doc_metadata is arbitrary JSON-like data; sanitize recursively.
    if cleaned_doc.doc_metadata is not None:
        cleaned_doc.doc_metadata = _sanitize_json_like(cleaned_doc.doc_metadata)
    if cleaned_doc.primary_owners is not None:
        cleaned_doc.primary_owners = [
            _sanitize_expert_info(expert) for expert in cleaned_doc.primary_owners
        ]
    if cleaned_doc.secondary_owners is not None:
        cleaned_doc.secondary_owners = [
            _sanitize_expert_info(expert) for expert in cleaned_doc.secondary_owners
        ]
    if cleaned_doc.external_access is not None:
        cleaned_doc.external_access = _sanitize_external_access(
            cleaned_doc.external_access
        )
    # Sections are mutated in place on the copied document.
    for section in cleaned_doc.sections:
        if section.link is not None:
            section.link = _sanitize_string(section.link)
        if section.text is not None:
            section.text = _sanitize_string(section.text)
        if section.image_file_id is not None:
            section.image_file_id = _sanitize_string(section.image_file_id)
    return cleaned_doc
def sanitize_documents_for_postgres(documents: list[Document]) -> list[Document]:
    """Sanitize each document in ``documents`` for Postgres storage."""
    cleaned: list[Document] = []
    for document in documents:
        cleaned.append(sanitize_document_for_postgres(document))
    return cleaned
def sanitize_hierarchy_node_for_postgres(node: HierarchyNode) -> HierarchyNode:
    """Return a deep copy of ``node`` with NUL (``\\x00``) bytes stripped from
    its Postgres-bound string fields: raw ids, display name, link, and
    external access entries. The input node is not mutated.
    """
    cleaned_node = node.model_copy(deep=True)
    cleaned_node.raw_node_id = _sanitize_string(cleaned_node.raw_node_id)
    cleaned_node.display_name = _sanitize_string(cleaned_node.display_name)
    if cleaned_node.raw_parent_id is not None:
        cleaned_node.raw_parent_id = _sanitize_string(cleaned_node.raw_parent_id)
    if cleaned_node.link is not None:
        cleaned_node.link = _sanitize_string(cleaned_node.link)
    if cleaned_node.external_access is not None:
        cleaned_node.external_access = _sanitize_external_access(
            cleaned_node.external_access
        )
    return cleaned_node
def sanitize_hierarchy_nodes_for_postgres(
    nodes: list[HierarchyNode],
) -> list[HierarchyNode]:
    """Sanitize every hierarchy node in ``nodes`` for Postgres storage."""
    return list(map(sanitize_hierarchy_node_for_postgres, nodes))

View File

@@ -97,6 +97,9 @@ from onyx.server.features.web_search.api import router as web_search_router
from onyx.server.federated.api import router as federated_router
from onyx.server.kg.api import admin_router as kg_admin_router
from onyx.server.manage.administrative import router as admin_router
from onyx.server.manage.code_interpreter.api import (
admin_router as code_interpreter_admin_router,
)
from onyx.server.manage.discord_bot.api import router as discord_bot_router
from onyx.server.manage.embedding.api import admin_router as embedding_admin_router
from onyx.server.manage.embedding.api import basic_router as embedding_router
@@ -421,6 +424,9 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, llm_admin_router)
include_router_with_global_prefix_prepended(application, kg_admin_router)
include_router_with_global_prefix_prepended(application, llm_router)
include_router_with_global_prefix_prepended(
application, code_interpreter_admin_router
)
include_router_with_global_prefix_prepended(
application, image_generation_admin_router
)

View File

@@ -1,14 +1,68 @@
import re
from typing import Any
from mistune import create_markdown
from mistune import HTMLRenderer
_CITATION_LINK_PATTERN = re.compile(r"\[\[\d+\]\]\(")
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
"""Extract markdown link destination, allowing nested parentheses in the URL."""
depth = 0
i = start_idx
while i < len(message):
curr = message[i]
if curr == "\\":
i += 2
continue
if curr == "(":
depth += 1
elif curr == ")":
if depth == 0:
return message[start_idx:i], i
depth -= 1
i += 1
return message[start_idx:], None
def _normalize_citation_link_destinations(message: str) -> str:
    """Wrap citation URLs in angle brackets so markdown parsers handle parentheses safely."""
    # Fast path: a citation link "[[n]](url)" cannot exist without "[[".
    if "[[" not in message:
        return message
    normalized_parts: list[str] = []
    cursor = 0
    while match := _CITATION_LINK_PATTERN.search(message, cursor):
        # Copy everything up to and including the "[[n]](" opener.
        normalized_parts.append(message[cursor : match.end()])
        destination_start = match.end()
        destination, end_idx = _extract_link_destination(message, destination_start)
        if end_idx is None:
            # Unterminated link: keep the remainder verbatim and stop.
            normalized_parts.append(message[destination_start:])
            return "".join(normalized_parts)
        already_wrapped = destination.startswith("<") and destination.endswith(">")
        if destination and not already_wrapped:
            # Angle brackets let markdown accept parens inside the URL.
            destination = f"<{destination}>"
        normalized_parts.append(destination)
        normalized_parts.append(")")
        # Resume scanning just past the closing ")".
        cursor = end_idx + 1
    normalized_parts.append(message[cursor:])
    return "".join(normalized_parts)
def format_slack_message(message: str | None) -> str:
if message is None:
return ""
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
result = md(message)
normalized_message = _normalize_citation_link_destinations(message)
result = md(normalized_message)
# With HTMLRenderer, result is always str (not AST list)
assert isinstance(result, str)
return result

View File

@@ -0,0 +1,47 @@
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.db.code_interpreter import fetch_code_interpreter_server
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.server.manage.code_interpreter.models import CodeInterpreterServer
from onyx.server.manage.code_interpreter.models import CodeInterpreterServerHealth
from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
admin_router = APIRouter(prefix="/admin/code-interpreter")
@admin_router.get("/health")
def get_code_interpreter_health(
    _: User = Depends(current_admin_user),
) -> CodeInterpreterServerHealth:
    """Admin-only health probe for the code interpreter service."""
    try:
        client = CodeInterpreterClient()
        return CodeInterpreterServerHealth(healthy=client.health())
    except ValueError:
        # Presumably CodeInterpreterClient() raises ValueError when the
        # service is not configured — report unhealthy rather than 500.
        return CodeInterpreterServerHealth(healthy=False)
@admin_router.get("")
def get_code_interpreter(
    _: User = Depends(current_admin_user), db_session: Session = Depends(get_session)
) -> CodeInterpreterServer:
    """Return whether the code interpreter server is enabled (admin only)."""
    ci_server = fetch_code_interpreter_server(db_session)
    return CodeInterpreterServer(enabled=ci_server.server_enabled)
@admin_router.put("")
def update_code_interpreter(
    update: CodeInterpreterServer,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Enable or disable the code interpreter server (admin only)."""
    update_code_interpreter_server_enabled(
        db_session=db_session,
        enabled=update.enabled,
    )

View File

@@ -0,0 +1,9 @@
from pydantic import BaseModel
class CodeInterpreterServer(BaseModel):
    """Request/response body carrying the code interpreter enabled flag."""
    enabled: bool
class CodeInterpreterServerHealth(BaseModel):
    """Response body for the code interpreter health endpoint."""

    # True when the service's health check succeeded.
    healthy: bool

View File

@@ -98,6 +98,17 @@ class CodeInterpreterClient:
payload["files"] = files
return payload
def health(self) -> bool:
    """Check if the Code Interpreter service is healthy"""
    try:
        response = self.session.get(f"{self.base_url}/health", timeout=5)
        response.raise_for_status()
        status_value = response.json().get("status")
    except Exception as e:
        # Any failure — connection error, non-2xx status, bad JSON — is unhealthy.
        logger.warning(f"Exception caught when checking health, e={e}")
        return False
    return status_value == "ok"
def execute(
self,
code: str,

View File

@@ -12,6 +12,7 @@ from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
from onyx.configs.app_configs import CODE_INTERPRETER_DEFAULT_TIMEOUT_MS
from onyx.configs.app_configs import CODE_INTERPRETER_MAX_OUTPUT_LENGTH
from onyx.configs.constants import FileOrigin
from onyx.db.code_interpreter import fetch_code_interpreter_server
from onyx.file_store.utils import build_full_frontend_file_url
from onyx.file_store.utils import get_default_file_store
from onyx.server.query_and_chat.placement import Placement
@@ -103,8 +104,10 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
@override
@classmethod
def is_available(cls, db_session: Session) -> bool:
    """The Python tool is available only when a code interpreter base URL is
    configured AND the server is enabled in the database.

    Removed the stale pre-refactor lines (`is_available = bool(...)` /
    `return is_available`) that made the DB-backed check unreachable.
    """
    if not CODE_INTERPRETER_BASE_URL:
        return False
    server = fetch_code_interpreter_server(db_session)
    return server.server_enabled
def tool_definition(self) -> dict:
return {

View File

@@ -317,7 +317,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.6.0
onyx-devtools==0.6.1
# via onyx
openai==2.14.0
# via

View File

@@ -3,8 +3,8 @@ set -e
cleanup() {
  echo "Error occurred. Cleaning up..."
  # Best-effort teardown of every dev container, including the code
  # interpreter. Deduplicated: the pre-refactor stop/rm lines (without
  # onyx_code_interpreter) were redundant with the lines below.
  docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
  docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
}
# Trap errors and output a message, then cleanup
@@ -20,8 +20,8 @@ MINIO_VOLUME=${4:-""} # Default is empty if not provided
# Stop and remove the existing containers
echo "Stopping and removing existing containers..."
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
# Start the PostgreSQL container with optional volume
echo "Starting PostgreSQL container..."
@@ -55,6 +55,10 @@ else
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
fi
# Start the Code Interpreter container
echo "Starting Code Interpreter container..."
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
# Ensure alembic runs in the correct directory (backend/)
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PARENT_DIR="$(dirname "$SCRIPT_DIR")"

View File

@@ -0,0 +1,130 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.tool import ToolManager
from tests.integration.common_utils.test_models import DATestUser
# Admin endpoints under test.
CODE_INTERPRETER_URL = f"{API_SERVER_URL}/admin/code-interpreter"
CODE_INTERPRETER_HEALTH_URL = f"{CODE_INTERPRETER_URL}/health"
# Display name of the Python (code interpreter) tool in GET /tool responses.
PYTHON_TOOL_NAME = "python"
def test_get_code_interpreter_health_as_admin(
    admin_user: DATestUser,
) -> None:
    """Health endpoint should return a JSON object with a 'healthy' boolean."""
    resp = requests.get(CODE_INTERPRETER_HEALTH_URL, headers=admin_user.headers)
    assert resp.status_code == 200
    body = resp.json()
    assert "healthy" in body
    assert isinstance(body["healthy"], bool)
def test_get_code_interpreter_status_as_admin(
    admin_user: DATestUser,
) -> None:
    """GET endpoint should return a JSON object with an 'enabled' boolean."""
    resp = requests.get(CODE_INTERPRETER_URL, headers=admin_user.headers)
    assert resp.status_code == 200
    body = resp.json()
    assert "enabled" in body
    assert isinstance(body["enabled"], bool)
def test_update_code_interpreter_disable_and_enable(
    admin_user: DATestUser,
) -> None:
    """PUT endpoint should update the enabled flag and persist across reads."""
    # Disable first, then re-enable, verifying persistence after each write.
    for desired in (False, True):
        put_resp = requests.put(
            CODE_INTERPRETER_URL,
            json={"enabled": desired},
            headers=admin_user.headers,
        )
        assert put_resp.status_code == 200

        get_resp = requests.get(
            CODE_INTERPRETER_URL,
            headers=admin_user.headers,
        )
        assert get_resp.status_code == 200
        assert get_resp.json()["enabled"] is desired
def test_code_interpreter_endpoints_require_admin(
    basic_user: DATestUser,
) -> None:
    """All code interpreter endpoints should reject non-admin users."""
    health_status = requests.get(
        CODE_INTERPRETER_HEALTH_URL, headers=basic_user.headers
    ).status_code
    assert health_status == 403

    get_status = requests.get(
        CODE_INTERPRETER_URL, headers=basic_user.headers
    ).status_code
    assert get_status == 403

    put_status = requests.put(
        CODE_INTERPRETER_URL,
        json={"enabled": True},
        headers=basic_user.headers,
    ).status_code
    assert put_status == 403
def test_python_tool_hidden_from_tool_list_when_disabled(
    admin_user: DATestUser,
) -> None:
    """When code interpreter is disabled, the Python tool should not appear
    in the GET /tool response (i.e. the frontend tool list)."""

    def _set_enabled(enabled: bool) -> None:
        # Flip the server-wide enabled flag via the admin endpoint.
        resp = requests.put(
            CODE_INTERPRETER_URL,
            json={"enabled": enabled},
            headers=admin_user.headers,
        )
        assert resp.status_code == 200

    def _tool_names() -> list[str]:
        tools = ToolManager.list_tools(user_performing_action=admin_user)
        return [tool.name for tool in tools]

    # Disabling hides the Python tool from the list.
    _set_enabled(False)
    assert PYTHON_TOOL_NAME not in _tool_names()

    # Re-enabling makes it reappear.
    _set_enabled(True)
    assert PYTHON_TOOL_NAME in _tool_names()

View File

@@ -0,0 +1,322 @@
import json
import os
import time
from uuid import uuid4
import pytest
import requests
from pydantic import BaseModel
from pydantic import ConfigDict
from onyx.configs import app_configs
from onyx.configs.constants import DocumentSource
from onyx.tools.constants import SEARCH_TOOL_ID
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.tool import ToolManager
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import ToolName
# Environment variables that configure the nightly LLM provider run.
_ENV_PROVIDER = "NIGHTLY_LLM_PROVIDER"
_ENV_MODELS = "NIGHTLY_LLM_MODELS"  # comma-separated list of model names
_ENV_API_KEY = "NIGHTLY_LLM_API_KEY"
_ENV_API_BASE = "NIGHTLY_LLM_API_BASE"
_ENV_CUSTOM_CONFIG_JSON = "NIGHTLY_LLM_CUSTOM_CONFIG_JSON"  # JSON object of extra settings
_ENV_STRICT = "NIGHTLY_LLM_STRICT"  # truthy => fail instead of skip on missing config
class NightlyProviderConfig(BaseModel):
    """Immutable bundle of nightly provider settings read from the environment."""

    model_config = ConfigDict(frozen=True)

    # Provider identifier, lowercased (e.g. "openrouter", "ollama_chat").
    provider: str
    # Models to exercise; each gets its own provider created, tested, deleted.
    model_names: list[str]
    api_key: str | None
    api_base: str | None
    # Extra provider-specific settings as string key/value pairs.
    custom_config: dict[str, str] | None
    # When True, missing configuration fails the run instead of skipping it.
    strict: bool
def _env_true(env_var: str, default: bool = False) -> bool:
value = os.environ.get(env_var)
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "on"}
def _split_csv_env(env_var: str) -> list[str]:
return [
part.strip() for part in os.environ.get(env_var, "").split(",") if part.strip()
]
def _load_provider_config() -> NightlyProviderConfig:
    """Read the nightly provider settings from the environment.

    Raises:
        ValueError: if the custom-config env var is set but not a JSON object.
    """
    provider = os.environ.get(_ENV_PROVIDER, "").strip().lower()
    api_key = os.environ.get(_ENV_API_KEY) or None

    custom_config: dict[str, str] | None = None
    raw_custom_config = os.environ.get(_ENV_CUSTOM_CONFIG_JSON, "").strip()
    if raw_custom_config:
        decoded = json.loads(raw_custom_config)
        if not isinstance(decoded, dict):
            raise ValueError(f"{_ENV_CUSTOM_CONFIG_JSON} must be a JSON object")
        custom_config = {str(key): str(value) for key, value in decoded.items()}

    # Ollama auth is passed via a provider-specific custom-config entry
    # rather than the generic api_key field.
    if provider == "ollama_chat" and api_key and not custom_config:
        custom_config = {"OLLAMA_API_KEY": api_key}

    return NightlyProviderConfig(
        provider=provider,
        model_names=_split_csv_env(_ENV_MODELS),
        api_key=api_key,
        api_base=os.environ.get(_ENV_API_BASE) or None,
        custom_config=custom_config,
        strict=_env_true(_ENV_STRICT, default=False),
    )
def _skip_or_fail(strict: bool, message: str) -> None:
    """Fail the test in strict mode; otherwise skip it with the same message."""
    if strict:
        pytest.fail(message)
    pytest.skip(message)
def _validate_provider_config(config: NightlyProviderConfig) -> None:
    """Skip (or fail, in strict mode) when required nightly settings are missing.

    Requirements: a provider name, at least one model, an API key for
    non-ollama providers, and an API base (explicit or default) for ollama.
    """
    if not config.provider:
        _skip_or_fail(strict=config.strict, message=f"{_ENV_PROVIDER} must be set")
    if not config.model_names:
        _skip_or_fail(
            strict=config.strict,
            message=f"{_ENV_MODELS} must include at least one model",
        )
    # ollama_chat is exempt from the API-key requirement (see _load_provider_config).
    if config.provider != "ollama_chat" and not config.api_key:
        _skip_or_fail(
            strict=config.strict,
            message=(f"{_ENV_API_KEY} is required for provider '{config.provider}'"),
        )
    if config.provider == "ollama_chat" and not (
        config.api_base or _default_api_base_for_provider(config.provider)
    ):
        _skip_or_fail(
            strict=config.strict,
            message=(f"{_ENV_API_BASE} is required for provider '{config.provider}'"),
        )
def _assert_integration_mode_enabled() -> None:
    """Guard: the workflow mutates the deployment's default LLM provider,
    so refuse to run unless integration-test mode is explicitly on."""
    assert (
        app_configs.INTEGRATION_TESTS_MODE is True
    ), "Integration tests require INTEGRATION_TESTS_MODE=true."
def _seed_connector_for_search_tool(admin_user: DATestUser) -> None:
    """Create a minimal connector so the SearchTool becomes available."""
    # SearchTool is only exposed when at least one non-default connector exists.
    CCPairManager.create_from_scratch(
        source=DocumentSource.INGESTION_API,
        user_performing_action=admin_user,
    )
def _get_internal_search_tool_id(admin_user: DATestUser) -> int:
    """Return the DB id of the internal SearchTool, asserting it exists."""
    tools = ToolManager.list_tools(user_performing_action=admin_user)
    match = next(
        (tool.id for tool in tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
        None,
    )
    if match is None:
        raise AssertionError("SearchTool must exist for this test")
    return match
def _default_api_base_for_provider(provider: str) -> str | None:
if provider == "openrouter":
return "https://openrouter.ai/api/v1"
if provider == "ollama_chat":
# host.docker.internal works when tests are running inside the integration test container.
return "http://host.docker.internal:11434"
return None
def _create_provider_payload(
provider: str,
provider_name: str,
model_name: str,
api_key: str | None,
api_base: str | None,
custom_config: dict[str, str] | None,
) -> dict:
return {
"name": provider_name,
"provider": provider,
"api_key": api_key,
"api_base": api_base,
"custom_config": custom_config,
"default_model_name": model_name,
"is_public": True,
"groups": [],
"personas": [],
"model_configurations": [{"name": model_name, "is_visible": True}],
"api_key_changed": bool(api_key),
"custom_config_changed": bool(custom_config),
}
def _ensure_provider_is_default(provider_id: int, admin_user: DATestUser) -> None:
    """Assert, via the admin list endpoint, that `provider_id` is the default."""
    list_response = requests.get(
        f"{API_SERVER_URL}/admin/llm/provider",
        headers=admin_user.headers,
    )
    list_response.raise_for_status()

    current_default = None
    for provider in list_response.json():
        if provider.get("is_default_provider"):
            current_default = provider
            break

    assert (
        current_default is not None
    ), "Expected a default provider after setting provider as default"
    assert (
        current_default["id"] == provider_id
    ), f"Expected provider {provider_id} to be default, found {current_default['id']}"
def _run_chat_assertions(
    admin_user: DATestUser,
    search_tool_id: int,
    provider: str,
    model_name: str,
) -> None:
    """Send a chat message with the search tool forced and assert it was used.

    An attempt succeeds when the stream finishes without error, the
    internal_search tool appears in both `used_tools` and the tool-call
    debug trace, and the final answer is non-empty. On exhaustion, fails
    with the last attempt's diagnostic context.
    """
    last_error: str | None = None
    # Retry once to reduce transient nightly flakes due provider-side blips.
    for attempt in range(1, 3):
        chat_session = ChatSessionManager.create(user_performing_action=admin_user)
        response = ChatSessionManager.send_message(
            chat_session_id=chat_session.id,
            message=(
                "Use internal_search to search for 'nightly-provider-regression-sentinel', "
                "then summarize the result in one short sentence."
            ),
            user_performing_action=admin_user,
            forced_tool_ids=[search_tool_id],
        )
        if response.error is None:
            used_internal_search = any(
                used_tool.tool_name == ToolName.INTERNAL_SEARCH
                for used_tool in response.used_tools
            )
            debug_has_internal_search = any(
                debug_tool_call.tool_name == "internal_search"
                for debug_tool_call in response.tool_call_debug
            )
            has_answer = bool(response.full_message.strip())
            if used_internal_search and debug_has_internal_search and has_answer:
                return
            # Stream succeeded but the tool/answer checks failed — record why.
            last_error = (
                f"attempt={attempt} provider={provider} model={model_name} "
                f"used_internal_search={used_internal_search} "
                f"debug_internal_search={debug_has_internal_search} "
                f"has_answer={has_answer} "
                f"tool_call_debug={response.tool_call_debug}"
            )
        else:
            last_error = (
                f"attempt={attempt} provider={provider} model={model_name} "
                f"stream_error={response.error.error}"
            )
        # Linear backoff between attempts.
        time.sleep(attempt)
    pytest.fail(f"Chat/tool-call assertions failed: {last_error}")
def _create_and_test_provider_for_model(
    admin_user: DATestUser,
    config: NightlyProviderConfig,
    model_name: str,
    search_tool_id: int,
) -> None:
    """Create a provider for one model, make it the default, run chat checks.

    The provider gets a unique name, is validated via the admin /llm/test
    endpoint before creation, and is deleted afterwards even on failure.
    """
    provider_name = f"nightly-{config.provider}-{uuid4().hex[:12]}"
    resolved_api_base = config.api_base or _default_api_base_for_provider(
        config.provider
    )
    provider_payload = _create_provider_payload(
        provider=config.provider,
        provider_name=provider_name,
        model_name=model_name,
        api_key=config.api_key,
        api_base=resolved_api_base,
        custom_config=config.custom_config,
    )
    # Validate credentials/connectivity before persisting the provider.
    test_response = requests.post(
        f"{API_SERVER_URL}/admin/llm/test",
        headers=admin_user.headers,
        json=provider_payload,
    )
    assert test_response.status_code == 200, (
        f"Provider test endpoint failed for provider={config.provider} "
        f"model={model_name}: {test_response.status_code} {test_response.text}"
    )
    create_response = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
        headers=admin_user.headers,
        json=provider_payload,
    )
    assert create_response.status_code == 200, (
        f"Provider creation failed for provider={config.provider} "
        f"model={model_name}: {create_response.status_code} {create_response.text}"
    )
    provider_id = create_response.json()["id"]
    try:
        set_default_response = requests.post(
            f"{API_SERVER_URL}/admin/llm/provider/{provider_id}/default",
            headers=admin_user.headers,
        )
        assert set_default_response.status_code == 200, (
            f"Setting default provider failed for provider={config.provider} "
            f"model={model_name}: {set_default_response.status_code} "
            f"{set_default_response.text}"
        )
        _ensure_provider_is_default(provider_id=provider_id, admin_user=admin_user)
        _run_chat_assertions(
            admin_user=admin_user,
            search_tool_id=search_tool_id,
            provider=config.provider,
            model_name=model_name,
        )
    finally:
        # Always clean up so repeated nightly runs don't accumulate providers.
        requests.delete(
            f"{API_SERVER_URL}/admin/llm/provider/{provider_id}",
            headers=admin_user.headers,
        )
def test_nightly_provider_chat_workflow(admin_user: DATestUser) -> None:
    """Nightly regression test for provider setup + default selection + chat tool calls."""
    _assert_integration_mode_enabled()
    config = _load_provider_config()
    _validate_provider_config(config)
    # A connector must exist for the SearchTool to be listed at all.
    _seed_connector_for_search_tool(admin_user)
    search_tool_id = _get_internal_search_tool_id(admin_user)
    # Each configured model gets its own full provider lifecycle.
    for model_name in config.model_names:
        _create_and_test_provider_for_model(
            admin_user=admin_user,
            config=config,
            model_name=model_name,
            search_tool_id=search_tool_id,
        )

View File

@@ -98,11 +98,6 @@ class TestScimDALUserMappings:
"external_id": "ext-1",
"user_id": user_id,
"scim_username": None,
"department": None,
"manager": None,
"given_name": None,
"family_name": None,
"scim_emails_json": None,
}
def test_delete_user_mapping(

View File

@@ -211,13 +211,16 @@ def test_openai_provider_rejects_reference_images_for_unsupported_model() -> Non
)
def test_openai_provider_rejects_multiple_reference_images_for_dalle2() -> None:
def test_openai_provider_rejects_multiple_reference_images_for_dalle3() -> None:
provider = OpenAIImageGenerationProvider(api_key="test-key")
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="does not support image edits with reference images",
):
provider.generate_image(
prompt="edit this image",
model="dall-e-2",
model="dall-e-3",
size="1024x1024",
n=1,
reference_images=[
@@ -307,17 +310,20 @@ def test_azure_provider_rejects_reference_images_for_unsupported_model() -> None
)
def test_azure_provider_rejects_multiple_reference_images_for_dalle2() -> None:
def test_azure_provider_rejects_multiple_reference_images_for_dalle3() -> None:
provider = AzureImageGenerationProvider(
api_key="test-key",
api_base="https://azure.example.com",
api_version="2024-05-01-preview",
)
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match="does not support image edits with reference images",
):
provider.generate_image(
prompt="edit this image",
model="dall-e-2",
model="dall-e-3",
size="1024x1024",
n=1,
reference_images=[

View File

@@ -0,0 +1,159 @@
from pytest import MonkeyPatch
from onyx.access.models import ExternalAccess
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.enums import HierarchyNodeType
from onyx.indexing import indexing_pipeline
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_node_for_postgres
def test_sanitize_document_for_postgres_removes_nul_bytes() -> None:
    """sanitize_document_for_postgres should strip NUL (0x00) bytes from every
    string-bearing field — ids, titles, sections, metadata keys and values,
    nested doc_metadata JSON, owners, external access — without mutating the
    input document."""
    # NUL bytes are planted in every field that reaches Postgres as text.
    document = Document(
        id="doc\x00-id",
        source=DocumentSource.FILE,
        semantic_identifier="sem\x00-id",
        title="ti\x00tle",
        parent_hierarchy_raw_node_id="parent\x00-id",
        sections=[TextSection(link="lin\x00k", text="te\x00xt")],
        metadata={"ke\x00y": "va\x00lue", "list\x00key": ["a\x00", "b"]},
        doc_metadata={
            "j\x00son": {
                "in\x00ner": "va\x00l",
                "arr": ["x\x00", {"dee\x00p": "y\x00"}],
            }
        },
        primary_owners=[BasicExpertInfo(display_name="Ali\x00ce", email="a\x00@x.com")],
        secondary_owners=[BasicExpertInfo(first_name="Bo\x00b", last_name="Sm\x00ith")],
        external_access=ExternalAccess(
            external_user_emails={"user\x00@example.com"},
            external_user_group_ids={"gro\x00up-1"},
            is_public=False,
        ),
    )
    sanitized = sanitize_document_for_postgres(document)
    assert sanitized.id == "doc-id"
    assert sanitized.semantic_identifier == "sem-id"
    assert sanitized.title == "title"
    assert sanitized.parent_hierarchy_raw_node_id == "parent-id"
    assert sanitized.sections[0].link == "link"
    assert sanitized.sections[0].text == "text"
    # Both mapping keys and values (including list elements) are cleaned.
    assert sanitized.metadata == {"key": "value", "listkey": ["a", "b"]}
    # Sanitization recurses through nested dicts and lists in doc_metadata.
    assert sanitized.doc_metadata == {
        "json": {"inner": "val", "arr": ["x", {"deep": "y"}]}
    }
    assert sanitized.primary_owners is not None
    assert sanitized.primary_owners[0].display_name == "Alice"
    assert sanitized.primary_owners[0].email == "a@x.com"
    assert sanitized.secondary_owners is not None
    assert sanitized.secondary_owners[0].first_name == "Bob"
    assert sanitized.secondary_owners[0].last_name == "Smith"
    assert sanitized.external_access is not None
    assert sanitized.external_access.external_user_emails == {"user@example.com"}
    assert sanitized.external_access.external_user_group_ids == {"group-1"}
    # Ensure original document is not mutated
    assert document.id == "doc\x00-id"
    assert document.metadata == {"ke\x00y": "va\x00lue", "list\x00key": ["a\x00", "b"]}
def test_sanitize_hierarchy_node_for_postgres_removes_nul_bytes() -> None:
    """sanitize_hierarchy_node_for_postgres should strip NUL (0x00) bytes from
    all string fields of a HierarchyNode, including its external access sets."""
    node = HierarchyNode(
        raw_node_id="raw\x00-id",
        raw_parent_id="paren\x00t-id",
        display_name="fol\x00der",
        link="https://exa\x00mple.com",
        node_type=HierarchyNodeType.FOLDER,
        external_access=ExternalAccess(
            external_user_emails={"a\x00@example.com"},
            external_user_group_ids={"g\x00-1"},
            is_public=True,
        ),
    )
    sanitized = sanitize_hierarchy_node_for_postgres(node)
    assert sanitized.raw_node_id == "raw-id"
    assert sanitized.raw_parent_id == "parent-id"
    assert sanitized.display_name == "folder"
    assert sanitized.link == "https://example.com"
    assert sanitized.external_access is not None
    assert sanitized.external_access.external_user_emails == {"a@example.com"}
    assert sanitized.external_access.external_user_group_ids == {"g-1"}
def test_index_doc_batch_prepare_sanitizes_before_db_ops(
    monkeypatch: MonkeyPatch,
) -> None:
    """index_doc_batch_prepare should sanitize documents (strip NUL bytes)
    BEFORE any DB writes: the patched upsert hooks capture what would hit the
    database and must only see cleaned ids/metadata."""
    document = Document(
        id="doc\x00id",
        source=DocumentSource.FILE,
        semantic_identifier="sem\x00id",
        sections=[TextSection(text="content", link="li\x00nk")],
        metadata={"ke\x00y": "va\x00lue"},
    )
    # Captures the arguments the pipeline would have sent to the DB layer.
    captured: dict[str, object] = {}

    # Pretend no documents exist yet so everything is treated as updatable.
    def _get_documents_by_ids(db_session: object, document_ids: list[str]) -> list:
        _ = db_session, document_ids
        return []

    monkeypatch.setattr(
        indexing_pipeline, "get_documents_by_ids", _get_documents_by_ids
    )

    def _capture_upsert_documents_in_db(**kwargs: object) -> None:
        captured["upsert_documents"] = kwargs["documents"]

    monkeypatch.setattr(
        indexing_pipeline, "_upsert_documents_in_db", _capture_upsert_documents_in_db
    )

    # Positional arg 3 is the list of document ids for the cc-pair upsert.
    def _capture_doc_cc_pair(*args: object) -> None:
        captured["cc_pair_doc_ids"] = args[3]

    monkeypatch.setattr(
        indexing_pipeline,
        "upsert_document_by_connector_credential_pair",
        _capture_doc_cc_pair,
    )

    def _noop_link_hierarchy_nodes_to_documents(
        db_session: object,
        document_ids: list[str],
        source: DocumentSource,
        commit: bool,
    ) -> int:
        _ = db_session, document_ids, source, commit
        return 0

    monkeypatch.setattr(
        indexing_pipeline,
        "link_hierarchy_nodes_to_documents",
        _noop_link_hierarchy_nodes_to_documents,
    )
    context = indexing_pipeline.index_doc_batch_prepare(
        documents=[document],
        index_attempt_metadata=IndexAttemptMetadata(connector_id=1, credential_id=2),
        db_session=object(),  # type: ignore[arg-type]
        ignore_time_skip=True,
    )
    assert context is not None
    # The context and every captured DB-bound value carry sanitized strings.
    assert context.updatable_docs[0].id == "docid"
    assert context.updatable_docs[0].semantic_identifier == "semid"
    assert context.updatable_docs[0].metadata == {"key": "value"}
    assert captured["cc_pair_doc_ids"] == ["docid"]
    upsert_documents = captured["upsert_documents"]
    assert isinstance(upsert_documents, list)
    assert upsert_documents[0].id == "docid"

View File

@@ -0,0 +1,52 @@
from onyx.onyxbot.slack.formatting import _normalize_citation_link_destinations
from onyx.onyxbot.slack.formatting import format_slack_message
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
from onyx.utils.text_processing import decode_escapes
def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
    """A citation URL containing parentheses gets wrapped in angle brackets."""
    message = (
        "See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
    )
    expected = (
        "See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
    )
    assert _normalize_citation_link_destinations(message) == expected
def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
    """Already-wrapped destinations must pass through unchanged (idempotence)."""
    message = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
    normalized = _normalize_citation_link_destinations(message)
    assert message == normalized
def test_normalize_citation_link_handles_multiple_links() -> None:
    """Every citation link in a message is normalized, not just the first."""
    message = (
        "[[1]](https://example.com/(USA)%20Guide.pdf) "
        "[[2]](https://example.com/Plan(s)%20Overview.pdf)"
    )
    normalized = _normalize_citation_link_destinations(message)
    assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in normalized
    assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in normalized
def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> None:
    """End-to-end: a citation link whose URL contains parentheses renders as a
    single Slack link instead of being truncated at the first ')'."""
    message = (
        "Download [[1]](https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf)"
    )
    formatted = format_slack_message(message)
    rendered = decode_escapes(remove_slack_text_interactions(formatted))
    assert (
        "<https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf|[1]>"
        in rendered
    )
    # Regression guard: the URL must not be split partway through at a ')'.
    assert "|[1]>%20Access%20ID%20Card" not in rendered

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
import json
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock
@@ -13,9 +12,7 @@ import pytest
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from ee.onyx.server.scim.api import ScimJSONResponse
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.providers.base import ScimProvider
@@ -118,11 +115,6 @@ def make_user_mapping(**kwargs: Any) -> MagicMock:
mapping.external_id = kwargs.get("external_id", "ext-default")
mapping.user_id = kwargs.get("user_id", uuid4())
mapping.scim_username = kwargs.get("scim_username", None)
mapping.department = kwargs.get("department", None)
mapping.manager = kwargs.get("manager", None)
mapping.given_name = kwargs.get("given_name", None)
mapping.family_name = kwargs.get("family_name", None)
mapping.scim_emails_json = kwargs.get("scim_emails_json", None)
return mapping
@@ -130,35 +122,3 @@ def assert_scim_error(result: object, expected_status: int) -> None:
"""Assert *result* is a JSONResponse with the given status code."""
assert isinstance(result, JSONResponse)
assert result.status_code == expected_status
# ---------------------------------------------------------------------------
# Response parsing helpers
# ---------------------------------------------------------------------------
def parse_scim_user(result: object, *, status: int = 200) -> ScimUserResource:
    """Assert *result* is a ScimJSONResponse and parse as ScimUserResource.

    Raises AssertionError on a wrong response type or status code.
    """
    assert isinstance(
        result, ScimJSONResponse
    ), f"Expected ScimJSONResponse, got {type(result).__name__}"
    assert result.status_code == status
    return ScimUserResource.model_validate(json.loads(result.body))
def parse_scim_group(result: object, *, status: int = 200) -> ScimGroupResource:
    """Assert *result* is a ScimJSONResponse and parse as ScimGroupResource.

    Raises AssertionError on a wrong response type or status code.
    """
    assert isinstance(
        result, ScimJSONResponse
    ), f"Expected ScimJSONResponse, got {type(result).__name__}"
    assert result.status_code == status
    return ScimGroupResource.model_validate(json.loads(result.body))
def parse_scim_list(result: object) -> ScimListResponse:
    """Assert *result* is a ScimJSONResponse and parse as ScimListResponse.

    Unlike the user/group parsers, the status is always expected to be 200.
    """
    assert isinstance(
        result, ScimJSONResponse
    ), f"Expected ScimJSONResponse, got {type(result).__name__}"
    assert result.status_code == 200
    return ScimListResponse.model_validate(json.loads(result.body))

View File

@@ -1,983 +0,0 @@
"""Comprehensive Entra ID (Azure AD) SCIM compatibility tests.
Covers the full Entra provisioning lifecycle: service discovery, user CRUD
with enterprise extension schema, group CRUD with excludedAttributes, and
all Entra-specific behavioral quirks (PascalCase ops, enterprise URN in
PATCH value dicts).
"""
from __future__ import annotations
import json
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from fastapi import Response
from ee.onyx.server.scim.api import create_user
from ee.onyx.server.scim.api import delete_user
from ee.onyx.server.scim.api import get_group
from ee.onyx.server.scim.api import get_resource_types
from ee.onyx.server.scim.api import get_schemas
from ee.onyx.server.scim.api import get_service_provider_config
from ee.onyx.server.scim.api import get_user
from ee.onyx.server.scim.api import list_groups
from ee.onyx.server.scim.api import list_users
from ee.onyx.server.scim.api import patch_group
from ee.onyx.server.scim.api import patch_user
from ee.onyx.server.scim.api import replace_user
from ee.onyx.server.scim.api import ScimJSONResponse
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimEnterpriseExtension
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimManagerRef
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimPatchResourceValue
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.providers.base import ScimProvider
from ee.onyx.server.scim.providers.entra import EntraProvider
from tests.unit.onyx.server.scim.conftest import make_db_group
from tests.unit.onyx.server.scim.conftest import make_db_user
from tests.unit.onyx.server.scim.conftest import make_scim_user
from tests.unit.onyx.server.scim.conftest import make_user_mapping
from tests.unit.onyx.server.scim.conftest import parse_scim_group
from tests.unit.onyx.server.scim.conftest import parse_scim_list
from tests.unit.onyx.server.scim.conftest import parse_scim_user
@pytest.fixture
def entra_provider() -> ScimProvider:
    """An EntraProvider instance for Entra-specific endpoint tests."""
    # Stateless, so a fresh instance per test is cheap.
    return EntraProvider()
# ---------------------------------------------------------------------------
# Service Discovery
# ---------------------------------------------------------------------------
class TestEntraServiceDiscovery:
    """Entra expects enterprise extension in discovery endpoints."""

    def test_service_provider_config_advertises_patch(self) -> None:
        """PATCH support must be advertised in the service provider config."""
        config = get_service_provider_config()
        assert config.patch.supported is True

    def test_resource_types_include_enterprise_extension(self) -> None:
        """The User resource type must list the enterprise user schema extension."""
        result = get_resource_types()
        assert isinstance(result, ScimJSONResponse)
        parsed = json.loads(result.body)
        assert "Resources" in parsed
        user_type = next(rt for rt in parsed["Resources"] if rt["id"] == "User")
        extension_schemas = [ext["schema"] for ext in user_type["schemaExtensions"]]
        assert SCIM_ENTERPRISE_USER_SCHEMA in extension_schemas

    def test_schemas_include_enterprise_user(self) -> None:
        """/Schemas must include the enterprise user schema id."""
        result = get_schemas()
        assert isinstance(result, ScimJSONResponse)
        parsed = json.loads(result.body)
        schema_ids = [s["id"] for s in parsed["Resources"]]
        assert SCIM_ENTERPRISE_USER_SCHEMA in schema_ids

    def test_enterprise_schema_has_expected_attributes(self) -> None:
        """The enterprise schema must declare the attributes Entra provisions."""
        result = get_schemas()
        assert isinstance(result, ScimJSONResponse)
        parsed = json.loads(result.body)
        enterprise = next(
            s for s in parsed["Resources"] if s["id"] == SCIM_ENTERPRISE_USER_SCHEMA
        )
        attr_names = {a["name"] for a in enterprise["attributes"]}
        assert "department" in attr_names
        assert "manager" in attr_names

    def test_service_discovery_content_type(self) -> None:
        """SCIM responses must use application/scim+json content type."""
        result = get_resource_types()
        assert isinstance(result, ScimJSONResponse)
        assert result.media_type == "application/scim+json"
# ---------------------------------------------------------------------------
# User Lifecycle (Entra-specific)
# ---------------------------------------------------------------------------
class TestEntraUserLifecycle:
    """Test user CRUD through Entra's lens: enterprise schemas, PascalCase ops."""
    @patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
    def test_create_user_includes_enterprise_schema(
        self,
        mock_seats: MagicMock,  # noqa: ARG002
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """POST /Users responses must list both the core and enterprise schemas."""
        mock_dal.get_user_by_email.return_value = None
        resource = make_scim_user(userName="alice@contoso.com")
        result = create_user(
            user_resource=resource,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        # NOTE: rebinds `resource` from the request model to the parsed response.
        resource = parse_scim_user(result, status=201)
        assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
        assert SCIM_USER_SCHEMA in resource.schemas
    @patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
    def test_create_user_with_enterprise_extension(
        self,
        mock_seats: MagicMock,  # noqa: ARG002
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Enterprise extension department/manager should round-trip on create."""
        mock_dal.get_user_by_email.return_value = None
        resource = make_scim_user(
            userName="alice@contoso.com",
            enterprise_extension=ScimEnterpriseExtension(
                department="Engineering",
                manager=ScimManagerRef(value="mgr-uuid-123"),
            ),
        )
        result = create_user(
            user_resource=resource,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        resource = parse_scim_user(result, status=201)
        assert resource.enterprise_extension is not None
        assert resource.enterprise_extension.department == "Engineering"
        assert resource.enterprise_extension.manager is not None
        assert resource.enterprise_extension.manager.value == "mgr-uuid-123"
        # Verify DAL received the enterprise fields
        mock_dal.create_user_mapping.assert_called_once()
        call_kwargs = mock_dal.create_user_mapping.call_args[1]
        assert call_kwargs["fields"] == ScimMappingFields(
            department="Engineering",
            manager="mgr-uuid-123",
            given_name="Test",
            family_name="User",
        )
    def test_get_user_includes_enterprise_schema(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """GET /Users/{id} responses must advertise the enterprise schema."""
        user = make_db_user(email="alice@contoso.com")
        mock_dal.get_user.return_value = user
        result = get_user(
            user_id=str(user.id),
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        resource = parse_scim_user(result)
        assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
    def test_get_user_returns_enterprise_extension_data(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """GET should return stored enterprise extension data."""
        user = make_db_user(email="alice@contoso.com")
        mock_dal.get_user.return_value = user
        mapping = make_user_mapping(user_id=user.id)
        mapping.department = "Sales"
        mapping.manager = "mgr-456"
        mock_dal.get_user_mapping_by_user_id.return_value = mapping
        result = get_user(
            user_id=str(user.id),
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        resource = parse_scim_user(result)
        assert resource.enterprise_extension is not None
        assert resource.enterprise_extension.department == "Sales"
        assert resource.enterprise_extension.manager is not None
        assert resource.enterprise_extension.manager.value == "mgr-456"
    def test_list_users_includes_enterprise_schema(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """List responses must advertise the enterprise schema on each resource."""
        user = make_db_user(email="alice@contoso.com")
        mapping = make_user_mapping(external_id="entra-ext-1", user_id=user.id)
        mock_dal.list_users.return_value = ([(user, mapping)], 1)
        result = list_users(
            filter=None,
            startIndex=1,
            count=100,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        parsed = parse_scim_list(result)
        resource = parsed.Resources[0]
        assert isinstance(resource, ScimUserResource)
        assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
    def test_patch_user_deactivate_with_pascal_case_replace(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ``"Replace"`` (PascalCase) instead of ``"replace"``."""
        user = make_db_user(is_active=True)
        mock_dal.get_user.return_value = user
        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op="Replace",  # type: ignore[arg-type]
                    path="active",
                    value=False,
                )
            ]
        )
        result = patch_user(
            user_id=str(user.id),
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        parse_scim_user(result)
        # Mock doesn't propagate the change, so verify via the DAL call
        mock_dal.update_user.assert_called_once()
        call_kwargs = mock_dal.update_user.call_args
        assert call_kwargs[1]["is_active"] is False
    def test_patch_user_add_external_id_with_pascal_case(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ``"Add"`` (PascalCase) instead of ``"add"``."""
        user = make_db_user()
        mock_dal.get_user.return_value = user
        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op="Add",  # type: ignore[arg-type]
                    path="externalId",
                    value="entra-ext-999",
                )
            ]
        )
        result = patch_user(
            user_id=str(user.id),
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        parse_scim_user(result)
        # Verify the patched externalId was synced to the DAL
        mock_dal.sync_user_external_id.assert_called_once()
        call_args = mock_dal.sync_user_external_id.call_args
        # positional arg 1 is the new external id
        assert call_args[0][1] == "entra-ext-999"
    def test_patch_user_enterprise_extension_in_value_dict(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends enterprise extension URN as key in path-less PATCH value
        dicts — enterprise data should be stored, not ignored."""
        user = make_db_user()
        mock_dal.get_user.return_value = user
        value = ScimPatchResourceValue(active=False)
        # Simulate Entra attaching the enterprise URN as pydantic "extra" data.
        assert value.__pydantic_extra__ is not None
        value.__pydantic_extra__[
            "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
        ] = {"department": "Engineering"}
        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op=ScimPatchOperationType.REPLACE,
                    path=None,
                    value=value,
                )
            ]
        )
        result = patch_user(
            user_id=str(user.id),
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        parse_scim_user(result)
        # Verify active=False was applied
        mock_dal.update_user.assert_called_once()
        call_kwargs = mock_dal.update_user.call_args
        assert call_kwargs[1]["is_active"] is False
        # Verify enterprise data was passed to DAL
        mock_dal.sync_user_external_id.assert_called_once()
        sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
        assert sync_kwargs["fields"] == ScimMappingFields(
            department="Engineering",
            given_name="Test",
            family_name="User",
            scim_emails_json='[{"value": "test@example.com", "type": "work", "primary": true}]',
        )
    def test_patch_user_remove_external_id(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """PATCH remove op should clear the target field."""
        user = make_db_user()
        mock_dal.get_user.return_value = user
        mapping = make_user_mapping(user_id=user.id)
        mapping.external_id = "ext-to-remove"
        mock_dal.get_user_mapping_by_user_id.return_value = mapping
        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op=ScimPatchOperationType.REMOVE,
                    path="externalId",
                )
            ]
        )
        result = patch_user(
            user_id=str(user.id),
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        parse_scim_user(result)
        # externalId should be cleared (None)
        mock_dal.sync_user_external_id.assert_called_once()
        call_args = mock_dal.sync_user_external_id.call_args
        assert call_args[0][1] is None
    def test_patch_user_emails_primary_eq_true_value(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """PATCH with path emails[primary eq true].value should update
        the primary email entry, not userName."""
        user = make_db_user(email="old@contoso.com")
        mock_dal.get_user.return_value = user
        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op=ScimPatchOperationType.REPLACE,
                    path="emails[primary eq true].value",
                    value="new@contoso.com",
                )
            ]
        )
        result = patch_user(
            user_id=str(user.id),
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        resource = parse_scim_user(result)
        # userName should remain unchanged — emails and userName are separate
        assert resource.userName == "old@contoso.com"
        # Primary email should be updated
        primary_emails = [e for e in resource.emails if e.primary]
        assert len(primary_emails) == 1
        assert primary_emails[0].value == "new@contoso.com"
    def test_patch_user_enterprise_urn_department_path(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """PATCH with dotted enterprise URN path should store department."""
        user = make_db_user()
        mock_dal.get_user.return_value = user
        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op=ScimPatchOperationType.REPLACE,
                    path="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:department",
                    value="Marketing",
                )
            ]
        )
        result = patch_user(
            user_id=str(user.id),
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        parse_scim_user(result)
        mock_dal.sync_user_external_id.assert_called_once()
        sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
        assert sync_kwargs["fields"] == ScimMappingFields(
            department="Marketing",
            given_name="Test",
            family_name="User",
            scim_emails_json='[{"value": "test@example.com", "type": "work", "primary": true}]',
        )
    def test_replace_user_includes_enterprise_schema(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """PUT /Users/{id} responses must advertise the enterprise schema."""
        user = make_db_user(email="old@contoso.com")
        mock_dal.get_user.return_value = user
        resource = make_scim_user(
            userName="new@contoso.com",
            name=ScimName(givenName="New", familyName="Name"),
        )
        result = replace_user(
            user_id=str(user.id),
            user_resource=resource,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        # NOTE: rebinds `resource` from the request model to the parsed response.
        resource = parse_scim_user(result)
        assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
    def test_replace_user_with_enterprise_extension(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """PUT with enterprise extension should store the fields."""
        user = make_db_user(email="alice@contoso.com")
        mock_dal.get_user.return_value = user
        resource = make_scim_user(
            userName="alice@contoso.com",
            enterprise_extension=ScimEnterpriseExtension(
                department="HR",
                manager=ScimManagerRef(value="boss-id"),
            ),
        )
        result = replace_user(
            user_id=str(user.id),
            user_resource=resource,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        parse_scim_user(result)
        mock_dal.sync_user_external_id.assert_called_once()
        sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
        assert sync_kwargs["fields"] == ScimMappingFields(
            department="HR",
            manager="boss-id",
            given_name="Test",
            family_name="User",
        )
    def test_delete_user_returns_204(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
    ) -> None:
        """DELETE of an existing SCIM-mapped user returns 204 No Content."""
        user = make_db_user()
        mock_dal.get_user.return_value = user
        mock_dal.get_user_mapping_by_user_id.return_value = MagicMock(id=1)
        result = delete_user(
            user_id=str(user.id),
            _token=mock_token,
            db_session=mock_db_session,
        )
        assert isinstance(result, Response)
        assert result.status_code == 204
    def test_double_delete_returns_404(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
    ) -> None:
        """Second DELETE should return 404 — the SCIM mapping is gone."""
        user = make_db_user()
        mock_dal.get_user.return_value = user
        # No mapping — user was already deleted from SCIM's perspective
        mock_dal.get_user_mapping_by_user_id.return_value = None
        result = delete_user(
            user_id=str(user.id),
            _token=mock_token,
            db_session=mock_db_session,
        )
        assert isinstance(result, ScimJSONResponse)
        assert result.status_code == 404
    def test_name_formatted_preserved_on_create(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """When name.formatted is provided, it should be used as personal_name."""
        mock_dal.get_user_by_email.return_value = None
        resource = make_scim_user(
            userName="alice@contoso.com",
            name=ScimName(
                givenName="Alice",
                familyName="Smith",
                formatted="Dr. Alice Smith",
            ),
        )
        with patch(
            "ee.onyx.server.scim.api._check_seat_availability", return_value=None
        ):
            result = create_user(
                user_resource=resource,
                _token=mock_token,
                provider=entra_provider,
                db_session=mock_db_session,
            )
        parse_scim_user(result, status=201)
        # The User constructor should have received the formatted name
        mock_dal.add_user.assert_called_once()
        created_user = mock_dal.add_user.call_args[0][0]
        assert created_user.personal_name == "Dr. Alice Smith"
# ---------------------------------------------------------------------------
# Group Lifecycle (Entra-specific)
# ---------------------------------------------------------------------------
class TestEntraGroupLifecycle:
    """Test group CRUD with Entra-specific behaviors."""
    def test_get_group_standard_response(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """GET /Groups/{id} returns displayName and the member list."""
        group = make_db_group(id=10, name="Contoso Engineering")
        mock_dal.get_group.return_value = group
        uid = uuid4()
        mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]
        result = get_group(
            group_id="10",
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        resource = parse_scim_group(result)
        assert resource.displayName == "Contoso Engineering"
        assert len(resource.members) == 1
    def test_list_groups_with_excluded_attributes_members(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ?excludedAttributes=members on group list queries."""
        group = make_db_group(id=10, name="Engineering")
        uid = uuid4()
        mock_dal.list_groups.return_value = ([(group, "ext-g-1")], 1)
        mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]
        result = list_groups(
            filter=None,
            excludedAttributes="members",
            startIndex=1,
            count=100,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        assert isinstance(result, ScimJSONResponse)
        parsed = json.loads(result.body)
        assert parsed["totalResults"] == 1
        resource = parsed["Resources"][0]
        # members must be omitted entirely, not serialized as an empty list
        assert "members" not in resource
        assert resource["displayName"] == "Engineering"
    def test_get_group_with_excluded_attributes_members(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ?excludedAttributes=members on single group GET."""
        group = make_db_group(id=10, name="Engineering")
        uid = uuid4()
        mock_dal.get_group.return_value = group
        mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]
        result = get_group(
            group_id="10",
            excludedAttributes="members",
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        assert isinstance(result, ScimJSONResponse)
        parsed = json.loads(result.body)
        assert "members" not in parsed
        assert parsed["displayName"] == "Engineering"
    @patch("ee.onyx.server.scim.api.apply_group_patch")
    def test_patch_group_add_members_with_pascal_case(
        self,
        mock_apply: MagicMock,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ``"Add"`` (PascalCase) for group member additions."""
        group = make_db_group(id=10)
        mock_dal.get_group.return_value = group
        mock_dal.get_group_members.return_value = []
        mock_dal.validate_member_ids.return_value = []
        uid = str(uuid4())
        patched = ScimGroupResource(
            id="10",
            displayName="Engineering",
            members=[ScimGroupMember(value=uid)],
        )
        # apply_group_patch returns (patched_resource, added_ids, removed_ids)
        mock_apply.return_value = (patched, [uid], [])
        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op="Add",  # type: ignore[arg-type]
                    path="members",
                    value=[ScimGroupMember(value=uid)],
                )
            ]
        )
        result = patch_group(
            group_id="10",
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        parse_scim_group(result)
        mock_dal.upsert_group_members.assert_called_once()
    @patch("ee.onyx.server.scim.api.apply_group_patch")
    def test_patch_group_remove_member_with_pascal_case(
        self,
        mock_apply: MagicMock,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ``"Remove"`` (PascalCase) for group member removals."""
        group = make_db_group(id=10)
        mock_dal.get_group.return_value = group
        mock_dal.get_group_members.return_value = []
        uid = str(uuid4())
        patched = ScimGroupResource(id="10", displayName="Engineering", members=[])
        # apply_group_patch returns (patched_resource, added_ids, removed_ids)
        mock_apply.return_value = (patched, [], [uid])
        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op="Remove",  # type: ignore[arg-type]
                    path=f'members[value eq "{uid}"]',
                )
            ]
        )
        result = patch_group(
            group_id="10",
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        parse_scim_group(result)
        mock_dal.remove_group_members.assert_called_once()
# ---------------------------------------------------------------------------
# excludedAttributes (RFC 7644 §3.4.2.5)
# ---------------------------------------------------------------------------
class TestExcludedAttributes:
    """Test excludedAttributes query parameter on GET endpoints."""
    def test_list_groups_excludes_members(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """excludedAttributes=members drops members from list results."""
        group = make_db_group(id=1, name="Team")
        uid = uuid4()
        mock_dal.list_groups.return_value = ([(group, None)], 1)
        mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
        result = list_groups(
            filter=None,
            excludedAttributes="members",
            startIndex=1,
            count=100,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        assert isinstance(result, ScimJSONResponse)
        parsed = json.loads(result.body)
        resource = parsed["Resources"][0]
        assert "members" not in resource
        assert "displayName" in resource
    def test_get_group_excludes_members(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """excludedAttributes=members drops members from a single group GET."""
        group = make_db_group(id=1, name="Team")
        uid = uuid4()
        mock_dal.get_group.return_value = group
        mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
        result = get_group(
            group_id="1",
            excludedAttributes="members",
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        assert isinstance(result, ScimJSONResponse)
        parsed = json.loads(result.body)
        assert "members" not in parsed
        assert "displayName" in parsed
    def test_list_users_excludes_groups(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """excludedAttributes=groups drops groups from user list results."""
        user = make_db_user()
        mapping = make_user_mapping(user_id=user.id)
        mock_dal.list_users.return_value = ([(user, mapping)], 1)
        mock_dal.get_users_groups_batch.return_value = {user.id: [(1, "Engineering")]}
        result = list_users(
            filter=None,
            excludedAttributes="groups",
            startIndex=1,
            count=100,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        assert isinstance(result, ScimJSONResponse)
        parsed = json.loads(result.body)
        resource = parsed["Resources"][0]
        assert "groups" not in resource
        assert "userName" in resource
    def test_get_user_excludes_groups(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """excludedAttributes=groups drops groups from a single user GET."""
        user = make_db_user()
        mock_dal.get_user.return_value = user
        mock_dal.get_user_groups.return_value = [(1, "Engineering")]
        result = get_user(
            user_id=str(user.id),
            excludedAttributes="groups",
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        assert isinstance(result, ScimJSONResponse)
        parsed = json.loads(result.body)
        assert "groups" not in parsed
        assert "userName" in parsed
    def test_multiple_excluded_attributes(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """A comma-separated excludedAttributes list drops every named field."""
        group = make_db_group(id=1, name="Team")
        mock_dal.get_group.return_value = group
        mock_dal.get_group_members.return_value = []
        result = get_group(
            group_id="1",
            excludedAttributes="members,externalId",
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        assert isinstance(result, ScimJSONResponse)
        parsed = json.loads(result.body)
        assert "members" not in parsed
        assert "externalId" not in parsed
        assert "displayName" in parsed
    def test_no_excluded_attributes_returns_full_response(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Without excludedAttributes, members are included in the response."""
        group = make_db_group(id=1, name="Team")
        uid = uuid4()
        mock_dal.get_group.return_value = group
        mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
        result = get_group(
            group_id="1",
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        resource = parse_scim_group(result)
        assert len(resource.members) == 1
# ---------------------------------------------------------------------------
# Entra Connection Probe
# ---------------------------------------------------------------------------
class TestEntraConnectionProbe:
    """Entra sends a probe request during initial SCIM setup."""
    def test_filter_for_nonexistent_user_returns_empty_list(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra probes with: GET /Users?filter=userName eq "non-existent"&count=1"""
        # The DAL finds nothing — the endpoint must still answer with a
        # well-formed, empty SCIM list rather than an error.
        mock_dal.list_users.return_value = ([], 0)
        response = list_users(
            filter='userName eq "non-existent@contoso.com"',
            startIndex=1,
            count=1,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )
        listing = parse_scim_list(response)
        assert listing.totalResults == 0
        assert listing.Resources == []

View File

@@ -16,6 +16,7 @@ from ee.onyx.server.scim.api import patch_group
from ee.onyx.server.scim.api import replace_group
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchRequest
@@ -24,8 +25,6 @@ from ee.onyx.server.scim.providers.base import ScimProvider
from tests.unit.onyx.server.scim.conftest import assert_scim_error
from tests.unit.onyx.server.scim.conftest import make_db_group
from tests.unit.onyx.server.scim.conftest import make_scim_group
from tests.unit.onyx.server.scim.conftest import parse_scim_group
from tests.unit.onyx.server.scim.conftest import parse_scim_list
class TestListGroups:
@@ -49,9 +48,9 @@ class TestListGroups:
db_session=mock_db_session,
)
parsed = parse_scim_list(result)
assert parsed.totalResults == 0
assert parsed.Resources == []
assert isinstance(result, ScimListResponse)
assert result.totalResults == 0
assert result.Resources == []
def test_unsupported_filter_returns_400(
self,
@@ -96,9 +95,9 @@ class TestListGroups:
db_session=mock_db_session,
)
parsed = parse_scim_list(result)
assert parsed.totalResults == 1
resource = parsed.Resources[0]
assert isinstance(result, ScimListResponse)
assert result.totalResults == 1
resource = result.Resources[0]
assert isinstance(resource, ScimGroupResource)
assert resource.displayName == "Engineering"
assert resource.externalId == "ext-g-1"
@@ -127,9 +126,9 @@ class TestGetGroup:
db_session=mock_db_session,
)
resource = parse_scim_group(result)
assert resource.displayName == "Engineering"
assert resource.id == "5"
assert isinstance(result, ScimGroupResource)
assert result.displayName == "Engineering"
assert result.id == "5"
def test_non_integer_id_returns_404(
self,
@@ -191,8 +190,8 @@ class TestCreateGroup:
db_session=mock_db_session,
)
resource = parse_scim_group(result, status=201)
assert resource.displayName == "New Group"
assert isinstance(result, ScimGroupResource)
assert result.displayName == "New Group"
mock_dal.add_group.assert_called_once()
mock_dal.commit.assert_called_once()
@@ -284,7 +283,7 @@ class TestCreateGroup:
db_session=mock_db_session,
)
parse_scim_group(result, status=201)
assert isinstance(result, ScimGroupResource)
mock_dal.create_group_mapping.assert_called_once()
@@ -315,7 +314,7 @@ class TestReplaceGroup:
db_session=mock_db_session,
)
parse_scim_group(result)
assert isinstance(result, ScimGroupResource)
mock_dal.update_group.assert_called_once_with(group, name="New Name")
mock_dal.replace_group_members.assert_called_once()
mock_dal.commit.assert_called_once()
@@ -428,7 +427,7 @@ class TestPatchGroup:
db_session=mock_db_session,
)
parse_scim_group(result)
assert isinstance(result, ScimGroupResource)
mock_dal.update_group.assert_called_once_with(group, name="New Name")
def test_not_found_returns_404(
@@ -535,7 +534,7 @@ class TestPatchGroup:
db_session=mock_db_session,
)
parse_scim_group(result)
assert isinstance(result, ScimGroupResource)
mock_dal.validate_member_ids.assert_called_once()
mock_dal.upsert_group_members.assert_called_once()
@@ -615,7 +614,7 @@ class TestPatchGroup:
db_session=mock_db_session,
)
parse_scim_group(result)
assert isinstance(result, ScimGroupResource)
mock_dal.remove_group_members.assert_called_once()

View File

@@ -1,6 +1,5 @@
import pytest
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimMeta
@@ -13,11 +12,9 @@ from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.providers.entra import EntraProvider
from ee.onyx.server.scim.providers.okta import OktaProvider
_OKTA_IGNORED = OktaProvider().ignored_patch_paths
_ENTRA_IGNORED = EntraProvider().ignored_patch_paths
def _make_user(**kwargs: object) -> ScimUserResource:
@@ -59,36 +56,36 @@ class TestApplyUserPatch:
def test_deactivate_user(self) -> None:
user = _make_user()
result, _ = apply_user_patch([_replace_op("active", False)], user)
result = apply_user_patch([_replace_op("active", False)], user)
assert result.active is False
assert result.userName == "test@example.com"
def test_activate_user(self) -> None:
user = _make_user(active=False)
result, _ = apply_user_patch([_replace_op("active", True)], user)
result = apply_user_patch([_replace_op("active", True)], user)
assert result.active is True
def test_replace_given_name(self) -> None:
user = _make_user()
result, _ = apply_user_patch([_replace_op("name.givenName", "NewFirst")], user)
result = apply_user_patch([_replace_op("name.givenName", "NewFirst")], user)
assert result.name is not None
assert result.name.givenName == "NewFirst"
assert result.name.familyName == "User"
def test_replace_family_name(self) -> None:
user = _make_user()
result, _ = apply_user_patch([_replace_op("name.familyName", "NewLast")], user)
result = apply_user_patch([_replace_op("name.familyName", "NewLast")], user)
assert result.name is not None
assert result.name.familyName == "NewLast"
def test_replace_username(self) -> None:
user = _make_user()
result, _ = apply_user_patch([_replace_op("userName", "new@example.com")], user)
result = apply_user_patch([_replace_op("userName", "new@example.com")], user)
assert result.userName == "new@example.com"
def test_replace_without_path_uses_dict(self) -> None:
user = _make_user()
result, _ = apply_user_patch(
result = apply_user_patch(
[
_replace_op(
None,
@@ -102,7 +99,7 @@ class TestApplyUserPatch:
def test_multiple_operations(self) -> None:
user = _make_user()
result, _ = apply_user_patch(
result = apply_user_patch(
[
_replace_op("active", False),
_replace_op("name.givenName", "Updated"),
@@ -115,7 +112,7 @@ class TestApplyUserPatch:
def test_case_insensitive_path(self) -> None:
user = _make_user()
result, _ = apply_user_patch([_replace_op("Active", False)], user)
result = apply_user_patch([_replace_op("Active", False)], user)
assert result.active is False
def test_original_not_mutated(self) -> None:
@@ -128,22 +125,15 @@ class TestApplyUserPatch:
with pytest.raises(ScimPatchError, match="Unsupported path"):
apply_user_patch([_replace_op("unknownField", "value")], user)
def test_remove_op_clears_field(self) -> None:
"""Remove op should clear the target field (not raise)."""
user = _make_user(externalId="ext-123")
result, _ = apply_user_patch([_remove_op("externalId")], user)
assert result.externalId is None
def test_remove_unsupported_path_raises(self) -> None:
"""Remove op on unsupported path (e.g. 'active') should raise."""
def test_remove_op_on_user_raises(self) -> None:
user = _make_user()
with pytest.raises(ScimPatchError, match="Unsupported remove path"):
with pytest.raises(ScimPatchError, match="Unsupported operation"):
apply_user_patch([_remove_op("active")], user)
def test_replace_without_path_ignores_id(self) -> None:
"""Okta sends 'id' alongside actual changes — it should be silently ignored."""
user = _make_user()
result, _ = apply_user_patch(
result = apply_user_patch(
[_replace_op(None, ScimPatchResourceValue(active=False, id="some-uuid"))],
user,
ignored_paths=_OKTA_IGNORED,
@@ -153,7 +143,7 @@ class TestApplyUserPatch:
def test_replace_without_path_ignores_schemas(self) -> None:
"""The 'schemas' key in a value dict should be silently ignored."""
user = _make_user()
result, _ = apply_user_patch(
result = apply_user_patch(
[
_replace_op(
None,
@@ -171,7 +161,7 @@ class TestApplyUserPatch:
def test_okta_deactivation_payload(self) -> None:
"""Exact Okta deactivation payload: path-less replace with id + active."""
user = _make_user()
result, _ = apply_user_patch(
result = apply_user_patch(
[
_replace_op(
None,
@@ -186,7 +176,7 @@ class TestApplyUserPatch:
def test_replace_displayname(self) -> None:
user = _make_user()
result, _ = apply_user_patch(
result = apply_user_patch(
[_replace_op("displayName", "New Display Name")], user
)
assert result.displayName == "New Display Name"
@@ -197,7 +187,7 @@ class TestApplyUserPatch:
"""Okta sends id/schemas/meta alongside actual changes — complex types
(lists, nested dicts) must not cause Pydantic validation errors."""
user = _make_user()
result, _ = apply_user_patch(
result = apply_user_patch(
[
_replace_op(
None,
@@ -217,101 +207,9 @@ class TestApplyUserPatch:
def test_add_operation_works_like_replace(self) -> None:
user = _make_user()
result, _ = apply_user_patch([_add_op("externalId", "ext-456")], user)
result = apply_user_patch([_add_op("externalId", "ext-456")], user)
assert result.externalId == "ext-456"
def test_entra_capitalized_replace_op(self) -> None:
"""Entra ID sends ``"Replace"`` instead of ``"replace"``."""
user = _make_user()
op = ScimPatchOperation(op="Replace", path="active", value=False) # type: ignore[arg-type]
result, _ = apply_user_patch([op], user)
assert result.active is False
def test_entra_capitalized_add_op(self) -> None:
"""Entra ID sends ``"Add"`` instead of ``"add"``."""
user = _make_user()
op = ScimPatchOperation(op="Add", path="externalId", value="ext-999") # type: ignore[arg-type]
result, _ = apply_user_patch([op], user)
assert result.externalId == "ext-999"
def test_entra_enterprise_extension_handled(self) -> None:
"""Entra sends the enterprise extension URN as a key in path-less
PATCH value dicts — enterprise data should be captured in ent_data."""
user = _make_user()
value = ScimPatchResourceValue(active=False)
# Simulate Entra including the enterprise extension URN as extra data
assert value.__pydantic_extra__ is not None
value.__pydantic_extra__[
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
] = {"department": "Engineering"}
result, ent_data = apply_user_patch(
[_replace_op(None, value)],
user,
ignored_paths=_ENTRA_IGNORED,
)
assert result.active is False
assert result.userName == "test@example.com"
assert ent_data["department"] == "Engineering"
def test_okta_handles_enterprise_extension_urn(self) -> None:
"""Enterprise extension URN paths are handled universally, even
for Okta — the data is captured in the enterprise data dict."""
user = _make_user()
value = ScimPatchResourceValue(active=False)
assert value.__pydantic_extra__ is not None
value.__pydantic_extra__[
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
] = {"department": "Engineering"}
result, ent_data = apply_user_patch(
[_replace_op(None, value)],
user,
ignored_paths=_OKTA_IGNORED,
)
assert result.active is False
assert ent_data["department"] == "Engineering"
def test_emails_primary_eq_true_value(self) -> None:
"""emails[primary eq true].value should update the primary email entry."""
user = _make_user(
emails=[ScimEmail(value="old@example.com", type="work", primary=True)]
)
result, _ = apply_user_patch(
[_replace_op("emails[primary eq true].value", "new@example.com")], user
)
# userName should remain unchanged — emails and userName are separate
assert result.userName == "test@example.com"
assert len(result.emails) == 1
assert result.emails[0].value == "new@example.com"
assert result.emails[0].primary is True
def test_enterprise_urn_department_path(self) -> None:
"""Dotted enterprise URN path should set department in ent_data."""
user = _make_user()
_, ent_data = apply_user_patch(
[
_replace_op(
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:department",
"Marketing",
)
],
user,
)
assert ent_data["department"] == "Marketing"
def test_enterprise_urn_manager_path(self) -> None:
"""Dotted enterprise URN path for manager should set manager."""
user = _make_user()
_, ent_data = apply_user_patch(
[
_replace_op(
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:manager",
ScimPatchResourceValue.model_validate({"value": "boss-id"}),
)
],
user,
)
assert ent_data["manager"] == "boss-id"
class TestApplyGroupPatch:
"""Tests for SCIM group PATCH operations."""

View File

@@ -2,8 +2,6 @@ from unittest.mock import MagicMock
from uuid import UUID
from uuid import uuid4
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
@@ -11,10 +9,7 @@ from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserGroupRef
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import get_default_provider
from ee.onyx.server.scim.providers.entra import _ENTRA_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.entra import EntraProvider
from ee.onyx.server.scim.providers.okta import OktaProvider
@@ -44,7 +39,9 @@ class TestOktaProvider:
assert OktaProvider().name == "okta"
def test_ignored_patch_paths(self) -> None:
assert OktaProvider().ignored_patch_paths == COMMON_IGNORED_PATCH_PATHS
assert OktaProvider().ignored_patch_paths == frozenset(
{"id", "schemas", "meta"}
)
def test_build_user_resource_basic(self) -> None:
provider = OktaProvider()
@@ -63,12 +60,6 @@ class TestOktaProvider:
meta=ScimMeta(resourceType="User"),
)
def test_build_user_resource_has_core_schema_only(self) -> None:
provider = OktaProvider()
user = _make_mock_user()
result = provider.build_user_resource(user, "ext-123")
assert result.schemas == [SCIM_USER_SCHEMA]
def test_build_user_resource_with_groups(self) -> None:
provider = OktaProvider()
user = _make_mock_user()
@@ -170,42 +161,6 @@ class TestOktaProvider:
assert result.members == []
class TestEntraProvider:
def test_name(self) -> None:
assert EntraProvider().name == "entra"
def test_ignored_patch_paths(self) -> None:
paths = EntraProvider().ignored_patch_paths
assert paths == _ENTRA_IGNORED_PATCH_PATHS
# Enterprise extension URN is now handled (not ignored)
assert paths >= COMMON_IGNORED_PATCH_PATHS
def test_build_user_resource_includes_enterprise_schema(self) -> None:
provider = EntraProvider()
user = _make_mock_user()
result = provider.build_user_resource(user, "ext-entra-1")
assert result.schemas == [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]
def test_build_user_resource_basic(self) -> None:
provider = EntraProvider()
user = _make_mock_user()
result = provider.build_user_resource(user, "ext-entra-1")
assert result == ScimUserResource(
schemas=[SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA],
id=str(user.id),
externalId="ext-entra-1",
userName="test@example.com",
name=ScimName(givenName="Test", familyName="User", formatted="Test User"),
displayName="Test User",
emails=[ScimEmail(value="test@example.com", type="work", primary=True)],
active=True,
groups=[],
meta=ScimMeta(resourceType="User"),
)
class TestGetDefaultProvider:
def test_returns_okta(self) -> None:
provider = get_default_provider()

View File

@@ -16,7 +16,7 @@ from ee.onyx.server.scim.api import get_user
from ee.onyx.server.scim.api import list_users
from ee.onyx.server.scim.api import patch_user
from ee.onyx.server.scim.api import replace_user
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
@@ -28,8 +28,6 @@ from tests.unit.onyx.server.scim.conftest import assert_scim_error
from tests.unit.onyx.server.scim.conftest import make_db_user
from tests.unit.onyx.server.scim.conftest import make_scim_user
from tests.unit.onyx.server.scim.conftest import make_user_mapping
from tests.unit.onyx.server.scim.conftest import parse_scim_list
from tests.unit.onyx.server.scim.conftest import parse_scim_user
class TestListUsers:
@@ -53,9 +51,9 @@ class TestListUsers:
db_session=mock_db_session,
)
parsed = parse_scim_list(result)
assert parsed.totalResults == 0
assert parsed.Resources == []
assert isinstance(result, ScimListResponse)
assert result.totalResults == 0
assert result.Resources == []
def test_returns_users_with_scim_shape(
self,
@@ -79,10 +77,10 @@ class TestListUsers:
db_session=mock_db_session,
)
parsed = parse_scim_list(result)
assert parsed.totalResults == 1
assert len(parsed.Resources) == 1
resource = parsed.Resources[0]
assert isinstance(result, ScimListResponse)
assert result.totalResults == 1
assert len(result.Resources) == 1
resource = result.Resources[0]
assert isinstance(resource, ScimUserResource)
assert resource.userName == "Alice@example.com"
assert resource.externalId == "ext-abc"
@@ -148,9 +146,9 @@ class TestGetUser:
db_session=mock_db_session,
)
resource = parse_scim_user(result)
assert resource.userName == "alice@example.com"
assert resource.id == str(user.id)
assert isinstance(result, ScimUserResource)
assert result.userName == "alice@example.com"
assert result.id == str(user.id)
def test_invalid_uuid_returns_404(
self,
@@ -209,8 +207,8 @@ class TestCreateUser:
db_session=mock_db_session,
)
resource = parse_scim_user(result, status=201)
assert resource.userName == "new@example.com"
assert isinstance(result, ScimUserResource)
assert result.userName == "new@example.com"
mock_dal.add_user.assert_called_once()
mock_dal.commit.assert_called_once()
@@ -316,8 +314,8 @@ class TestCreateUser:
db_session=mock_db_session,
)
resource = parse_scim_user(result, status=201)
assert resource.externalId == "ext-123"
assert isinstance(result, ScimUserResource)
assert result.externalId == "ext-123"
mock_dal.create_user_mapping.assert_called_once()
@@ -346,7 +344,7 @@ class TestReplaceUser:
db_session=mock_db_session,
)
parse_scim_user(result)
assert isinstance(result, ScimUserResource)
mock_dal.update_user.assert_called_once()
mock_dal.commit.assert_called_once()
@@ -414,15 +412,9 @@ class TestReplaceUser:
db_session=mock_db_session,
)
parse_scim_user(result)
assert isinstance(result, ScimUserResource)
mock_dal.sync_user_external_id.assert_called_once_with(
user.id,
None,
scim_username="test@example.com",
fields=ScimMappingFields(
given_name="Test",
family_name="User",
),
user.id, None, scim_username="test@example.com"
)
@@ -456,7 +448,7 @@ class TestPatchUser:
db_session=mock_db_session,
)
parse_scim_user(result)
assert isinstance(result, ScimUserResource)
mock_dal.update_user.assert_called_once()
def test_not_found_returns_404(
@@ -515,7 +507,7 @@ class TestPatchUser:
db_session=mock_db_session,
)
parse_scim_user(result)
assert isinstance(result, ScimUserResource)
# Verify the update_user call received the new display name
call_kwargs = mock_dal.update_user.call_args
assert call_kwargs[1]["personal_name"] == "New Display Name"
@@ -613,12 +605,10 @@ class TestDeleteUser:
class TestScimNameToStr:
"""Tests for _scim_name_to_str helper."""
def test_prefers_formatted_over_components(self) -> None:
"""When client provides formatted, use it — the client knows what it wants."""
name = ScimName(
givenName="Jane", familyName="Smith", formatted="Dr. Jane Smith"
)
assert _scim_name_to_str(name) == "Dr. Jane Smith"
def test_prefers_given_family_over_formatted(self) -> None:
"""Okta may send stale formatted while updating givenName/familyName."""
name = ScimName(givenName="Jane", familyName="Smith", formatted="Old Name")
assert _scim_name_to_str(name) == "Jane Smith"
def test_given_name_only(self) -> None:
name = ScimName(givenName="Jane")
@@ -663,9 +653,9 @@ class TestEmailCasePreservation:
db_session=mock_db_session,
)
resource = parse_scim_user(result, status=201)
assert resource.userName == "Alice@Example.COM"
assert resource.emails[0].value == "Alice@Example.COM"
assert isinstance(result, ScimUserResource)
assert result.userName == "Alice@Example.COM"
assert result.emails[0].value == "Alice@Example.COM"
def test_get_preserves_username_case(
self,
@@ -691,6 +681,6 @@ class TestEmailCasePreservation:
db_session=mock_db_session,
)
resource = parse_scim_user(result)
assert resource.userName == "Alice@Example.COM"
assert resource.emails[0].value == "Alice@Example.COM"
assert isinstance(result, ScimUserResource)
assert result.userName == "Alice@Example.COM"
assert result.emails[0].value == "Alice@Example.COM"

View File

@@ -0,0 +1,88 @@
"""Tests for PythonTool availability based on server_enabled flag.
Verifies that PythonTool reports itself as unavailable when either:
- CODE_INTERPRETER_BASE_URL is not set, or
- CodeInterpreterServer.server_enabled is False in the database.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from sqlalchemy.orm import Session
# ------------------------------------------------------------------
# Unavailable when CODE_INTERPRETER_BASE_URL is not set
# ------------------------------------------------------------------
@patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
None,
)
def test_python_tool_unavailable_without_base_url() -> None:
from onyx.tools.tool_implementations.python.python_tool import PythonTool
db_session = MagicMock(spec=Session)
assert PythonTool.is_available(db_session) is False
@patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
"",
)
def test_python_tool_unavailable_with_empty_base_url() -> None:
from onyx.tools.tool_implementations.python.python_tool import PythonTool
db_session = MagicMock(spec=Session)
assert PythonTool.is_available(db_session) is False
# ------------------------------------------------------------------
# Unavailable when server_enabled is False
# ------------------------------------------------------------------
@patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
"http://localhost:8000",
)
@patch(
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
)
def test_python_tool_unavailable_when_server_disabled(
mock_fetch: MagicMock,
) -> None:
from onyx.tools.tool_implementations.python.python_tool import PythonTool
mock_server = MagicMock()
mock_server.server_enabled = False
mock_fetch.return_value = mock_server
db_session = MagicMock(spec=Session)
assert PythonTool.is_available(db_session) is False
# ------------------------------------------------------------------
# Available when both conditions are met
# ------------------------------------------------------------------
@patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
"http://localhost:8000",
)
@patch(
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
)
def test_python_tool_available_when_server_enabled(
mock_fetch: MagicMock,
) -> None:
from onyx.tools.tool_implementations.python.python_tool import PythonTool
mock_server = MagicMock()
mock_server.server_enabled = True
mock_fetch.return_value = mock_server
db_session = MagicMock(spec=Session)
assert PythonTool.is_available(db_session) is True

View File

@@ -487,16 +487,7 @@ services:
code-interpreter:
image: onyxdotapp/code-interpreter:${CODE_INTERPRETER_IMAGE_TAG:-latest}
entrypoint: ["/bin/bash", "-c"]
command: >
"
if [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"True\" ] || [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"true\" ]; then
exec bash ./entrypoint.sh code-interpreter-api;
else
echo 'Skipping code interpreter';
exec tail -f /dev/null;
fi
"
command: ["bash", "./entrypoint.sh", "code-interpreter-api"]
restart: unless-stopped
env_file:
- path: .env

View File

@@ -69,6 +69,4 @@ services:
inference_model_server:
profiles: ["inference"]
# Code interpreter is not needed in minimal mode.
code-interpreter:
profiles: ["code-interpreter"]
code-interpreter: {}

View File

@@ -315,16 +315,7 @@ services:
code-interpreter:
image: onyxdotapp/code-interpreter:${CODE_INTERPRETER_IMAGE_TAG:-latest}
entrypoint: ["/bin/bash", "-c"]
command: >
"
if [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"True\" ] || [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"true\" ]; then
exec bash ./entrypoint.sh code-interpreter-api;
else
echo 'Skipping code interpreter';
exec tail -f /dev/null;
fi
"
command: ["bash", "./entrypoint.sh", "code-interpreter-api"]
restart: unless-stopped
env_file:
- path: .env

View File

@@ -352,16 +352,7 @@ services:
code-interpreter:
image: onyxdotapp/code-interpreter:${CODE_INTERPRETER_IMAGE_TAG:-latest}
entrypoint: ["/bin/bash", "-c"]
command: >
"
if [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"True\" ] || [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"true\" ]; then
exec bash ./entrypoint.sh code-interpreter-api;
else
echo 'Skipping code interpreter';
exec tail -f /dev/null;
fi
"
command: ["bash", "./entrypoint.sh", "code-interpreter-api"]
restart: unless-stopped
env_file:
- path: .env

View File

@@ -527,16 +527,7 @@ services:
code-interpreter:
image: onyxdotapp/code-interpreter:${CODE_INTERPRETER_IMAGE_TAG:-latest}
entrypoint: ["/bin/bash", "-c"]
command: >
"
if [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"True\" ] || [ \"$${CODE_INTERPRETER_BETA_ENABLED}\" = \"true\" ]; then
exec bash ./entrypoint.sh code-interpreter-api;
else
echo 'Skipping code interpreter';
exec tail -f /dev/null;
fi
"
command: ["bash", "./entrypoint.sh", "code-interpreter-api"]
restart: unless-stopped
env_file:
- path: .env

View File

@@ -19,6 +19,6 @@ dependencies:
version: 5.4.0
- name: code-interpreter
repository: https://onyx-dot-app.github.io/python-sandbox/
version: 0.2.1
digest: sha256:aedc211d9732c934be8b79735b62f8caa9bcd235e03fd0dd10b49e0a13ed15b7
generated: "2026-02-20T11:19:47.957449-08:00"
version: 0.3.0
digest: sha256:cf8f01906d46034962c6ce894770621ee183ac761e6942951118aeb48540eddd
generated: "2026-02-24T10:59:38.78318-08:00"

View File

@@ -45,6 +45,6 @@ dependencies:
repository: https://charts.min.io/
condition: minio.enabled
- name: code-interpreter
version: 0.2.1
version: 0.3.0
repository: https://onyx-dot-app.github.io/python-sandbox/
condition: codeInterpreter.enabled

View File

@@ -957,7 +957,7 @@ minio:
# Code Interpreter - Python code execution service (beta feature)
codeInterpreter:
enabled: false # Disabled by default (beta feature)
enabled: true
replicaCount: 1

View File

@@ -144,7 +144,7 @@ dev = [
"matplotlib==3.10.8",
"mypy-extensions==1.0.0",
"mypy==1.13.0",
"onyx-devtools==0.6.0",
"onyx-devtools==0.6.1",
"openapi-generator-cli==7.17.0",
"pandas-stubs~=2.3.3",
"pre-commit==3.2.2",

View File

@@ -222,6 +222,7 @@ ods run-ci 7353
### `cherry-pick` - Backport Commits to Release Branches
Cherry-pick one or more commits to release branches and automatically create PRs.
Cherry-pick PRs created by this command are labeled `cherry-pick 🍒`.
```shell
ods cherry-pick <commit-sha> [<commit-sha>...] [--release <version>]

View File

@@ -16,6 +16,8 @@ import (
"github.com/onyx-dot-app/onyx/tools/ods/internal/prompt"
)
const cherryPickPRLabel = "cherry-pick 🍒"
// CherryPickOptions holds options for the cherry-pick command
type CherryPickOptions struct {
Releases []string
@@ -510,6 +512,7 @@ func createCherryPickPR(headBranch, baseBranch, title string, commitSHAs, commit
"--head", headBranch,
"--title", title,
"--body", body,
"--label", cherryPickPRLabel,
}
for _, assignee := range assignees {

18
uv.lock generated
View File

@@ -4654,7 +4654,7 @@ requires-dist = [
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.6.2" },
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.0" },
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.1" },
{ name = "openai", specifier = "==2.14.0" },
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
@@ -4759,20 +4759,20 @@ requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable =
[[package]]
name = "onyx-devtools"
version = "0.6.0"
version = "0.6.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "fastapi" },
{ name = "openapi-generator-cli" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/f9/79d66c1f06e4d1dca0a9df30afcd65ec1a69219fdf17c45349396d1ec668/onyx_devtools-0.6.0-py3-none-any.whl", hash = "sha256:26049075a6d3eb794f44c1bbe55a7cfc0c5427de681ed29319064e2deb956a15", size = 3777572, upload-time = "2026-02-19T23:05:51.823Z" },
{ url = "https://files.pythonhosted.org/packages/40/37/0abff5ab8d79c90f9d57eeaf4998f668145b01e81da0307df56c3b15d16c/onyx_devtools-0.6.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:a7c00f2f1924c231b2480edcd3b6aa83398e13e4587c213fe1c97e0f6d3cfce1", size = 3822965, upload-time = "2026-02-19T23:06:02.992Z" },
{ url = "https://files.pythonhosted.org/packages/59/79/a8c23e456b7f1bb4cb741875af6c323fba11d5ef1ba121ea8b44587c236f/onyx_devtools-0.6.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:0e67fc47dfffb510826a6487dd5029a65b4a5b3f8a42e0e1208b6faee353518c", size = 3570391, upload-time = "2026-02-19T23:05:48.853Z" },
{ url = "https://files.pythonhosted.org/packages/c5/c5/d166bf2c98b80fd83d76abe88e57d63a8cb55880ba40a3d34c831361e3cf/onyx_devtools-0.6.0-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:0fdbd085f82788b900620424798d04dc1b10c3b1baf9be821ac178adc41c6858", size = 3432611, upload-time = "2026-02-19T23:05:51.924Z" },
{ url = "https://files.pythonhosted.org/packages/18/8e/c53fb7f7781acbf37ca80ebcee5d1274d54c6d853606adefc517df715f9a/onyx_devtools-0.6.0-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:3915ad5ea245e597a8ad91bd2ba5efc2b6a336ca59c7f3670bd89530cc9ab00f", size = 3777586, upload-time = "2026-02-19T23:05:51.877Z" },
{ url = "https://files.pythonhosted.org/packages/e5/57/194ded4aa5151d96911b021829e015370b4f1fc7493ac584d445fd96f97b/onyx_devtools-0.6.0-py3-none-win_amd64.whl", hash = "sha256:478cdae03ae2e797345396397318446622c7472df0a7d9dbd58d3e96489198b2", size = 3871835, upload-time = "2026-02-19T23:05:51.209Z" },
{ url = "https://files.pythonhosted.org/packages/3c/e9/cc7d204b9b1103b2f33f8f62d29076083f40f44697b398e83b3d44daca23/onyx_devtools-0.6.0-py3-none-win_arm64.whl", hash = "sha256:4bff060fd5f017ddceaf753252e0bc16699922d9a0a88506a56505aad4580824", size = 3492854, upload-time = "2026-02-19T23:05:51.856Z" },
{ url = "https://files.pythonhosted.org/packages/bf/3c/fc0c152ecc403b8d4c929eacc7ea4c3d6cba2094f3cfa51d9e5c4d3bda3d/onyx_devtools-0.6.1-py3-none-any.whl", hash = "sha256:a9ad90ca4536ebe9aaeb604f82c418f3fd148100f14cca7749df0d076ee5c4b0", size = 3781440, upload-time = "2026-02-25T00:59:03.565Z" },
{ url = "https://files.pythonhosted.org/packages/fd/1c/2df5a06eed5490057f0852153940142f9987ff9b865c9c185b733fa360b1/onyx_devtools-0.6.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:769a656737e2389312e8e24bf3e9dd559dcb00160f323228dfe34d005ab47af3", size = 3827421, upload-time = "2026-02-25T00:58:59.672Z" },
{ url = "https://files.pythonhosted.org/packages/a2/e3/389644eb9ba0a3cfa975cc015a48140702b05abc9093542b2a3ba6cc5cc1/onyx_devtools-0.6.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:93886332e97e6efa5f3d7a1d1e4facf1442d301df379f65dfc2a328ed43c8f39", size = 3573060, upload-time = "2026-02-25T00:59:02.582Z" },
{ url = "https://files.pythonhosted.org/packages/68/fe/dd0f32e08f7e7fb1861a28b82431e0a43cf6ab33e04fb2938f4ee20c891b/onyx_devtools-0.6.1-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:cf896e420c78c08c541135473627ffcab0a0156e0e462e71bcb476f560c324fa", size = 3435936, upload-time = "2026-02-25T00:59:02.313Z" },
{ url = "https://files.pythonhosted.org/packages/bb/3a/4376cba6adcf86b9fc55f146493450955497d988920eaa37a8aec9f9f897/onyx_devtools-0.6.1-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:4cb5a1b44a4e74c2fc68164a5caa34bce3f6d2dd5639e48438c1d04f09c4c7c6", size = 3781457, upload-time = "2026-02-25T00:59:02.126Z" },
{ url = "https://files.pythonhosted.org/packages/9d/0d/d2ecf7edc02354d16d9a1d9bd7d8d35f46cdde08b86635ba02075e4d3c7c/onyx_devtools-0.6.1-py3-none-win_amd64.whl", hash = "sha256:0c6c6a667851b9ab215980f1b391216bc2f157c8a29d0cfa96c32c6d10116a5c", size = 3875146, upload-time = "2026-02-25T00:59:02.364Z" },
{ url = "https://files.pythonhosted.org/packages/c5/c3/04783dcfad36b18f48befb6d85bf4f9a9f36fd4cd6e08077676c72c9c504/onyx_devtools-0.6.1-py3-none-win_arm64.whl", hash = "sha256:f095e58b4dad0671c7127a452c5d5f411f55070ebf586a2e47f9193ab753ce44", size = 3496971, upload-time = "2026-02-25T00:59:17.98Z" },
]
[[package]]

View File

@@ -394,7 +394,7 @@
}
.interactive[data-interactive-base-variant="select"][data-disabled] {
@apply bg-transparent;
--interactive-foreground: var(--text-02);
--interactive-foreground: var(--text-01);
}
.interactive[data-interactive-base-variant="select"][data-selected="true"][data-disabled] {
--interactive-foreground: var(--action-link-03);

View File

@@ -119,10 +119,10 @@ function HorizontalInputLayout({
justifyContent="between"
alignItems={center ? "center" : "start"}
>
<div className="flex flex-col self-stretch flex-[2]">
<div className="flex flex-col flex-1 self-stretch">
<TitleLayout {...titleLayoutProps} />
</div>
<div className="flex flex-col flex-[1] items-end">{children}</div>
<div className="flex flex-col items-end">{children}</div>
</Section>
{name && <ErrorLayout name={name} />}
</Section>