Compare commits

...

3 Commits

Author SHA1 Message Date
Jessica Singh
a0a74c30e5 Merge branch 'main' into teams-ampersand-fix-3 2025-11-13 13:40:28 -08:00
Alex Kim
ad22d98007 Merge branch 'main' into teams-ampersand-fix-3 2025-11-12 19:17:14 -05:00
Alexander Kim
2a52023c1d fix 2025-11-07 10:12:32 -05:00
2 changed files with 269 additions and 114 deletions

View File

@@ -105,44 +105,37 @@ class TeamsConnector(
if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams credentials not loaded.")
# Determine timeout based on special characters
# Check if any requested teams have special characters that need client-side filtering
has_special_chars = _has_odata_incompatible_chars(self.requested_team_list)
timeout = 30 if has_special_chars else 10
if has_special_chars:
logger.info(
"Some requested team names contain special characters (&, (, )) that require "
"client-side filtering during data retrieval."
)
try:
# Minimal call to confirm we can retrieve Teams
# Use longer timeout if team names have special characters (requires client-side filtering)
# For validation, do a lightweight check instead of full team search
logger.info(
f"Requested team count: {len(self.requested_team_list) if self.requested_team_list else 0}, "
f"Has special chars: {has_special_chars}, "
f"Timeout: {timeout}s"
f"Has special chars: {has_special_chars}"
)
found_teams = run_with_timeout(
# Minimal validation: just check if we can access the teams endpoint
timeout = 10 # Short timeout for basic validation
validation_query = self.graph_client.teams.get().top(1)
run_with_timeout(
timeout=timeout,
func=_collect_all_teams,
graph_client=self.graph_client,
requested=self.requested_team_list,
func=lambda: validation_query.execute_query(),
)
logger.info(
f"Teams validation successful - " f"Found {len(found_teams)} team(s)"
)
logger.info("Teams validation successful - Access to teams endpoint confirmed")
except TimeoutError as e:
if has_special_chars:
raise ConnectorValidationError(
f"Timeout while fetching Teams (waited {timeout}s). "
f"Team names with special characters (&, (, )) require fetching all teams "
f"for client-side filtering, which can take longer. "
f"Error: {e}"
)
else:
raise ConnectorValidationError(
f"Timeout while fetching Teams (waited {timeout}s). "
f"This may indicate network issues or a large number of teams. "
f"Error: {e}"
)
raise ConnectorValidationError(
f"Timeout while validating Teams access (waited {timeout}s). "
f"This may indicate network issues or authentication problems. "
f"Error: {e}"
)
except ClientRequestException as e:
if not e.response:
@@ -176,12 +169,6 @@ class TeamsConnector(
f"Unexpected error during Teams validation: {e}"
)
if not found_teams:
raise ConnectorValidationError(
"No Teams found for the given credentials. "
"Either there are no Teams in this tenant, or your app does not have permission to view them."
)
# impls for CheckpointedConnector
def build_dummy_checkpoint(self) -> TeamsCheckpoint:
@@ -262,8 +249,8 @@ class TeamsConnector(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
_end: SecondsSinceUnixEpoch | None = None,
_callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
start = start or 0
@@ -274,7 +261,7 @@ class TeamsConnector(
for team in teams:
if not team.id:
logger.warn(f"Expected a team with an id, instead got no id: {team=}")
logger.warning(f"Expected a team with an id, instead got no id: {team=}")
continue
channels = _collect_all_channels_from_team(
@@ -283,7 +270,7 @@ class TeamsConnector(
for channel in channels:
if not channel.id:
logger.warn(
logger.warning(
f"Expected a channel with an id, instead got no id: {channel=}"
)
continue
@@ -319,18 +306,68 @@ class TeamsConnector(
slim_doc_buffer = []
def _has_odata_incompatible_chars(team_names: list[str] | None) -> bool:
"""Check if any team name contains characters that break OData filters.
def _escape_odata_string(name: str) -> str:
"""Escape special characters for OData string literals.
The &, (, and ) characters are not allowed in OData string literals and are
reserved characters in OData syntax. Server-side filtering is not possible for
team names containing these characters.
Uses proper OData v4 string literal escaping:
- Single quotes: ' becomes ''
- Other characters are handled by using contains() instead of eq for problematic cases
"""
# Escape single quotes for OData syntax (replace ' with '')
escaped = name.replace("'", "''")
return escaped
def _has_odata_incompatible_chars(team_names: list[str] | None) -> bool:
"""Check if any team name contains characters that break Microsoft Graph OData filters.
The Microsoft Graph Teams API has limited OData support. Characters like
&, (, and ) cause parsing errors and require client-side filtering instead.
"""
if not team_names:
return False
return any(char in name for name in team_names for char in ["&", "(", ")"])
def _can_use_odata_filter(team_names: list[str] | None) -> tuple[bool, list[str], list[str]]:
"""Determine which teams can use OData filtering vs client-side filtering.
Microsoft Graph /teams endpoint OData limitations:
- Only supports basic 'eq' operators in filters
- No 'contains', 'startswith', or other advanced operators
- Special characters (&, (, )) break OData parsing
Returns:
tuple: (can_use_odata, safe_names, problematic_names)
"""
if not team_names:
return False, [], []
safe_names = []
problematic_names = []
for name in team_names:
if any(char in name for char in ["&", "(", ")"]):
problematic_names.append(name)
else:
safe_names.append(name)
return bool(safe_names), safe_names, problematic_names
def _build_simple_odata_filter(safe_names: list[str]) -> str | None:
"""Build simple OData filter using only 'eq' operators for safe names."""
if not safe_names:
return None
filter_parts = []
for name in safe_names:
escaped_name = _escape_odata_string(name)
filter_parts.append(f"displayName eq '{escaped_name}'")
return " or ".join(filter_parts)
def _construct_semantic_identifier(channel: Channel, top_message: Message) -> str:
top_message_user_name: str
@@ -340,7 +377,7 @@ def _construct_semantic_identifier(channel: Channel, top_message: Message) -> st
user_display_name if user_display_name else "Unknown User"
)
else:
logger.warn(f"Message {top_message=} has no `from.user` field")
logger.warning(f"Message {top_message=} has no `from.user` field")
top_message_user_name = "Unknown User"
top_message_content = top_message.body.content or ""
@@ -433,45 +470,72 @@ def _collect_all_teams(
graph_client: GraphClient,
requested: list[str] | None = None,
) -> list[Team]:
"""Collect teams from Microsoft Graph using appropriate filtering strategy.
For teams with special characters (&, (, )), uses client-side filtering
with paginated search. For teams without special characters, uses efficient
OData server-side filtering.
Args:
graph_client: Authenticated Microsoft Graph client
requested: List of team names to find, or None for all teams
Returns:
List of Team objects matching the requested names
"""
teams: list[Team] = []
next_url: str | None = None
# Check if team names have special characters that break OData filters
has_special_chars = _has_odata_incompatible_chars(requested)
if (
has_special_chars and requested
): # requested must exist if has_special_chars is True
logger.info(
f"Team name(s) contain special characters (&, (, or )) which are not supported "
f"in OData string literals. Fetching all teams and using client-side filtering. "
f"Count: {len(requested)}"
)
# Determine filtering strategy based on Microsoft Graph limitations
if not requested:
# No specific teams requested - return empty list (avoid fetching all teams)
logger.info("No specific teams requested - returning empty list")
return []
# Build OData filter for requested teams (only if we didn't already return from raw HTTP above)
filter = None
use_filter = (
bool(requested) and not has_special_chars
) # Skip OData for special chars (fallback to client-side)
if use_filter and requested:
filter_parts = []
for name in requested:
# Escape single quotes for OData syntax (replace ' with '')
escaped_name = name.replace("'", "''")
filter_parts.append(f"displayName eq '{escaped_name}'")
filter = " or ".join(filter_parts)
_, safe_names, problematic_names = _can_use_odata_filter(requested)
if problematic_names and not safe_names:
# ALL requested teams have special characters - cannot use OData filtering
logger.info(
f"All requested team names contain special characters (&, (, )) which require "
f"client-side filtering. Using basic /teams endpoint with pagination. "
f"Teams: {problematic_names}"
)
# Use unfiltered query with pagination limit to avoid fetching too many teams
use_client_side_filtering = True
odata_filter = None
elif problematic_names and safe_names:
# Mixed scenario - need to fetch more teams to find the problematic ones
logger.info(
f"Mixed team types: will use client-side filtering for all. "
f"Safe names: {safe_names}, Special char names: {problematic_names}"
)
use_client_side_filtering = True
odata_filter = None
elif safe_names:
# All names are safe - use OData filtering
logger.info(f"Using OData filtering for all requested teams: {safe_names}")
use_client_side_filtering = False
odata_filter = _build_simple_odata_filter(safe_names)
else:
# No valid names
return []
# Track pagination to avoid fetching too many teams for client-side filtering
max_pages = 200
page_count = 0
while True:
try:
if filter:
# Use normal filter for teams without special characters
query = graph_client.teams.get().filter(filter)
# Add header to work around Microsoft Graph API ampersand bug
query.before_execute(lambda req: _add_prefer_header(request=req))
if use_client_side_filtering:
# Use basic /teams endpoint with top parameter to limit results per page
query = graph_client.teams.get().top(50) # Limit to 50 teams per page
else:
query = graph_client.teams.get_all(
# explicitly needed because of incorrect type definitions provided by the `office365` library
page_loaded=lambda _: None
)
# Use OData filter with only 'eq' operators
query = graph_client.teams.get().filter(odata_filter)
# Add header to work around Microsoft Graph API issues
query.before_execute(lambda req: _add_prefer_header(request=req))
if next_url:
url = next_url
@@ -481,17 +545,17 @@ def _collect_all_teams(
team_collection = query.execute_query()
except (ClientRequestException, ValueError) as e:
# If OData filter fails, fallback to client-side filtering
if use_filter:
logger.warning(
f"OData filter failed with {type(e).__name__}: {e}. "
f"Falling back to client-side filtering."
)
use_filter = False
filter = None
# If OData filter fails, fall back to client-side filtering
if not use_client_side_filtering and odata_filter:
logger.warning(f"OData filter failed: {e}. Falling back to client-side filtering.")
use_client_side_filtering = True
odata_filter = None
teams = []
next_url = None
page_count = 0
continue
# If client-side approach also fails, re-raise
logger.error(f"Teams query failed: {e}")
raise
filtered_teams = (
@@ -501,6 +565,30 @@ def _collect_all_teams(
)
teams.extend(filtered_teams)
# For client-side filtering, check if we found all requested teams or hit page limit
if use_client_side_filtering:
page_count += 1
found_team_names = {team.display_name for team in teams if team.display_name}
requested_set = set(requested)
# Log progress every 10 pages to avoid excessive logging
if page_count % 10 == 0:
logger.info(
f"Searched {page_count} pages, found {len(found_team_names)} matching teams so far"
)
# Stop if we found all requested teams or hit the page limit
if requested_set.issubset(found_team_names):
logger.info(f"Found all requested teams after {page_count} pages")
break
elif page_count >= max_pages:
logger.warning(
f"Reached maximum page limit ({max_pages}) while searching for teams. "
f"Found: {found_team_names & requested_set}, "
f"Missing: {requested_set - found_team_names}"
)
break
if not team_collection.has_next:
break
@@ -514,6 +602,53 @@ def _collect_all_teams(
return teams
def _normalize_team_name(name: str) -> str:
"""Normalize team name for flexible matching."""
if not name:
return ""
# Convert to lowercase and strip whitespace for case-insensitive matching
return name.lower().strip()
def _matches_requested_team(team_display_name: str, requested: list[str]) -> bool:
"""Check if team display name matches any of the requested team names.
Uses flexible matching to handle slight variations in team names.
"""
if not requested or not team_display_name:
return not requested # If no teams requested, match all; if no name, don't match
normalized_team_name = _normalize_team_name(team_display_name)
for requested_name in requested:
normalized_requested = _normalize_team_name(requested_name)
# Exact match after normalization
if normalized_team_name == normalized_requested:
return True
# Flexible matching - check if team name contains all significant words
# This helps with slight variations in formatting
team_words = set(normalized_team_name.split())
requested_words = set(normalized_requested.split())
# If the requested name has special characters, split on those too
for char in ['&', '(', ')']:
if char in normalized_requested:
# Split on special characters and add words
parts = normalized_requested.replace(char, ' ').split()
requested_words.update(parts)
# Remove very short words that aren't meaningful
meaningful_requested_words = {word for word in requested_words if len(word) >= 3}
# Check if team name contains most of the meaningful words
if meaningful_requested_words and len(meaningful_requested_words & team_words) >= len(meaningful_requested_words) * 0.7:
return True
return False
def _filter_team(
team: Team,
requested: list[str] | None = None,
@@ -522,7 +657,7 @@ def _filter_team(
Returns the true if:
- Team is not expired / deleted
- Team has a display-name and ID
- Team display-name is in the requested teams list
- Team display-name matches any of the requested teams (with flexible matching)
Otherwise, returns false.
"""
@@ -530,7 +665,7 @@ def _filter_team(
if not team.id or not team.display_name:
return False
if requested and team.display_name not in requested:
if not _matches_requested_team(team.display_name, requested):
return False
props = team.properties

View File

@@ -6,48 +6,41 @@ from onyx.connectors.teams.connector import _collect_all_teams
def test_special_characters_in_team_names() -> None:
"""Test that team names with special characters skip OData and use get_all()."""
"""Test that team names with special characters use client-side filtering."""
mock_graph_client = MagicMock()
# Mock successful responses
# Mock team with special characters
mock_team = MagicMock()
mock_team.id = "test-id"
mock_team.display_name = "Research & Development (R&D) Team"
mock_team.properties = {}
# Mock successful responses for client-side filtering
mock_team_collection = MagicMock()
mock_team_collection.has_next = False
mock_team_collection.__iter__ = lambda self: iter([])
mock_get_all_query = MagicMock()
mock_get_all_query.execute_query.return_value = mock_team_collection
mock_graph_client.teams.get_all = MagicMock(return_value=mock_get_all_query)
mock_team_collection.__iter__ = lambda self: iter([mock_team])
mock_get_query = MagicMock()
mock_filter_query = MagicMock()
mock_filter_query.execute_query.return_value = mock_team_collection
mock_get_query.filter.return_value = mock_filter_query
mock_top_query = MagicMock()
mock_top_query.execute_query.return_value = mock_team_collection
mock_get_query.top.return_value = mock_top_query
mock_graph_client.teams.get = MagicMock(return_value=mock_get_query)
# Test with the actual customer's problematic team name (has &, parentheses, spaces)
# This should skip OData filtering entirely and use get_all() for client-side filtering
_collect_all_teams(mock_graph_client, ["Grainger Data & Analytics (GDA) Users"])
# Test with team name containing special characters (has &, parentheses)
# This should use client-side filtering (get().top()) instead of OData filtering
result = _collect_all_teams(mock_graph_client, ["Research & Development (R&D) Team"])
# Verify that get_all() was called (NOT get().filter())
# because special characters are not supported in OData string literals
mock_graph_client.teams.get_all.assert_called()
mock_graph_client.teams.get.assert_not_called()
# Verify that get().top() was called for client-side filtering
mock_graph_client.teams.get.assert_called()
mock_get_query.top.assert_called_with(50)
# Reset mocks
mock_graph_client.reset_mock()
mock_get_all_query.execute_query.return_value = mock_team_collection
mock_filter_query.execute_query.return_value = mock_team_collection
# Test that OData filter failure falls back to get_all()
mock_filter_query.execute_query.side_effect = ValueError(
"OData query parsing error"
)
_collect_all_teams(mock_graph_client, ["Simple Team"])
mock_graph_client.teams.get_all.assert_called()
# Verify the team was found through client-side filtering
assert len(result) == 1
assert result[0].display_name == "Research & Development (R&D) Team"
def test_single_quote_escaping() -> None:
"""Test that team names with single quotes are properly escaped for OData."""
"""Test that team names with single quotes use OData filtering with proper escaping."""
mock_graph_client = MagicMock()
# Mock successful responses
@@ -57,14 +50,15 @@ def test_single_quote_escaping() -> None:
mock_get_query = MagicMock()
mock_filter_query = MagicMock()
mock_filter_query.before_execute = MagicMock(return_value=mock_filter_query)
mock_filter_query.execute_query.return_value = mock_team_collection
mock_get_query.filter.return_value = mock_filter_query
mock_graph_client.teams.get = MagicMock(return_value=mock_get_query)
# Test with a team name containing a single quote
# Test with a team name containing a single quote (no &, (, ) so uses OData)
_collect_all_teams(mock_graph_client, ["Team's Group"])
# Verify OData filter was used
# Verify OData filter was used (since no special characters)
mock_graph_client.teams.get.assert_called()
mock_get_query.filter.assert_called_once()
@@ -74,3 +68,29 @@ def test_single_quote_escaping() -> None:
assert (
filter_arg == expected_filter
), f"Expected: {expected_filter}, Got: {filter_arg}"
def test_helper_functions() -> None:
"""Test the helper functions for team name processing."""
from onyx.connectors.teams.connector import (
_escape_odata_string,
_has_odata_incompatible_chars,
_can_use_odata_filter,
)
# Test OData string escaping
assert _escape_odata_string("Team's Group") == "Team''s Group"
assert _escape_odata_string("Normal Team") == "Normal Team"
# Test special character detection
assert _has_odata_incompatible_chars(["R&D Team"]) == True
assert _has_odata_incompatible_chars(["Team (Alpha)"]) == True
assert _has_odata_incompatible_chars(["Normal Team"]) == False
assert _has_odata_incompatible_chars([]) == False
assert _has_odata_incompatible_chars(None) == False
# Test filtering strategy determination
can_use, safe, problematic = _can_use_odata_filter(["Normal Team", "R&D Team"])
assert can_use == True
assert "Normal Team" in safe
assert "R&D Team" in problematic