forked from github/onyx
feat: sharepoint perm sync (#5033)
* sharepoint perm sync first draft
* feat: Implement SharePoint permission synchronization
* mypy fix
* remove commented code
* bot comments fixes and job failure fixes
* introduce generic way to upload certificates in credentials
* mypy fix
* add checkpointing to sharepoint connector
* add sharepoint integration tests
* Refactor SharePoint connector to derive tenant domain from verified domains and remove direct tenant domain input from credentials
* address review comments
* add permission sync to site pages
* mypy fix
* fix tests error
* fix tests and address comments
* Update file extraction behavior in SharePoint connector to continue processing on unprocessable files
@@ -102,6 +102,19 @@ TEAMS_PERMISSION_DOC_SYNC_FREQUENCY = int(
    os.environ.get("TEAMS_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)

#####
# SharePoint
#####
# In seconds, default is 30 minutes
SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY = int(
    os.environ.get("SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
)

# In seconds, default is 5 minutes
SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY = int(
    os.environ.get("SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)


####
# Celery Job Frequency
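Aside (illustrative, not part of the diff): the `int(os.environ.get(...) or default)` pattern above treats an unset or empty environment variable the same way, falling back to the default, while any numeric string overrides it. A minimal sketch using the SharePoint doc-sync variable:

import os

# Unset or empty value falls through to the 30-minute default
os.environ.pop("SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY", None)
assert int(os.environ.get("SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60) == 1800

# A numeric string overrides the default
os.environ["SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY"] = "900"
assert int(os.environ.get("SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60) == 900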
backend/ee/onyx/external_permissions/sharepoint/doc_sync.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from collections.abc import Generator

from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()

SHAREPOINT_DOC_SYNC_TAG = "sharepoint_doc_sync"


def sharepoint_doc_sync(
    cc_pair: ConnectorCredentialPair,
    fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
    fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
    callback: IndexingHeartbeatInterface | None = None,
) -> Generator[DocExternalAccess, None, None]:
    sharepoint_connector = SharepointConnector(
        **cc_pair.connector.connector_specific_config,
    )
    sharepoint_connector.load_credentials(cc_pair.credential.credential_json)

    yield from generic_doc_sync(
        cc_pair=cc_pair,
        fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
        callback=callback,
        doc_source=DocumentSource.SHAREPOINT,
        slim_connector=sharepoint_connector,
        label=SHAREPOINT_DOC_SYNC_TAG,
    )
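Aside (illustrative, not part of the diff): sharepoint_doc_sync is a generator of DocExternalAccess records, so a caller can stream results rather than materializing them. The driver below is a hypothetical sketch; the real permission-sync worker would upsert each record into the document index instead of counting.

def count_sharepoint_doc_permissions(cc_pair, fetch_docs_fn, fetch_doc_ids_fn) -> int:
    # Stream DocExternalAccess records; callback is optional per the signature above.
    count = 0
    for _doc_access in sharepoint_doc_sync(
        cc_pair,
        fetch_docs_fn,
        fetch_doc_ids_fn,
        callback=None,
    ):
        count += 1
    return count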
backend/ee/onyx/external_permissions/sharepoint/group_sync.py (new file, 63 lines)
@@ -0,0 +1,63 @@
from collections.abc import Generator

from office365.sharepoint.client_context import ClientContext  # type: ignore[import-untyped]

from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
    get_sharepoint_external_groups,
)
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger

logger = setup_logger()


def sharepoint_group_sync(
    tenant_id: str,
    cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
    """Sync SharePoint groups and their members"""

    # Get site URLs from connector config
    connector_config = cc_pair.connector.connector_specific_config

    # Create SharePoint connector instance and load credentials
    connector = SharepointConnector(**connector_config)
    connector.load_credentials(cc_pair.credential.credential_json)

    if not connector.msal_app:
        raise RuntimeError("MSAL app not initialized in connector")

    if not connector.sp_tenant_domain:
        raise RuntimeError("Tenant domain not initialized in connector")

    # Get site descriptors from connector (either configured sites or all sites)
    site_descriptors = connector.site_descriptors or connector.fetch_sites()

    if not site_descriptors:
        raise RuntimeError("No SharePoint sites found for group sync")

    logger.info(f"Processing {len(site_descriptors)} sites for group sync")

    msal_app = connector.msal_app
    sp_tenant_domain = connector.sp_tenant_domain
    # Process each site
    for site_descriptor in site_descriptors:
        logger.debug(f"Processing site: {site_descriptor.url}")

        # Create client context for the site using connector's MSAL app
        ctx = ClientContext(site_descriptor.url).with_access_token(
            lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
        )

        # Get external groups for this site
        external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)

        # Yield each group
        for group in external_groups:
            logger.debug(
                f"Found group: {group.id} with {len(group.user_emails)} members"
            )
            yield group
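Aside (illustrative, not part of the diff): because the same group can be yielded once per site, a consumer of this generator will likely want to de-duplicate by group id before persisting. A minimal sketch; the upsert callable is a hypothetical placeholder for whatever persistence layer the worker uses, and "last-seen membership wins" is an assumption.

from collections.abc import Callable

from ee.onyx.db.external_perm import ExternalUserGroup


def collect_sharepoint_groups(
    tenant_id: str,
    cc_pair,
    upsert_group: Callable[[ExternalUserGroup], None],
) -> int:
    # De-duplicate groups seen across multiple sites by their id.
    seen: dict[str, ExternalUserGroup] = {}
    for group in sharepoint_group_sync(tenant_id, cc_pair):
        seen[group.id] = group

    for group in seen.values():
        upsert_group(group)
    return len(seen)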
backend/ee/onyx/external_permissions/sharepoint/permission_utils.py (new file, 658 lines)
@@ -0,0 +1,658 @@
import re
from collections import deque
from typing import Any

from office365.graph_client import GraphClient  # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem  # type: ignore[import-untyped]
from office365.sharepoint.client_context import ClientContext  # type: ignore[import-untyped]
from office365.sharepoint.permissions.securable_object import RoleAssignmentCollection  # type: ignore[import-untyped]
from pydantic import BaseModel

from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger

logger = setup_logger()


# These values represent different types of SharePoint principals used in permission assignments
USER_PRINCIPAL_TYPE = 1  # Individual user accounts
ANONYMOUS_USER_PRINCIPAL_TYPE = 3  # Anonymous/unauthenticated users (public access)
AZURE_AD_GROUP_PRINCIPAL_TYPE = 4  # Azure Active Directory security groups
SHAREPOINT_GROUP_PRINCIPAL_TYPE = 8  # SharePoint site groups (local to the site)
MICROSOFT_DOMAIN = ".onmicrosoft"
# Limited Access role types: Limited Access is a pass-through (traversal) permission, not an actual permission
LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]


class SharepointGroup(BaseModel):
    model_config = {"frozen": True}

    name: str
    login_name: str
    principal_type: int


class GroupsResult(BaseModel):
    groups_to_emails: dict[str, set[str]]
    found_public_group: bool


def _get_azuread_group_guid_by_name(
|
||||
graph_client: GraphClient, group_name: str
|
||||
) -> str | None:
|
||||
try:
|
||||
# Search for groups by display name
|
||||
groups = sleep_and_retry(
|
||||
graph_client.groups.filter(f"displayName eq '{group_name}'").get(),
|
||||
"get_azuread_group_guid_by_name",
|
||||
)
|
||||
|
||||
if groups and len(groups) > 0:
|
||||
return groups[0].id
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get Azure AD group GUID for name {group_name}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _extract_guid_from_claims_token(claims_token: str) -> str | None:
|
||||
|
||||
try:
|
||||
# Pattern to match GUID in claims token
|
||||
# Claims tokens often have format: c:0o.c|provider|GUID_suffix
|
||||
guid_pattern = r"([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})"
|
||||
|
||||
match = re.search(guid_pattern, claims_token, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract GUID from claims token {claims_token}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _get_group_guid_from_identifier(
|
||||
graph_client: GraphClient, identifier: str
|
||||
) -> str | None:
|
||||
try:
|
||||
# Check if it's already a GUID
|
||||
guid_pattern = r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
|
||||
if re.match(guid_pattern, identifier, re.IGNORECASE):
|
||||
return identifier
|
||||
|
||||
# Check if it's a SharePoint claims token
|
||||
if identifier.startswith("c:0") and "|" in identifier:
|
||||
guid = _extract_guid_from_claims_token(identifier)
|
||||
if guid:
|
||||
logger.info(f"Extracted GUID {guid} from claims token {identifier}")
|
||||
return guid
|
||||
|
||||
# Try to search by display name as fallback
|
||||
return _get_azuread_group_guid_by_name(graph_client, identifier)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get group GUID from identifier {identifier}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _get_security_group_owners(graph_client: GraphClient, group_id: str) -> list[str]:
|
||||
try:
|
||||
# Get group owners using Graph API
|
||||
group = graph_client.groups[group_id]
|
||||
owners = sleep_and_retry(
|
||||
group.owners.get_all(page_loaded=lambda _: None),
|
||||
"get_security_group_owners",
|
||||
)
|
||||
|
||||
owner_emails: list[str] = []
|
||||
logger.info(f"Owners: {owners}")
|
||||
|
||||
for owner in owners:
|
||||
owner_data = owner.to_json()
|
||||
|
||||
# Extract email from the JSON data
|
||||
mail: str | None = owner_data.get("mail")
|
||||
user_principal_name: str | None = owner_data.get("userPrincipalName")
|
||||
|
||||
# Check if owner is a user and has an email
|
||||
if mail:
|
||||
if MICROSOFT_DOMAIN in mail:
|
||||
mail = mail.replace(MICROSOFT_DOMAIN, "")
|
||||
owner_emails.append(mail)
|
||||
elif user_principal_name:
|
||||
if MICROSOFT_DOMAIN in user_principal_name:
|
||||
user_principal_name = user_principal_name.replace(
|
||||
MICROSOFT_DOMAIN, ""
|
||||
)
|
||||
owner_emails.append(user_principal_name)
|
||||
|
||||
logger.info(
|
||||
f"Retrieved {len(owner_emails)} owners from security group {group_id}"
|
||||
)
|
||||
return owner_emails
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get security group owners for group {group_id}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def _get_sharepoint_list_item_id(drive_item: DriveItem) -> str | None:
|
||||
|
||||
try:
|
||||
# First try to get the list item directly from the drive item
|
||||
if hasattr(drive_item, "listItem"):
|
||||
list_item = drive_item.listItem
|
||||
if list_item:
|
||||
# Load the list item properties to get the ID
|
||||
sleep_and_retry(list_item.get(), "get_sharepoint_list_item_id")
|
||||
if hasattr(list_item, "id") and list_item.id:
|
||||
return str(list_item.id)
|
||||
|
||||
# The SharePoint list item ID is typically available in the sharepointIds property
|
||||
sharepoint_ids = getattr(drive_item, "sharepoint_ids", None)
|
||||
if sharepoint_ids and hasattr(sharepoint_ids, "listItemId"):
|
||||
return sharepoint_ids.listItemId
|
||||
|
||||
# Alternative: try to get it from the properties
|
||||
properties = getattr(drive_item, "properties", None)
|
||||
if properties:
|
||||
# Sometimes the SharePoint list item ID is in the properties
|
||||
for prop_name, prop_value in properties.items():
|
||||
if "listitemid" in prop_name.lower():
|
||||
return str(prop_value)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting SharePoint list item ID for item {drive_item.id}: {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
def _is_public_item(drive_item: DriveItem) -> bool:
|
||||
is_public = False
|
||||
try:
|
||||
permissions = sleep_and_retry(
|
||||
drive_item.permissions.get_all(page_loaded=lambda _: None), "is_public_item"
|
||||
)
|
||||
for permission in permissions:
|
||||
if permission.link and (
|
||||
permission.link.scope == "anonymous"
|
||||
or permission.link.scope == "organization"
|
||||
):
|
||||
is_public = True
|
||||
break
|
||||
return is_public
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if item {drive_item.id} is public: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _is_public_login_name(login_name: str) -> bool:
|
||||
# Patterns that indicate public access
|
||||
# This list is derived from the below link
|
||||
# https://learn.microsoft.com/en-us/answers/questions/2085339/guid-in-the-loginname-of-site-user-everyone-except
|
||||
public_login_patterns: list[str] = [
|
||||
"c:0-.f|rolemanager|spo-grid-all-users/",
|
||||
"c:0(.s|true",
|
||||
]
|
||||
for pattern in public_login_patterns:
|
||||
if pattern in login_name:
|
||||
logger.info(f"Login name {login_name} is public")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Azure AD allows the same display name for multiple groups, so we append the group GUID to the name
|
||||
def _get_group_name_with_suffix(
|
||||
login_name: str, group_name: str, graph_client: GraphClient
|
||||
) -> str:
|
||||
ad_group_suffix = _get_group_guid_from_identifier(graph_client, login_name)
|
||||
return f"{group_name}_{ad_group_suffix}"
|
||||
|
||||
|
||||
def _get_sharepoint_groups(
|
||||
client_context: ClientContext, group_name: str, graph_client: GraphClient
|
||||
) -> tuple[set[SharepointGroup], set[str]]:
|
||||
|
||||
groups: set[SharepointGroup] = set()
|
||||
user_emails: set[str] = set()
|
||||
|
||||
def process_users(users: list[Any]) -> None:
|
||||
nonlocal groups, user_emails
|
||||
|
||||
for user in users:
|
||||
if user.principal_type == USER_PRINCIPAL_TYPE and hasattr(
|
||||
user, "user_principal_name"
|
||||
):
|
||||
if user.user_principal_name:
|
||||
email = user.user_principal_name
|
||||
if MICROSOFT_DOMAIN in email:
|
||||
email = email.replace(MICROSOFT_DOMAIN, "")
|
||||
user_emails.add(email)
|
||||
else:
|
||||
logger.warning(
|
||||
f"User don't have a user principal name: {user.login_name}"
|
||||
)
|
||||
elif user.principal_type in [
|
||||
AZURE_AD_GROUP_PRINCIPAL_TYPE,
|
||||
SHAREPOINT_GROUP_PRINCIPAL_TYPE,
|
||||
]:
|
||||
name = user.title
|
||||
if user.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
|
||||
name = _get_group_name_with_suffix(
|
||||
user.login_name, name, graph_client
|
||||
)
|
||||
groups.add(
|
||||
SharepointGroup(
|
||||
login_name=user.login_name,
|
||||
principal_type=user.principal_type,
|
||||
name=name,
|
||||
)
|
||||
)
|
||||
|
||||
group = client_context.web.site_groups.get_by_name(group_name)
|
||||
sleep_and_retry(
|
||||
group.users.get_all(page_loaded=process_users), "get_sharepoint_groups"
|
||||
)
|
||||
|
||||
return groups, user_emails
|
||||
|
||||
|
||||
def _get_azuread_groups(
|
||||
graph_client: GraphClient, group_name: str
|
||||
) -> tuple[set[SharepointGroup], set[str]]:
|
||||
|
||||
group_id = _get_group_guid_from_identifier(graph_client, group_name)
|
||||
if not group_id:
|
||||
logger.error(f"Failed to get Azure AD group GUID for name {group_name}")
|
||||
return set(), set()
|
||||
group = graph_client.groups[group_id]
|
||||
groups: set[SharepointGroup] = set()
|
||||
user_emails: set[str] = set()
|
||||
|
||||
def process_members(members: list[Any]) -> None:
|
||||
nonlocal groups, user_emails
|
||||
|
||||
for member in members:
|
||||
member_data = member.to_json()
|
||||
|
||||
# Check for user-specific attributes
|
||||
user_principal_name = member_data.get("userPrincipalName")
|
||||
mail = member_data.get("mail")
|
||||
display_name = member_data.get("displayName") or member_data.get(
|
||||
"display_name"
|
||||
)
|
||||
|
||||
# Check object attributes directly (if available)
|
||||
is_user = False
|
||||
is_group = False
|
||||
|
||||
# Users typically have userPrincipalName or mail
|
||||
if user_principal_name or (mail and "@" in str(mail)):
|
||||
is_user = True
|
||||
# Groups typically have displayName but no userPrincipalName
|
||||
elif display_name and not user_principal_name:
|
||||
# Additional check: try to access group-specific properties
|
||||
if (
|
||||
hasattr(member, "groupTypes")
|
||||
or member_data.get("groupTypes") is not None
|
||||
):
|
||||
is_group = True
|
||||
# Or check if it has an 'id' field typical for groups
|
||||
elif member_data.get("id") and not user_principal_name:
|
||||
is_group = True
|
||||
|
||||
# Check the object type name (fallback)
|
||||
if not is_user and not is_group:
|
||||
obj_type = type(member).__name__.lower()
|
||||
if "user" in obj_type:
|
||||
is_user = True
|
||||
elif "group" in obj_type:
|
||||
is_group = True
|
||||
|
||||
# Process based on identification
|
||||
if is_user:
|
||||
if user_principal_name:
|
||||
email = user_principal_name
|
||||
if MICROSOFT_DOMAIN in email:
|
||||
email = email.replace(MICROSOFT_DOMAIN, "")
|
||||
user_emails.add(email)
|
||||
elif mail:
|
||||
email = mail
|
||||
if MICROSOFT_DOMAIN in email:
|
||||
email = email.replace(MICROSOFT_DOMAIN, "")
|
||||
user_emails.add(email)
|
||||
logger.info(f"Added user: {user_principal_name or mail}")
|
||||
elif is_group:
|
||||
if not display_name:
|
||||
logger.error(f"No display name for group: {member_data.get('id')}")
|
||||
continue
|
||||
name = _get_group_name_with_suffix(
|
||||
member_data.get("id", ""), display_name, graph_client
|
||||
)
|
||||
groups.add(
|
||||
SharepointGroup(
|
||||
login_name=member_data.get("id", ""), # Use ID for groups
|
||||
principal_type=AZURE_AD_GROUP_PRINCIPAL_TYPE,
|
||||
name=name,
|
||||
)
|
||||
)
|
||||
logger.info(f"Added group: {name}")
|
||||
else:
|
||||
# Log unidentified members for debugging
|
||||
logger.warning(f"Could not identify member type for: {member_data}")
|
||||
|
||||
sleep_and_retry(
|
||||
group.members.get_all(page_loaded=process_members), "get_azuread_groups"
|
||||
)
|
||||
|
||||
owner_emails = _get_security_group_owners(graph_client, group_id)
|
||||
user_emails.update(owner_emails)
|
||||
|
||||
return groups, user_emails
|
||||
|
||||
|
||||
def _get_groups_and_members_recursively(
|
||||
client_context: ClientContext,
|
||||
graph_client: GraphClient,
|
||||
groups: set[SharepointGroup],
|
||||
) -> GroupsResult:
|
||||
"""
|
||||
Get all groups and their members recursively.
|
||||
"""
|
||||
group_queue: deque[SharepointGroup] = deque(groups)
|
||||
visited_groups: set[str] = set()
|
||||
visited_group_name_to_emails: dict[str, set[str]] = {}
|
||||
while group_queue:
|
||||
group = group_queue.popleft()
|
||||
if group.login_name in visited_groups:
|
||||
continue
|
||||
visited_groups.add(group.login_name)
|
||||
visited_group_name_to_emails[group.name] = set()
|
||||
logger.info(
|
||||
f"Processing group: {group.name} principal type: {group.principal_type}"
|
||||
)
|
||||
if group.principal_type == SHAREPOINT_GROUP_PRINCIPAL_TYPE:
|
||||
group_info, user_emails = _get_sharepoint_groups(
|
||||
client_context, group.login_name, graph_client
|
||||
)
|
||||
visited_group_name_to_emails[group.name].update(user_emails)
|
||||
if group_info:
|
||||
group_queue.extend(group_info)
|
||||
if group.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
|
||||
# if the site is public, we have default groups assigned to it, so we return early
|
||||
if _is_public_login_name(group.login_name):
|
||||
return GroupsResult(groups_to_emails={}, found_public_group=True)
|
||||
|
||||
group_info, user_emails = _get_azuread_groups(
|
||||
graph_client, group.login_name
|
||||
)
|
||||
visited_group_name_to_emails[group.name].update(user_emails)
|
||||
if group_info:
|
||||
group_queue.extend(group_info)
|
||||
|
||||
return GroupsResult(
|
||||
groups_to_emails=visited_group_name_to_emails, found_public_group=False
|
||||
)
|
||||
|
||||
|
||||
def get_external_access_from_sharepoint(
|
||||
client_context: ClientContext,
|
||||
graph_client: GraphClient,
|
||||
drive_name: str | None,
|
||||
drive_item: DriveItem | None,
|
||||
site_page: dict[str, Any] | None,
|
||||
add_prefix: bool = False,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
Get external access information from SharePoint.
|
||||
"""
|
||||
groups: set[SharepointGroup] = set()
|
||||
user_emails: set[str] = set()
|
||||
group_ids: set[str] = set()
|
||||
|
||||
# Add all members to a processing set first
|
||||
def add_user_and_group_to_sets(
|
||||
role_assignments: RoleAssignmentCollection,
|
||||
) -> None:
|
||||
nonlocal user_emails, groups
|
||||
for assignment in role_assignments:
|
||||
if assignment.role_definition_bindings:
|
||||
is_limited_access = True
|
||||
for role_definition_binding in assignment.role_definition_bindings:
|
||||
if (
|
||||
role_definition_binding.role_type_kind
|
||||
not in LIMITED_ACCESS_ROLE_TYPES
|
||||
or role_definition_binding.name not in LIMITED_ACCESS_ROLE_NAMES
|
||||
):
|
||||
is_limited_access = False
|
||||
break
|
||||
|
||||
# Skip if the role is only Limited Access; it is a pass-through (traversal) permission, not an actual permission
|
||||
if is_limited_access:
|
||||
logger.info(
|
||||
"Skipping assignment because it has only Limited Access role"
|
||||
)
|
||||
continue
|
||||
if assignment.member:
|
||||
member = assignment.member
|
||||
if member.principal_type == USER_PRINCIPAL_TYPE and hasattr(
|
||||
member, "user_principal_name"
|
||||
):
|
||||
email = member.user_principal_name
|
||||
if MICROSOFT_DOMAIN in email:
|
||||
email = email.replace(MICROSOFT_DOMAIN, "")
|
||||
user_emails.add(email)
|
||||
elif member.principal_type in [
|
||||
AZURE_AD_GROUP_PRINCIPAL_TYPE,
|
||||
SHAREPOINT_GROUP_PRINCIPAL_TYPE,
|
||||
]:
|
||||
name = member.title
|
||||
if member.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
|
||||
name = _get_group_name_with_suffix(
|
||||
member.login_name, name, graph_client
|
||||
)
|
||||
groups.add(
|
||||
SharepointGroup(
|
||||
login_name=member.login_name,
|
||||
principal_type=member.principal_type,
|
||||
name=name,
|
||||
)
|
||||
)
|
||||
|
||||
if drive_item and drive_name:
|
||||
# Check whether the item has any public links; if so, return early
|
||||
is_public = _is_public_item(drive_item)
|
||||
if is_public:
|
||||
logger.info(f"Item {drive_item.id} is public")
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
item_id = _get_sharepoint_list_item_id(drive_item)
|
||||
|
||||
if not item_id:
|
||||
raise RuntimeError(
|
||||
f"Failed to get SharePoint list item ID for item {drive_item.id}"
|
||||
)
|
||||
|
||||
if drive_name == "Shared Documents":
|
||||
drive_name = "Documents"
|
||||
|
||||
item = client_context.web.lists.get_by_title(drive_name).items.get_by_id(
|
||||
item_id
|
||||
)
|
||||
|
||||
sleep_and_retry(
|
||||
item.role_assignments.expand(["Member", "RoleDefinitionBindings"]).get_all(
|
||||
page_loaded=add_user_and_group_to_sets,
|
||||
),
|
||||
"get_external_access_from_sharepoint",
|
||||
)
|
||||
elif site_page:
|
||||
site_url = site_page.get("webUrl")
|
||||
site_pages = client_context.web.lists.get_by_title("Site Pages")
|
||||
client_context.load(site_pages)
|
||||
client_context.execute_query()
|
||||
site_pages.items.get_by_url(site_url).role_assignments.expand(
|
||||
["Member", "RoleDefinitionBindings"]
|
||||
).get_all(page_loaded=add_user_and_group_to_sets).execute_query()
|
||||
else:
|
||||
raise RuntimeError("No drive item or site page provided")
|
||||
|
||||
groups_and_members: GroupsResult = _get_groups_and_members_recursively(
|
||||
client_context, graph_client, groups
|
||||
)
|
||||
|
||||
# If the site is public, it has default groups assigned to it, so we return early
|
||||
if groups_and_members.found_public_group:
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
for group_name, _ in groups_and_members.groups_to_emails.items():
|
||||
if add_prefix:
|
||||
group_name = build_ext_group_name_for_onyx(
|
||||
group_name, DocumentSource.SHAREPOINT
|
||||
)
|
||||
group_ids.add(group_name.lower())
|
||||
|
||||
logger.info(f"User emails: {len(user_emails)}")
|
||||
logger.info(f"Group IDs: {len(group_ids)}")
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_ids,
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
|
||||
def get_sharepoint_external_groups(
|
||||
client_context: ClientContext, graph_client: GraphClient
|
||||
) -> list[ExternalUserGroup]:
|
||||
|
||||
groups: set[SharepointGroup] = set()
|
||||
|
||||
def add_group_to_sets(role_assignments: RoleAssignmentCollection) -> None:
|
||||
nonlocal groups
|
||||
for assignment in role_assignments:
|
||||
if assignment.role_definition_bindings:
|
||||
is_limited_access = True
|
||||
for role_definition_binding in assignment.role_definition_bindings:
|
||||
if (
|
||||
role_definition_binding.role_type_kind
|
||||
not in LIMITED_ACCESS_ROLE_TYPES
|
||||
or role_definition_binding.name not in LIMITED_ACCESS_ROLE_NAMES
|
||||
):
|
||||
is_limited_access = False
|
||||
break
|
||||
|
||||
# Skip if the role assignment is only Limited Access, because this is a
# pass-through (traversal) permission, not an actual permission
|
||||
if is_limited_access:
|
||||
logger.info(
|
||||
"Skipping assignment because it has only Limited Access role"
|
||||
)
|
||||
continue
|
||||
if assignment.member:
|
||||
member = assignment.member
|
||||
if member.principal_type in [
|
||||
AZURE_AD_GROUP_PRINCIPAL_TYPE,
|
||||
SHAREPOINT_GROUP_PRINCIPAL_TYPE,
|
||||
]:
|
||||
name = member.title
|
||||
if member.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
|
||||
name = _get_group_name_with_suffix(
|
||||
member.login_name, name, graph_client
|
||||
)
|
||||
|
||||
groups.add(
|
||||
SharepointGroup(
|
||||
login_name=member.login_name,
|
||||
principal_type=member.principal_type,
|
||||
name=name,
|
||||
)
|
||||
)
|
||||
|
||||
sleep_and_retry(
|
||||
client_context.web.role_assignments.expand(
|
||||
["Member", "RoleDefinitionBindings"]
|
||||
).get_all(page_loaded=add_group_to_sets),
|
||||
"get_sharepoint_external_groups",
|
||||
)
|
||||
groups_and_members: GroupsResult = _get_groups_and_members_recursively(
|
||||
client_context, graph_client, groups
|
||||
)
|
||||
|
||||
# We don't have any direct way to check if the site is public, so we check if any public group is present
|
||||
if groups_and_members.found_public_group:
|
||||
return []
|
||||
|
||||
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
|
||||
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
|
||||
azure_ad_groups = sleep_and_retry(
|
||||
graph_client.groups.get_all(page_loaded=lambda _: None),
|
||||
"get_sharepoint_external_groups:get_azure_ad_groups",
|
||||
)
|
||||
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
|
||||
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
|
||||
ad_groups_to_emails: dict[str, set[str]] = {}
|
||||
for group in azure_ad_groups:
|
||||
# If the group is already identified, we don't need to get the members
|
||||
if group.display_name in identified_groups:
|
||||
continue
|
||||
# Azure AD allows the same display name for multiple groups, so we append the group GUID to the name
|
||||
name = group.display_name
|
||||
name = _get_group_name_with_suffix(group.id, name, graph_client)
|
||||
|
||||
members = sleep_and_retry(
|
||||
group.members.get_all(page_loaded=lambda _: None),
|
||||
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
|
||||
)
|
||||
for member in members:
|
||||
member_data = member.to_json()
|
||||
user_principal_name = member_data.get("userPrincipalName")
|
||||
mail = member_data.get("mail")
|
||||
if not ad_groups_to_emails.get(name):
|
||||
ad_groups_to_emails[name] = set()
|
||||
if user_principal_name:
|
||||
if MICROSOFT_DOMAIN in user_principal_name:
|
||||
user_principal_name = user_principal_name.replace(
|
||||
MICROSOFT_DOMAIN, ""
|
||||
)
|
||||
ad_groups_to_emails[name].add(user_principal_name)
|
||||
elif mail:
|
||||
if MICROSOFT_DOMAIN in mail:
|
||||
mail = mail.replace(MICROSOFT_DOMAIN, "")
|
||||
ad_groups_to_emails[name].add(mail)
|
||||
|
||||
external_user_groups: list[ExternalUserGroup] = []
|
||||
for group_name, emails in groups_and_members.groups_to_emails.items():
|
||||
external_user_group = ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(emails),
|
||||
)
|
||||
external_user_groups.append(external_user_group)
|
||||
|
||||
for group_name, emails in ad_groups_to_emails.items():
|
||||
external_user_group = ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(emails),
|
||||
)
|
||||
external_user_groups.append(external_user_group)
|
||||
|
||||
return external_user_groups
|
||||
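Aside (illustrative, not part of the diff): the claims-token handling above boils down to pulling a GUID out of SharePoint login names such as c:0o.c|federateddirectoryclaimprovider|<guid>. A standalone sketch of that extraction logic, using the same regex as _extract_guid_from_sketch above; the sample login name below is hypothetical.

import re

GUID_PATTERN = r"([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})"


def extract_guid(claims_token: str) -> str | None:
    # First GUID found in the claims token wins; None if there is no GUID.
    match = re.search(GUID_PATTERN, claims_token, re.IGNORECASE)
    return match.group(1) if match else None


sample = "c:0o.c|federateddirectoryclaimprovider|3f2504e0-4f89-11d3-9a0c-0305e82c3301"
assert extract_guid(sample) == "3f2504e0-4f89-11d3-9a0c-0305e82c3301"
assert extract_guid("Everyone except external users") is None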
@@ -11,6 +11,8 @@ from ee.onyx.configs.app_configs import GITHUB_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import JIRA_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import TEAMS_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
@@ -29,6 +31,8 @@ from ee.onyx.external_permissions.perm_sync_types import GroupSyncFuncType
from ee.onyx.external_permissions.salesforce.postprocessing import (
    censor_salesforce_chunks,
)
from ee.onyx.external_permissions.sharepoint.doc_sync import sharepoint_doc_sync
from ee.onyx.external_permissions.sharepoint.group_sync import sharepoint_group_sync
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from ee.onyx.external_permissions.teams.doc_sync import teams_doc_sync
from onyx.configs.constants import DocumentSource
@@ -156,6 +160,18 @@ _SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
            initial_index_should_sync=True,
        ),
    ),
    DocumentSource.SHAREPOINT: SyncConfig(
        doc_sync_config=DocSyncConfig(
            doc_sync_frequency=SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY,
            doc_sync_func=sharepoint_doc_sync,
            initial_index_should_sync=True,
        ),
        group_sync_config=GroupSyncConfig(
            group_sync_frequency=SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY,
            group_sync_func=sharepoint_group_sync,
            group_sync_is_cc_pair_agnostic=False,
        ),
    ),
}
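Aside (illustrative, not part of the diff): once SharePoint is registered in _SOURCE_TO_SYNC_CONFIG, scheduling code can look it up like any other source. The sketch below assumes the SyncConfig/DocSyncConfig field names shown in this hunk and that doc_sync_config may be absent for sources without doc-level permission sync.

def get_doc_sync_interval_seconds(source: DocumentSource) -> int | None:
    # Returns None for sources that do not support permission doc sync.
    sync_config = _SOURCE_TO_SYNC_CONFIG.get(source)
    if sync_config is None or sync_config.doc_sync_config is None:
        return None
    return sync_config.doc_sync_config.doc_sync_frequency


# SharePoint should report the 30-minute default configured earlier in this commit.
assert (
    get_doc_sync_interval_seconds(DocumentSource.SHAREPOINT)
    == SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY
)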
File diff suppressed because it is too large
backend/onyx/connectors/sharepoint/connector_utils.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from typing import Any

from office365.graph_client import GraphClient  # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem  # type: ignore[import-untyped]
from office365.sharepoint.client_context import ClientContext  # type: ignore[import-untyped]

from onyx.connectors.models import ExternalAccess
from onyx.utils.variable_functionality import (
    fetch_versioned_implementation_with_fallback,
)


def get_sharepoint_external_access(
    ctx: ClientContext,
    graph_client: GraphClient,
    drive_item: DriveItem | None = None,
    drive_name: str | None = None,
    site_page: dict[str, Any] | None = None,
    add_prefix: bool = False,
) -> ExternalAccess:
    if drive_item and drive_item.id is None:
        raise ValueError("DriveItem ID is required")

    # Get external access using the EE implementation
    def noop_fallback(*args: Any, **kwargs: Any) -> ExternalAccess:
        return ExternalAccess.empty()

    get_external_access_func = fetch_versioned_implementation_with_fallback(
        "onyx.external_permissions.sharepoint.permission_utils",
        "get_external_access_from_sharepoint",
        fallback=noop_fallback,
    )

    external_access = get_external_access_func(
        ctx, graph_client, drive_name, drive_item, site_page, add_prefix
    )

    return external_access
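Aside (illustrative, not part of the diff): the helper above relies on Onyx's versioned-implementation lookup, which resolves an enterprise (ee) implementation at runtime and falls back to a no-op in the MIT build. The following is a generic sketch of that dispatch pattern, not the actual onyx.utils.variable_functionality code, under the assumption that the ee module simply may or may not be importable.

import importlib
from collections.abc import Callable
from typing import Any


def resolve_with_fallback(
    module_name: str, attr_name: str, fallback: Callable[..., Any]
) -> Callable[..., Any]:
    # Try the versioned (ee) module first; fall back to the provided no-op
    # if the module or attribute is unavailable in this build.
    try:
        module = importlib.import_module(module_name)
        return getattr(module, attr_name)
    except (ImportError, AttributeError):
        return fallback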
@@ -1,7 +1,12 @@
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Form
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -27,6 +32,9 @@ from onyx.server.documents.models import CredentialDataUpdateRequest
|
||||
from onyx.server.documents.models import CredentialSnapshot
|
||||
from onyx.server.documents.models import CredentialSwapRequest
|
||||
from onyx.server.documents.models import ObjectCreationIdResponse
|
||||
from onyx.server.documents.private_key_types import FILE_TYPE_TO_FILE_PROCESSOR
|
||||
from onyx.server.documents.private_key_types import PrivateKeyFileTypes
|
||||
from onyx.server.documents.private_key_types import ProcessPrivateKeyFileProtocol
|
||||
from onyx.server.models import StatusResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -76,6 +84,7 @@ def get_cc_source_full_info(
|
||||
document_source=source_type,
|
||||
get_editable=get_editable,
|
||||
)
|
||||
|
||||
return [
|
||||
CredentialSnapshot.from_credential_db_model(credential)
|
||||
for credential in credentials
|
||||
@@ -149,6 +158,70 @@ def create_credential_from_model(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/credential/private-key")
|
||||
def create_credential_with_private_key(
|
||||
credential_json: str = Form(...),
|
||||
admin_public: bool = Form(False),
|
||||
curator_public: bool = Form(False),
|
||||
groups: list[int] = Form([]),
|
||||
name: str | None = Form(None),
|
||||
source: str = Form(...),
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
uploaded_file: UploadFile = File(...),
|
||||
field_key: str = Form(...),
|
||||
type_definition_key: str = Form(...),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ObjectCreationIdResponse:
|
||||
try:
|
||||
credential_data = json.loads(credential_json)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid JSON in credential_json: {str(e)}",
|
||||
)
|
||||
|
||||
private_key_processor: ProcessPrivateKeyFileProtocol | None = (
|
||||
FILE_TYPE_TO_FILE_PROCESSOR.get(PrivateKeyFileTypes(type_definition_key))
|
||||
)
|
||||
if private_key_processor is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid type definition key for private key file",
|
||||
)
|
||||
private_key_content: str = private_key_processor(uploaded_file)
|
||||
|
||||
credential_data[field_key] = private_key_content
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json=credential_data,
|
||||
admin_public=admin_public,
|
||||
curator_public=curator_public,
|
||||
groups=groups,
|
||||
name=name,
|
||||
source=DocumentSource(source),
|
||||
)
|
||||
|
||||
if not _ignore_credential_permissions(DocumentSource(source)):
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.db.user_group", "validate_object_creation_for_user", None
|
||||
)(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=groups,
|
||||
object_is_public=curator_public,
|
||||
)
|
||||
|
||||
# Temporary fix for empty Google App credentials
|
||||
if DocumentSource(source) == DocumentSource.GMAIL:
|
||||
cleanup_gmail_credentials(db_session=db_session)
|
||||
|
||||
credential = create_credential(credential_info, user, db_session)
|
||||
return ObjectCreationIdResponse(
|
||||
id=credential.id,
|
||||
credential=CredentialSnapshot.from_credential_db_model(credential),
|
||||
)
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
|
||||
|
||||
@@ -209,6 +282,53 @@ def update_credential_data(
|
||||
return CredentialSnapshot.from_credential_db_model(credential)
|
||||
|
||||
|
||||
@router.put("/admin/credential/private-key/{credential_id}")
|
||||
def update_credential_private_key(
|
||||
credential_id: int,
|
||||
name: str = Form(...),
|
||||
credential_json: str = Form(...),
|
||||
uploaded_file: UploadFile = File(...),
|
||||
field_key: str = Form(...),
|
||||
type_definition_key: str = Form(...),
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CredentialBase:
|
||||
try:
|
||||
credential_data = json.loads(credential_json)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid JSON in credential_json: {str(e)}",
|
||||
)
|
||||
|
||||
private_key_processor: ProcessPrivateKeyFileProtocol | None = (
|
||||
FILE_TYPE_TO_FILE_PROCESSOR.get(PrivateKeyFileTypes(type_definition_key))
|
||||
)
|
||||
if private_key_processor is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid type definition key for private key file",
|
||||
)
|
||||
private_key_content: str = private_key_processor(uploaded_file)
|
||||
credential_data[field_key] = private_key_content
|
||||
|
||||
credential = alter_credential(
|
||||
credential_id,
|
||||
name,
|
||||
credential_data,
|
||||
user,
|
||||
db_session,
|
||||
)
|
||||
|
||||
if credential is None:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Credential {credential_id} does not exist or does not belong to user",
|
||||
)
|
||||
|
||||
return CredentialSnapshot.from_credential_db_model(credential)
|
||||
|
||||
|
||||
@router.patch("/credential/{credential_id}")
|
||||
def update_credential_from_model(
|
||||
credential_id: int,
|
||||
|
||||
backend/onyx/server/documents/document_utils.py (new file, 75 lines)
@@ -0,0 +1,75 @@
from cryptography.hazmat.primitives.serialization import pkcs12

from onyx.utils.logger import setup_logger

logger = setup_logger()


def _is_password_related_error(error: Exception) -> bool:
    """
    Check if the exception indicates a password-related issue rather than a format issue.
    """
    error_msg = str(error).lower()
    password_keywords = ["mac", "integrity", "password", "authentication", "verify"]
    return any(keyword in error_msg for keyword in password_keywords)


def validate_pkcs12_content(file_bytes: bytes) -> bool:
    """
    Validate that the file content is actually a PKCS#12 file.
    This performs basic format validation without requiring passwords.
    """
    try:
        # Basic file size check
        if len(file_bytes) < 10:
            logger.debug("File too small to be a valid PKCS#12 file")
            return False

        # Check for PKCS#12 magic bytes/ASN.1 structure
        # PKCS#12 files start with ASN.1 SEQUENCE tag (0x30)
        if file_bytes[0] != 0x30:
            logger.debug("File does not start with ASN.1 SEQUENCE tag")
            return False

        # Try to parse the outer ASN.1 structure without password validation
        # This checks if the file has the basic PKCS#12 structure
        try:
            # Attempt to load just to validate the basic format
            # We expect this to fail due to password, but it should fail with a specific error
            pkcs12.load_key_and_certificates(file_bytes, password=None)
            return True
        except ValueError as e:
            # Check if the error is related to password (expected) vs format issues
            if _is_password_related_error(e):
                # These errors indicate the file format is correct but password is wrong/missing
                logger.debug(
                    f"PKCS#12 format appears valid, password-related error: {e}"
                )
                return True
            else:
                # Other ValueError likely indicates format issues
                logger.debug(f"PKCS#12 format validation failed: {e}")
                return False
        except Exception as e:
            # Try with empty password as fallback
            try:
                pkcs12.load_key_and_certificates(file_bytes, password=b"")
                return True
            except ValueError as e2:
                if _is_password_related_error(e2):
                    logger.debug(
                        f"PKCS#12 format appears valid with empty password attempt: {e2}"
                    )
                    return True
                else:
                    logger.debug(
                        f"PKCS#12 validation failed on both attempts: {e}, {e2}"
                    )
                    return False
            except Exception:
                logger.debug(f"PKCS#12 validation failed: {e}")
                return False

    except Exception as e:
        logger.debug(f"Unexpected error during PKCS#12 validation: {e}")
        return False
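Aside (illustrative, not part of the diff): one way to exercise validate_pkcs12_content is to build a throwaway password-protected PKCS#12 bundle with the cryptography package and confirm that it passes while arbitrary bytes fail. This is a test sketch under the assumption that a recent cryptography release is installed, not shipped code.

from datetime import datetime, timedelta, timezone

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.serialization.pkcs12 import (
    serialize_key_and_certificates,
)
from cryptography.x509.oid import NameOID

# Build a self-signed certificate and key purely for testing purposes.
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "perm-sync-test")])
now = datetime.now(timezone.utc)
cert = (
    x509.CertificateBuilder()
    .subject_name(name)
    .issuer_name(name)
    .public_key(key.public_key())
    .serial_number(x509.random_serial_number())
    .not_valid_before(now)
    .not_valid_after(now + timedelta(days=1))
    .sign(key, hashes.SHA256())
)
pfx_bytes = serialize_key_and_certificates(
    name=b"test",
    key=key,
    cert=cert,
    cas=None,
    encryption_algorithm=serialization.BestAvailableEncryption(b"secret"),
)

# A real (encrypted) PKCS#12 passes the format check; random bytes do not.
assert validate_pkcs12_content(pfx_bytes) is True
assert validate_pkcs12_content(b"definitely not a pfx file") is False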
backend/onyx/server/documents/private_key_types.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import base64
from enum import Enum
from typing import Protocol

from fastapi import HTTPException
from fastapi import UploadFile

from onyx.server.documents.document_utils import validate_pkcs12_content


class ProcessPrivateKeyFileProtocol(Protocol):
    def __call__(self, file: UploadFile) -> str:
        """
        Accepts a file-like object, validates the file (e.g., checks extension and content),
        and returns its contents as a base64-encoded string if valid.
        Raises an exception if validation fails.
        """
        ...


class PrivateKeyFileTypes(Enum):
    SHAREPOINT_PFX_FILE = "sharepoint_pfx_file"


def process_sharepoint_private_key_file(file: UploadFile) -> str:
    """
    Process and validate a private key file upload.

    Validates both the file extension and file content to ensure it's a valid PKCS#12 file.
    Content validation prevents attacks that rely on file extension spoofing.
    """
    # First check file extension (basic filter)
    if not (file.filename and file.filename.lower().endswith(".pfx")):
        raise HTTPException(
            status_code=400, detail="Invalid file type. Only .pfx files are supported."
        )

    # Read file content for validation and processing
    private_key_bytes = file.file.read()

    # Validate file content to prevent extension spoofing attacks
    if not validate_pkcs12_content(private_key_bytes):
        raise HTTPException(
            status_code=400,
            detail="Invalid file content. The uploaded file does not appear to be a valid PKCS#12 (.pfx) file.",
        )

    # Convert to base64 if validation passes
    pfx_64 = base64.b64encode(private_key_bytes).decode("ascii")
    return pfx_64


FILE_TYPE_TO_FILE_PROCESSOR: dict[
    PrivateKeyFileTypes, ProcessPrivateKeyFileProtocol
] = {
    PrivateKeyFileTypes.SHAREPOINT_PFX_FILE: process_sharepoint_private_key_file,
}
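Aside (illustrative, not part of the diff): a client could upload a SharePoint .pfx certificate to the new create_credential_with_private_key endpoint roughly as below. The base URL, route prefix, auth, and the credential field names (sp_client_id, sp_certificate_password) and field_key value are placeholders/assumptions; the form field names and the "sharepoint_pfx_file" type_definition_key come from the endpoint and enum shown in this diff.

import json

import requests  # assumption: requests is available in the client environment

BASE_URL = "http://localhost:8080"  # placeholder host; "/manage" prefix below is an assumption

with open("sp-app-cert.pfx", "rb") as pfx:
    resp = requests.post(
        f"{BASE_URL}/manage/credential/private-key",
        data={
            "credential_json": json.dumps(
                {
                    "sp_client_id": "<client-id>",  # hypothetical credential fields
                    "sp_certificate_password": "<pfx-password>",
                }
            ),
            "source": "sharepoint",
            "name": "SharePoint cert credential",
            "field_key": "sp_private_key",  # hypothetical target field for the base64 key
            "type_definition_key": "sharepoint_pfx_file",
        },
        files={"uploaded_file": ("sp-app-cert.pfx", pfx)},
    )
resp.raise_for_status()
print(resp.json()["id"])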
@@ -1,14 +1,18 @@
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -76,6 +80,17 @@ def find_document(documents: list[Document], semantic_identifier: str) -> Docume
|
||||
return matching_docs[0]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_store_image() -> MagicMock:
|
||||
"""Mock store_image_and_create_section to return a predefined ImageSection."""
|
||||
mock = MagicMock()
|
||||
mock.return_value = (
|
||||
ImageSection(image_file_id="mocked-file-id", link="https://example.com/image"),
|
||||
"mocked-file-id",
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sharepoint_credentials() -> dict[str, str]:
|
||||
return {
|
||||
@@ -87,175 +102,219 @@ def sharepoint_credentials() -> dict[str, str]:
|
||||
|
||||
def test_sharepoint_connector_all_sites__docs_only(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
mock_store_image: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with no sites
|
||||
connector = SharepointConnector(include_site_pages=False)
|
||||
with patch(
|
||||
"onyx.connectors.sharepoint.connector.store_image_and_create_section",
|
||||
mock_store_image,
|
||||
):
|
||||
# Initialize connector with no sites
|
||||
connector = SharepointConnector(include_site_pages=False)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Not asserting expected sites because that can change in test tenant at any time
|
||||
# Finding any docs is good enough to verify that the connector is working
|
||||
document_batches = list(connector.load_from_state())
|
||||
assert document_batches, "Should find documents from all sites"
|
||||
# Not asserting expected sites because that can change in test tenant at any time
|
||||
# Finding any docs is good enough to verify that the connector is working
|
||||
document_batches = load_all_docs_from_checkpoint_connector(
|
||||
connector=connector,
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
assert document_batches, "Should find documents from all sites"
|
||||
|
||||
|
||||
def test_sharepoint_connector_specific_folder(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
mock_store_image: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the test site URL and specific folder
|
||||
connector = SharepointConnector(
|
||||
sites=[os.environ["SHAREPOINT_SITE"] + "/Shared Documents/test"]
|
||||
)
|
||||
with patch(
|
||||
"onyx.connectors.sharepoint.connector.store_image_and_create_section",
|
||||
mock_store_image,
|
||||
):
|
||||
# Initialize connector with the test site URL and specific folder
|
||||
connector = SharepointConnector(
|
||||
sites=[os.environ["SHAREPOINT_SITE"] + "/Shared Documents/test"]
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Get all documents
|
||||
document_batches = list(connector.load_from_state())
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
# Get all documents
|
||||
found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
|
||||
connector=connector,
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
|
||||
# Should only find documents in the test folder
|
||||
test_folder_docs = [
|
||||
doc
|
||||
for doc in EXPECTED_DOCUMENTS
|
||||
if doc.folder_path and doc.folder_path.startswith("test")
|
||||
]
|
||||
assert len(found_documents) == len(
|
||||
test_folder_docs
|
||||
), "Should only find documents in test folder"
|
||||
# Should only find documents in the test folder
|
||||
test_folder_docs = [
|
||||
doc
|
||||
for doc in EXPECTED_DOCUMENTS
|
||||
if doc.folder_path and doc.folder_path.startswith("test")
|
||||
]
|
||||
assert len(found_documents) == len(
|
||||
test_folder_docs
|
||||
), "Should only find documents in test folder"
|
||||
|
||||
# Verify each expected document
|
||||
for expected in test_folder_docs:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
# Verify each expected document
|
||||
for expected in test_folder_docs:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
|
||||
|
||||
def test_sharepoint_connector_root_folder__docs_only(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
mock_store_image: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the base site URL
|
||||
connector = SharepointConnector(
|
||||
sites=[os.environ["SHAREPOINT_SITE"]], include_site_pages=False
|
||||
)
|
||||
with patch(
|
||||
"onyx.connectors.sharepoint.connector.store_image_and_create_section",
|
||||
mock_store_image,
|
||||
):
|
||||
# Initialize connector with the base site URL
|
||||
connector = SharepointConnector(
|
||||
sites=[os.environ["SHAREPOINT_SITE"]], include_site_pages=False
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Get all documents
|
||||
document_batches = list(connector.load_from_state())
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
# Get all documents
|
||||
found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
|
||||
connector=connector,
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
|
||||
assert len(found_documents) == len(
|
||||
EXPECTED_DOCUMENTS
|
||||
), "Should find all documents in main library"
|
||||
assert len(found_documents) == len(
|
||||
EXPECTED_DOCUMENTS
|
||||
), "Should find all documents in main library"
|
||||
|
||||
# Verify each expected document
|
||||
for expected in EXPECTED_DOCUMENTS:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
# Verify each expected document
|
||||
for expected in EXPECTED_DOCUMENTS:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
|
||||
|
||||
def test_sharepoint_connector_other_library(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
mock_store_image: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the other library
|
||||
connector = SharepointConnector(
|
||||
sites=[
|
||||
os.environ["SHAREPOINT_SITE"] + "/Other Library",
|
||||
with patch(
|
||||
"onyx.connectors.sharepoint.connector.store_image_and_create_section",
|
||||
mock_store_image,
|
||||
):
|
||||
# Initialize connector with the other library
|
||||
connector = SharepointConnector(
|
||||
sites=[
|
||||
os.environ["SHAREPOINT_SITE"] + "/Other Library",
|
||||
]
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Get all documents
|
||||
found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
|
||||
connector=connector,
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
expected_documents: list[ExpectedDocument] = [
|
||||
doc for doc in EXPECTED_DOCUMENTS if doc.library == "Other Library"
|
||||
]
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
# Should find all documents in `Other Library`
|
||||
assert len(found_documents) == len(
|
||||
expected_documents
|
||||
), "Should find all documents in `Other Library`"
|
||||
|
||||
# Get all documents
|
||||
document_batches = list(connector.load_from_state())
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
expected_documents: list[ExpectedDocument] = [
|
||||
doc for doc in EXPECTED_DOCUMENTS if doc.library == "Other Library"
|
||||
]
|
||||
|
||||
# Should find all documents in `Other Library`
|
||||
assert len(found_documents) == len(
|
||||
expected_documents
|
||||
), "Should find all documents in `Other Library`"
|
||||
|
||||
# Verify each expected document
|
||||
for expected in expected_documents:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
# Verify each expected document
|
||||
for expected in expected_documents:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
|
||||
|
||||
def test_sharepoint_connector_poll(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
mock_store_image: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the base site URL
|
||||
connector = SharepointConnector(
|
||||
sites=["https://danswerai.sharepoint.com/sites/sharepoint-tests"]
|
||||
)
|
||||
with patch(
|
||||
"onyx.connectors.sharepoint.connector.store_image_and_create_section",
|
||||
mock_store_image,
|
||||
):
|
||||
# Initialize connector with the base site URL
|
||||
connector = SharepointConnector(
|
||||
sites=["https://danswerai.sharepoint.com/sites/sharepoint-tests"]
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Set time window to only capture test1.docx (modified at 2025-01-28 20:51:42+00:00)
|
||||
start = datetime(2025, 1, 28, 20, 51, 30, tzinfo=timezone.utc) # 12 seconds before
|
||||
end = datetime(2025, 1, 28, 20, 51, 50, tzinfo=timezone.utc) # 8 seconds after
|
||||
# Set time window to only capture test1.docx (modified at 2025-01-28 20:51:42+00:00)
|
||||
start = datetime(
|
||||
2025, 1, 28, 20, 51, 30, tzinfo=timezone.utc
|
||||
) # 12 seconds before
|
||||
end = datetime(2025, 1, 28, 20, 51, 50, tzinfo=timezone.utc) # 8 seconds after
|
||||
|
||||
# Get documents within the time window
|
||||
document_batches = list(connector._fetch_from_sharepoint(start=start, end=end))
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
# Get documents within the time window
|
||||
found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
|
||||
connector=connector,
|
||||
start=start.timestamp(),
|
||||
end=end.timestamp(),
|
||||
)
|
||||
|
||||
# Should only find test1.docx
|
||||
assert len(found_documents) == 1, "Should only find one document in the time window"
|
||||
doc = found_documents[0]
|
||||
assert doc.semantic_identifier == "test1.docx"
|
||||
verify_document_metadata(doc)
|
||||
verify_document_content(
|
||||
doc, [d for d in EXPECTED_DOCUMENTS if d.semantic_identifier == "test1.docx"][0]
|
||||
)
|
||||
# Should only find test1.docx
|
||||
assert (
|
||||
len(found_documents) == 1
|
||||
), "Should only find one document in the time window"
|
||||
doc = found_documents[0]
|
||||
assert doc.semantic_identifier == "test1.docx"
|
||||
verify_document_metadata(doc)
|
||||
verify_document_content(
|
||||
doc,
|
||||
[d for d in EXPECTED_DOCUMENTS if d.semantic_identifier == "test1.docx"][0],
|
||||
)
|
||||
|
||||
|
||||
def test_sharepoint_connector_pages(
    mock_get_unstructured_api_key: MagicMock,
    mock_store_image: MagicMock,
    sharepoint_credentials: dict[str, str],
) -> None:
    with patch(
        "onyx.connectors.sharepoint.connector.store_image_and_create_section",
        mock_store_image,
    ):
        # Initialize connector with the base site URL
        connector = SharepointConnector(
            sites=["https://danswerai.sharepoint.com/sites/sharepoint-tests-pages"]
        )

        # Load credentials
        connector.load_credentials(sharepoint_credentials)

        # Get documents within the time window
        found_documents = load_all_docs_from_checkpoint_connector(
            connector=connector,
            start=0,
            end=time.time(),
        )

        # Should only find CollabHome
        assert len(found_documents) == 1, "Should only find one page"
        doc = found_documents[0]
        assert doc.semantic_identifier == "CollabHome"
        verify_document_metadata(doc)
        assert len(doc.sections) == 1
        assert (
            doc.sections[0].text
            == """
# Home

Display recent news.
@@ -282,4 +341,4 @@ Add a document library

## Document library
""".strip()
        )

@@ -0,0 +1,113 @@
import os
from collections.abc import Generator
from datetime import datetime
from datetime import timezone

import pytest

from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.connectors.sharepoint.connector import SharepointAuthMethod
from onyx.db.enums import AccessType
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestCredential
from tests.integration.common_utils.test_models import DATestUser

SharepointTestEnvSetupTuple = tuple[
    DATestUser,  # admin_user
    DATestUser,  # regular_user_1
    DATestUser,  # regular_user_2
    DATestCredential,
    DATestConnector,
    DATestCCPair,
]


@pytest.fixture(scope="module")
def sharepoint_test_env_setup() -> Generator[SharepointTestEnvSetupTuple]:
    # Reset all data before running the test
    reset_all()
    # Required environment variables for SharePoint certificate authentication
    sp_client_id = os.environ.get("PERM_SYNC_SHAREPOINT_CLIENT_ID")
    sp_private_key = os.environ.get("PERM_SYNC_SHAREPOINT_PRIVATE_KEY")
    sp_certificate_password = os.environ.get(
        "PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD"
    )
    sp_directory_id = os.environ.get("PERM_SYNC_SHAREPOINT_DIRECTORY_ID")
    sharepoint_sites = "https://danswerai.sharepoint.com/sites/Permisisonsync"
    admin_email = "admin@onyx.app"
    user1_email = "subash@onyx.app"
    user2_email = "raunak@onyx.app"

    if not sp_private_key or not sp_certificate_password or not sp_directory_id:
        pytest.skip("Skipping test because required environment variables are not set")

    # Certificate-based credentials
    credentials = {
        "authentication_method": SharepointAuthMethod.CERTIFICATE.value,
        "sp_client_id": sp_client_id,
        "sp_private_key": sp_private_key,
        "sp_certificate_password": sp_certificate_password,
        "sp_directory_id": sp_directory_id,
    }

    # Create users
    admin_user: DATestUser = UserManager.create(email=admin_email)
    regular_user_1: DATestUser = UserManager.create(email=user1_email)
    regular_user_2: DATestUser = UserManager.create(email=user2_email)

    # Create LLM provider for search functionality
    LLMProviderManager.create(user_performing_action=admin_user)

    # Create credential
    credential: DATestCredential = CredentialManager.create(
        source=DocumentSource.SHAREPOINT,
        credential_json=credentials,
        user_performing_action=admin_user,
    )

    # Create connector with SharePoint-specific configuration
    connector: DATestConnector = ConnectorManager.create(
        name="SharePoint Test",
        input_type=InputType.POLL,
        source=DocumentSource.SHAREPOINT,
        connector_specific_config={
            "sites": sharepoint_sites.split(","),
        },
        access_type=AccessType.SYNC,  # Enable permission sync
        user_performing_action=admin_user,
    )

    # Create CC pair with permission sync enabled
    cc_pair: DATestCCPair = CCPairManager.create(
        credential_id=credential.id,
        connector_id=connector.id,
        access_type=AccessType.SYNC,  # Enable permission sync
        user_performing_action=admin_user,
    )

    # Wait for both indexing and permission sync to complete
    before = datetime.now(tz=timezone.utc)
    CCPairManager.wait_for_indexing_completion(
        cc_pair=cc_pair,
        after=before,
        user_performing_action=admin_user,
        timeout=float("inf"),
    )

    # Wait for permission sync completion specifically
    CCPairManager.wait_for_sync(
        cc_pair=cc_pair,
        after=before,
        user_performing_action=admin_user,
        timeout=float("inf"),
    )

    yield admin_user, regular_user_1, regular_user_2, credential, connector, cc_pair
@@ -0,0 +1,205 @@
from typing import List
from uuid import UUID

import pytest
from sqlalchemy.orm import Session

from ee.onyx.access.access import _get_access_for_documents
from ee.onyx.db.external_perm import fetch_external_groups_for_user
from onyx.access.utils import prefix_external_group
from onyx.access.utils import prefix_user_email
from onyx.configs.constants import PUBLIC_DOC_PAT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import User
from onyx.db.users import fetch_user_by_id
from onyx.utils.logger import setup_logger
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.connector_job_tests.sharepoint.conftest import (
    SharepointTestEnvSetupTuple,
)

logger = setup_logger()


def get_user_acl(user: User, db_session: Session) -> set[str]:
    db_external_groups = (
        fetch_external_groups_for_user(db_session, user.id) if user else []
    )
    prefixed_external_groups = [
        prefix_external_group(db_external_group.external_user_group_id)
        for db_external_group in db_external_groups
    ]

    user_acl = set(prefixed_external_groups)
    user_acl.update({prefix_user_email(user.email), PUBLIC_DOC_PAT})
    return user_acl


def get_user_document_access_via_acl(
    test_user: DATestUser, document_ids: List[str], db_session: Session
) -> List[str]:
    # Get the actual User object from the database
    user = fetch_user_by_id(db_session, UUID(test_user.id))
    if not user:
        logger.error(f"Could not find user with ID {test_user.id}")
        return []

    user_acl = get_user_acl(user, db_session)
    logger.info(f"User {user.email} ACL entries: {user_acl}")

    # Get document access information
    doc_access_map = _get_access_for_documents(document_ids, db_session)
    logger.info(f"Found access info for {len(doc_access_map)} documents")

    accessible_docs = []
    for doc_id, doc_access in doc_access_map.items():
        doc_acl = doc_access.to_acl()
        logger.info(f"Document {doc_id} ACL: {doc_acl}")

        # Check if user has any matching ACL entry
        if user_acl.intersection(doc_acl):
            accessible_docs.append(doc_id)
            logger.info(f"User {user.email} has access to document {doc_id}")
        else:
            logger.info(f"User {user.email} does NOT have access to document {doc_id}")

    return accessible_docs


def get_all_connector_documents(
    cc_pair: DATestCCPair, db_session: Session
) -> List[str]:
    from onyx.db.models import DocumentByConnectorCredentialPair
    from sqlalchemy import select

    stmt = select(DocumentByConnectorCredentialPair.id).where(
        DocumentByConnectorCredentialPair.connector_id == cc_pair.connector_id,
        DocumentByConnectorCredentialPair.credential_id == cc_pair.credential_id,
    )

    result = db_session.execute(stmt)
    document_ids = [row[0] for row in result.fetchall()]
    logger.info(
        f"Found {len(document_ids)} documents for connector {cc_pair.connector_id}"
    )

    return document_ids


def get_documents_by_permission_type(
    document_ids: List[str], db_session: Session
) -> List[str]:
"""
|
||||
Categorize documents by their permission types
|
||||
Returns a dictionary with lists of document IDs for each permission type
|
||||
"""
|
||||
    doc_access_map = _get_access_for_documents(document_ids, db_session)

    public_docs = []

    for doc_id, doc_access in doc_access_map.items():
        if doc_access.is_public:
            public_docs.append(doc_id)

    return public_docs


def test_public_documents_accessible_by_all_users(
    sharepoint_test_env_setup: SharepointTestEnvSetupTuple,
) -> None:
    """Test that public documents are accessible by both test users using ACL verification"""
    (
        admin_user,
        regular_user_1,
        regular_user_2,
        credential,
        connector,
        cc_pair,
    ) = sharepoint_test_env_setup

    with get_session_with_current_tenant() as db_session:
        # Get all documents for this connector
        all_document_ids = get_all_connector_documents(cc_pair, db_session)

        # Test that regular_user_1 can access documents
        accessible_docs_user1 = get_user_document_access_via_acl(
            test_user=regular_user_1,
            document_ids=all_document_ids,
            db_session=db_session,
        )

        # Test that regular_user_2 can access documents
        accessible_docs_user2 = get_user_document_access_via_acl(
            test_user=regular_user_2,
            document_ids=all_document_ids,
            db_session=db_session,
        )

        logger.info(f"User 1 has access to {len(accessible_docs_user1)} documents")
        logger.info(f"User 2 has access to {len(accessible_docs_user2)} documents")

        # User 1 should see all eight synced documents; user 2 only the single public one
        assert len(accessible_docs_user1) == 8, (
            f"User 1 should have access to documents. Found "
            f"{len(accessible_docs_user1)} accessible docs out of "
            f"{len(all_document_ids)} total"
        )
        assert len(accessible_docs_user2) == 1, (
            f"User 2 should have access to documents. Found "
            f"{len(accessible_docs_user2)} accessible docs out of "
            f"{len(all_document_ids)} total"
        )

        logger.info(
            "Successfully verified public documents are accessible by users via ACL"
        )


def test_group_based_permissions(
    sharepoint_test_env_setup: SharepointTestEnvSetupTuple,
) -> None:
    """Test that documents with group permissions are accessible only by users in that group using ACL verification"""
    (
        admin_user,
        regular_user_1,
        regular_user_2,
        credential,
        connector,
        cc_pair,
    ) = sharepoint_test_env_setup

    with get_session_with_current_tenant() as db_session:
        # Get all documents for this connector
        all_document_ids = get_all_connector_documents(cc_pair, db_session)

        if not all_document_ids:
            pytest.skip("No documents found for connector - skipping test")

        # Test access for both users
        accessible_docs_user1 = get_user_document_access_via_acl(
            test_user=regular_user_1,
            document_ids=all_document_ids,
            db_session=db_session,
        )

        accessible_docs_user2 = get_user_document_access_via_acl(
            test_user=regular_user_2,
            document_ids=all_document_ids,
            db_session=db_session,
        )

        logger.info(f"User 1 has access to {len(accessible_docs_user1)} documents")
        logger.info(f"User 2 has access to {len(accessible_docs_user2)} documents")

        public_docs = get_documents_by_permission_type(all_document_ids, db_session)

        # Check if user 2 has access to any non-public documents
        non_public_access_user2 = [
            doc for doc in accessible_docs_user2 if doc not in public_docs
        ]

        assert (
            len(non_public_access_user2) == 0
        ), f"User 2 should only have access to public documents. Found access to non-public docs: {non_public_access_user2}"
@@ -8,6 +8,7 @@ import {
  useField,
  useFormikContext,
} from "formik";
import { FileUpload } from "@/components/admin/connectors/FileUpload";
import * as Yup from "yup";
import { FormBodyBuilder } from "./admin/connectors/types";
import { StringOrNumberOption } from "@/components/Dropdown";
@@ -37,6 +38,12 @@ import { transformLinkUri } from "@/lib/utils";
import FileInput from "@/app/admin/connectors/[connector]/pages/ConnectorInput/FileInput";
import { DatePicker } from "./ui/datePicker";
import { Textarea, TextareaProps } from "./ui/textarea";
import {
  TypedFile,
  createTypedFile,
  getFileTypeDefinitionForField,
  FILE_TYPE_DEFINITIONS,
} from "@/lib/connectors/fileTypes";

export function SectionHeader({
  children,
@@ -386,6 +393,120 @@ export function FileUploadFormField({
  );
}

export function TypedFileUploadFormField({
  name,
  label,
  subtext,
}: {
  name: string;
  label: string;
  subtext?: string | JSX.Element;
}) {
  const [field, , helpers] = useField<TypedFile | null>(name);
  const [customError, setCustomError] = useState<string>("");
  const [isValidating, setIsValidating] = useState(false);
  const [description, setDescription] = useState<string>("");

  useEffect(() => {
    const typeDefinitionKey = getFileTypeDefinitionForField(name);
    if (typeDefinitionKey) {
      setDescription(
        FILE_TYPE_DEFINITIONS[typeDefinitionKey].description || ""
      );
    }
  }, [name]);

  useEffect(() => {
    const validateFile = async () => {
      if (!field.value) {
        setIsValidating(false);
        return;
      }

      setIsValidating(true);

      try {
        const validation = await field.value.validate();
        if (validation?.isValid) {
          setCustomError("");
        } else {
          setCustomError(validation?.errors.join(", ") || "Unknown error");
          helpers.setValue(null);
        }
      } catch (error) {
        setCustomError(
          error instanceof Error ? error.message : "Validation error"
        );
        helpers.setValue(null);
      } finally {
        setIsValidating(false);
      }
    };

    validateFile();
  }, [field.value, helpers]);

  const handleFileSelection = async (files: File[]) => {
    if (files.length === 0) {
      helpers.setValue(null);
      setCustomError("");
      return;
    }

    const file = files[0];
    if (!file) {
      setCustomError("File selection error");
      return;
    }

    const typeDefinitionKey = getFileTypeDefinitionForField(name);

    if (!typeDefinitionKey) {
      setCustomError(`No file type definition found for field: ${name}`);
      return;
    }

    try {
      const typedFile = createTypedFile(file, name, typeDefinitionKey);
      helpers.setValue(typedFile);
      setCustomError("");
    } catch (error) {
      setCustomError(error instanceof Error ? error.message : "Unknown error");
      helpers.setValue(null);
    } finally {
      setIsValidating(false);
    }
  };

  return (
    <div className="w-full">
      <FieldLabel name={name} label={label} subtext={subtext} />
      {description && (
        <div className="text-sm text-gray-500 mb-2">{description}</div>
      )}
      <FileUpload
        selectedFiles={field.value ? [field.value.file] : []}
        setSelectedFiles={handleFileSelection}
        multiple={false}
      />
      {/* Validation feedback */}
      {isValidating && (
        <div className="text-blue-500 text-sm mt-1">Validating file...</div>
      )}

      {customError ? (
        <div className="text-red-500 text-sm mt-1">{customError}</div>
      ) : (
        <ErrorMessage
          name={name}
          component="div"
          className="text-red-500 text-sm mt-1"
        />
      )}
    </div>
  );
}

export function MultiSelectField({
  name,
  label,

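For orientation only (not part of the diff): rendered inside a Formik form, the new field can drive the SharePoint certificate upload roughly as in the sketch below; the wrapping component name is hypothetical.

import { TypedFileUploadFormField } from "@/components/Field";

// Hypothetical usage sketch: the field key must be one that isTypedFileField()
// recognizes (currently only "sp_private_key"), and the component must sit
// inside a Formik form because it relies on useField() internally.
export function SharepointPrivateKeyUpload() {
  return (
    <TypedFileUploadFormField
      name="sp_private_key"
      label="SharePoint Private Key"
      subtext="Upload the .pfx certificate used for certificate authentication."
    />
  );
}
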
@@ -4,11 +4,20 @@ import * as Yup from "yup";
|
||||
import { Popup } from "./Popup";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
|
||||
import { createCredential } from "@/lib/credential";
|
||||
import { CredentialBase, Credential } from "@/lib/connectors/credentials";
|
||||
import {
|
||||
createCredential,
|
||||
createCredentialWithPrivateKey,
|
||||
} from "@/lib/credential";
|
||||
import {
|
||||
CredentialBase,
|
||||
Credential,
|
||||
CredentialWithPrivateKey,
|
||||
} from "@/lib/connectors/credentials";
|
||||
|
||||
const PRIVATE_KEY_FIELD_KEY = "private_key";
|
||||
|
||||
export async function submitCredential<T>(
|
||||
credential: CredentialBase<T>
|
||||
credential: CredentialBase<T> | CredentialWithPrivateKey<T>
|
||||
): Promise<{
|
||||
credential?: Credential<any>;
|
||||
message: string;
|
||||
@@ -16,8 +25,14 @@ export async function submitCredential<T>(
|
||||
}> {
|
||||
let isSuccess = false;
|
||||
try {
|
||||
const response = await createCredential(credential);
|
||||
|
||||
let response: Response;
|
||||
if (PRIVATE_KEY_FIELD_KEY in credential && credential.private_key) {
|
||||
response = await createCredentialWithPrivateKey(
|
||||
credential as CredentialWithPrivateKey<T>
|
||||
);
|
||||
} else {
|
||||
response = await createCredential(credential as CredentialBase<T>);
|
||||
}
|
||||
if (response.ok) {
|
||||
const parsed_response = await response.json();
|
||||
const credential = parsed_response.credential;
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
  deleteCredential,
  swapCredential,
  updateCredential,
  updateCredentialWithPrivateKey,
} from "@/lib/credential";
import { usePopup } from "@/components/admin/connectors/Popup";
import CreateCredential from "./actions/CreateCredential";
@@ -34,6 +35,7 @@ import {
import { Spinner } from "@/components/Spinner";
import { CreateStdOAuthCredential } from "@/components/credentials/actions/CreateStdOAuthCredential";
import { Card } from "../ui/card";
import { isTypedFileField, TypedFile } from "@/lib/connectors/fileTypes";

export default function CredentialSection({
  ccPair,
@@ -111,7 +113,23 @@ export default function CredentialSection({
    details: any,
    onSucces: () => void
  ) => {
    const response = await updateCredential(selectedCredential.id, details);
    let privateKey: TypedFile | null = null;
    Object.entries(details).forEach(([key, value]) => {
      if (isTypedFileField(key)) {
        privateKey = value as TypedFile;
        delete details[key];
      }
    });
    let response;
    if (privateKey) {
      response = await updateCredentialWithPrivateKey(
        selectedCredential.id,
        details,
        privateKey
      );
    } else {
      response = await updateCredential(selectedCredential.id, details);
    }
    if (response.ok) {
      setPopup({
        message: "Updated credential",

@@ -23,6 +23,7 @@ import {
import { useUser } from "@/components/user/UserProvider";
import CardSection from "@/components/admin/CardSection";
import { CredentialFieldsRenderer } from "./CredentialFieldsRenderer";
import { TypedFile } from "@/lib/connectors/fileTypes";

const CreateButton = ({
  onClick,
@@ -114,10 +115,15 @@ export default function CreateCredential({

    const { name, is_public, groups, ...credentialValues } = values;

    let privateKey: TypedFile | null = null;
    const filteredCredentialValues = Object.fromEntries(
      Object.entries(credentialValues).filter(
        ([_, value]) => value !== null && value !== ""
      )
      Object.entries(credentialValues).filter(([key, value]) => {
        if (value instanceof TypedFile) {
          privateKey = value;
          return false;
        }
        return value !== null && value !== "";
      })
    );

    try {
@@ -128,6 +134,7 @@ export default function CreateCredential({
      groups: groups,
      name: name,
      source: sourceType,
      private_key: privateKey || undefined,
    });

    const { message, isSuccess, credential } = response;

@@ -1,12 +1,17 @@
import React from "react";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { useFormikContext } from "formik";
import { BooleanFormField, TextFormField } from "@/components/Field";
import {
  BooleanFormField,
  TextFormField,
  TypedFileUploadFormField,
} from "@/components/Field";
import {
  getDisplayNameForCredentialKey,
  CredentialTemplateWithAuth,
} from "@/lib/connectors/credentials";
import { dictionaryType } from "../types";
import { isTypedFileField } from "@/lib/connectors/fileTypes";

interface CredentialFieldsRendererProps {
  credentialTemplate: dictionaryType;
@@ -90,6 +95,16 @@ export function CredentialFieldsRenderer({
      )}

      {Object.entries(method.fields).map(([key, val]) => {
        if (isTypedFileField(key)) {
          return (
            <TypedFileUploadFormField
              key={key}
              name={key}
              label={getDisplayNameForCredentialKey(key)}
            />
          );
        }

        if (typeof val === "boolean") {
          return (
            <BooleanFormField
@@ -130,6 +145,15 @@ export function CredentialFieldsRenderer({
        if (key === "authentication_method" || key === "authMethods") {
          return null;
        }
        if (isTypedFileField(key)) {
          return (
            <TypedFileUploadFormField
              key={key}
              name={key}
              label={getDisplayNameForCredentialKey(key)}
            />
          );
        }

        if (typeof val === "boolean") {
          return (
@@ -144,7 +168,7 @@ export function CredentialFieldsRenderer({
        <TextFormField
          key={key}
          name={key}
          placeholder={val}
          placeholder={val as string}
          label={getDisplayNameForCredentialKey(key)}
          type={
            key.toLowerCase().includes("token") ||

@@ -3,7 +3,7 @@ import { Button } from "@/components/ui/button";
|
||||
import Text from "@/components/ui/text";
|
||||
|
||||
import { FaNewspaper, FaTrash } from "react-icons/fa";
|
||||
import { TextFormField } from "@/components/Field";
|
||||
import { TextFormField, TypedFileUploadFormField } from "@/components/Field";
|
||||
import { Form, Formik, FormikHelpers } from "formik";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import {
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
} from "@/lib/connectors/credentials";
|
||||
import { createEditingValidationSchema, createInitialValues } from "../lib";
|
||||
import { dictionaryType, formType } from "../types";
|
||||
import { isTypedFileField } from "@/lib/connectors/fileTypes";
|
||||
|
||||
const EditCredential = ({
|
||||
credential,
|
||||
@@ -68,22 +69,30 @@ const EditCredential = ({
|
||||
label="Name (optional):"
|
||||
/>
|
||||
|
||||
{Object.entries(credential.credential_json).map(([key, value]) => (
|
||||
<TextFormField
|
||||
includeRevert
|
||||
key={key}
|
||||
name={key}
|
||||
placeholder={value}
|
||||
label={getDisplayNameForCredentialKey(key)}
|
||||
type={
|
||||
key.toLowerCase().includes("token") ||
|
||||
key.toLowerCase().includes("password")
|
||||
? "password"
|
||||
: "text"
|
||||
}
|
||||
disabled={key === "authentication_method"}
|
||||
/>
|
||||
))}
|
||||
{Object.entries(credential.credential_json).map(([key, value]) =>
|
||||
isTypedFileField(key) ? (
|
||||
<TypedFileUploadFormField
|
||||
key={key}
|
||||
name={key}
|
||||
label={getDisplayNameForCredentialKey(key)}
|
||||
/>
|
||||
) : (
|
||||
<TextFormField
|
||||
includeRevert
|
||||
key={key}
|
||||
name={key}
|
||||
placeholder={value as string}
|
||||
label={getDisplayNameForCredentialKey(key)}
|
||||
type={
|
||||
key.toLowerCase().includes("token") ||
|
||||
key.toLowerCase().includes("password")
|
||||
? "password"
|
||||
: "text"
|
||||
}
|
||||
disabled={key === "authentication_method"}
|
||||
/>
|
||||
)
|
||||
)}
|
||||
<div className="flex justify-between w-full">
|
||||
<Button type="button" onClick={() => resetForm()}>
|
||||
<div className="flex gap-x-2 items-center w-full border-none">
|
||||
|
||||
@@ -6,6 +6,7 @@ import {
  getDisplayNameForCredentialKey,
  CredentialTemplateWithAuth,
} from "@/lib/connectors/credentials";
import { isTypedFileField } from "@/lib/connectors/fileTypes";

export function createValidationSchema(json_values: Record<string, any>) {
  const schemaFields: Record<string, Yup.AnySchema> = {};
@@ -16,7 +17,6 @@ export function createValidationSchema(json_values: Record<string, any>) {
    schemaFields["authentication_method"] = Yup.string().required(
      "Please select an authentication method"
    );

    // conditional rules per authMethod
    template.authMethods.forEach((method) => {
      Object.entries(method.fields).forEach(([key, def]) => {
@@ -26,6 +26,14 @@ export function createValidationSchema(json_values: Record<string, any>) {
            .nullable()
            .default(false)
            .transform((v, o) => (o === undefined ? false : v));
        } else if (isTypedFileField(key)) {
          // TypedFile fields - use mixed schema instead of string (check before null check)
          schemaFields[key] = Yup.mixed().when("authentication_method", {
            is: method.value,
            then: () =>
              Yup.mixed().required(`Please select a ${displayName} file`),
            otherwise: () => Yup.mixed().notRequired(),
          });
        } else if (def === null) {
          schemaFields[key] = Yup.string()
            .trim()
@@ -58,6 +66,11 @@ export function createValidationSchema(json_values: Record<string, any>) {
          .nullable()
          .default(false)
          .transform((v, o) => (o === undefined ? false : v));
      } else if (isTypedFileField(key)) {
        // TypedFile fields - use mixed schema instead of string (check before null check)
        schemaFields[key] = Yup.mixed().required(
          `Please select a ${displayName} file`
        );
      } else if (def === null) {
        schemaFields[key] = Yup.string()
          .trim()
@@ -77,11 +90,16 @@ export function createValidationSchema(json_values: Record<string, any>) {
}

export function createEditingValidationSchema(json_values: dictionaryType) {
  const schemaFields: { [key: string]: Yup.StringSchema } = {};
  const schemaFields: { [key: string]: Yup.AnySchema } = {};

  for (const key in json_values) {
    if (Object.prototype.hasOwnProperty.call(json_values, key)) {
      schemaFields[key] = Yup.string().optional();
      if (isTypedFileField(key)) {
        // TypedFile fields - use mixed schema for optional file uploads during editing
        schemaFields[key] = Yup.mixed().optional();
      } else {
        schemaFields[key] = Yup.string().optional();
      }
    }
  }

@@ -95,7 +113,12 @@ export function createInitialValues(credential: Credential<any>): formType {
  };

  for (const key in credential.credential_json) {
    initialValues[key] = "";
    // Initialize TypedFile fields as null, other fields as empty strings
    if (isTypedFileField(key)) {
      initialValues[key] = null as any; // TypedFile fields start as null
    } else {
      initialValues[key] = "";
    }
  }

  return initialValues;

@@ -1,5 +1,7 @@
import { TypedFile } from "@/lib/connectors/fileTypes";

export interface dictionaryType {
  [key: string]: string;
  [key: string]: string | TypedFile;
}
export interface formType extends dictionaryType {
  name: string;

@@ -17,4 +17,5 @@ export const autoSyncConfigBySource: Record<
  github: {},
  slack: {},
  salesforce: {},
  sharepoint: {},
};

@@ -1,4 +1,5 @@
import { ValidSources } from "../types";
import { TypedFile } from "./fileTypes";

export interface OAuthAdditionalKwargDescription {
  name: string;
@@ -30,6 +31,10 @@ export interface CredentialBase<T> {
  groups?: number[];
}

export interface CredentialWithPrivateKey<T> extends CredentialBase<T> {
  private_key: TypedFile;
}

export interface Credential<T> extends CredentialBase<T> {
  id: number;
  user_id: string | null;
@@ -188,8 +193,10 @@ export interface SalesforceCredentialJson {

export interface SharepointCredentialJson {
  sp_client_id: string;
  sp_client_secret: string;
  sp_client_secret?: string;
  sp_directory_id: string;
  sp_certificate_password?: string;
  sp_private_key?: TypedFile;
}

export interface AsanaCredentialJson {
@@ -297,10 +304,33 @@ export const credentialTemplates: Record<ValidSources, any> = {
    is_sandbox: false,
  } as SalesforceCredentialJson,
  sharepoint: {
    sp_client_id: "",
    sp_client_secret: "",
    sp_directory_id: "",
  } as SharepointCredentialJson,
    authentication_method: "client_credentials",
    authMethods: [
      {
        value: "client_secret",
        label: "Client Secret",
        fields: {
          sp_client_id: "",
          sp_client_secret: "",
          sp_directory_id: "",
        },
        description:
          "If you select this mode, the SharePoint connector will use a client secret to authenticate. You will need to provide the client ID and client secret.",
      },
      {
        value: "certificate",
        label: "Certificate Authentication",
        fields: {
          sp_client_id: "",
          sp_directory_id: "",
          sp_certificate_password: "",
          sp_private_key: null,
        },
        description:
          "If you select this mode, the SharePoint connector will use a certificate to authenticate. You will need to provide the client ID, directory ID, certificate password, and PFX data.",
      },
    ],
  } as CredentialTemplateWithAuth<SharepointCredentialJson>,
  asana: {
    asana_api_token_secret: "",
  } as AsanaCredentialJson,
@@ -522,6 +552,8 @@ export const credentialDisplayNames: Record<string, string> = {
  sp_client_id: "SharePoint Client ID",
  sp_client_secret: "SharePoint Client Secret",
  sp_directory_id: "SharePoint Directory ID",
  sp_certificate_password: "SharePoint Certificate Password",
  sp_private_key: "SharePoint Private Key",

  // Asana
  asana_api_token_secret: "Asana API Token",

117
web/src/lib/connectors/fileTypes.ts
Normal file
@@ -0,0 +1,117 @@
export enum FileTypeCategory {
  SHAREPOINT_PFX_FILE = "sharepoint_pfx_file",
}

export interface FileValidationRule {
  maxSizeKB?: number;
  allowedExtensions?: string[];
  contentValidation?: (file: File) => Promise<boolean>;
}

export interface FileTypeDefinition {
  category: FileTypeCategory;
  validation?: FileValidationRule;
  description?: string;
}

export const FILE_TYPE_DEFINITIONS: Record<
  FileTypeCategory,
  FileTypeDefinition
> = {
  [FileTypeCategory.SHAREPOINT_PFX_FILE]: {
    category: FileTypeCategory.SHAREPOINT_PFX_FILE,
    validation: {
      maxSizeKB: 10,
      allowedExtensions: [".pfx"],
    },
    description:
      "Please upload a .pfx file containing the private key for SharePoint. The file size must be under 10KB.",
  },
};

export class TypedFile {
  constructor(
    public readonly file: File,
    public readonly typeDefinition: FileTypeDefinition,
    public readonly fieldKey: string
  ) {}

  async validate(): Promise<{ isValid: boolean; errors: string[] }> {
    const errors: string[] = [];
    const { validation } = this.typeDefinition;

    if (!validation) {
      return {
        isValid: true,
        errors: [],
      };
    }

    // Size validation
    if (validation.maxSizeKB && this.file.size > validation.maxSizeKB * 1024) {
      errors.push(`File size must not exceed ${validation.maxSizeKB}KB`);
    }

    // Extension validation
    if (validation.allowedExtensions) {
      const extension = this.file.name.toLowerCase().split(".").pop();
      if (
        !extension ||
        !validation.allowedExtensions.includes(`.${extension}`)
      ) {
        errors.push(
          `File must have one of these extensions: ${validation.allowedExtensions.join(", ")}`
        );
      }
    }

    // Content validation
    if (validation.contentValidation) {
      try {
        const isContentValid = await validation.contentValidation(this.file);
        if (!isContentValid) {
          errors.push(`File content validation failed`);
        }
      } catch (error) {
        errors.push(
          `Content validation error: ${error instanceof Error ? error.message : "Unknown error"}`
        );
      }
    }

    return {
      isValid: errors.length === 0,
      errors,
    };
  }
}

export function createTypedFile(
  file: File,
  fieldKey: string,
  typeDefinitionKey: FileTypeCategory
): TypedFile {
  const typeDefinition = FILE_TYPE_DEFINITIONS[typeDefinitionKey];
  if (!typeDefinition) {
    throw new Error(`Unknown file type definition: ${typeDefinitionKey}`);
  }

  return new TypedFile(file, typeDefinition, fieldKey);
}

export function isTypedFileField(fieldKey: string): boolean {
  // Define which fields should be typed files
  const typedFileFields = new Set(["sp_private_key"]);
  return typedFileFields.has(fieldKey);
}

// Get the appropriate file type definition for a field
export function getFileTypeDefinitionForField(
  fieldKey: string
): FileTypeCategory | null {
  const fieldToTypeMap: Record<string, FileTypeCategory> = {
    sp_private_key: FileTypeCategory.SHAREPOINT_PFX_FILE,
  };

  return fieldToTypeMap[fieldKey] || null;
}

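A rough sketch (not part of the change set) of how these helpers compose, assuming a File obtained from a browser file input; the wrapper function below is illustrative only.

import {
  createTypedFile,
  getFileTypeDefinitionForField,
} from "@/lib/connectors/fileTypes";

// Hypothetical helper: wrap a selected .pfx file and surface validation errors.
export async function wrapSharepointPfx(file: File) {
  const category = getFileTypeDefinitionForField("sp_private_key"); // SHAREPOINT_PFX_FILE
  if (!category) {
    throw new Error("No file type definition registered for sp_private_key");
  }
  const typed = createTypedFile(file, "sp_private_key", category);
  // validate() applies the rules defined above: the 10KB size cap and the .pfx extension
  const { isValid, errors } = await typed.validate();
  if (!isValid) {
    throw new Error(errors.join(", "));
  }
  return typed;
}
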
@@ -108,3 +108,11 @@ export const ALLOWED_URL_PROTOCOLS = [
];

export const MAX_CHARACTERS_PERSONA_DESCRIPTION = 5000000;

// Credential form data key constants
export const CREDENTIAL_NAME = "name";
export const CREDENTIAL_SOURCE = "source";
export const CREDENTIAL_UPLOADED_FILE = "uploaded_file";
export const CREDENTIAL_FIELD_KEY = "field_key";
export const CREDENTIAL_TYPE_DEFINITION_KEY = "type_definition_key";
export const CREDENTIAL_JSON = "credential_json";

@@ -1,5 +1,17 @@
import { CredentialBase } from "./connectors/credentials";
import {
  CredentialBase,
  CredentialWithPrivateKey,
} from "./connectors/credentials";
import { AccessType } from "@/lib/types";
import { TypedFile } from "./connectors/fileTypes";
import {
  CREDENTIAL_NAME,
  CREDENTIAL_SOURCE,
  CREDENTIAL_UPLOADED_FILE,
  CREDENTIAL_FIELD_KEY,
  CREDENTIAL_TYPE_DEFINITION_KEY,
  CREDENTIAL_JSON,
} from "./constants";

export async function createCredential(credential: CredentialBase<any>) {
  return await fetch(`/api/manage/credential`, {
@@ -11,6 +23,37 @@ export async function createCredential(credential: CredentialBase<any>) {
  });
}

export async function createCredentialWithPrivateKey(
  credential: CredentialWithPrivateKey<any>
) {
  const formData = new FormData();
  formData.append(CREDENTIAL_JSON, JSON.stringify(credential.credential_json));
  formData.append("admin_public", credential.admin_public.toString());
  formData.append(
    "curator_public",
    credential.curator_public?.toString() || "false"
  );
  if (credential.groups && credential.groups.length > 0) {
    credential.groups.forEach((group) => {
      formData.append("groups", String(group));
    });
  }
  formData.append(CREDENTIAL_NAME, credential.name || "");
  formData.append(CREDENTIAL_SOURCE, credential.source);
  if (credential.private_key) {
    formData.append(CREDENTIAL_UPLOADED_FILE, credential.private_key.file);
    formData.append(CREDENTIAL_FIELD_KEY, credential.private_key.fieldKey);
    formData.append(
      CREDENTIAL_TYPE_DEFINITION_KEY,
      credential.private_key.typeDefinition.category
    );
  }
  return await fetch(`/api/manage/credential/private-key`, {
    method: "POST",
    body: formData,
  });
}

export async function adminDeleteCredential<T>(credentialId: number) {
  return await fetch(`/api/manage/admin/credential/${credentialId}`, {
    method: "DELETE",
@@ -70,7 +113,7 @@ export function updateCredential(credentialId: number, newDetails: any) {
  const name = newDetails.name;
  const details = Object.fromEntries(
    Object.entries(newDetails).filter(
      ([key, value]) => key !== "name" && value !== ""
      ([key, value]) => key !== CREDENTIAL_NAME && value !== ""
    )
  );
  return fetch(`/api/manage/admin/credential/${credentialId}`, {
@@ -85,6 +128,32 @@ export function updateCredential(credentialId: number, newDetails: any) {
  });
}

export function updateCredentialWithPrivateKey(
  credentialId: number,
  newDetails: any,
  privateKey: TypedFile
) {
  const name = newDetails.name;
  const details = Object.fromEntries(
    Object.entries(newDetails).filter(
      ([key, value]) => key !== CREDENTIAL_NAME && value !== ""
    )
  );
  const formData = new FormData();
  formData.append(CREDENTIAL_NAME, name);
  formData.append(CREDENTIAL_JSON, JSON.stringify(details));
  formData.append(CREDENTIAL_UPLOADED_FILE, privateKey.file);
  formData.append(CREDENTIAL_FIELD_KEY, privateKey.fieldKey);
  formData.append(
    CREDENTIAL_TYPE_DEFINITION_KEY,
    privateKey.typeDefinition.category
  );
  return fetch(`/api/manage/admin/credential/private-key/${credentialId}`, {
    method: "PUT",
    body: formData,
  });
}

export function swapCredential(
  newCredentialId: number,
  connectorId: number,

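As an illustrative sketch only (placeholders marked with angle brackets, and the exact CredentialBase fields beyond those exercised above are assumed), a caller could create a certificate-based SharePoint credential roughly like this:

import { createCredentialWithPrivateKey } from "@/lib/credential";
import { CredentialWithPrivateKey } from "@/lib/connectors/credentials";
import { TypedFile } from "@/lib/connectors/fileTypes";
import { ValidSources } from "@/lib/types";

// Hypothetical caller: pfxFile is a TypedFile produced by the upload form field.
export async function createSharepointCertCredential(pfxFile: TypedFile) {
  const credential: CredentialWithPrivateKey<any> = {
    credential_json: {
      authentication_method: "certificate",
      sp_client_id: "<client-id>",
      sp_directory_id: "<directory-id>",
      sp_certificate_password: "<certificate-password>",
    },
    admin_public: true,
    curator_public: false,
    groups: [],
    name: "SharePoint (certificate)",
    source: ValidSources.Sharepoint,
    private_key: pfxFile,
  };
  // The PFX bytes are sent as multipart form data to /api/manage/credential/private-key
  const response = await createCredentialWithPrivateKey(credential);
  return response.ok;
}
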
@@ -456,6 +456,7 @@ export const validAutoSyncSources = [
  ValidSources.Slack,
  ValidSources.Salesforce,
  ValidSources.GitHub,
  ValidSources.Sharepoint,
] as const;

// Create a type from the array elements