
feat: sharepoint perm sync (#5033)

* sharepoint perm sync first draft

* feat: Implement SharePoint permission synchronization

* mypy fix

* remove commented code

* address bot comments and fix job failures

* introduce generic way to upload certificates in credentials

* mypy fix

* add checkpointing to sharepoint connector

* add sharepoint integration tests

* Refactor SharePoint connector to derive tenant domain from verified domains and remove direct tenant domain input from credentials

* address review comments

* add permission sync to site pages

* mypy fix

* fix test errors

* fix tests and address comments

* Update file extraction behavior in SharePoint connector to continue processing on unprocessable files
SubashMohan committed 2025-08-11 22:29:16 +05:30 (committed by GitHub)
parent bf6705a9a5 · commit 9bc62cc803
27 changed files with 2926 additions and 262 deletions

View File

@@ -102,6 +102,19 @@ TEAMS_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("TEAMS_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
#####
# SharePoint
#####
# In seconds, default is 30 minutes
SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
)
# In seconds, default is 5 minutes
SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
####
# Celery Job Frequency
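Aside: the `int(os.environ.get(NAME) or default)` pattern above treats an unset and an empty variable identically, unlike `os.environ.get(NAME, default)`. A minimal standalone sketch of that behavior:

```python
import os

# An empty string is falsy, so both "unset" and "set to empty" fall back
# to the default of 30 * 60 seconds (30 minutes).
os.environ["SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY"] = ""
freq = int(os.environ.get("SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60)
assert freq == 1800
```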

View File

@@ -0,0 +1,36 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
SHAREPOINT_DOC_SYNC_TAG = "sharepoint_doc_sync"
def sharepoint_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None = None,
) -> Generator[DocExternalAccess, None, None]:
sharepoint_connector = SharepointConnector(
**cc_pair.connector.connector_specific_config,
)
sharepoint_connector.load_credentials(cc_pair.credential.credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
callback=callback,
doc_source=DocumentSource.SHAREPOINT,
slim_connector=sharepoint_connector,
label=SHAREPOINT_DOC_SYNC_TAG,
)
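For orientation, a sketch of how a caller might drive this generator; the `cc_pair` and fetch callables below are assumed to come from the permission-sync task, and the attribute names on `DocExternalAccess` follow `onyx.access.models` as used elsewhere in this PR:

```python
# Hypothetical driver loop: each yielded DocExternalAccess pairs a document ID
# with its resolved external permissions and would be written to the index.
for doc_access in sharepoint_doc_sync(
    cc_pair=cc_pair,
    fetch_all_existing_docs_fn=fetch_docs,
    fetch_all_existing_docs_ids_fn=fetch_doc_ids,
    callback=None,
):
    print(doc_access.doc_id, doc_access.external_access.is_public)
```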

View File

@@ -0,0 +1,63 @@
from collections.abc import Generator
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
def sharepoint_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
"""Sync SharePoint groups and their members"""
# Get site URLs from connector config
connector_config = cc_pair.connector.connector_specific_config
# Create SharePoint connector instance and load credentials
connector = SharepointConnector(**connector_config)
connector.load_credentials(cc_pair.credential.credential_json)
if not connector.msal_app:
raise RuntimeError("MSAL app not initialized in connector")
if not connector.sp_tenant_domain:
raise RuntimeError("Tenant domain not initialized in connector")
# Get site descriptors from connector (either configured sites or all sites)
site_descriptors = connector.site_descriptors or connector.fetch_sites()
if not site_descriptors:
raise RuntimeError("No SharePoint sites found for group sync")
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
# Process each site
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
# Create client context for the site using connector's MSAL app
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
# Yield each group
for group in external_groups:
logger.debug(
f"Found group: {group.id} with {len(group.user_emails)} members"
)
yield group
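A sketch of the consuming side, assuming a `cc_pair` loaded from the database; the real group-sync task would upsert each `ExternalUserGroup` into the external-permission tables rather than print it:

```python
# Hypothetical consumer of the generator above.
for external_group in sharepoint_group_sync(tenant_id="tenant-1", cc_pair=cc_pair):
    print(external_group.id, len(external_group.user_emails))
```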

View File

@@ -0,0 +1,658 @@
import re
from collections import deque
from typing import Any
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from office365.sharepoint.permissions.securable_object import RoleAssignmentCollection # type: ignore[import-untyped]
from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger
logger = setup_logger()
# These values represent different types of SharePoint principals used in permission assignments
USER_PRINCIPAL_TYPE = 1 # Individual user accounts
ANONYMOUS_USER_PRINCIPAL_TYPE = 3 # Anonymous/unauthenticated users (public access)
AZURE_AD_GROUP_PRINCIPAL_TYPE = 4 # Azure Active Directory security groups
SHAREPOINT_GROUP_PRINCIPAL_TYPE = 8 # SharePoint site groups (local to the site)
MICROSOFT_DOMAIN = ".onmicrosoft"
# Limited Access role types; Limited Access is a traversal (pass-through) permission, not an actual permission
LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
class SharepointGroup(BaseModel):
model_config = {"frozen": True}
name: str
login_name: str
principal_type: int
class GroupsResult(BaseModel):
groups_to_emails: dict[str, set[str]]
found_public_group: bool
def _get_azuread_group_guid_by_name(
graph_client: GraphClient, group_name: str
) -> str | None:
try:
# Search for groups by display name
groups = sleep_and_retry(
graph_client.groups.filter(f"displayName eq '{group_name}'").get(),
"get_azuread_group_guid_by_name",
)
if groups and len(groups) > 0:
return groups[0].id
return None
except Exception as e:
logger.error(f"Failed to get Azure AD group GUID for name {group_name}: {e}")
return None
def _extract_guid_from_claims_token(claims_token: str) -> str | None:
try:
# Pattern to match GUID in claims token
# Claims tokens often have format: c:0o.c|provider|GUID_suffix
guid_pattern = r"([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})"
match = re.search(guid_pattern, claims_token, re.IGNORECASE)
if match:
return match.group(1)
return None
except Exception as e:
logger.error(f"Failed to extract GUID from claims token {claims_token}: {e}")
return None
def _get_group_guid_from_identifier(
graph_client: GraphClient, identifier: str
) -> str | None:
try:
# Check if it's already a GUID
guid_pattern = r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
if re.match(guid_pattern, identifier, re.IGNORECASE):
return identifier
# Check if it's a SharePoint claims token
if identifier.startswith("c:0") and "|" in identifier:
guid = _extract_guid_from_claims_token(identifier)
if guid:
logger.info(f"Extracted GUID {guid} from claims token {identifier}")
return guid
# Try to search by display name as fallback
return _get_azuread_group_guid_by_name(graph_client, identifier)
except Exception as e:
logger.error(f"Failed to get group GUID from identifier {identifier}: {e}")
return None
def _get_security_group_owners(graph_client: GraphClient, group_id: str) -> list[str]:
try:
# Get group owners using Graph API
group = graph_client.groups[group_id]
owners = sleep_and_retry(
group.owners.get_all(page_loaded=lambda _: None),
"get_security_group_owners",
)
owner_emails: list[str] = []
logger.info(f"Owners: {owners}")
for owner in owners:
owner_data = owner.to_json()
# Extract email from the JSON data
mail: str | None = owner_data.get("mail")
user_principal_name: str | None = owner_data.get("userPrincipalName")
# Check if owner is a user and has an email
if mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
owner_emails.append(mail)
elif user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
owner_emails.append(user_principal_name)
logger.info(
f"Retrieved {len(owner_emails)} owners from security group {group_id}"
)
return owner_emails
except Exception as e:
logger.error(f"Failed to get security group owners for group {group_id}: {e}")
return []
def _get_sharepoint_list_item_id(drive_item: DriveItem) -> str | None:
try:
# First try to get the list item directly from the drive item
if hasattr(drive_item, "listItem"):
list_item = drive_item.listItem
if list_item:
# Load the list item properties to get the ID
sleep_and_retry(list_item.get(), "get_sharepoint_list_item_id")
if hasattr(list_item, "id") and list_item.id:
return str(list_item.id)
# The SharePoint list item ID is typically available in the sharepointIds property
sharepoint_ids = getattr(drive_item, "sharepoint_ids", None)
if sharepoint_ids and hasattr(sharepoint_ids, "listItemId"):
return sharepoint_ids.listItemId
# Alternative: try to get it from the properties
properties = getattr(drive_item, "properties", None)
if properties:
# Sometimes the SharePoint list item ID is in the properties
for prop_name, prop_value in properties.items():
if "listitemid" in prop_name.lower():
return str(prop_value)
return None
except Exception as e:
logger.error(
f"Error getting SharePoint list item ID for item {drive_item.id}: {e}"
)
raise e
def _is_public_item(drive_item: DriveItem) -> bool:
is_public = False
try:
permissions = sleep_and_retry(
drive_item.permissions.get_all(page_loaded=lambda _: None), "is_public_item"
)
for permission in permissions:
if permission.link and (
permission.link.scope == "anonymous"
or permission.link.scope == "organization"
):
is_public = True
break
return is_public
except Exception as e:
logger.error(f"Failed to check if item {drive_item.id} is public: {e}")
return False
def _is_public_login_name(login_name: str) -> bool:
# Patterns that indicate public access
# This list is derived from the below link
# https://learn.microsoft.com/en-us/answers/questions/2085339/guid-in-the-loginname-of-site-user-everyone-except
public_login_patterns: list[str] = [
"c:0-.f|rolemanager|spo-grid-all-users/",
"c:0(.s|true",
]
for pattern in public_login_patterns:
if pattern in login_name:
logger.info(f"Login name {login_name} is public")
return True
return False
# Azure AD allows the same display name for multiple groups, so we append the GUID to the name
def _get_group_name_with_suffix(
login_name: str, group_name: str, graph_client: GraphClient
) -> str:
ad_group_suffix = _get_group_guid_from_identifier(graph_client, login_name)
return f"{group_name}_{ad_group_suffix}"
def _get_sharepoint_groups(
client_context: ClientContext, group_name: str, graph_client: GraphClient
) -> tuple[set[SharepointGroup], set[str]]:
groups: set[SharepointGroup] = set()
user_emails: set[str] = set()
def process_users(users: list[Any]) -> None:
nonlocal groups, user_emails
for user in users:
if user.principal_type == USER_PRINCIPAL_TYPE and hasattr(
user, "user_principal_name"
):
if user.user_principal_name:
email = user.user_principal_name
if MICROSOFT_DOMAIN in email:
email = email.replace(MICROSOFT_DOMAIN, "")
user_emails.add(email)
else:
logger.warning(
f"User don't have a user principal name: {user.login_name}"
)
elif user.principal_type in [
AZURE_AD_GROUP_PRINCIPAL_TYPE,
SHAREPOINT_GROUP_PRINCIPAL_TYPE,
]:
name = user.title
if user.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
name = _get_group_name_with_suffix(
user.login_name, name, graph_client
)
groups.add(
SharepointGroup(
login_name=user.login_name,
principal_type=user.principal_type,
name=name,
)
)
group = client_context.web.site_groups.get_by_name(group_name)
sleep_and_retry(
group.users.get_all(page_loaded=process_users), "get_sharepoint_groups"
)
return groups, user_emails
def _get_azuread_groups(
graph_client: GraphClient, group_name: str
) -> tuple[set[SharepointGroup], set[str]]:
group_id = _get_group_guid_from_identifier(graph_client, group_name)
if not group_id:
logger.error(f"Failed to get Azure AD group GUID for name {group_name}")
return set(), set()
group = graph_client.groups[group_id]
groups: set[SharepointGroup] = set()
user_emails: set[str] = set()
def process_members(members: list[Any]) -> None:
nonlocal groups, user_emails
for member in members:
member_data = member.to_json()
# Check for user-specific attributes
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
display_name = member_data.get("displayName") or member_data.get(
"display_name"
)
# Check object attributes directly (if available)
is_user = False
is_group = False
# Users typically have userPrincipalName or mail
if user_principal_name or (mail and "@" in str(mail)):
is_user = True
# Groups typically have displayName but no userPrincipalName
elif display_name and not user_principal_name:
# Additional check: try to access group-specific properties
if (
hasattr(member, "groupTypes")
or member_data.get("groupTypes") is not None
):
is_group = True
# Or check if it has an 'id' field typical for groups
elif member_data.get("id") and not user_principal_name:
is_group = True
# Check the object type name (fallback)
if not is_user and not is_group:
obj_type = type(member).__name__.lower()
if "user" in obj_type:
is_user = True
elif "group" in obj_type:
is_group = True
# Process based on identification
if is_user:
if user_principal_name:
email = user_principal_name
if MICROSOFT_DOMAIN in email:
email = email.replace(MICROSOFT_DOMAIN, "")
user_emails.add(email)
elif mail:
email = mail
if MICROSOFT_DOMAIN in email:
email = email.replace(MICROSOFT_DOMAIN, "")
user_emails.add(email)
logger.info(f"Added user: {user_principal_name or mail}")
elif is_group:
if not display_name:
logger.error(f"No display name for group: {member_data.get('id')}")
continue
name = _get_group_name_with_suffix(
member_data.get("id", ""), display_name, graph_client
)
groups.add(
SharepointGroup(
login_name=member_data.get("id", ""), # Use ID for groups
principal_type=AZURE_AD_GROUP_PRINCIPAL_TYPE,
name=name,
)
)
logger.info(f"Added group: {name}")
else:
# Log unidentified members for debugging
logger.warning(f"Could not identify member type for: {member_data}")
sleep_and_retry(
group.members.get_all(page_loaded=process_members), "get_azuread_groups"
)
owner_emails = _get_security_group_owners(graph_client, group_id)
user_emails.update(owner_emails)
return groups, user_emails
def _get_groups_and_members_recursively(
client_context: ClientContext,
graph_client: GraphClient,
groups: set[SharepointGroup],
) -> GroupsResult:
"""
Get all groups and their members recursively.
"""
group_queue: deque[SharepointGroup] = deque(groups)
visited_groups: set[str] = set()
visited_group_name_to_emails: dict[str, set[str]] = {}
while group_queue:
group = group_queue.popleft()
if group.login_name in visited_groups:
continue
visited_groups.add(group.login_name)
visited_group_name_to_emails[group.name] = set()
logger.info(
f"Processing group: {group.name} principal type: {group.principal_type}"
)
if group.principal_type == SHAREPOINT_GROUP_PRINCIPAL_TYPE:
group_info, user_emails = _get_sharepoint_groups(
client_context, group.login_name, graph_client
)
visited_group_name_to_emails[group.name].update(user_emails)
if group_info:
group_queue.extend(group_info)
if group.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
# if the site is public, we have default groups assigned to it, so we return early
if _is_public_login_name(group.login_name):
return GroupsResult(groups_to_emails={}, found_public_group=True)
group_info, user_emails = _get_azuread_groups(
graph_client, group.login_name
)
visited_group_name_to_emails[group.name].update(user_emails)
if group_info:
group_queue.extend(group_info)
return GroupsResult(
groups_to_emails=visited_group_name_to_emails, found_public_group=False
)
def get_external_access_from_sharepoint(
client_context: ClientContext,
graph_client: GraphClient,
drive_name: str | None,
drive_item: DriveItem | None,
site_page: dict[str, Any] | None,
add_prefix: bool = False,
) -> ExternalAccess:
"""
Get external access information from SharePoint.
"""
groups: set[SharepointGroup] = set()
user_emails: set[str] = set()
group_ids: set[str] = set()
# Add all members to a processing set first
def add_user_and_group_to_sets(
role_assignments: RoleAssignmentCollection,
) -> None:
nonlocal user_emails, groups
for assignment in role_assignments:
if assignment.role_definition_bindings:
is_limited_access = True
for role_definition_binding in assignment.role_definition_bindings:
if (
role_definition_binding.role_type_kind
not in LIMITED_ACCESS_ROLE_TYPES
or role_definition_binding.name not in LIMITED_ACCESS_ROLE_NAMES
):
is_limited_access = False
break
# Skip if the role is only Limited Access; this is a traversal permission, not an actual permission
if is_limited_access:
logger.info(
"Skipping assignment because it has only Limited Access role"
)
continue
if assignment.member:
member = assignment.member
if member.principal_type == USER_PRINCIPAL_TYPE and hasattr(
member, "user_principal_name"
):
email = member.user_principal_name
if MICROSOFT_DOMAIN in email:
email = email.replace(MICROSOFT_DOMAIN, "")
user_emails.add(email)
elif member.principal_type in [
AZURE_AD_GROUP_PRINCIPAL_TYPE,
SHAREPOINT_GROUP_PRINCIPAL_TYPE,
]:
name = member.title
if member.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
name = _get_group_name_with_suffix(
member.login_name, name, graph_client
)
groups.add(
SharepointGroup(
login_name=member.login_name,
principal_type=member.principal_type,
name=name,
)
)
if drive_item and drive_name:
# Check if the item has any public links; if so, return early
is_public = _is_public_item(drive_item)
if is_public:
logger.info(f"Item {drive_item.id} is public")
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)
item_id = _get_sharepoint_list_item_id(drive_item)
if not item_id:
raise RuntimeError(
f"Failed to get SharePoint list item ID for item {drive_item.id}"
)
if drive_name == "Shared Documents":
drive_name = "Documents"
item = client_context.web.lists.get_by_title(drive_name).items.get_by_id(
item_id
)
sleep_and_retry(
item.role_assignments.expand(["Member", "RoleDefinitionBindings"]).get_all(
page_loaded=add_user_and_group_to_sets,
),
"get_external_access_from_sharepoint",
)
elif site_page:
site_url = site_page.get("webUrl")
site_pages = client_context.web.lists.get_by_title("Site Pages")
client_context.load(site_pages)
client_context.execute_query()
site_pages.items.get_by_url(site_url).role_assignments.expand(
["Member", "RoleDefinitionBindings"]
).get_all(page_loaded=add_user_and_group_to_sets).execute_query()
else:
raise RuntimeError("No drive item or site page provided")
groups_and_members: GroupsResult = _get_groups_and_members_recursively(
client_context, graph_client, groups
)
# If the site is public, we have default groups assigned to it, so we return early
if groups_and_members.found_public_group:
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)
for group_name, _ in groups_and_members.groups_to_emails.items():
if add_prefix:
group_name = build_ext_group_name_for_onyx(
group_name, DocumentSource.SHAREPOINT
)
group_ids.add(group_name.lower())
logger.info(f"User emails: {len(user_emails)}")
logger.info(f"Group IDs: {len(group_ids)}")
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_ids,
is_public=False,
)
def get_sharepoint_external_groups(
client_context: ClientContext, graph_client: GraphClient
) -> list[ExternalUserGroup]:
groups: set[SharepointGroup] = set()
def add_group_to_sets(role_assignments: RoleAssignmentCollection) -> None:
nonlocal groups
for assignment in role_assignments:
if assignment.role_definition_bindings:
is_limited_access = True
for role_definition_binding in assignment.role_definition_bindings:
if (
role_definition_binding.role_type_kind
not in LIMITED_ACCESS_ROLE_TYPES
or role_definition_binding.name not in LIMITED_ACCESS_ROLE_NAMES
):
is_limited_access = False
break
# Skip if the role assignment is only Limited Access; this is a traversal
# permission, not an actual permission
if is_limited_access:
logger.info(
"Skipping assignment because it has only Limited Access role"
)
continue
if assignment.member:
member = assignment.member
if member.principal_type in [
AZURE_AD_GROUP_PRINCIPAL_TYPE,
SHAREPOINT_GROUP_PRINCIPAL_TYPE,
]:
name = member.title
if member.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
name = _get_group_name_with_suffix(
member.login_name, name, graph_client
)
groups.add(
SharepointGroup(
login_name=member.login_name,
principal_type=member.principal_type,
name=name,
)
)
sleep_and_retry(
client_context.web.role_assignments.expand(
["Member", "RoleDefinitionBindings"]
).get_all(page_loaded=add_group_to_sets),
"get_sharepoint_external_groups",
)
groups_and_members: GroupsResult = _get_groups_and_members_recursively(
client_context, graph_client, groups
)
# We don't have any direct way to check if the site is public, so we check if any public group is present
if groups_and_members.found_public_group:
return []
# Get all Azure AD groups so that groups assigned directly to drive items are not missed.
# SharePoint groups cannot be assigned to drive items or drives, so we skip those.
azure_ad_groups = sleep_and_retry(
graph_client.groups.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups",
)
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
ad_groups_to_emails: dict[str, set[str]] = {}
for group in azure_ad_groups:
# If the group is already identified, we don't need to get the members
if group.display_name in identified_groups:
continue
# Azure AD allows the same display name for multiple groups, so we append the GUID to the name
name = group.display_name
name = _get_group_name_with_suffix(group.id, name, graph_client)
members = sleep_and_retry(
group.members.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
)
for member in members:
member_data = member.to_json()
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
if not ad_groups_to_emails.get(name):
ad_groups_to_emails[name] = set()
if user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
ad_groups_to_emails[name].add(user_principal_name)
elif mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
ad_groups_to_emails[name].add(mail)
external_user_groups: list[ExternalUserGroup] = []
for group_name, emails in groups_and_members.groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
for group_name, emails in ad_groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
return external_user_groups
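To make the entry point concrete, a usage sketch assuming an already-authenticated `ClientContext` and `GraphClient` (the token plumbing is shown in `group_sync.py` above; `site_url` and `acquire_rest_token` are placeholders):

```python
# Hypothetical wiring: list every external group visible on one site.
ctx = ClientContext(site_url).with_access_token(acquire_rest_token)
for group in get_sharepoint_external_groups(ctx, graph_client):
    print(group.id, sorted(group.user_emails))
```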

View File

@@ -11,6 +11,8 @@ from ee.onyx.configs.app_configs import GITHUB_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import JIRA_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import TEAMS_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
@@ -29,6 +31,8 @@ from ee.onyx.external_permissions.perm_sync_types import GroupSyncFuncType
from ee.onyx.external_permissions.salesforce.postprocessing import (
censor_salesforce_chunks,
)
from ee.onyx.external_permissions.sharepoint.doc_sync import sharepoint_doc_sync
from ee.onyx.external_permissions.sharepoint.group_sync import sharepoint_group_sync
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from ee.onyx.external_permissions.teams.doc_sync import teams_doc_sync
from onyx.configs.constants import DocumentSource
@@ -156,6 +160,18 @@ _SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
initial_index_should_sync=True,
),
),
DocumentSource.SHAREPOINT: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=sharepoint_doc_sync,
initial_index_should_sync=True,
),
group_sync_config=GroupSyncConfig(
group_sync_frequency=SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY,
group_sync_func=sharepoint_group_sync,
group_sync_is_cc_pair_agnostic=False,
),
),
}
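A sketch of how this registry is presumably consulted by the sync scheduler; the field names follow the `SyncConfig` entries in this diff:

```python
# Hypothetical lookup: resolve the SharePoint sync settings by source.
sync_config = _SOURCE_TO_SYNC_CONFIG.get(DocumentSource.SHAREPOINT)
if sync_config and sync_config.group_sync_config:
    # 300 seconds by default (SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY)
    print(sync_config.group_sync_config.group_sync_frequency)
```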

File diff suppressed because it is too large

View File

@@ -0,0 +1,38 @@
from typing import Any
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from onyx.connectors.models import ExternalAccess
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
def get_sharepoint_external_access(
ctx: ClientContext,
graph_client: GraphClient,
drive_item: DriveItem | None = None,
drive_name: str | None = None,
site_page: dict[str, Any] | None = None,
add_prefix: bool = False,
) -> ExternalAccess:
if drive_item and drive_item.id is None:
raise ValueError("DriveItem ID is required")
# Get external access using the EE implementation
def noop_fallback(*args: Any, **kwargs: Any) -> ExternalAccess:
return ExternalAccess.empty()
get_external_access_func = fetch_versioned_implementation_with_fallback(
"onyx.external_permissions.sharepoint.permission_utils",
"get_external_access_from_sharepoint",
fallback=noop_fallback,
)
external_access = get_external_access_func(
ctx, graph_client, drive_name, drive_item, site_page, add_prefix
)
return external_access
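The versioned-implementation indirection means an MIT-only build degrades gracefully. A hedged sketch of the effect, with `ctx`, `graph_client`, and `page_data` assumed to exist:

```python
# In a deployment without the EE module, the resolver returns noop_fallback,
# so the document presumably carries no external permissions at all.
access = get_sharepoint_external_access(ctx, graph_client, site_page=page_data)
print(access.is_public, access.external_user_emails)
```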

View File

@@ -1,7 +1,12 @@
import json
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import Form
from fastapi import HTTPException
from fastapi import Query
from fastapi import UploadFile
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
@@ -27,6 +32,9 @@ from onyx.server.documents.models import CredentialDataUpdateRequest
from onyx.server.documents.models import CredentialSnapshot
from onyx.server.documents.models import CredentialSwapRequest
from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.documents.private_key_types import FILE_TYPE_TO_FILE_PROCESSOR
from onyx.server.documents.private_key_types import PrivateKeyFileTypes
from onyx.server.documents.private_key_types import ProcessPrivateKeyFileProtocol
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -76,6 +84,7 @@ def get_cc_source_full_info(
document_source=source_type,
get_editable=get_editable,
)
return [
CredentialSnapshot.from_credential_db_model(credential)
for credential in credentials
@@ -149,6 +158,70 @@ def create_credential_from_model(
)
@router.post("/credential/private-key")
def create_credential_with_private_key(
credential_json: str = Form(...),
admin_public: bool = Form(False),
curator_public: bool = Form(False),
groups: list[int] = Form([]),
name: str | None = Form(None),
source: str = Form(...),
user: User | None = Depends(current_curator_or_admin_user),
uploaded_file: UploadFile = File(...),
field_key: str = Form(...),
type_definition_key: str = Form(...),
db_session: Session = Depends(get_session),
) -> ObjectCreationIdResponse:
try:
credential_data = json.loads(credential_json)
except json.JSONDecodeError as e:
raise HTTPException(
status_code=400,
detail=f"Invalid JSON in credential_json: {str(e)}",
)
private_key_processor: ProcessPrivateKeyFileProtocol | None = (
FILE_TYPE_TO_FILE_PROCESSOR.get(PrivateKeyFileTypes(type_definition_key))
)
if private_key_processor is None:
raise HTTPException(
status_code=400,
detail="Invalid type definition key for private key file",
)
private_key_content: str = private_key_processor(uploaded_file)
credential_data[field_key] = private_key_content
credential_info = CredentialBase(
credential_json=credential_data,
admin_public=admin_public,
curator_public=curator_public,
groups=groups,
name=name,
source=DocumentSource(source),
)
if not _ignore_credential_permissions(DocumentSource(source)):
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,
target_group_ids=groups,
object_is_public=curator_public,
)
# Temporary fix for empty Google App credentials
if DocumentSource(source) == DocumentSource.GMAIL:
cleanup_gmail_credentials(db_session=db_session)
credential = create_credential(credential_info, user, db_session)
return ObjectCreationIdResponse(
id=credential.id,
credential=CredentialSnapshot.from_credential_db_model(credential),
)
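A client-side sketch of calling the new endpoint; the URL prefix, auth, and JSON payload are assumptions, while the `field_key`/`type_definition_key` values match identifiers elsewhere in this PR:

```python
import json

import requests

# Hypothetical multipart upload of a SharePoint .pfx certificate.
resp = requests.post(
    "http://localhost:8080/manage/credential/private-key",  # prefix assumed
    data={
        "credential_json": json.dumps(
            {"sp_client_id": "...", "sp_directory_id": "..."}
        ),
        "source": "sharepoint",
        "field_key": "sp_private_key",
        "type_definition_key": "sharepoint_pfx_file",
    },
    files={"uploaded_file": ("cert.pfx", open("cert.pfx", "rb"))},
)
print(resp.status_code, resp.json())
```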
"""Endpoints for all"""
@@ -209,6 +282,53 @@ def update_credential_data(
return CredentialSnapshot.from_credential_db_model(credential)
@router.put("/admin/credential/private-key/{credential_id}")
def update_credential_private_key(
credential_id: int,
name: str = Form(...),
credential_json: str = Form(...),
uploaded_file: UploadFile = File(...),
field_key: str = Form(...),
type_definition_key: str = Form(...),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> CredentialBase:
try:
credential_data = json.loads(credential_json)
except json.JSONDecodeError as e:
raise HTTPException(
status_code=400,
detail=f"Invalid JSON in credential_json: {str(e)}",
)
private_key_processor: ProcessPrivateKeyFileProtocol | None = (
FILE_TYPE_TO_FILE_PROCESSOR.get(PrivateKeyFileTypes(type_definition_key))
)
if private_key_processor is None:
raise HTTPException(
status_code=400,
detail="Invalid type definition key for private key file",
)
private_key_content: str = private_key_processor(uploaded_file)
credential_data[field_key] = private_key_content
credential = alter_credential(
credential_id,
name,
credential_data,
user,
db_session,
)
if credential is None:
raise HTTPException(
status_code=401,
detail=f"Credential {credential_id} does not exist or does not belong to user",
)
return CredentialSnapshot.from_credential_db_model(credential)
@router.patch("/credential/{credential_id}")
def update_credential_from_model(
credential_id: int,

View File

@@ -0,0 +1,75 @@
from cryptography.hazmat.primitives.serialization import pkcs12
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _is_password_related_error(error: Exception) -> bool:
"""
Check if the exception indicates a password-related issue rather than a format issue.
"""
error_msg = str(error).lower()
password_keywords = ["mac", "integrity", "password", "authentication", "verify"]
return any(keyword in error_msg for keyword in password_keywords)
def validate_pkcs12_content(file_bytes: bytes) -> bool:
"""
Validate that the file content is actually a PKCS#12 file.
This performs basic format validation without requiring passwords.
"""
try:
# Basic file size check
if len(file_bytes) < 10:
logger.debug("File too small to be a valid PKCS#12 file")
return False
# Check for PKCS#12 magic bytes/ASN.1 structure
# PKCS#12 files start with ASN.1 SEQUENCE tag (0x30)
if file_bytes[0] != 0x30:
logger.debug("File does not start with ASN.1 SEQUENCE tag")
return False
# Try to parse the outer ASN.1 structure without password validation
# This checks if the file has the basic PKCS#12 structure
try:
# Attempt to load just to validate the basic format
# We expect this to fail due to password, but it should fail with a specific error
pkcs12.load_key_and_certificates(file_bytes, password=None)
return True
except ValueError as e:
# Check if the error is related to password (expected) vs format issues
if _is_password_related_error(e):
# These errors indicate the file format is correct but password is wrong/missing
logger.debug(
f"PKCS#12 format appears valid, password-related error: {e}"
)
return True
else:
# Other ValueError likely indicates format issues
logger.debug(f"PKCS#12 format validation failed: {e}")
return False
except Exception as e:
# Try with empty password as fallback
try:
pkcs12.load_key_and_certificates(file_bytes, password=b"")
return True
except ValueError as e2:
if _is_password_related_error(e2):
logger.debug(
f"PKCS#12 format appears valid with empty password attempt: {e2}"
)
return True
else:
logger.debug(
f"PKCS#12 validation failed on both attempts: {e}, {e2}"
)
return False
except Exception:
logger.debug(f"PKCS#12 validation failed: {e}")
return False
except Exception as e:
logger.debug(f"Unexpected error during PKCS#12 validation: {e}")
return False
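A quick way to exercise the validator, assuming a password-protected `client.pfx` on disk (hypothetical path); the point of the function is that format validation passes even when no password is supplied:

```python
# Should print True for a well-formed PKCS#12 file, False for random bytes.
with open("client.pfx", "rb") as f:
    print(validate_pkcs12_content(f.read()))
print(validate_pkcs12_content(b"not a pfx"))
```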

View File

@@ -0,0 +1,57 @@
import base64
from enum import Enum
from typing import Protocol
from fastapi import HTTPException
from fastapi import UploadFile
from onyx.server.documents.document_utils import validate_pkcs12_content
class ProcessPrivateKeyFileProtocol(Protocol):
def __call__(self, file: UploadFile) -> str:
"""
Accepts a file-like object, validates the file (e.g., checks extension and content),
and returns its contents as a base64-encoded string if valid.
Raises an exception if validation fails.
"""
...
class PrivateKeyFileTypes(Enum):
SHAREPOINT_PFX_FILE = "sharepoint_pfx_file"
def process_sharepoint_private_key_file(file: UploadFile) -> str:
"""
Process and validate a private key file upload.
Validates both the file extension and file content to ensure it's a valid PKCS#12 file.
Content validation prevents attacks that rely on file extension spoofing.
"""
# First check file extension (basic filter)
if not (file.filename and file.filename.lower().endswith(".pfx")):
raise HTTPException(
status_code=400, detail="Invalid file type. Only .pfx files are supported."
)
# Read file content for validation and processing
private_key_bytes = file.file.read()
# Validate file content to prevent extension spoofing attacks
if not validate_pkcs12_content(private_key_bytes):
raise HTTPException(
status_code=400,
detail="Invalid file content. The uploaded file does not appear to be a valid PKCS#12 (.pfx) file.",
)
# Convert to base64 if validation passes
pfx_64 = base64.b64encode(private_key_bytes).decode("ascii")
return pfx_64
FILE_TYPE_TO_FILE_PROCESSOR: dict[
PrivateKeyFileTypes, ProcessPrivateKeyFileProtocol
] = {
PrivateKeyFileTypes.SHAREPOINT_PFX_FILE: process_sharepoint_private_key_file,
}
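A sketch of the dispatch the credential endpoints perform, with `upload` assumed to be a `fastapi.UploadFile` taken from the request:

```python
# Hypothetical dispatch: look up the processor by type key, then run it.
processor = FILE_TYPE_TO_FILE_PROCESSOR[PrivateKeyFileTypes("sharepoint_pfx_file")]
encoded = processor(upload)  # base64 string of the validated .pfx bytes
```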

View File

@@ -1,14 +1,18 @@
import os
import time
from dataclasses import dataclass
from datetime import datetime
from datetime import timezone
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.sharepoint.connector import SharepointConnector
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
@dataclass
@@ -76,6 +80,17 @@ def find_document(documents: list[Document], semantic_identifier: str) -> Docume
return matching_docs[0]
@pytest.fixture
def mock_store_image() -> MagicMock:
"""Mock store_image_and_create_section to return a predefined ImageSection."""
mock = MagicMock()
mock.return_value = (
ImageSection(image_file_id="mocked-file-id", link="https://example.com/image"),
"mocked-file-id",
)
return mock
@pytest.fixture
def sharepoint_credentials() -> dict[str, str]:
return {
@@ -87,175 +102,219 @@ def sharepoint_credentials() -> dict[str, str]:
def test_sharepoint_connector_all_sites__docs_only(
mock_get_unstructured_api_key: MagicMock,
mock_store_image: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
    with patch(
        "onyx.connectors.sharepoint.connector.store_image_and_create_section",
        mock_store_image,
    ):
        # Initialize connector with no sites
        connector = SharepointConnector(include_site_pages=False)

        # Load credentials
        connector.load_credentials(sharepoint_credentials)

        # Not asserting expected sites because that can change in test tenant at any time
        # Finding any docs is good enough to verify that the connector is working
        document_batches = load_all_docs_from_checkpoint_connector(
            connector=connector,
            start=0,
            end=time.time(),
        )
        assert document_batches, "Should find documents from all sites"
def test_sharepoint_connector_specific_folder(
mock_get_unstructured_api_key: MagicMock,
mock_store_image: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
    with patch(
        "onyx.connectors.sharepoint.connector.store_image_and_create_section",
        mock_store_image,
    ):
        # Initialize connector with the test site URL and specific folder
        connector = SharepointConnector(
            sites=[os.environ["SHAREPOINT_SITE"] + "/Shared Documents/test"]
        )

        # Load credentials
        connector.load_credentials(sharepoint_credentials)

        # Get all documents
        found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
            connector=connector,
            start=0,
            end=time.time(),
        )

        # Should only find documents in the test folder
        test_folder_docs = [
            doc
            for doc in EXPECTED_DOCUMENTS
            if doc.folder_path and doc.folder_path.startswith("test")
        ]
        assert len(found_documents) == len(
            test_folder_docs
        ), "Should only find documents in test folder"

        # Verify each expected document
        for expected in test_folder_docs:
            doc = find_document(found_documents, expected.semantic_identifier)
            verify_document_content(doc, expected)
def test_sharepoint_connector_root_folder__docs_only(
mock_get_unstructured_api_key: MagicMock,
mock_store_image: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
    with patch(
        "onyx.connectors.sharepoint.connector.store_image_and_create_section",
        mock_store_image,
    ):
        # Initialize connector with the base site URL
        connector = SharepointConnector(
            sites=[os.environ["SHAREPOINT_SITE"]], include_site_pages=False
        )

        # Load credentials
        connector.load_credentials(sharepoint_credentials)

        # Get all documents
        found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
            connector=connector,
            start=0,
            end=time.time(),
        )

        assert len(found_documents) == len(
            EXPECTED_DOCUMENTS
        ), "Should find all documents in main library"

        # Verify each expected document
        for expected in EXPECTED_DOCUMENTS:
            doc = find_document(found_documents, expected.semantic_identifier)
            verify_document_content(doc, expected)
def test_sharepoint_connector_other_library(
mock_get_unstructured_api_key: MagicMock,
mock_store_image: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
    with patch(
        "onyx.connectors.sharepoint.connector.store_image_and_create_section",
        mock_store_image,
    ):
        # Initialize connector with the other library
        connector = SharepointConnector(
            sites=[
                os.environ["SHAREPOINT_SITE"] + "/Other Library",
            ]
        )

        # Load credentials
        connector.load_credentials(sharepoint_credentials)

        # Get all documents
        found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
            connector=connector,
            start=0,
            end=time.time(),
        )

        expected_documents: list[ExpectedDocument] = [
            doc for doc in EXPECTED_DOCUMENTS if doc.library == "Other Library"
        ]

        # Should find all documents in `Other Library`
        assert len(found_documents) == len(
            expected_documents
        ), "Should find all documents in `Other Library`"

        # Verify each expected document
        for expected in expected_documents:
            doc = find_document(found_documents, expected.semantic_identifier)
            verify_document_content(doc, expected)
def test_sharepoint_connector_poll(
mock_get_unstructured_api_key: MagicMock,
mock_store_image: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
    with patch(
        "onyx.connectors.sharepoint.connector.store_image_and_create_section",
        mock_store_image,
    ):
        # Initialize connector with the base site URL
        connector = SharepointConnector(
            sites=["https://danswerai.sharepoint.com/sites/sharepoint-tests"]
        )

        # Load credentials
        connector.load_credentials(sharepoint_credentials)

        # Set time window to only capture test1.docx (modified at 2025-01-28 20:51:42+00:00)
        start = datetime(
            2025, 1, 28, 20, 51, 30, tzinfo=timezone.utc
        )  # 12 seconds before
        end = datetime(2025, 1, 28, 20, 51, 50, tzinfo=timezone.utc)  # 8 seconds after

        # Get documents within the time window
        found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
            connector=connector,
            start=start.timestamp(),
            end=end.timestamp(),
        )

        # Should only find test1.docx
        assert (
            len(found_documents) == 1
        ), "Should only find one document in the time window"
        doc = found_documents[0]
        assert doc.semantic_identifier == "test1.docx"
        verify_document_metadata(doc)
        verify_document_content(
            doc,
            [d for d in EXPECTED_DOCUMENTS if d.semantic_identifier == "test1.docx"][0],
        )
def test_sharepoint_connector_pages(
mock_get_unstructured_api_key: MagicMock,
mock_store_image: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
    with patch(
        "onyx.connectors.sharepoint.connector.store_image_and_create_section",
        mock_store_image,
    ):
        # Initialize connector with the base site URL
        connector = SharepointConnector(
            sites=["https://danswerai.sharepoint.com/sites/sharepoint-tests-pages"]
        )

        # Load credentials
        connector.load_credentials(sharepoint_credentials)

        # Get documents within the time window
        found_documents = load_all_docs_from_checkpoint_connector(
            connector=connector,
            start=0,
            end=time.time(),
        )

        # Should only find CollabHome
        assert len(found_documents) == 1, "Should only find one page"
        doc = found_documents[0]
        assert doc.semantic_identifier == "CollabHome"
        verify_document_metadata(doc)
        assert len(doc.sections) == 1
        assert (
            doc.sections[0].text
            == """
# Home
Display recent news.
@@ -282,4 +341,4 @@ Add a document library
## Document library
""".strip()
)
)

View File

@@ -0,0 +1,113 @@
import os
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.connectors.sharepoint.connector import SharepointAuthMethod
from onyx.db.enums import AccessType
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestCredential
from tests.integration.common_utils.test_models import DATestUser
SharepointTestEnvSetupTuple = tuple[
DATestUser, # admin_user
DATestUser, # regular_user_1
DATestUser, # regular_user_2
DATestCredential,
DATestConnector,
DATestCCPair,
]
@pytest.fixture(scope="module")
def sharepoint_test_env_setup() -> Generator[SharepointTestEnvSetupTuple]:
# Reset all data before running the test
reset_all()
# Required environment variables for SharePoint certificate authentication
sp_client_id = os.environ.get("PERM_SYNC_SHAREPOINT_CLIENT_ID")
sp_private_key = os.environ.get("PERM_SYNC_SHAREPOINT_PRIVATE_KEY")
sp_certificate_password = os.environ.get(
"PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD"
)
sp_directory_id = os.environ.get("PERM_SYNC_SHAREPOINT_DIRECTORY_ID")
sharepoint_sites = "https://danswerai.sharepoint.com/sites/Permisisonsync"
admin_email = "admin@onyx.app"
user1_email = "subash@onyx.app"
user2_email = "raunak@onyx.app"
    if (
        not sp_client_id
        or not sp_private_key
        or not sp_certificate_password
        or not sp_directory_id
    ):
        pytest.skip("Skipping test because required environment variables are not set")
# Certificate-based credentials
credentials = {
"authentication_method": SharepointAuthMethod.CERTIFICATE.value,
"sp_client_id": sp_client_id,
"sp_private_key": sp_private_key,
"sp_certificate_password": sp_certificate_password,
"sp_directory_id": sp_directory_id,
}
# Create users
admin_user: DATestUser = UserManager.create(email=admin_email)
regular_user_1: DATestUser = UserManager.create(email=user1_email)
regular_user_2: DATestUser = UserManager.create(email=user2_email)
# Create LLM provider for search functionality
LLMProviderManager.create(user_performing_action=admin_user)
# Create credential
credential: DATestCredential = CredentialManager.create(
source=DocumentSource.SHAREPOINT,
credential_json=credentials,
user_performing_action=admin_user,
)
# Create connector with SharePoint-specific configuration
connector: DATestConnector = ConnectorManager.create(
name="SharePoint Test",
input_type=InputType.POLL,
source=DocumentSource.SHAREPOINT,
connector_specific_config={
"sites": sharepoint_sites.split(","),
},
access_type=AccessType.SYNC, # Enable permission sync
user_performing_action=admin_user,
)
# Create CC pair with permission sync enabled
cc_pair: DATestCCPair = CCPairManager.create(
credential_id=credential.id,
connector_id=connector.id,
access_type=AccessType.SYNC, # Enable permission sync
user_performing_action=admin_user,
)
# Wait for both indexing and permission sync to complete
before = datetime.now(tz=timezone.utc)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
timeout=float("inf"),
)
# Wait for permission sync completion specifically
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
timeout=float("inf"),
)
yield admin_user, regular_user_1, regular_user_2, credential, connector, cc_pair

View File

@@ -0,0 +1,205 @@
from typing import List
from uuid import UUID
import pytest
from sqlalchemy.orm import Session
from ee.onyx.access.access import _get_access_for_documents
from ee.onyx.db.external_perm import fetch_external_groups_for_user
from onyx.access.utils import prefix_external_group
from onyx.access.utils import prefix_user_email
from onyx.configs.constants import PUBLIC_DOC_PAT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import User
from onyx.db.users import fetch_user_by_id
from onyx.utils.logger import setup_logger
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.connector_job_tests.sharepoint.conftest import (
SharepointTestEnvSetupTuple,
)
logger = setup_logger()
def get_user_acl(user: User, db_session: Session) -> set[str]:
db_external_groups = (
fetch_external_groups_for_user(db_session, user.id) if user else []
)
prefixed_external_groups = [
prefix_external_group(db_external_group.external_user_group_id)
for db_external_group in db_external_groups
]
user_acl = set(prefixed_external_groups)
user_acl.update({prefix_user_email(user.email), PUBLIC_DOC_PAT})
return user_acl
def get_user_document_access_via_acl(
test_user: DATestUser, document_ids: List[str], db_session: Session
) -> List[str]:
# Get the actual User object from the database
user = fetch_user_by_id(db_session, UUID(test_user.id))
if not user:
logger.error(f"Could not find user with ID {test_user.id}")
return []
user_acl = get_user_acl(user, db_session)
logger.info(f"User {user.email} ACL entries: {user_acl}")
# Get document access information
doc_access_map = _get_access_for_documents(document_ids, db_session)
logger.info(f"Found access info for {len(doc_access_map)} documents")
accessible_docs = []
for doc_id, doc_access in doc_access_map.items():
doc_acl = doc_access.to_acl()
logger.info(f"Document {doc_id} ACL: {doc_acl}")
# Check if user has any matching ACL entry
if user_acl.intersection(doc_acl):
accessible_docs.append(doc_id)
logger.info(f"User {user.email} has access to document {doc_id}")
else:
logger.info(f"User {user.email} does NOT have access to document {doc_id}")
return accessible_docs
def get_all_connector_documents(
cc_pair: DATestCCPair, db_session: Session
) -> List[str]:
from onyx.db.models import DocumentByConnectorCredentialPair
from sqlalchemy import select
stmt = select(DocumentByConnectorCredentialPair.id).where(
DocumentByConnectorCredentialPair.connector_id == cc_pair.connector_id,
DocumentByConnectorCredentialPair.credential_id == cc_pair.credential_id,
)
result = db_session.execute(stmt)
document_ids = [row[0] for row in result.fetchall()]
logger.info(
f"Found {len(document_ids)} documents for connector {cc_pair.connector_id}"
)
return document_ids
def get_documents_by_permission_type(
document_ids: List[str], db_session: Session
) -> List[str]:
"""
Categorize documents by their permission types
Returns a dictionary with lists of document IDs for each permission type
"""
doc_access_map = _get_access_for_documents(document_ids, db_session)
public_docs = []
for doc_id, doc_access in doc_access_map.items():
if doc_access.is_public:
public_docs.append(doc_id)
return public_docs
def test_public_documents_accessible_by_all_users(
sharepoint_test_env_setup: SharepointTestEnvSetupTuple,
) -> None:
"""Test that public documents are accessible by both test users using ACL verification"""
(
admin_user,
regular_user_1,
regular_user_2,
credential,
connector,
cc_pair,
) = sharepoint_test_env_setup
with get_session_with_current_tenant() as db_session:
# Get all documents for this connector
all_document_ids = get_all_connector_documents(cc_pair, db_session)
# Test that regular_user_1 can access documents
accessible_docs_user1 = get_user_document_access_via_acl(
test_user=regular_user_1,
document_ids=all_document_ids,
db_session=db_session,
)
# Test that regular_user_2 can access documents
accessible_docs_user2 = get_user_document_access_via_acl(
test_user=regular_user_2,
document_ids=all_document_ids,
db_session=db_session,
)
logger.info(f"User 1 has access to {len(accessible_docs_user1)} documents")
logger.info(f"User 2 has access to {len(accessible_docs_user2)} documents")
        # Both users should see the exact document counts expected for this test tenant
        assert len(accessible_docs_user1) == 8, (
            f"User 1 should have access to exactly 8 documents. Found "
            f"{len(accessible_docs_user1)} accessible docs out of "
            f"{len(all_document_ids)} total"
        )
        assert len(accessible_docs_user2) == 1, (
            f"User 2 should have access to exactly 1 document. Found "
            f"{len(accessible_docs_user2)} accessible docs out of "
            f"{len(all_document_ids)} total"
        )
logger.info(
"Successfully verified public documents are accessible by users via ACL"
)
def test_group_based_permissions(
sharepoint_test_env_setup: SharepointTestEnvSetupTuple,
) -> None:
"""Test that documents with group permissions are accessible only by users in that group using ACL verification"""
(
admin_user,
regular_user_1,
regular_user_2,
credential,
connector,
cc_pair,
) = sharepoint_test_env_setup
with get_session_with_current_tenant() as db_session:
# Get all documents for this connector
all_document_ids = get_all_connector_documents(cc_pair, db_session)
if not all_document_ids:
pytest.skip("No documents found for connector - skipping test")
# Test access for both users
accessible_docs_user1 = get_user_document_access_via_acl(
test_user=regular_user_1,
document_ids=all_document_ids,
db_session=db_session,
)
accessible_docs_user2 = get_user_document_access_via_acl(
test_user=regular_user_2,
document_ids=all_document_ids,
db_session=db_session,
)
logger.info(f"User 1 has access to {len(accessible_docs_user1)} documents")
logger.info(f"User 2 has access to {len(accessible_docs_user2)} documents")
public_docs = get_documents_by_permission_type(all_document_ids, db_session)
# Check if user 2 has access to any non-public documents
non_public_access_user2 = [
doc for doc in accessible_docs_user2 if doc not in public_docs
]
assert (
len(non_public_access_user2) == 0
), f"User 2 should only have access to public documents. Found access to non-public docs: {non_public_access_user2}"

View File

@@ -8,6 +8,7 @@ import {
useField,
useFormikContext,
} from "formik";
import { FileUpload } from "@/components/admin/connectors/FileUpload";
import * as Yup from "yup";
import { FormBodyBuilder } from "./admin/connectors/types";
import { StringOrNumberOption } from "@/components/Dropdown";
@@ -37,6 +38,12 @@ import { transformLinkUri } from "@/lib/utils";
import FileInput from "@/app/admin/connectors/[connector]/pages/ConnectorInput/FileInput";
import { DatePicker } from "./ui/datePicker";
import { Textarea, TextareaProps } from "./ui/textarea";
import {
TypedFile,
createTypedFile,
getFileTypeDefinitionForField,
FILE_TYPE_DEFINITIONS,
} from "@/lib/connectors/fileTypes";
export function SectionHeader({
children,
@@ -386,6 +393,120 @@ export function FileUploadFormField({
);
}
export function TypedFileUploadFormField({
name,
label,
subtext,
}: {
name: string;
label: string;
subtext?: string | JSX.Element;
}) {
const [field, , helpers] = useField<TypedFile | null>(name);
const [customError, setCustomError] = useState<string>("");
const [isValidating, setIsValidating] = useState(false);
const [description, setDescription] = useState<string>("");
useEffect(() => {
const typeDefinitionKey = getFileTypeDefinitionForField(name);
if (typeDefinitionKey) {
setDescription(
FILE_TYPE_DEFINITIONS[typeDefinitionKey].description || ""
);
}
}, [name]);
useEffect(() => {
const validateFile = async () => {
if (!field.value) {
setIsValidating(false);
return;
}
setIsValidating(true);
try {
const validation = await field.value.validate();
if (validation?.isValid) {
setCustomError("");
} else {
setCustomError(validation?.errors.join(", ") || "Unknown error");
helpers.setValue(null);
}
} catch (error) {
setCustomError(
error instanceof Error ? error.message : "Validation error"
);
helpers.setValue(null);
} finally {
setIsValidating(false);
}
};
validateFile();
}, [field.value, helpers]);
const handleFileSelection = async (files: File[]) => {
if (files.length === 0) {
helpers.setValue(null);
setCustomError("");
return;
}
const file = files[0];
if (!file) {
setCustomError("File selection error");
return;
}
const typeDefinitionKey = getFileTypeDefinitionForField(name);
if (!typeDefinitionKey) {
setCustomError(`No file type definition found for field: ${name}`);
return;
}
try {
const typedFile = createTypedFile(file, name, typeDefinitionKey);
helpers.setValue(typedFile);
setCustomError("");
} catch (error) {
setCustomError(error instanceof Error ? error.message : "Unknown error");
helpers.setValue(null);
} finally {
setIsValidating(false);
}
};
return (
<div className="w-full">
<FieldLabel name={name} label={label} subtext={subtext} />
{description && (
<div className="text-sm text-gray-500 mb-2">{description}</div>
)}
<FileUpload
selectedFiles={field.value ? [field.value.file] : []}
setSelectedFiles={handleFileSelection}
multiple={false}
/>
{/* Validation feedback */}
{isValidating && (
<div className="text-blue-500 text-sm mt-1">Validating file...</div>
)}
{customError ? (
<div className="text-red-500 text-sm mt-1">{customError}</div>
) : (
<ErrorMessage
name={name}
component="div"
className="text-red-500 text-sm mt-1"
/>
)}
</div>
);
}
export function MultiSelectField({
name,
label,

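A minimal usage sketch (not part of this commit) of the new upload field. The standalone form below is hypothetical; in the real flow the field is mounted by CredentialFieldsRenderer, and sp_private_key is the only key registered in isTypedFileField.
import { Form, Formik } from "formik";
import { TypedFileUploadFormField } from "@/components/Field";
// Hypothetical wrapper purely for illustration.
export function PrivateKeyFormSketch() {
  return (
    <Formik initialValues={{ sp_private_key: null }} onSubmit={() => {}}>
      <Form>
        <TypedFileUploadFormField
          name="sp_private_key"
          label="SharePoint Private Key"
          subtext="Upload the .pfx file used for certificate authentication"
        />
      </Form>
    </Formik>
  );
}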
View File

@@ -4,11 +4,20 @@ import * as Yup from "yup";
import { Popup } from "./Popup";
import { ValidSources } from "@/lib/types";
import { createCredential } from "@/lib/credential";
import { CredentialBase, Credential } from "@/lib/connectors/credentials";
import {
createCredential,
createCredentialWithPrivateKey,
} from "@/lib/credential";
import {
CredentialBase,
Credential,
CredentialWithPrivateKey,
} from "@/lib/connectors/credentials";
const PRIVATE_KEY_FIELD_KEY = "private_key";
export async function submitCredential<T>(
credential: CredentialBase<T>
credential: CredentialBase<T> | CredentialWithPrivateKey<T>
): Promise<{
credential?: Credential<any>;
message: string;
@@ -16,8 +25,14 @@ export async function submitCredential<T>(
}> {
let isSuccess = false;
try {
const response = await createCredential(credential);
let response: Response;
if (PRIVATE_KEY_FIELD_KEY in credential && credential.private_key) {
response = await createCredentialWithPrivateKey(
credential as CredentialWithPrivateKey<T>
);
} else {
response = await createCredential(credential as CredentialBase<T>);
}
if (response.ok) {
const parsed_response = await response.json();
const credential = parsed_response.credential;

View File

@@ -10,6 +10,7 @@ import {
deleteCredential,
swapCredential,
updateCredential,
updateCredentialWithPrivateKey,
} from "@/lib/credential";
import { usePopup } from "@/components/admin/connectors/Popup";
import CreateCredential from "./actions/CreateCredential";
@@ -34,6 +35,7 @@ import {
import { Spinner } from "@/components/Spinner";
import { CreateStdOAuthCredential } from "@/components/credentials/actions/CreateStdOAuthCredential";
import { Card } from "../ui/card";
import { isTypedFileField, TypedFile } from "@/lib/connectors/fileTypes";
export default function CredentialSection({
ccPair,
@@ -111,7 +113,23 @@ export default function CredentialSection({
details: any,
onSucces: () => void
) => {
const response = await updateCredential(selectedCredential.id, details);
let privateKey: TypedFile | null = null;
Object.entries(details).forEach(([key, value]) => {
if (isTypedFileField(key)) {
privateKey = value as TypedFile;
delete details[key];
}
});
let response;
if (privateKey) {
response = await updateCredentialWithPrivateKey(
selectedCredential.id,
details,
privateKey
);
} else {
response = await updateCredential(selectedCredential.id, details);
}
if (response.ok) {
setPopup({
message: "Updated credential",

View File

@@ -23,6 +23,7 @@ import {
import { useUser } from "@/components/user/UserProvider";
import CardSection from "@/components/admin/CardSection";
import { CredentialFieldsRenderer } from "./CredentialFieldsRenderer";
import { TypedFile } from "@/lib/connectors/fileTypes";
const CreateButton = ({
onClick,
@@ -114,10 +115,15 @@ export default function CreateCredential({
const { name, is_public, groups, ...credentialValues } = values;
let privateKey: TypedFile | null = null;
const filteredCredentialValues = Object.fromEntries(
Object.entries(credentialValues).filter(
([_, value]) => value !== null && value !== ""
)
Object.entries(credentialValues).filter(([key, value]) => {
if (value instanceof TypedFile) {
privateKey = value;
return false;
}
return value !== null && value !== "";
})
);
try {
@@ -128,6 +134,7 @@ export default function CreateCredential({
groups: groups,
name: name,
source: sourceType,
private_key: privateKey || undefined,
});
const { message, isSuccess, credential } = response;

View File

@@ -1,12 +1,17 @@
import React from "react";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { useFormikContext } from "formik";
import { BooleanFormField, TextFormField } from "@/components/Field";
import {
BooleanFormField,
TextFormField,
TypedFileUploadFormField,
} from "@/components/Field";
import {
getDisplayNameForCredentialKey,
CredentialTemplateWithAuth,
} from "@/lib/connectors/credentials";
import { dictionaryType } from "../types";
import { isTypedFileField } from "@/lib/connectors/fileTypes";
interface CredentialFieldsRendererProps {
credentialTemplate: dictionaryType;
@@ -90,6 +95,16 @@ export function CredentialFieldsRenderer({
)}
{Object.entries(method.fields).map(([key, val]) => {
if (isTypedFileField(key)) {
return (
<TypedFileUploadFormField
key={key}
name={key}
label={getDisplayNameForCredentialKey(key)}
/>
);
}
if (typeof val === "boolean") {
return (
<BooleanFormField
@@ -130,6 +145,15 @@ export function CredentialFieldsRenderer({
if (key === "authentication_method" || key === "authMethods") {
return null;
}
if (isTypedFileField(key)) {
return (
<TypedFileUploadFormField
key={key}
name={key}
label={getDisplayNameForCredentialKey(key)}
/>
);
}
if (typeof val === "boolean") {
return (
@@ -144,7 +168,7 @@ export function CredentialFieldsRenderer({
<TextFormField
key={key}
name={key}
placeholder={val}
placeholder={val as string}
label={getDisplayNameForCredentialKey(key)}
type={
key.toLowerCase().includes("token") ||

View File

@@ -3,7 +3,7 @@ import { Button } from "@/components/ui/button";
import Text from "@/components/ui/text";
import { FaNewspaper, FaTrash } from "react-icons/fa";
import { TextFormField } from "@/components/Field";
import { TextFormField, TypedFileUploadFormField } from "@/components/Field";
import { Form, Formik, FormikHelpers } from "formik";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import {
@@ -12,6 +12,7 @@ import {
} from "@/lib/connectors/credentials";
import { createEditingValidationSchema, createInitialValues } from "../lib";
import { dictionaryType, formType } from "../types";
import { isTypedFileField } from "@/lib/connectors/fileTypes";
const EditCredential = ({
credential,
@@ -68,22 +69,30 @@ const EditCredential = ({
label="Name (optional):"
/>
{Object.entries(credential.credential_json).map(([key, value]) => (
<TextFormField
includeRevert
key={key}
name={key}
placeholder={value}
label={getDisplayNameForCredentialKey(key)}
type={
key.toLowerCase().includes("token") ||
key.toLowerCase().includes("password")
? "password"
: "text"
}
disabled={key === "authentication_method"}
/>
))}
{Object.entries(credential.credential_json).map(([key, value]) =>
isTypedFileField(key) ? (
<TypedFileUploadFormField
key={key}
name={key}
label={getDisplayNameForCredentialKey(key)}
/>
) : (
<TextFormField
includeRevert
key={key}
name={key}
placeholder={value as string}
label={getDisplayNameForCredentialKey(key)}
type={
key.toLowerCase().includes("token") ||
key.toLowerCase().includes("password")
? "password"
: "text"
}
disabled={key === "authentication_method"}
/>
)
)}
<div className="flex justify-between w-full">
<Button type="button" onClick={() => resetForm()}>
<div className="flex gap-x-2 items-center w-full border-none">

View File

@@ -6,6 +6,7 @@ import {
getDisplayNameForCredentialKey,
CredentialTemplateWithAuth,
} from "@/lib/connectors/credentials";
import { isTypedFileField } from "@/lib/connectors/fileTypes";
export function createValidationSchema(json_values: Record<string, any>) {
const schemaFields: Record<string, Yup.AnySchema> = {};
@@ -16,7 +17,6 @@ export function createValidationSchema(json_values: Record<string, any>) {
schemaFields["authentication_method"] = Yup.string().required(
"Please select an authentication method"
);
// conditional rules per authMethod
template.authMethods.forEach((method) => {
Object.entries(method.fields).forEach(([key, def]) => {
@@ -26,6 +26,14 @@ export function createValidationSchema(json_values: Record<string, any>) {
.nullable()
.default(false)
.transform((v, o) => (o === undefined ? false : v));
} else if (isTypedFileField(key)) {
// TypedFile fields - use mixed schema instead of string (check before null check)
schemaFields[key] = Yup.mixed().when("authentication_method", {
is: method.value,
then: () =>
Yup.mixed().required(`Please select a ${displayName} file`),
otherwise: () => Yup.mixed().notRequired(),
});
} else if (def === null) {
schemaFields[key] = Yup.string()
.trim()
@@ -58,6 +66,11 @@ export function createValidationSchema(json_values: Record<string, any>) {
.nullable()
.default(false)
.transform((v, o) => (o === undefined ? false : v));
} else if (isTypedFileField(key)) {
// TypedFile fields - use mixed schema instead of string (check before null check)
schemaFields[key] = Yup.mixed().required(
`Please select a ${displayName} file`
);
} else if (def === null) {
schemaFields[key] = Yup.string()
.trim()
@@ -77,11 +90,16 @@ export function createValidationSchema(json_values: Record<string, any>) {
}
export function createEditingValidationSchema(json_values: dictionaryType) {
const schemaFields: { [key: string]: Yup.StringSchema } = {};
const schemaFields: { [key: string]: Yup.AnySchema } = {};
for (const key in json_values) {
if (Object.prototype.hasOwnProperty.call(json_values, key)) {
schemaFields[key] = Yup.string().optional();
if (isTypedFileField(key)) {
// TypedFile fields - use mixed schema for optional file uploads during editing
schemaFields[key] = Yup.mixed().optional();
} else {
schemaFields[key] = Yup.string().optional();
}
}
}
@@ -95,7 +113,12 @@ export function createInitialValues(credential: Credential<any>): formType {
};
for (const key in credential.credential_json) {
initialValues[key] = "";
// Initialize TypedFile fields as null, other fields as empty strings
if (isTypedFileField(key)) {
initialValues[key] = null as any; // TypedFile fields start as null
} else {
initialValues[key] = "";
}
}
return initialValues;

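A hedged sketch of the editing-schema behavior above (the import path is assumed): the branch keys off the field name via isTypedFileField, so sp_private_key gets an optional Yup.mixed() while editing and no longer trips string validation.
import { createEditingValidationSchema } from "@/components/credentials/lib"; // path assumed
async function editingSchemaSketch() {
  const schema = createEditingValidationSchema({
    sp_client_id: "client-id",
    sp_private_key: "", // the key alone routes this field to Yup.mixed().optional()
  });
  // Passes: no file is required when editing an existing credential.
  await schema.validate({ sp_client_id: "client-id" });
}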
View File

@@ -1,5 +1,7 @@
import { TypedFile } from "@/lib/connectors/fileTypes";
export interface dictionaryType {
[key: string]: string;
[key: string]: string | TypedFile;
}
export interface formType extends dictionaryType {
name: string;

View File

@@ -17,4 +17,5 @@ export const autoSyncConfigBySource: Record<
github: {},
slack: {},
salesforce: {},
sharepoint: {},
};

View File

@@ -1,4 +1,5 @@
import { ValidSources } from "../types";
import { TypedFile } from "./fileTypes";
export interface OAuthAdditionalKwargDescription {
name: string;
@@ -30,6 +31,10 @@ export interface CredentialBase<T> {
groups?: number[];
}
export interface CredentialWithPrivateKey<T> extends CredentialBase<T> {
private_key: TypedFile;
}
export interface Credential<T> extends CredentialBase<T> {
id: number;
user_id: string | null;
@@ -188,8 +193,10 @@ export interface SalesforceCredentialJson {
export interface SharepointCredentialJson {
sp_client_id: string;
sp_client_secret: string;
sp_client_secret?: string;
sp_directory_id: string;
sp_certificate_password?: string;
sp_private_key?: TypedFile;
}
export interface AsanaCredentialJson {
@@ -297,10 +304,33 @@ export const credentialTemplates: Record<ValidSources, any> = {
is_sandbox: false,
} as SalesforceCredentialJson,
sharepoint: {
sp_client_id: "",
sp_client_secret: "",
sp_directory_id: "",
} as SharepointCredentialJson,
authentication_method: "client_credentials",
authMethods: [
{
value: "client_secret",
label: "Client Secret",
fields: {
sp_client_id: "",
sp_client_secret: "",
sp_directory_id: "",
},
description:
"If you select this mode, the SharePoint connector will use a client secret to authenticate. You will need to provide the client ID and client secret.",
},
{
value: "certificate",
label: "Certificate Authentication",
fields: {
sp_client_id: "",
sp_directory_id: "",
sp_certificate_password: "",
sp_private_key: null,
},
description:
"If you select this mode, the SharePoint connector will use a certificate to authenticate. You will need to provide the client ID, directory ID, certificate password, and PFX data.",
},
],
} as CredentialTemplateWithAuth<SharepointCredentialJson>,
asana: {
asana_api_token_secret: "",
} as AsanaCredentialJson,
@@ -522,6 +552,8 @@ export const credentialDisplayNames: Record<string, string> = {
sp_client_id: "SharePoint Client ID",
sp_client_secret: "SharePoint Client Secret",
sp_directory_id: "SharePoint Directory ID",
sp_certificate_password: "SharePoint Certificate Password",
sp_private_key: "SharePoint Private Key",
// Asana
asana_api_token_secret: "Asana API Token",

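For illustration only (all values are placeholders): the credential_json an admin ends up submitting for the certificate method above. The PFX itself travels separately as a TypedFile upload rather than being serialized into this object.
import { SharepointCredentialJson } from "@/lib/connectors/credentials";
export const exampleCertificateCredentialJson: SharepointCredentialJson = {
  sp_client_id: "<azure-app-client-id>",
  sp_directory_id: "<azure-tenant-directory-id>",
  sp_certificate_password: "<pfx-password>",
  // sp_client_secret is omitted: unused in certificate mode
  // sp_private_key is omitted: uploaded out-of-band as multipart form data
};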
View File

@@ -0,0 +1,117 @@
export enum FileTypeCategory {
SHAREPOINT_PFX_FILE = "sharepoint_pfx_file",
}
export interface FileValidationRule {
maxSizeKB?: number;
allowedExtensions?: string[];
contentValidation?: (file: File) => Promise<boolean>;
}
export interface FileTypeDefinition {
category: FileTypeCategory;
validation?: FileValidationRule;
description?: string;
}
export const FILE_TYPE_DEFINITIONS: Record<
FileTypeCategory,
FileTypeDefinition
> = {
[FileTypeCategory.SHAREPOINT_PFX_FILE]: {
category: FileTypeCategory.SHAREPOINT_PFX_FILE,
validation: {
maxSizeKB: 10,
allowedExtensions: [".pfx"],
},
description:
"Please upload a .pfx file containing the private key for SharePoint. The file size must be under 10KB.",
},
};
export class TypedFile {
constructor(
public readonly file: File,
public readonly typeDefinition: FileTypeDefinition,
public readonly fieldKey: string
) {}
async validate(): Promise<{ isValid: boolean; errors: string[] }> {
const errors: string[] = [];
const { validation } = this.typeDefinition;
if (!validation) {
return {
isValid: true,
errors: [],
};
}
// Size validation
if (validation.maxSizeKB && this.file.size > validation.maxSizeKB * 1024) {
errors.push(`File size must not exceed ${validation.maxSizeKB}KB`);
}
// Extension validation
if (validation.allowedExtensions) {
const extension = this.file.name.toLowerCase().split(".").pop();
if (
!extension ||
!validation.allowedExtensions.includes(`.${extension}`)
) {
errors.push(
`File must have one of these extensions: ${validation.allowedExtensions.join(", ")}`
);
}
}
// Content validation
if (validation.contentValidation) {
try {
const isContentValid = await validation.contentValidation(this.file);
if (!isContentValid) {
errors.push(`File content validation failed`);
}
} catch (error) {
errors.push(
`Content validation error: ${error instanceof Error ? error.message : "Unknown error"}`
);
}
}
return {
isValid: errors.length === 0,
errors,
};
}
}
export function createTypedFile(
file: File,
fieldKey: string,
typeDefinitionKey: FileTypeCategory
): TypedFile {
const typeDefinition = FILE_TYPE_DEFINITIONS[typeDefinitionKey];
if (!typeDefinition) {
throw new Error(`Unknown file type definition: ${typeDefinitionKey}`);
}
return new TypedFile(file, typeDefinition, fieldKey);
}
export function isTypedFileField(fieldKey: string): boolean {
// Define which fields should be typed files
const typedFileFields = new Set(["sp_private_key"]);
return typedFileFields.has(fieldKey);
}
// Get the appropriate file type definition for a field
export function getFileTypeDefinitionForField(
fieldKey: string
): FileTypeCategory | null {
const fieldToTypeMap: Record<string, FileTypeCategory> = {
sp_private_key: FileTypeCategory.SHAREPOINT_PFX_FILE,
};
return fieldToTypeMap[fieldKey] || null;
}
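A minimal usage sketch of the API above; the File instance stands in for a browser upload.
import {
  createTypedFile,
  FileTypeCategory,
} from "@/lib/connectors/fileTypes";
async function validatePfxUpload(pfxFile: File): Promise<boolean> {
  // Binds the file to the SharePoint PFX rules (.pfx extension, <= 10KB).
  const typedFile = createTypedFile(
    pfxFile,
    "sp_private_key",
    FileTypeCategory.SHAREPOINT_PFX_FILE
  );
  const { isValid, errors } = await typedFile.validate();
  if (!isValid) {
    console.warn(`PFX rejected: ${errors.join(", ")}`);
  }
  return isValid;
}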

View File

@@ -108,3 +108,11 @@ export const ALLOWED_URL_PROTOCOLS = [
];
export const MAX_CHARACTERS_PERSONA_DESCRIPTION = 5000000;
// Credential form data key constants
export const CREDENTIAL_NAME = "name";
export const CREDENTIAL_SOURCE = "source";
export const CREDENTIAL_UPLOADED_FILE = "uploaded_file";
export const CREDENTIAL_FIELD_KEY = "field_key";
export const CREDENTIAL_TYPE_DEFINITION_KEY = "type_definition_key";
export const CREDENTIAL_JSON = "credential_json";

View File

@@ -1,5 +1,17 @@
import { CredentialBase } from "./connectors/credentials";
import {
CredentialBase,
CredentialWithPrivateKey,
} from "./connectors/credentials";
import { AccessType } from "@/lib/types";
import { TypedFile } from "./connectors/fileTypes";
import {
CREDENTIAL_NAME,
CREDENTIAL_SOURCE,
CREDENTIAL_UPLOADED_FILE,
CREDENTIAL_FIELD_KEY,
CREDENTIAL_TYPE_DEFINITION_KEY,
CREDENTIAL_JSON,
} from "./constants";
export async function createCredential(credential: CredentialBase<any>) {
return await fetch(`/api/manage/credential`, {
@@ -11,6 +23,37 @@ export async function createCredential(credential: CredentialBase<any>) {
});
}
export async function createCredentialWithPrivateKey(
credential: CredentialWithPrivateKey<any>
) {
const formData = new FormData();
formData.append(CREDENTIAL_JSON, JSON.stringify(credential.credential_json));
formData.append("admin_public", credential.admin_public.toString());
formData.append(
"curator_public",
credential.curator_public?.toString() || "false"
);
if (credential.groups && credential.groups.length > 0) {
credential.groups.forEach((group) => {
formData.append("groups", String(group));
});
}
formData.append(CREDENTIAL_NAME, credential.name || "");
formData.append(CREDENTIAL_SOURCE, credential.source);
if (credential.private_key) {
formData.append(CREDENTIAL_UPLOADED_FILE, credential.private_key.file);
formData.append(CREDENTIAL_FIELD_KEY, credential.private_key.fieldKey);
formData.append(
CREDENTIAL_TYPE_DEFINITION_KEY,
credential.private_key.typeDefinition.category
);
}
return await fetch(`/api/manage/credential/private-key`, {
method: "POST",
body: formData,
});
}
export async function adminDeleteCredential<T>(credentialId: number) {
return await fetch(`/api/manage/admin/credential/${credentialId}`, {
method: "DELETE",
@@ -70,7 +113,7 @@ export function updateCredential(credentialId: number, newDetails: any) {
const name = newDetails.name;
const details = Object.fromEntries(
Object.entries(newDetails).filter(
([key, value]) => key !== "name" && value !== ""
([key, value]) => key !== CREDENTIAL_NAME && value !== ""
)
);
return fetch(`/api/manage/admin/credential/${credentialId}`, {
@@ -85,6 +128,32 @@ export function updateCredential(credentialId: number, newDetails: any) {
});
}
export function updateCredentialWithPrivateKey(
credentialId: number,
newDetails: any,
privateKey: TypedFile
) {
const name = newDetails.name;
const details = Object.fromEntries(
Object.entries(newDetails).filter(
([key, value]) => key !== CREDENTIAL_NAME && value !== ""
)
);
const formData = new FormData();
formData.append(CREDENTIAL_NAME, name);
formData.append(CREDENTIAL_JSON, JSON.stringify(details));
formData.append(CREDENTIAL_UPLOADED_FILE, privateKey.file);
formData.append(CREDENTIAL_FIELD_KEY, privateKey.fieldKey);
formData.append(
CREDENTIAL_TYPE_DEFINITION_KEY,
privateKey.typeDefinition.category
);
return fetch(`/api/manage/admin/credential/private-key/${credentialId}`, {
method: "PUT",
body: formData,
});
}
export function swapCredential(
newCredentialId: number,
connectorId: number,

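A hedged end-to-end sketch (IDs and names are placeholders, and CredentialBase fields beyond those shown in this diff are assumed) of creating a certificate-based SharePoint credential through the new multipart endpoint.
import { createCredentialWithPrivateKey } from "@/lib/credential";
import { createTypedFile, FileTypeCategory } from "@/lib/connectors/fileTypes";
import { ValidSources } from "@/lib/types";
async function createSharepointCertCredential(pfxFile: File) {
  const privateKey = createTypedFile(
    pfxFile,
    "sp_private_key",
    FileTypeCategory.SHAREPOINT_PFX_FILE
  );
  // POSTs multipart form data to /api/manage/credential/private-key.
  return createCredentialWithPrivateKey({
    credential_json: {
      sp_client_id: "<client-id>",
      sp_directory_id: "<directory-id>",
      sp_certificate_password: "<pfx-password>",
    },
    admin_public: true,
    curator_public: false,
    source: ValidSources.Sharepoint,
    name: "SharePoint certificate credential",
    private_key: privateKey,
  });
}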
View File

@@ -456,6 +456,7 @@ export const validAutoSyncSources = [
ValidSources.Slack,
ValidSources.Salesforce,
ValidSources.GitHub,
ValidSources.Sharepoint,
] as const;
// Create a type from the array elements