Compare commits

...

14 Commits

Author SHA1 Message Date
Evan Lohn
d972fb08fc feat: configurable sharepoint endpoints 2026-02-19 15:19:27 -08:00
Evan Lohn
4433d130db test 2026-02-19 14:32:29 -08:00
Evan Lohn
f108232415 feat: azure ad group pagination 2026-02-19 14:32:29 -08:00
Evan Lohn
6cd7c59a1c feat: sharepoint scalability 3 2026-02-19 14:32:26 -08:00
Evan Lohn
d31d8092ce nit 2026-02-19 14:32:23 -08:00
Evan Lohn
2cdeecc844 feat: delta sync sharepoint
Co-authored-by: CE11-Kishan <CE11-Kishan@users.noreply.github.com>
2026-02-19 14:32:23 -08:00
Evan Lohn
8f910a187b shouldve trusted claude 2026-02-19 14:32:19 -08:00
Evan Lohn
cfe5b95cc4 pr comments and fixes 2026-02-19 14:32:19 -08:00
Evan Lohn
760c4fae6a more test fixes 2026-02-19 14:32:19 -08:00
Evan Lohn
b769ee530f fix test 2026-02-19 14:32:19 -08:00
Evan Lohn
a42114a932 more comments 2026-02-19 14:32:18 -08:00
Evan Lohn
e709a9dd0e address pr comments 2026-02-19 14:32:18 -08:00
Evan Lohn
e1c16c2391 pr comments 2026-02-19 14:32:18 -08:00
Evan Lohn
9ff29b8879 feat: sharepoint scalability 1 2026-02-19 14:32:18 -08:00
9 changed files with 1404 additions and 497 deletions

View File

@@ -1,9 +1,13 @@
from collections.abc import Generator
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -43,14 +47,27 @@ def sharepoint_group_sync(
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
# Process each site
enumerate_all = connector_config.get(
"exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
)
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
sp_domain_suffix = connector.sharepoint_domain_suffix
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
ctx = connector._create_rest_client_context(site_descriptor.url)
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
external_groups = get_sharepoint_external_groups(
ctx,
connector.graph_client,
graph_api_base=connector.graph_api_base,
get_access_token=connector._get_graph_access_token,
enumerate_all_ad_groups=enumerate_all,
)
# Yield each group
for group in external_groups:

View File

@@ -1,9 +1,13 @@
import re
import time
from collections import deque
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse
import requests as _requests
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
@@ -14,7 +18,10 @@ from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger
@@ -33,6 +40,70 @@ LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
AD_GROUP_ENUMERATION_THRESHOLD = 100_000
def _graph_api_get(
url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> dict[str, Any]:
"""Authenticated Graph API GET with retry on transient errors."""
for attempt in range(GRAPH_API_MAX_RETRIES + 1):
access_token = get_access_token()
headers = {"Authorization": f"Bearer {access_token}"}
try:
resp = _requests.get(
url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
)
if (
resp.status_code in GRAPH_API_RETRYABLE_STATUSES
and attempt < GRAPH_API_MAX_RETRIES
):
wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
logger.warning(
f"Graph API {resp.status_code} on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
resp.raise_for_status()
return resp.json()
except (_requests.ConnectionError, _requests.Timeout):
if attempt < GRAPH_API_MAX_RETRIES:
wait = min(2**attempt, 60)
logger.warning(
f"Graph API connection error on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
raise
raise RuntimeError(
f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
)
def _iter_graph_collection(
initial_url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> Generator[dict[str, Any], None, None]:
"""Paginate through a Graph API collection, yielding items one at a time."""
url: str | None = initial_url
while url:
data = _graph_api_get(url, get_access_token, params)
params = None
yield from data.get("value", [])
url = data.get("@odata.nextLink")
def _normalize_email(email: str) -> str:
if MICROSOFT_DOMAIN in email:
return email.replace(MICROSOFT_DOMAIN, "")
return email
class SharepointGroup(BaseModel):
model_config = {"frozen": True}
@@ -572,8 +643,65 @@ def get_external_access_from_sharepoint(
)
def _enumerate_ad_groups_paginated(
get_access_token: Callable[[], str],
already_resolved: set[str],
graph_api_base: str,
) -> Generator[ExternalUserGroup, None, None]:
"""Paginate through all Azure AD groups and yield ExternalUserGroup for each.
Skips groups whose suffixed name is already in *already_resolved*.
Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
"""
groups_url = f"{graph_api_base}/groups"
groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
total_groups = 0
for group_json in _iter_graph_collection(
groups_url, get_access_token, groups_params
):
group_id: str = group_json.get("id", "")
display_name: str = group_json.get("displayName", "")
if not group_id or not display_name:
continue
total_groups += 1
if total_groups > AD_GROUP_ENUMERATION_THRESHOLD:
logger.warning(
f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
"groups — stopping to avoid excessive memory/API usage. "
"Remaining groups will be resolved from role assignments only."
)
return
name = f"{display_name}_{group_id}"
if name in already_resolved:
continue
member_emails: list[str] = []
members_url = f"{graph_api_base}/groups/{group_id}/members"
members_params: dict[str, str] = {
"$select": "userPrincipalName,mail",
"$top": "999",
}
for member_json in _iter_graph_collection(
members_url, get_access_token, members_params
):
email = member_json.get("userPrincipalName") or member_json.get("mail")
if email:
member_emails.append(_normalize_email(email))
yield ExternalUserGroup(id=name, user_emails=member_emails)
logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API")
def get_sharepoint_external_groups(
client_context: ClientContext, graph_client: GraphClient
client_context: ClientContext,
graph_client: GraphClient,
graph_api_base: str,
get_access_token: Callable[[], str] | None = None,
enumerate_all_ad_groups: bool = False,
) -> list[ExternalUserGroup]:
groups: set[SharepointGroup] = set()
@@ -629,57 +757,22 @@ def get_sharepoint_external_groups(
client_context, graph_client, groups, is_group_sync=True
)
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
azure_ad_groups = sleep_and_retry(
graph_client.groups.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups",
)
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
ad_groups_to_emails: dict[str, set[str]] = {}
for group in azure_ad_groups:
# If the group is already identified, we don't need to get the members
if group.display_name in identified_groups:
continue
# AD groups allows same display name for multiple groups, so we need to add the GUID to the name
name = group.display_name
name = _get_group_name_with_suffix(group.id, name, graph_client)
external_user_groups: list[ExternalUserGroup] = [
ExternalUserGroup(id=group_name, user_emails=list(emails))
for group_name, emails in groups_and_members.groups_to_emails.items()
]
members = sleep_and_retry(
group.members.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
if not enumerate_all_ad_groups or get_access_token is None:
logger.info(
"Skipping exhaustive Azure AD group enumeration. "
"Only groups found in site role assignments are included."
)
for member in members:
member_data = member.to_json()
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
if not ad_groups_to_emails.get(name):
ad_groups_to_emails[name] = set()
if user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
ad_groups_to_emails[name].add(user_principal_name)
elif mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
ad_groups_to_emails[name].add(mail)
return external_user_groups
external_user_groups: list[ExternalUserGroup] = []
for group_name, emails in groups_and_members.groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
for group_name, emails in ad_groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
already_resolved = set(groups_and_members.groups_to_emails.keys())
for group in _enumerate_ad_groups_paginated(
get_access_token, already_resolved, graph_api_base
):
external_user_groups.append(group)
return external_user_groups

View File

@@ -625,6 +625,14 @@ SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
)
# When True, group sync enumerates every Azure AD group in the tenant (expensive).
# When False (default), only groups found in site role assignments are synced.
# Can be overridden per-connector via the "exhaustive_ad_enumeration" key in
# connector_specific_config.
SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION = (
os.environ.get("SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION", "").lower() == "true"
)
BLOB_STORAGE_SIZE_THRESHOLD = int(
os.environ.get("BLOB_STORAGE_SIZE_THRESHOLD", 20 * 1024 * 1024)
)

File diff suppressed because it is too large Load Diff

View File

@@ -50,12 +50,15 @@ class TeamsCheckpoint(ConnectorCheckpoint):
todo_team_ids: list[str] | None = None
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
class TeamsConnector(
CheckpointedConnectorWithPermSync[TeamsCheckpoint],
SlimConnectorWithPermSync,
):
MAX_WORKERS = 10
AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"
def __init__(
self,
@@ -63,11 +66,15 @@ class TeamsConnector(
# are not necessarily guaranteed to be unique
teams: list[str] = [],
max_workers: int = MAX_WORKERS,
authority_host: str = DEFAULT_AUTHORITY_HOST,
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
) -> None:
self.graph_client: GraphClient | None = None
self.msal_app: msal.ConfidentialClientApplication | None = None
self.max_workers = max_workers
self.requested_team_list: list[str] = teams
self.authority_host = authority_host.rstrip("/")
self.graph_api_host = graph_api_host.rstrip("/")
# impls for BaseConnector
@@ -76,7 +83,7 @@ class TeamsConnector(
teams_client_secret = credentials["teams_client_secret"]
teams_directory_id = credentials["teams_directory_id"]
authority_url = f"{TeamsConnector.AUTHORITY_URL_PREFIX}{teams_directory_id}"
authority_url = f"{self.authority_host}/{teams_directory_id}"
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=teams_client_id,
@@ -91,7 +98,7 @@ class TeamsConnector(
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
scopes=["https://graph.microsoft.com/.default"]
scopes=[f"{self.graph_api_host}/.default"]
)
if not isinstance(token, dict):

View File

@@ -25,6 +25,7 @@ class ExpectedDocument:
content: str
folder_path: str | None = None
library: str = "Shared Documents" # Default to main library
expected_link_substrings: list[str] | None = None
EXPECTED_DOCUMENTS = [
@@ -32,22 +33,29 @@ EXPECTED_DOCUMENTS = [
semantic_identifier="test1.docx",
content="test1",
folder_path="test",
expected_link_substrings=["_layouts/15/Doc.aspx", "file=test1.docx"],
),
ExpectedDocument(
semantic_identifier="test2.docx",
content="test2",
folder_path="test/nested with spaces",
expected_link_substrings=["_layouts/15/Doc.aspx", "file=test2.docx"],
),
ExpectedDocument(
semantic_identifier="should-not-index-on-specific-folder.docx",
content="should-not-index-on-specific-folder",
folder_path=None, # root folder
expected_link_substrings=[
"_layouts/15/Doc.aspx",
"file=should-not-index-on-specific-folder.docx",
],
),
ExpectedDocument(
semantic_identifier="other.docx",
content="other",
folder_path=None,
library="Other Library",
expected_link_substrings=["_layouts/15/Doc.aspx", "file=other.docx"],
),
]
@@ -61,11 +69,13 @@ EXPECTED_PAGES = [
"Add a document library\n\n## Document library"
),
folder_path=None,
expected_link_substrings=["SitePages/CollabHome.aspx"],
),
ExpectedDocument(
semantic_identifier="Home",
content="# Home",
folder_path=None,
expected_link_substrings=["SitePages/Home.aspx"],
),
]
@@ -88,6 +98,20 @@ def verify_document_content(doc: Document, expected: ExpectedDocument) -> None:
assert len(doc.sections) == 1
assert doc.sections[0].text is not None
assert expected.content == doc.sections[0].text
if expected.expected_link_substrings is not None:
actual_link = doc.sections[0].link
assert actual_link is not None, (
f"Expected section link containing {expected.expected_link_substrings} "
f"for '{expected.semantic_identifier}', but link was None"
)
for substr in expected.expected_link_substrings:
assert substr in actual_link, (
f"Section link for '{expected.semantic_identifier}' "
f"missing expected substring '{substr}', "
f"actual link: '{actual_link}'"
)
verify_document_metadata(doc)

View File

@@ -0,0 +1,252 @@
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from ee.onyx.external_permissions.sharepoint.permission_utils import (
_enumerate_ad_groups_paginated,
)
from ee.onyx.external_permissions.sharepoint.permission_utils import (
_iter_graph_collection,
)
from ee.onyx.external_permissions.sharepoint.permission_utils import (
_normalize_email,
)
from ee.onyx.external_permissions.sharepoint.permission_utils import (
AD_GROUP_ENUMERATION_THRESHOLD,
)
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from ee.onyx.external_permissions.sharepoint.permission_utils import GroupsResult
MODULE = "ee.onyx.external_permissions.sharepoint.permission_utils"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _fake_token() -> str:
return "fake-token"
def _make_graph_page(
items: list[dict[str, Any]],
next_link: str | None = None,
) -> dict[str, Any]:
page: dict[str, Any] = {"value": items}
if next_link:
page["@odata.nextLink"] = next_link
return page
# ---------------------------------------------------------------------------
# _normalize_email
# ---------------------------------------------------------------------------
def test_normalize_email_strips_onmicrosoft() -> None:
assert _normalize_email("user@contoso.onmicrosoft.com") == "user@contoso.com"
def test_normalize_email_noop_for_normal_domain() -> None:
assert _normalize_email("user@contoso.com") == "user@contoso.com"
# ---------------------------------------------------------------------------
# _iter_graph_collection
# ---------------------------------------------------------------------------
@patch(f"{MODULE}._graph_api_get")
def test_iter_graph_collection_single_page(mock_get: MagicMock) -> None:
mock_get.return_value = _make_graph_page([{"id": "1"}, {"id": "2"}])
items = list(_iter_graph_collection("https://graph/items", _fake_token))
assert items == [{"id": "1"}, {"id": "2"}]
mock_get.assert_called_once()
@patch(f"{MODULE}._graph_api_get")
def test_iter_graph_collection_multi_page(mock_get: MagicMock) -> None:
mock_get.side_effect = [
_make_graph_page([{"id": "1"}], next_link="https://graph/items?page=2"),
_make_graph_page([{"id": "2"}]),
]
items = list(_iter_graph_collection("https://graph/items", _fake_token))
assert items == [{"id": "1"}, {"id": "2"}]
assert mock_get.call_count == 2
@patch(f"{MODULE}._graph_api_get")
def test_iter_graph_collection_empty(mock_get: MagicMock) -> None:
mock_get.return_value = _make_graph_page([])
assert list(_iter_graph_collection("https://graph/items", _fake_token)) == []
# ---------------------------------------------------------------------------
# _enumerate_ad_groups_paginated
# ---------------------------------------------------------------------------
def _mock_graph_get_for_enumeration(
groups: list[dict[str, Any]],
members_by_group: dict[str, list[dict[str, Any]]],
) -> Generator[dict[str, Any], None, None]:
"""Return a side_effect function for _graph_api_get that serves
groups on the /groups URL and members on /groups/{id}/members URLs."""
def side_effect(
url: str,
get_access_token: Any, # noqa: ARG001
params: dict[str, str] | None = None, # noqa: ARG001
) -> dict[str, Any]:
if "/members" in url:
group_id = url.split("/groups/")[1].split("/members")[0]
return _make_graph_page(members_by_group.get(group_id, []))
return _make_graph_page(groups)
return side_effect # type: ignore[return-value]
@patch(f"{MODULE}._graph_api_get")
def test_enumerate_ad_groups_yields_groups(mock_get: MagicMock) -> None:
groups = [
{"id": "g1", "displayName": "Engineering"},
{"id": "g2", "displayName": "Marketing"},
]
members = {
"g1": [{"userPrincipalName": "alice@contoso.com"}],
"g2": [{"mail": "bob@contoso.onmicrosoft.com"}],
}
mock_get.side_effect = _mock_graph_get_for_enumeration(groups, members)
results = list(_enumerate_ad_groups_paginated(_fake_token, already_resolved=set()))
assert len(results) == 2
eng = next(r for r in results if r.id == "Engineering_g1")
assert eng.user_emails == ["alice@contoso.com"]
mkt = next(r for r in results if r.id == "Marketing_g2")
assert mkt.user_emails == ["bob@contoso.com"]
@patch(f"{MODULE}._graph_api_get")
def test_enumerate_ad_groups_skips_already_resolved(mock_get: MagicMock) -> None:
groups = [{"id": "g1", "displayName": "Engineering"}]
mock_get.side_effect = _mock_graph_get_for_enumeration(groups, {})
results = list(
_enumerate_ad_groups_paginated(_fake_token, already_resolved={"Engineering_g1"})
)
assert results == []
@patch(f"{MODULE}._graph_api_get")
def test_enumerate_ad_groups_circuit_breaker(mock_get: MagicMock) -> None:
"""Enumeration stops after AD_GROUP_ENUMERATION_THRESHOLD groups."""
over_limit = AD_GROUP_ENUMERATION_THRESHOLD + 5
groups = [{"id": f"g{i}", "displayName": f"Group{i}"} for i in range(over_limit)]
mock_get.side_effect = _mock_graph_get_for_enumeration(groups, {})
results = list(_enumerate_ad_groups_paginated(_fake_token, already_resolved=set()))
assert len(results) <= AD_GROUP_ENUMERATION_THRESHOLD
# ---------------------------------------------------------------------------
# get_sharepoint_external_groups
# ---------------------------------------------------------------------------
def _stub_role_assignment_resolution(
groups_to_emails: dict[str, set[str]],
) -> tuple[MagicMock, MagicMock]:
"""Return (mock_sleep_and_retry, mock_recursive) pre-configured to
simulate role-assignment group resolution."""
mock_sleep = MagicMock()
mock_recursive = MagicMock(
return_value=GroupsResult(
groups_to_emails=groups_to_emails,
found_public_group=False,
)
)
return mock_sleep, mock_recursive
@patch(f"{MODULE}._get_groups_and_members_recursively")
@patch(f"{MODULE}.sleep_and_retry")
def test_default_skips_ad_enumeration(
mock_sleep: MagicMock, mock_recursive: MagicMock # noqa: ARG001
) -> None:
mock_recursive.return_value = GroupsResult(
groups_to_emails={"SiteGroup_abc": {"alice@contoso.com"}},
found_public_group=False,
)
results = get_sharepoint_external_groups(
client_context=MagicMock(),
graph_client=MagicMock(),
)
assert len(results) == 1
assert results[0].id == "SiteGroup_abc"
assert results[0].user_emails == ["alice@contoso.com"]
@patch(f"{MODULE}._enumerate_ad_groups_paginated")
@patch(f"{MODULE}._get_groups_and_members_recursively")
@patch(f"{MODULE}.sleep_and_retry")
def test_enumerate_all_includes_ad_groups(
mock_sleep: MagicMock, # noqa: ARG001
mock_recursive: MagicMock,
mock_enum: MagicMock,
) -> None:
from ee.onyx.db.external_perm import ExternalUserGroup
mock_recursive.return_value = GroupsResult(
groups_to_emails={"SiteGroup_abc": {"alice@contoso.com"}},
found_public_group=False,
)
mock_enum.return_value = [
ExternalUserGroup(id="ADGroup_xyz", user_emails=["bob@contoso.com"]),
]
results = get_sharepoint_external_groups(
client_context=MagicMock(),
graph_client=MagicMock(),
get_access_token=_fake_token,
enumerate_all_ad_groups=True,
)
assert len(results) == 2
ids = {r.id for r in results}
assert ids == {"SiteGroup_abc", "ADGroup_xyz"}
mock_enum.assert_called_once()
@patch(f"{MODULE}._enumerate_ad_groups_paginated")
@patch(f"{MODULE}._get_groups_and_members_recursively")
@patch(f"{MODULE}.sleep_and_retry")
def test_enumerate_all_without_token_skips(
mock_sleep: MagicMock, # noqa: ARG001
mock_recursive: MagicMock,
mock_enum: MagicMock,
) -> None:
"""Even if enumerate_all_ad_groups=True, no token means skip."""
mock_recursive.return_value = GroupsResult(
groups_to_emails={},
found_public_group=False,
)
results = get_sharepoint_external_groups(
client_context=MagicMock(),
graph_client=MagicMock(),
get_access_token=None,
enumerate_all_ad_groups=True,
)
assert results == []
mock_enum.assert_not_called()

View File

@@ -1,13 +1,16 @@
from __future__ import annotations
from collections import deque
from collections.abc import Generator
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from types import SimpleNamespace
from typing import Any
import pytest
from onyx.connectors.sharepoint.connector import DriveItemData
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.connectors.sharepoint.connector import SharepointConnectorCheckpoint
@@ -22,24 +25,10 @@ class _FakeQuery:
return self._payload
class _FakeFolder:
def __init__(self, items: Sequence[Any]) -> None:
self._items = items
self.name = "root"
def get_by_path(self, _path: str) -> _FakeFolder:
return self
def get_files(
self, *, recursive: bool, page_size: int # noqa: ARG002
) -> _FakeQuery:
return _FakeQuery(self._items)
class _FakeDrive:
def __init__(self, name: str, items: Sequence[Any]) -> None:
def __init__(self, name: str) -> None:
self.name = name
self.root = _FakeFolder(items)
self.id = f"fake-drive-id-{name}"
self.web_url = f"https://example.sharepoint.com/sites/sample/{name}"
@@ -69,12 +58,42 @@ class _FakeGraphClient:
self.sites = _FakeSites(drives)
_SAMPLE_ITEM = DriveItemData(
id="item-1",
name="sample.pdf",
web_url="https://example.sharepoint.com/sites/sample/sample.pdf",
parent_reference_path=None,
drive_id="fake-drive-id",
)
def _build_connector(drives: Sequence[_FakeDrive]) -> SharepointConnector:
connector = SharepointConnector()
connector._graph_client = _FakeGraphClient(drives)
return connector
def _fake_iter_drive_items_paged(
self: SharepointConnector, # noqa: ARG001
drive_id: str, # noqa: ARG001
folder_path: str | None = None, # noqa: ARG001
start: datetime | None = None, # noqa: ARG001
end: datetime | None = None, # noqa: ARG001
page_size: int = 200, # noqa: ARG001
) -> Generator[DriveItemData, None, None]:
yield _SAMPLE_ITEM
def _fake_iter_drive_items_delta(
self: SharepointConnector, # noqa: ARG001
drive_id: str, # noqa: ARG001
start: datetime | None = None, # noqa: ARG001
end: datetime | None = None, # noqa: ARG001
page_size: int = 200, # noqa: ARG001
) -> Generator[DriveItemData, None, None]:
yield _SAMPLE_ITEM
@pytest.mark.parametrize(
("requested_drive_name", "graph_drive_name"),
[
@@ -84,21 +103,28 @@ def _build_connector(drives: Sequence[_FakeDrive]) -> SharepointConnector:
],
)
def test_fetch_driveitems_matches_international_drive_names(
requested_drive_name: str, graph_drive_name: str
requested_drive_name: str,
graph_drive_name: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
item = SimpleNamespace(parent_reference=SimpleNamespace(path=None))
connector = _build_connector([_FakeDrive(graph_drive_name, [item])])
connector = _build_connector([_FakeDrive(graph_drive_name)])
site_descriptor = SiteDescriptor(
url="https://example.sharepoint.com/sites/sample",
drive_name=requested_drive_name,
folder_path=None,
)
results = connector._fetch_driveitems(site_descriptor=site_descriptor)
monkeypatch.setattr(
SharepointConnector,
"_iter_drive_items_delta",
_fake_iter_drive_items_delta,
)
results = list(connector._fetch_driveitems(site_descriptor=site_descriptor))
assert len(results) == 1
drive_item, returned_drive_name, drive_web_url = results[0]
assert drive_item is item
assert drive_item.id == _SAMPLE_ITEM.id
assert returned_drive_name == requested_drive_name
assert drive_web_url is not None
@@ -111,25 +137,32 @@ def test_fetch_driveitems_matches_international_drive_names(
("Documentos compartidos", "Documentos"),
],
)
def test_get_drive_items_for_drive_name_matches_map(
requested_drive_name: str, graph_drive_name: str
def test_get_drive_items_for_drive_id_matches_map(
requested_drive_name: str,
graph_drive_name: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
item = SimpleNamespace()
connector = _build_connector([_FakeDrive(graph_drive_name, [item])])
connector = _build_connector([_FakeDrive(graph_drive_name)])
site_descriptor = SiteDescriptor(
url="https://example.sharepoint.com/sites/sample",
drive_name=requested_drive_name,
folder_path=None,
)
results, drive_web_url = connector._get_drive_items_for_drive_name(
site_descriptor=site_descriptor,
drive_name=requested_drive_name,
monkeypatch.setattr(
SharepointConnector,
"_iter_drive_items_delta",
_fake_iter_drive_items_delta,
)
items_iter = connector._get_drive_items_for_drive_id(
site_descriptor=site_descriptor,
drive_id="fake-drive-id",
)
results = list(items_iter)
assert len(results) == 1
assert results[0] is item
assert drive_web_url is not None
assert results[0].id == _SAMPLE_ITEM.id
def test_load_from_checkpoint_maps_drive_name(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -138,46 +171,68 @@ def test_load_from_checkpoint_maps_drive_name(monkeypatch: pytest.MonkeyPatch) -
connector.include_site_pages = False
captured_drive_names: list[str] = []
sample_item = DriveItemData(
id="doc-1",
name="sample.pdf",
web_url="https://example.sharepoint.com/sites/sample/sample.pdf",
parent_reference_path=None,
drive_id="fake-drive-id",
)
def fake_resolve_drive(
self: SharepointConnector, # noqa: ARG001
site_descriptor: SiteDescriptor, # noqa: ARG001
drive_name: str,
) -> tuple[str, str | None]:
assert drive_name == "Documents"
return (
"fake-drive-id",
"https://example.sharepoint.com/sites/sample/Documents",
)
def fake_get_drive_items(
self: SharepointConnector, # noqa: ARG001
site_descriptor: SiteDescriptor, # noqa: ARG001
drive_name: str,
drive_id: str, # noqa: ARG001
start: datetime | None, # noqa: ARG001
end: datetime | None, # noqa: ARG001
) -> tuple[list[SimpleNamespace], str | None]:
assert drive_name == "Documents"
return (
[
SimpleNamespace(
name="sample.pdf",
web_url="https://example.sharepoint.com/sites/sample/sample.pdf",
parent_reference=SimpleNamespace(path=None),
)
],
"https://example.sharepoint.com/sites/sample/Documents",
)
) -> Generator[DriveItemData, None, None]:
yield sample_item
def fake_convert(
driveitem: SimpleNamespace, # noqa: ARG001
driveitem: DriveItemData, # noqa: ARG001
drive_name: str,
ctx: Any, # noqa: ARG001
graph_client: Any, # noqa: ARG001
include_permissions: bool, # noqa: ARG001
parent_hierarchy_raw_node_id: str | None = None, # noqa: ARG001
access_token: str | None = None, # noqa: ARG001
) -> SimpleNamespace:
captured_drive_names.append(drive_name)
return SimpleNamespace(sections=["content"])
def fake_get_access_token(self: SharepointConnector) -> str: # noqa: ARG001
return "fake-access-token"
monkeypatch.setattr(
SharepointConnector,
"_get_drive_items_for_drive_name",
"_resolve_drive",
fake_resolve_drive,
)
monkeypatch.setattr(
SharepointConnector,
"_get_drive_items_for_drive_id",
fake_get_drive_items,
)
monkeypatch.setattr(
"onyx.connectors.sharepoint.connector._convert_driveitem_to_document_with_permissions",
fake_convert,
)
monkeypatch.setattr(
SharepointConnector,
"_get_graph_access_token",
fake_get_access_token,
)
checkpoint = SharepointConnectorCheckpoint(has_more=True)
checkpoint.cached_site_descriptors = deque()
@@ -204,7 +259,6 @@ def test_load_from_checkpoint_maps_drive_name(monkeypatch: pytest.MonkeyPatch) -
except StopIteration:
pass
# Filter out hierarchy nodes (which are also yielded now)
from onyx.connectors.models import HierarchyNode
documents = [item for item in all_yielded if not isinstance(item, HierarchyNode)]
@@ -212,5 +266,248 @@ def test_load_from_checkpoint_maps_drive_name(monkeypatch: pytest.MonkeyPatch) -
assert len(documents) == 1
assert captured_drive_names == [SHARED_DOCUMENTS_MAP["Documents"]]
# Verify a drive hierarchy node was yielded
assert len(hierarchy_nodes) >= 1
def test_get_drive_items_uses_delta_when_no_folder_path(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When folder_path is None, _get_drive_items_for_drive_name should use delta."""
connector = _build_connector([_FakeDrive("Documents")])
site = SiteDescriptor(
url="https://example.sharepoint.com/sites/sample",
drive_name="Documents",
folder_path=None,
)
called_method: list[str] = []
def fake_delta(
self: SharepointConnector, # noqa: ARG001
drive_id: str, # noqa: ARG001
start: datetime | None = None, # noqa: ARG001
end: datetime | None = None, # noqa: ARG001
page_size: int = 200, # noqa: ARG001
) -> Generator[DriveItemData, None, None]:
called_method.append("delta")
yield _SAMPLE_ITEM
def fake_paged(
self: SharepointConnector, # noqa: ARG001
drive_id: str, # noqa: ARG001
folder_path: str | None = None, # noqa: ARG001
start: datetime | None = None, # noqa: ARG001
end: datetime | None = None, # noqa: ARG001
page_size: int = 200, # noqa: ARG001
) -> Generator[DriveItemData, None, None]:
called_method.append("paged")
yield _SAMPLE_ITEM
monkeypatch.setattr(SharepointConnector, "_iter_drive_items_delta", fake_delta)
monkeypatch.setattr(SharepointConnector, "_iter_drive_items_paged", fake_paged)
items, _ = connector._get_drive_items_for_drive_name(site, "Documents")
list(items)
assert called_method == ["delta"]
def test_get_drive_items_uses_paged_when_folder_path_set(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When folder_path is set, _get_drive_items_for_drive_name should use BFS."""
    connector = _build_connector([_FakeDrive("Documents")])
    site_descriptor = SiteDescriptor(
        url="https://example.sharepoint.com/sites/sample",
        drive_name="Documents",
        folder_path="Engineering/Docs",
    )

    invocations: list[str] = []

    # Parameter names mirror the real methods so keyword calls still bind.
    def record_delta(
        self: SharepointConnector,  # noqa: ARG001
        drive_id: str,  # noqa: ARG001
        start: datetime | None = None,  # noqa: ARG001
        end: datetime | None = None,  # noqa: ARG001
        page_size: int = 200,  # noqa: ARG001
    ) -> Generator[DriveItemData, None, None]:
        invocations.append("delta")
        yield _SAMPLE_ITEM

    def record_paged(
        self: SharepointConnector,  # noqa: ARG001
        drive_id: str,  # noqa: ARG001
        folder_path: str | None = None,  # noqa: ARG001
        start: datetime | None = None,  # noqa: ARG001
        end: datetime | None = None,  # noqa: ARG001
        page_size: int = 200,  # noqa: ARG001
    ) -> Generator[DriveItemData, None, None]:
        invocations.append("paged")
        yield _SAMPLE_ITEM

    monkeypatch.setattr(SharepointConnector, "_iter_drive_items_delta", record_delta)
    monkeypatch.setattr(SharepointConnector, "_iter_drive_items_paged", record_paged)

    item_gen, _ = connector._get_drive_items_for_drive_name(
        site_descriptor, "Documents"
    )
    # Consume the generator so the patched method actually executes.
    list(item_gen)

    assert invocations == ["paged"]
def test_iter_drive_items_delta_uses_timestamp_token(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Delta iteration should pass the start time as a URL token for incremental sync."""
    connector = SharepointConnector()
    requested_urls: list[str] = []

    delta_page: dict[str, Any] = {
        "value": [
            {
                "id": "file-1",
                "name": "report.docx",
                "webUrl": "https://example.sharepoint.com/report.docx",
                "file": {
                    "mimeType": "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
                },
                "lastModifiedDateTime": "2025-06-15T12:00:00Z",
                "parentReference": {"path": "/drives/d1/root:", "driveId": "d1"},
            }
        ],
        "@odata.deltaLink": "https://graph.microsoft.com/v1.0/drives/d1/root/delta?token=final",
    }

    def spy_graph_get(
        self: SharepointConnector,  # noqa: ARG001
        url: str,
        params: dict[str, str] | None = None,  # noqa: ARG001
    ) -> dict[str, Any]:
        requested_urls.append(url)
        return delta_page

    monkeypatch.setattr(
        SharepointConnector, "_graph_api_get_json", spy_graph_get
    )

    window_start = datetime(2025, 6, 1, 0, 0, 0, tzinfo=timezone.utc)
    results = list(connector._iter_drive_items_delta("d1", start=window_start))

    assert len(results) == 1
    assert results[0].id == "file-1"
    # Exactly one request, carrying the URL-encoded start timestamp as the token.
    assert len(requested_urls) == 1
    assert "token=2025-06-01T00%3A00%3A00Z" in requested_urls[0]
def test_iter_drive_items_delta_full_crawl_when_no_start(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Delta iteration without a start time should do a full enumeration (no token)."""
    connector = SharepointConnector()
    seen_urls: list[str] = []

    def spy_graph_get(
        self: SharepointConnector,  # noqa: ARG001
        url: str,
        params: dict[str, str] | None = None,  # noqa: ARG001
    ) -> dict[str, Any]:
        seen_urls.append(url)
        return {
            "value": [],
            "@odata.deltaLink": "https://graph.microsoft.com/v1.0/drives/d1/root/delta?token=final",
        }

    monkeypatch.setattr(
        SharepointConnector, "_graph_api_get_json", spy_graph_get
    )

    list(connector._iter_drive_items_delta("d1"))

    # A full crawl hits the bare delta endpoint with no token query parameter.
    assert len(seen_urls) == 1
    assert "token=" not in seen_urls[0]
    assert seen_urls[0].endswith("/drives/d1/root/delta")
def test_iter_drive_items_delta_skips_folders_and_deleted(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Delta results with folder or deleted facets should be skipped."""
    connector = SharepointConnector()

    mixed_page: dict[str, Any] = {
        "value": [
            {"id": "folder-1", "name": "Docs", "folder": {"childCount": 5}},
            {"id": "deleted-1", "name": "old.txt", "deleted": {"state": "deleted"}},
            {
                "id": "file-1",
                "name": "keep.pdf",
                "webUrl": "https://example.sharepoint.com/keep.pdf",
                "file": {"mimeType": "application/pdf"},
                "lastModifiedDateTime": "2025-06-15T12:00:00Z",
                "parentReference": {"path": "/drives/d1/root:", "driveId": "d1"},
            },
        ],
        "@odata.deltaLink": "https://graph.microsoft.com/v1.0/drives/d1/root/delta?token=final",
    }

    def stub_graph_get(
        self: SharepointConnector,  # noqa: ARG001
        url: str,  # noqa: ARG001
        params: dict[str, str] | None = None,  # noqa: ARG001
    ) -> dict[str, Any]:
        return mixed_page

    monkeypatch.setattr(
        SharepointConnector, "_graph_api_get_json", stub_graph_get
    )

    surviving = list(connector._iter_drive_items_delta("d1"))

    # Only the plain file entry should survive the filtering.
    assert len(surviving) == 1
    assert surviving[0].id == "file-1"
def test_iter_drive_items_delta_handles_410_gone(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """On 410 Gone, delta should fall back to full enumeration."""
    import requests as req

    connector = SharepointConnector()
    calls = {"count": 0}

    def flaky_graph_get(
        self: SharepointConnector,  # noqa: ARG001
        url: str,
        params: dict[str, str] | None = None,  # noqa: ARG001
    ) -> dict[str, Any]:
        calls["count"] += 1
        # First request carrying a delta token gets 410 Gone (expired token);
        # the tokenless retry succeeds.
        if calls["count"] == 1 and "token=" in url:
            response = req.Response()
            response.status_code = 410
            raise req.HTTPError(response=response)
        return {
            "value": [
                {
                    "id": "file-1",
                    "name": "doc.pdf",
                    "webUrl": "https://example.sharepoint.com/doc.pdf",
                    "file": {"mimeType": "application/pdf"},
                    "lastModifiedDateTime": "2025-06-15T12:00:00Z",
                    "parentReference": {"path": "/drives/d1/root:", "driveId": "d1"},
                }
            ],
            "@odata.deltaLink": "https://graph.microsoft.com/v1.0/drives/d1/root/delta?token=final",
        }

    monkeypatch.setattr(
        SharepointConnector, "_graph_api_get_json", flaky_graph_get
    )

    window_start = datetime(2025, 6, 1, 0, 0, 0, tzinfo=timezone.utc)
    results = list(connector._iter_drive_items_delta("d1", start=window_start))

    assert len(results) == 1
    assert results[0].id == "file-1"
    # Two calls total: the 410-failed token request plus the full-crawl retry.
    assert calls["count"] == 2

View File

@@ -839,6 +839,42 @@ export const connectorConfigs: Record<
description:
"Index aspx-pages of all SharePoint sites defined above, even if a library or folder is specified.",
},
{
type: "text",
query: "Microsoft Authority Host:",
label: "Authority Host",
name: "authority_host",
optional: true,
default: "https://login.microsoftonline.com",
description:
"The Microsoft identity authority host used for authentication. " +
"For most deployments, leave as default. " +
"For GCC High / DoD, use https://login.microsoftonline.us",
},
{
type: "text",
query: "Microsoft Graph API Host:",
label: "Graph API Host",
name: "graph_api_host",
optional: true,
default: "https://graph.microsoft.com",
description:
"The Microsoft Graph API host. " +
"For most deployments, leave as default. " +
"For GCC High / DoD, use https://graph.microsoft.us",
},
{
type: "text",
query: "SharePoint Domain Suffix:",
label: "SharePoint Domain Suffix",
name: "sharepoint_domain_suffix",
optional: true,
default: "sharepoint.com",
description:
"The domain suffix for SharePoint sites (e.g. sharepoint.com). " +
"For most deployments, leave as default. " +
"For GCC High, use sharepoint.us",
},
],
},
teams: {
@@ -853,7 +889,32 @@ export const connectorConfigs: Record<
description: `Specify 0 or more Teams to index. For example, specifying the Team 'Support' for the 'onyxai' Org will cause us to only index messages sent in channels belonging to the 'Support' Team. If no Teams are specified, all Teams in your organization will be indexed.`,
},
],
advanced_values: [],
advanced_values: [
{
type: "text",
query: "Microsoft Authority Host:",
label: "Authority Host",
name: "authority_host",
optional: true,
default: "https://login.microsoftonline.com",
description:
"The Microsoft identity authority host used for authentication. " +
"For most deployments, leave as default. " +
"For GCC High / DoD, use https://login.microsoftonline.us",
},
{
type: "text",
query: "Microsoft Graph API Host:",
label: "Graph API Host",
name: "graph_api_host",
optional: true,
default: "https://graph.microsoft.com",
description:
"The Microsoft Graph API host. " +
"For most deployments, leave as default. " +
"For GCC High / DoD, use https://graph.microsoft.us",
},
],
},
discourse: {
description: "Configure Discourse connector",
@@ -1881,10 +1942,15 @@ export interface SharepointConfig {
sites?: string[];
include_site_pages?: boolean;
include_site_documents?: boolean;
authority_host?: string;
graph_api_host?: string;
sharepoint_domain_suffix?: string;
}
/**
 * Connector-specific configuration for the Teams connector.
 * `authority_host` and `graph_api_host` allow targeting alternate Microsoft
 * cloud endpoints (the config defaults elsewhere in this file are
 * https://login.microsoftonline.com and https://graph.microsoft.com).
 */
export interface TeamsConfig {
  teams?: string[];
  authority_host?: string;
  graph_api_host?: string;
}
export interface DiscourseConfig {
@@ -1905,6 +1971,8 @@ export interface DrupalWikiConfig {
// NOTE(review): this appears to be a duplicate TeamsConfig declaration —
// another identical `export interface TeamsConfig` exists earlier in this
// file. TypeScript declaration merging makes the duplicate legal, but one of
// the two declarations should likely be removed.
export interface TeamsConfig {
  teams?: string[];
  authority_host?: string;
  graph_api_host?: string;
}
export interface ProductboardConfig {}