mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-08 08:22:42 +00:00
Compare commits
3 Commits
cli/v0.2.1
...
temp/pr-61
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0a74c30e5 | ||
|
|
ad22d98007 | ||
|
|
2a52023c1d |
@@ -105,44 +105,37 @@ class TeamsConnector(
|
||||
if self.graph_client is None:
|
||||
raise ConnectorMissingCredentialError("Teams credentials not loaded.")
|
||||
|
||||
# Determine timeout based on special characters
|
||||
# Check if any requested teams have special characters that need client-side filtering
|
||||
has_special_chars = _has_odata_incompatible_chars(self.requested_team_list)
|
||||
timeout = 30 if has_special_chars else 10
|
||||
if has_special_chars:
|
||||
logger.info(
|
||||
"Some requested team names contain special characters (&, (, )) that require "
|
||||
"client-side filtering during data retrieval."
|
||||
)
|
||||
|
||||
try:
|
||||
# Minimal call to confirm we can retrieve Teams
|
||||
# Use longer timeout if team names have special characters (requires client-side filtering)
|
||||
# For validation, do a lightweight check instead of full team search
|
||||
logger.info(
|
||||
f"Requested team count: {len(self.requested_team_list) if self.requested_team_list else 0}, "
|
||||
f"Has special chars: {has_special_chars}, "
|
||||
f"Timeout: {timeout}s"
|
||||
f"Has special chars: {has_special_chars}"
|
||||
)
|
||||
|
||||
found_teams = run_with_timeout(
|
||||
# Minimal validation: just check if we can access the teams endpoint
|
||||
timeout = 10 # Short timeout for basic validation
|
||||
validation_query = self.graph_client.teams.get().top(1)
|
||||
run_with_timeout(
|
||||
timeout=timeout,
|
||||
func=_collect_all_teams,
|
||||
graph_client=self.graph_client,
|
||||
requested=self.requested_team_list,
|
||||
func=lambda: validation_query.execute_query(),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Teams validation successful - " f"Found {len(found_teams)} team(s)"
|
||||
)
|
||||
logger.info("Teams validation successful - Access to teams endpoint confirmed")
|
||||
|
||||
except TimeoutError as e:
|
||||
if has_special_chars:
|
||||
raise ConnectorValidationError(
|
||||
f"Timeout while fetching Teams (waited {timeout}s). "
|
||||
f"Team names with special characters (&, (, )) require fetching all teams "
|
||||
f"for client-side filtering, which can take longer. "
|
||||
f"Error: {e}"
|
||||
)
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"Timeout while fetching Teams (waited {timeout}s). "
|
||||
f"This may indicate network issues or a large number of teams. "
|
||||
f"Error: {e}"
|
||||
)
|
||||
raise ConnectorValidationError(
|
||||
f"Timeout while validating Teams access (waited {timeout}s). "
|
||||
f"This may indicate network issues or authentication problems. "
|
||||
f"Error: {e}"
|
||||
)
|
||||
|
||||
except ClientRequestException as e:
|
||||
if not e.response:
|
||||
@@ -176,12 +169,6 @@ class TeamsConnector(
|
||||
f"Unexpected error during Teams validation: {e}"
|
||||
)
|
||||
|
||||
if not found_teams:
|
||||
raise ConnectorValidationError(
|
||||
"No Teams found for the given credentials. "
|
||||
"Either there are no Teams in this tenant, or your app does not have permission to view them."
|
||||
)
|
||||
|
||||
# impls for CheckpointedConnector
|
||||
|
||||
def build_dummy_checkpoint(self) -> TeamsCheckpoint:
|
||||
@@ -262,8 +249,8 @@ class TeamsConnector(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
_end: SecondsSinceUnixEpoch | None = None,
|
||||
_callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
start = start or 0
|
||||
|
||||
@@ -274,7 +261,7 @@ class TeamsConnector(
|
||||
|
||||
for team in teams:
|
||||
if not team.id:
|
||||
logger.warn(f"Expected a team with an id, instead got no id: {team=}")
|
||||
logger.warning(f"Expected a team with an id, instead got no id: {team=}")
|
||||
continue
|
||||
|
||||
channels = _collect_all_channels_from_team(
|
||||
@@ -283,7 +270,7 @@ class TeamsConnector(
|
||||
|
||||
for channel in channels:
|
||||
if not channel.id:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Expected a channel with an id, instead got no id: {channel=}"
|
||||
)
|
||||
continue
|
||||
@@ -319,18 +306,68 @@ class TeamsConnector(
|
||||
slim_doc_buffer = []
|
||||
|
||||
|
||||
def _has_odata_incompatible_chars(team_names: list[str] | None) -> bool:
|
||||
"""Check if any team name contains characters that break OData filters.
|
||||
def _escape_odata_string(name: str) -> str:
|
||||
"""Escape special characters for OData string literals.
|
||||
|
||||
The &, (, and ) characters are not allowed in OData string literals and are
|
||||
reserved characters in OData syntax. Server-side filtering is not possible for
|
||||
team names containing these characters.
|
||||
Uses proper OData v4 string literal escaping:
|
||||
- Single quotes: ' becomes ''
|
||||
- Other characters are handled by using contains() instead of eq for problematic cases
|
||||
"""
|
||||
# Escape single quotes for OData syntax (replace ' with '')
|
||||
escaped = name.replace("'", "''")
|
||||
return escaped
|
||||
|
||||
|
||||
def _has_odata_incompatible_chars(team_names: list[str] | None) -> bool:
|
||||
"""Check if any team name contains characters that break Microsoft Graph OData filters.
|
||||
|
||||
The Microsoft Graph Teams API has limited OData support. Characters like
|
||||
&, (, and ) cause parsing errors and require client-side filtering instead.
|
||||
"""
|
||||
if not team_names:
|
||||
return False
|
||||
return any(char in name for name in team_names for char in ["&", "(", ")"])
|
||||
|
||||
|
||||
def _can_use_odata_filter(team_names: list[str] | None) -> tuple[bool, list[str], list[str]]:
|
||||
"""Determine which teams can use OData filtering vs client-side filtering.
|
||||
|
||||
Microsoft Graph /teams endpoint OData limitations:
|
||||
- Only supports basic 'eq' operators in filters
|
||||
- No 'contains', 'startswith', or other advanced operators
|
||||
- Special characters (&, (, )) break OData parsing
|
||||
|
||||
Returns:
|
||||
tuple: (can_use_odata, safe_names, problematic_names)
|
||||
"""
|
||||
if not team_names:
|
||||
return False, [], []
|
||||
|
||||
safe_names = []
|
||||
problematic_names = []
|
||||
|
||||
for name in team_names:
|
||||
if any(char in name for char in ["&", "(", ")"]):
|
||||
problematic_names.append(name)
|
||||
else:
|
||||
safe_names.append(name)
|
||||
|
||||
return bool(safe_names), safe_names, problematic_names
|
||||
|
||||
|
||||
def _build_simple_odata_filter(safe_names: list[str]) -> str | None:
|
||||
"""Build simple OData filter using only 'eq' operators for safe names."""
|
||||
if not safe_names:
|
||||
return None
|
||||
|
||||
filter_parts = []
|
||||
for name in safe_names:
|
||||
escaped_name = _escape_odata_string(name)
|
||||
filter_parts.append(f"displayName eq '{escaped_name}'")
|
||||
|
||||
return " or ".join(filter_parts)
|
||||
|
||||
|
||||
def _construct_semantic_identifier(channel: Channel, top_message: Message) -> str:
|
||||
top_message_user_name: str
|
||||
|
||||
@@ -340,7 +377,7 @@ def _construct_semantic_identifier(channel: Channel, top_message: Message) -> st
|
||||
user_display_name if user_display_name else "Unknown User"
|
||||
)
|
||||
else:
|
||||
logger.warn(f"Message {top_message=} has no `from.user` field")
|
||||
logger.warning(f"Message {top_message=} has no `from.user` field")
|
||||
top_message_user_name = "Unknown User"
|
||||
|
||||
top_message_content = top_message.body.content or ""
|
||||
@@ -433,45 +470,72 @@ def _collect_all_teams(
|
||||
graph_client: GraphClient,
|
||||
requested: list[str] | None = None,
|
||||
) -> list[Team]:
|
||||
"""Collect teams from Microsoft Graph using appropriate filtering strategy.
|
||||
|
||||
For teams with special characters (&, (, )), uses client-side filtering
|
||||
with paginated search. For teams without special characters, uses efficient
|
||||
OData server-side filtering.
|
||||
|
||||
Args:
|
||||
graph_client: Authenticated Microsoft Graph client
|
||||
requested: List of team names to find, or None for all teams
|
||||
|
||||
Returns:
|
||||
List of Team objects matching the requested names
|
||||
"""
|
||||
teams: list[Team] = []
|
||||
next_url: str | None = None
|
||||
|
||||
# Check if team names have special characters that break OData filters
|
||||
has_special_chars = _has_odata_incompatible_chars(requested)
|
||||
if (
|
||||
has_special_chars and requested
|
||||
): # requested must exist if has_special_chars is True
|
||||
logger.info(
|
||||
f"Team name(s) contain special characters (&, (, or )) which are not supported "
|
||||
f"in OData string literals. Fetching all teams and using client-side filtering. "
|
||||
f"Count: {len(requested)}"
|
||||
)
|
||||
# Determine filtering strategy based on Microsoft Graph limitations
|
||||
if not requested:
|
||||
# No specific teams requested - return empty list (avoid fetching all teams)
|
||||
logger.info("No specific teams requested - returning empty list")
|
||||
return []
|
||||
|
||||
# Build OData filter for requested teams (only if we didn't already return from raw HTTP above)
|
||||
filter = None
|
||||
use_filter = (
|
||||
bool(requested) and not has_special_chars
|
||||
) # Skip OData for special chars (fallback to client-side)
|
||||
if use_filter and requested:
|
||||
filter_parts = []
|
||||
for name in requested:
|
||||
# Escape single quotes for OData syntax (replace ' with '')
|
||||
escaped_name = name.replace("'", "''")
|
||||
filter_parts.append(f"displayName eq '{escaped_name}'")
|
||||
filter = " or ".join(filter_parts)
|
||||
_, safe_names, problematic_names = _can_use_odata_filter(requested)
|
||||
|
||||
if problematic_names and not safe_names:
|
||||
# ALL requested teams have special characters - cannot use OData filtering
|
||||
logger.info(
|
||||
f"All requested team names contain special characters (&, (, )) which require "
|
||||
f"client-side filtering. Using basic /teams endpoint with pagination. "
|
||||
f"Teams: {problematic_names}"
|
||||
)
|
||||
# Use unfiltered query with pagination limit to avoid fetching too many teams
|
||||
use_client_side_filtering = True
|
||||
odata_filter = None
|
||||
elif problematic_names and safe_names:
|
||||
# Mixed scenario - need to fetch more teams to find the problematic ones
|
||||
logger.info(
|
||||
f"Mixed team types: will use client-side filtering for all. "
|
||||
f"Safe names: {safe_names}, Special char names: {problematic_names}"
|
||||
)
|
||||
use_client_side_filtering = True
|
||||
odata_filter = None
|
||||
elif safe_names:
|
||||
# All names are safe - use OData filtering
|
||||
logger.info(f"Using OData filtering for all requested teams: {safe_names}")
|
||||
use_client_side_filtering = False
|
||||
odata_filter = _build_simple_odata_filter(safe_names)
|
||||
else:
|
||||
# No valid names
|
||||
return []
|
||||
|
||||
# Track pagination to avoid fetching too many teams for client-side filtering
|
||||
max_pages = 200
|
||||
page_count = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
if filter:
|
||||
# Use normal filter for teams without special characters
|
||||
query = graph_client.teams.get().filter(filter)
|
||||
# Add header to work around Microsoft Graph API ampersand bug
|
||||
query.before_execute(lambda req: _add_prefer_header(request=req))
|
||||
if use_client_side_filtering:
|
||||
# Use basic /teams endpoint with top parameter to limit results per page
|
||||
query = graph_client.teams.get().top(50) # Limit to 50 teams per page
|
||||
else:
|
||||
query = graph_client.teams.get_all(
|
||||
# explicitly needed because of incorrect type definitions provided by the `office365` library
|
||||
page_loaded=lambda _: None
|
||||
)
|
||||
# Use OData filter with only 'eq' operators
|
||||
query = graph_client.teams.get().filter(odata_filter)
|
||||
|
||||
# Add header to work around Microsoft Graph API issues
|
||||
query.before_execute(lambda req: _add_prefer_header(request=req))
|
||||
|
||||
if next_url:
|
||||
url = next_url
|
||||
@@ -481,17 +545,17 @@ def _collect_all_teams(
|
||||
|
||||
team_collection = query.execute_query()
|
||||
except (ClientRequestException, ValueError) as e:
|
||||
# If OData filter fails, fallback to client-side filtering
|
||||
if use_filter:
|
||||
logger.warning(
|
||||
f"OData filter failed with {type(e).__name__}: {e}. "
|
||||
f"Falling back to client-side filtering."
|
||||
)
|
||||
use_filter = False
|
||||
filter = None
|
||||
# If OData filter fails, fall back to client-side filtering
|
||||
if not use_client_side_filtering and odata_filter:
|
||||
logger.warning(f"OData filter failed: {e}. Falling back to client-side filtering.")
|
||||
use_client_side_filtering = True
|
||||
odata_filter = None
|
||||
teams = []
|
||||
next_url = None
|
||||
page_count = 0
|
||||
continue
|
||||
# If client-side approach also fails, re-raise
|
||||
logger.error(f"Teams query failed: {e}")
|
||||
raise
|
||||
|
||||
filtered_teams = (
|
||||
@@ -501,6 +565,30 @@ def _collect_all_teams(
|
||||
)
|
||||
teams.extend(filtered_teams)
|
||||
|
||||
# For client-side filtering, check if we found all requested teams or hit page limit
|
||||
if use_client_side_filtering:
|
||||
page_count += 1
|
||||
found_team_names = {team.display_name for team in teams if team.display_name}
|
||||
requested_set = set(requested)
|
||||
|
||||
# Log progress every 10 pages to avoid excessive logging
|
||||
if page_count % 10 == 0:
|
||||
logger.info(
|
||||
f"Searched {page_count} pages, found {len(found_team_names)} matching teams so far"
|
||||
)
|
||||
|
||||
# Stop if we found all requested teams or hit the page limit
|
||||
if requested_set.issubset(found_team_names):
|
||||
logger.info(f"Found all requested teams after {page_count} pages")
|
||||
break
|
||||
elif page_count >= max_pages:
|
||||
logger.warning(
|
||||
f"Reached maximum page limit ({max_pages}) while searching for teams. "
|
||||
f"Found: {found_team_names & requested_set}, "
|
||||
f"Missing: {requested_set - found_team_names}"
|
||||
)
|
||||
break
|
||||
|
||||
if not team_collection.has_next:
|
||||
break
|
||||
|
||||
@@ -514,6 +602,53 @@ def _collect_all_teams(
|
||||
return teams
|
||||
|
||||
|
||||
def _normalize_team_name(name: str) -> str:
|
||||
"""Normalize team name for flexible matching."""
|
||||
if not name:
|
||||
return ""
|
||||
# Convert to lowercase and strip whitespace for case-insensitive matching
|
||||
return name.lower().strip()
|
||||
|
||||
|
||||
def _matches_requested_team(team_display_name: str, requested: list[str]) -> bool:
|
||||
"""Check if team display name matches any of the requested team names.
|
||||
|
||||
Uses flexible matching to handle slight variations in team names.
|
||||
"""
|
||||
if not requested or not team_display_name:
|
||||
return not requested # If no teams requested, match all; if no name, don't match
|
||||
|
||||
normalized_team_name = _normalize_team_name(team_display_name)
|
||||
|
||||
for requested_name in requested:
|
||||
normalized_requested = _normalize_team_name(requested_name)
|
||||
|
||||
# Exact match after normalization
|
||||
if normalized_team_name == normalized_requested:
|
||||
return True
|
||||
|
||||
# Flexible matching - check if team name contains all significant words
|
||||
# This helps with slight variations in formatting
|
||||
team_words = set(normalized_team_name.split())
|
||||
requested_words = set(normalized_requested.split())
|
||||
|
||||
# If the requested name has special characters, split on those too
|
||||
for char in ['&', '(', ')']:
|
||||
if char in normalized_requested:
|
||||
# Split on special characters and add words
|
||||
parts = normalized_requested.replace(char, ' ').split()
|
||||
requested_words.update(parts)
|
||||
|
||||
# Remove very short words that aren't meaningful
|
||||
meaningful_requested_words = {word for word in requested_words if len(word) >= 3}
|
||||
|
||||
# Check if team name contains most of the meaningful words
|
||||
if meaningful_requested_words and len(meaningful_requested_words & team_words) >= len(meaningful_requested_words) * 0.7:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _filter_team(
|
||||
team: Team,
|
||||
requested: list[str] | None = None,
|
||||
@@ -522,7 +657,7 @@ def _filter_team(
|
||||
Returns the true if:
|
||||
- Team is not expired / deleted
|
||||
- Team has a display-name and ID
|
||||
- Team display-name is in the requested teams list
|
||||
- Team display-name matches any of the requested teams (with flexible matching)
|
||||
|
||||
Otherwise, returns false.
|
||||
"""
|
||||
@@ -530,7 +665,7 @@ def _filter_team(
|
||||
if not team.id or not team.display_name:
|
||||
return False
|
||||
|
||||
if requested and team.display_name not in requested:
|
||||
if not _matches_requested_team(team.display_name, requested):
|
||||
return False
|
||||
|
||||
props = team.properties
|
||||
|
||||
@@ -6,48 +6,41 @@ from onyx.connectors.teams.connector import _collect_all_teams
|
||||
|
||||
|
||||
def test_special_characters_in_team_names() -> None:
|
||||
"""Test that team names with special characters skip OData and use get_all()."""
|
||||
"""Test that team names with special characters use client-side filtering."""
|
||||
mock_graph_client = MagicMock()
|
||||
|
||||
# Mock successful responses
|
||||
# Mock team with special characters
|
||||
mock_team = MagicMock()
|
||||
mock_team.id = "test-id"
|
||||
mock_team.display_name = "Research & Development (R&D) Team"
|
||||
mock_team.properties = {}
|
||||
|
||||
# Mock successful responses for client-side filtering
|
||||
mock_team_collection = MagicMock()
|
||||
mock_team_collection.has_next = False
|
||||
mock_team_collection.__iter__ = lambda self: iter([])
|
||||
|
||||
mock_get_all_query = MagicMock()
|
||||
mock_get_all_query.execute_query.return_value = mock_team_collection
|
||||
mock_graph_client.teams.get_all = MagicMock(return_value=mock_get_all_query)
|
||||
mock_team_collection.__iter__ = lambda self: iter([mock_team])
|
||||
|
||||
mock_get_query = MagicMock()
|
||||
mock_filter_query = MagicMock()
|
||||
mock_filter_query.execute_query.return_value = mock_team_collection
|
||||
mock_get_query.filter.return_value = mock_filter_query
|
||||
mock_top_query = MagicMock()
|
||||
mock_top_query.execute_query.return_value = mock_team_collection
|
||||
mock_get_query.top.return_value = mock_top_query
|
||||
mock_graph_client.teams.get = MagicMock(return_value=mock_get_query)
|
||||
|
||||
# Test with the actual customer's problematic team name (has &, parentheses, spaces)
|
||||
# This should skip OData filtering entirely and use get_all() for client-side filtering
|
||||
_collect_all_teams(mock_graph_client, ["Grainger Data & Analytics (GDA) Users"])
|
||||
# Test with team name containing special characters (has &, parentheses)
|
||||
# This should use client-side filtering (get().top()) instead of OData filtering
|
||||
result = _collect_all_teams(mock_graph_client, ["Research & Development (R&D) Team"])
|
||||
|
||||
# Verify that get_all() was called (NOT get().filter())
|
||||
# because special characters are not supported in OData string literals
|
||||
mock_graph_client.teams.get_all.assert_called()
|
||||
mock_graph_client.teams.get.assert_not_called()
|
||||
# Verify that get().top() was called for client-side filtering
|
||||
mock_graph_client.teams.get.assert_called()
|
||||
mock_get_query.top.assert_called_with(50)
|
||||
|
||||
# Reset mocks
|
||||
mock_graph_client.reset_mock()
|
||||
mock_get_all_query.execute_query.return_value = mock_team_collection
|
||||
mock_filter_query.execute_query.return_value = mock_team_collection
|
||||
|
||||
# Test that OData filter failure falls back to get_all()
|
||||
mock_filter_query.execute_query.side_effect = ValueError(
|
||||
"OData query parsing error"
|
||||
)
|
||||
_collect_all_teams(mock_graph_client, ["Simple Team"])
|
||||
mock_graph_client.teams.get_all.assert_called()
|
||||
# Verify the team was found through client-side filtering
|
||||
assert len(result) == 1
|
||||
assert result[0].display_name == "Research & Development (R&D) Team"
|
||||
|
||||
|
||||
def test_single_quote_escaping() -> None:
|
||||
"""Test that team names with single quotes are properly escaped for OData."""
|
||||
"""Test that team names with single quotes use OData filtering with proper escaping."""
|
||||
mock_graph_client = MagicMock()
|
||||
|
||||
# Mock successful responses
|
||||
@@ -57,14 +50,15 @@ def test_single_quote_escaping() -> None:
|
||||
|
||||
mock_get_query = MagicMock()
|
||||
mock_filter_query = MagicMock()
|
||||
mock_filter_query.before_execute = MagicMock(return_value=mock_filter_query)
|
||||
mock_filter_query.execute_query.return_value = mock_team_collection
|
||||
mock_get_query.filter.return_value = mock_filter_query
|
||||
mock_graph_client.teams.get = MagicMock(return_value=mock_get_query)
|
||||
|
||||
# Test with a team name containing a single quote
|
||||
# Test with a team name containing a single quote (no &, (, ) so uses OData)
|
||||
_collect_all_teams(mock_graph_client, ["Team's Group"])
|
||||
|
||||
# Verify OData filter was used
|
||||
# Verify OData filter was used (since no special characters)
|
||||
mock_graph_client.teams.get.assert_called()
|
||||
mock_get_query.filter.assert_called_once()
|
||||
|
||||
@@ -74,3 +68,29 @@ def test_single_quote_escaping() -> None:
|
||||
assert (
|
||||
filter_arg == expected_filter
|
||||
), f"Expected: {expected_filter}, Got: {filter_arg}"
|
||||
|
||||
|
||||
def test_helper_functions() -> None:
|
||||
"""Test the helper functions for team name processing."""
|
||||
from onyx.connectors.teams.connector import (
|
||||
_escape_odata_string,
|
||||
_has_odata_incompatible_chars,
|
||||
_can_use_odata_filter,
|
||||
)
|
||||
|
||||
# Test OData string escaping
|
||||
assert _escape_odata_string("Team's Group") == "Team''s Group"
|
||||
assert _escape_odata_string("Normal Team") == "Normal Team"
|
||||
|
||||
# Test special character detection
|
||||
assert _has_odata_incompatible_chars(["R&D Team"]) == True
|
||||
assert _has_odata_incompatible_chars(["Team (Alpha)"]) == True
|
||||
assert _has_odata_incompatible_chars(["Normal Team"]) == False
|
||||
assert _has_odata_incompatible_chars([]) == False
|
||||
assert _has_odata_incompatible_chars(None) == False
|
||||
|
||||
# Test filtering strategy determination
|
||||
can_use, safe, problematic = _can_use_odata_filter(["Normal Team", "R&D Team"])
|
||||
assert can_use == True
|
||||
assert "Normal Team" in safe
|
||||
assert "R&D Team" in problematic
|
||||
|
||||
Reference in New Issue
Block a user