Compare commits

...

5 Commits

Author      SHA1        Message                    Date
Evan Lohn   d81337e345  mypy                       2025-12-29 16:23:03 -08:00
Evan Lohn   df3c8982a1  mypy                       2025-12-29 16:22:11 -08:00
Evan Lohn   dbb720e7f9  WIP                        2025-12-29 16:20:45 -08:00
Evan Lohn   f68b9526fb  remove future stuff        2025-12-29 16:20:45 -08:00
Evan Lohn   6460b5df4b  refactor: drive connector  2025-12-29 16:20:45 -08:00
4 changed files with 231 additions and 91 deletions

View File

@@ -564,7 +564,7 @@ REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
 if platform.system() == "Darwin":
     REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60  # type: ignore
 else:
-    REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60
+    REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60  # type: ignore[attr-defined]


 class OnyxCallTypes(str, Enum):
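
Note: the macOS socket module exposes TCP_KEEPALIVE while Linux exposes TCP_KEEPIDLE, so the type checker flags whichever constant is missing on the platform it runs on; the new # type: ignore[attr-defined] silences that for the Linux branch. As a rough sketch of how such an options dict is consumed, redis-py accepts it via socket_keepalive_options (the client construction below is illustrative, not part of this diff):

import redis

# Illustrative wiring of the keepalive options into a redis-py client;
# host/port are placeholders, not values from this PR.
client = redis.Redis(
    host="localhost",
    port=6379,
    socket_keepalive=True,
    socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
)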

View File

@@ -8,7 +8,6 @@ from collections.abc import Generator
 from collections.abc import Iterator
 from datetime import datetime
 from enum import Enum
-from functools import partial
 from typing import Any
 from typing import cast
 from typing import Protocol
@@ -39,6 +38,9 @@ from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
 from onyx.connectors.google_drive.file_retrieval import (
     get_all_files_in_my_drive_and_shared,
 )
+from onyx.connectors.google_drive.file_retrieval import (
+    get_files_by_web_view_links_batch,
+)
 from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
 from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
 from onyx.connectors.google_drive.file_retrieval import has_link_only_permission
@@ -64,6 +66,7 @@ from onyx.connectors.google_utils.shared_constants import USER_FIELDS
 from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
 from onyx.connectors.interfaces import CheckpointOutput
 from onyx.connectors.interfaces import GenerateSlimDocumentOutput
+from onyx.connectors.interfaces import Resolver
 from onyx.connectors.interfaces import SecondsSinceUnixEpoch
 from onyx.connectors.interfaces import SlimConnectorWithPermSync
 from onyx.connectors.models import ConnectorFailure
@@ -154,7 +157,9 @@ class DriveIdStatus(Enum):


 class GoogleDriveConnector(
-    SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
+    SlimConnectorWithPermSync,
+    CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
+    Resolver,
 ):
     def __init__(
         self,
@@ -1130,26 +1135,67 @@ class GoogleDriveConnector(
             end=end,
         )

-    def _extract_docs_from_google_drive(
+    def _convert_retrieved_files_to_documents(
         self,
-        checkpoint: GoogleDriveCheckpoint,
-        start: SecondsSinceUnixEpoch | None,
-        end: SecondsSinceUnixEpoch | None,
+        drive_files_iter: Iterator[RetrievedDriveFile],
         include_permissions: bool,
     ) -> Iterator[Document | ConnectorFailure]:
         """
-        Retrieves and converts Google Drive files to documents.
+        Converts retrieved files to documents.
         """
-        field_type = (
-            DriveFileFieldType.WITH_PERMISSIONS
-            if include_permissions or self.exclude_domain_link_only
-            else DriveFileFieldType.STANDARD
-        )
+        files_batch: list[RetrievedDriveFile] = []
+        for retrieved_file in drive_files_iter:
+            if self.exclude_domain_link_only and has_link_only_permission(
+                retrieved_file.drive_file
+            ):
+                continue
+            if retrieved_file.error is None:
+                files_batch.append(retrieved_file)
+                continue
+
+            # handle retrieval errors
+            failure_stage = retrieved_file.completion_stage.value
+            failure_message = f"retrieval failure during stage: {failure_stage}, "
+            failure_message += f"user: {retrieved_file.user_email}, "
+            failure_message += f"parent drive/folder: {retrieved_file.parent_id}, "
+            failure_message += f"error: {retrieved_file.error}"
+            logger.error(failure_message)
+            yield ConnectorFailure(
+                failed_entity=EntityFailure(
+                    entity_id=failure_stage,
+                ),
+                failure_message=failure_message,
+                exception=retrieved_file.error,
+            )
+
+        # Process the batch using run_functions_tuples_in_parallel
+        func_with_args = [
+            (
+                self._convert_retrieved_file_to_document,
+                (retrieved_file, include_permissions),
+            )
+            for retrieved_file in files_batch
+        ]
+        results = cast(
+            list[Document | ConnectorFailure | None],
+            run_functions_tuples_in_parallel(func_with_args, max_workers=8),
+        )
+        results_cleaned = [result for result in results if result is not None]
+        logger.debug(f"batch has {len(results_cleaned)} docs or failures")
+        yield from results_cleaned
+
+    def _convert_retrieved_file_to_document(
+        self,
+        retrieved_file: RetrievedDriveFile,
+        include_permissions: bool,
+    ) -> Document | ConnectorFailure | None:
+        """
+        Converts a retrieved file to a document.
+        """
         try:
-            # Prepare a partial function with the credentials and admin email
-            convert_func = partial(
-                convert_drive_item_to_document,
+            return convert_drive_item_to_document(
                 self.creds,
                 self.allow_images,
                 self.size_threshold,
@@ -1161,83 +1207,15 @@ class GoogleDriveConnector(
                     if include_permissions
                     else None
                 ),
-            )
-
-            # Fetch files in batches
-            batches_complete = 0
-            files_batch: list[RetrievedDriveFile] = []
-
-            def _yield_batch(
-                files_batch: list[RetrievedDriveFile],
-            ) -> Iterator[Document | ConnectorFailure]:
-                nonlocal batches_complete
-                # Process the batch using run_functions_tuples_in_parallel
-                func_with_args = [
-                    (
-                        convert_func,
-                        (
-                            [file.user_email, self.primary_admin_email]
-                            + get_file_owners(
-                                file.drive_file, self.primary_admin_email
-                            ),
-                            file.drive_file,
-                        ),
-                    )
-                    for file in files_batch
-                ]
-                results = cast(
-                    list[Document | ConnectorFailure | None],
-                    run_functions_tuples_in_parallel(func_with_args, max_workers=8),
-                )
-                logger.debug(
-                    f"finished processing batch {batches_complete} with {len(results)} results"
-                )
-                docs_and_failures = [result for result in results if result is not None]
-                logger.debug(
-                    f"batch {batches_complete} has {len(docs_and_failures)} docs or failures"
-                )
-
-                if docs_and_failures:
-                    yield from docs_and_failures
-                    batches_complete += 1
-
-                logger.debug(f"finished yielding batch {batches_complete}")
-
-            for retrieved_file in self._fetch_drive_items(
-                field_type=field_type,
-                checkpoint=checkpoint,
-                start=start,
-                end=end,
-            ):
-                if self.exclude_domain_link_only and has_link_only_permission(
-                    retrieved_file.drive_file
-                ):
-                    continue
-                if retrieved_file.error is None:
-                    files_batch.append(retrieved_file)
-                    continue
-
-                # handle retrieval errors
-                failure_stage = retrieved_file.completion_stage.value
-                failure_message = f"retrieval failure during stage: {failure_stage},"
-                failure_message += f"user: {retrieved_file.user_email},"
-                failure_message += f"parent drive/folder: {retrieved_file.parent_id},"
-                failure_message += f"error: {retrieved_file.error}"
-                logger.error(failure_message)
-                yield ConnectorFailure(
-                    failed_entity=EntityFailure(
-                        entity_id=failure_stage,
-                    ),
-                    failure_message=failure_message,
-                    exception=retrieved_file.error,
-                )
-
-            yield from _yield_batch(files_batch)
-
-            checkpoint.retrieved_folder_and_drive_ids = (
-                self._retrieved_folder_and_drive_ids
+                [retrieved_file.user_email, self.primary_admin_email]
+                + get_file_owners(retrieved_file.drive_file, self.primary_admin_email),
+                retrieved_file.drive_file,
             )
         except Exception as e:
-            logger.exception(f"Error extracting documents from Google Drive: {e}")
+            logger.exception(
+                f"Error extracting document: {retrieved_file.drive_file.get('name')} from Google Drive"
            )
             raise e

     def _load_from_checkpoint(
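
Note: the removed _yield_batch closure and the new _convert_retrieved_files_to_documents both fan per-file conversion out through run_functions_tuples_in_parallel with max_workers=8. Assuming that helper behaves like an order-preserving parallel map over (function, args) tuples (this sketch is a stand-in, not Onyx's actual implementation):

from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable


def run_in_parallel_sketch(
    func_with_args: list[tuple[Callable[..., Any], tuple[Any, ...]]],
    max_workers: int = 8,
) -> list[Any]:
    # Order-preserving parallel map: submit every (func, args) pair and
    # collect results in submission order, mirroring how the connector
    # fans _convert_retrieved_file_to_document out over a files batch.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(func, *args) for func, args in func_with_args]
        return [future.result() for future in futures]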
@@ -1262,8 +1240,19 @@ class GoogleDriveConnector(
         checkpoint = copy.deepcopy(checkpoint)
         self._retrieved_folder_and_drive_ids = checkpoint.retrieved_folder_and_drive_ids
         try:
-            yield from self._extract_docs_from_google_drive(
-                checkpoint, start, end, include_permissions
+            field_type = (
+                DriveFileFieldType.WITH_PERMISSIONS
+                if include_permissions or self.exclude_domain_link_only
+                else DriveFileFieldType.STANDARD
+            )
+            drive_files_iter = self._fetch_drive_items(
+                field_type=field_type,
+                checkpoint=checkpoint,
+                start=start,
+                end=end,
+            )
+            yield from self._convert_retrieved_files_to_documents(
+                drive_files_iter, include_permissions
             )
         except Exception as e:
             if MISSING_SCOPES_ERROR_STR in str(e):
@@ -1300,6 +1289,43 @@ class GoogleDriveConnector(
             start, end, checkpoint, include_permissions=True
         )

+    @override
+    def resolve_errors(
+        self, errors: list[ConnectorFailure], include_permissions: bool = False
+    ) -> Generator[Document | ConnectorFailure, None, None]:
+        """Attempts to yield back ALL the documents described by the errors; no checkpointing.
+        It is the caller's responsibility to delete the old ConnectorFailures and
+        replace them with the new ones.
+        """
+        if self._creds is None or self._primary_admin_email is None:
+            raise RuntimeError(
+                "Credentials missing; load_credentials must be called before this method"
+            )
+        logger.info(f"Resolving {len(errors)} errors")
+        doc_ids = set(
+            failure.failed_document.document_id
+            for failure in errors
+            if failure.failed_document
+        )
+
+        service = get_drive_service(self.creds, self.primary_admin_email)
+        field_type = (
+            DriveFileFieldType.WITH_PERMISSIONS
+            if include_permissions or self.exclude_domain_link_only
+            else DriveFileFieldType.STANDARD
+        )
+        files = get_files_by_web_view_links_batch(service, list(doc_ids), field_type)
+        retrieved_iter = (
+            RetrievedDriveFile(
+                drive_file=file,
+                user_email=self.primary_admin_email,
+                completion_stage=DriveRetrievalStage.DONE,
+            )
+            for file in files.values()
+        )
+        yield from self._convert_retrieved_files_to_documents(
+            retrieved_iter, include_permissions
+        )
+
     def _extract_slim_docs_from_google_drive(
         self,
         checkpoint: GoogleDriveCheckpoint,
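
Note: the new resolve_errors replays failed documents outside the checkpoint loop; per its docstring, the caller swaps the old failures for whatever it yields. A hypothetical caller-side helper (the name replay_failures and its wiring are illustrative, not from this PR):

def replay_failures(
    connector: GoogleDriveConnector,
    old_failures: list[ConnectorFailure],
) -> tuple[list[Document], list[ConnectorFailure]]:
    # Assumes load_credentials was already called on the connector.
    recovered: list[Document] = []
    still_failing: list[ConnectorFailure] = []
    for result in connector.resolve_errors(old_failures, include_permissions=True):
        if isinstance(result, Document):
            recovered.append(result)
        else:
            still_failing.append(result)
    # The caller then deletes old_failures and persists still_failing instead.
    return recovered, still_failing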

View File

@@ -3,9 +3,14 @@ from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from enum import Enum
from typing import cast
from typing import Dict
from urllib.parse import parse_qs
from urllib.parse import urlparse
from googleapiclient.discovery import Resource # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from googleapiclient.http import BatchHttpRequest # type: ignore
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
@@ -52,6 +57,8 @@ SLIM_FILE_FIELDS = (
 )
 FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"

+MAX_BATCH_SIZE = 100
+

 def generate_time_range_filter(
     start: SecondsSinceUnixEpoch | None = None,
@@ -393,3 +400,98 @@ def get_root_folder_id(service: Resource) -> str:
         .get(fileId="root", fields=GoogleFields.ID.value)
         .execute()[GoogleFields.ID.value]
     )
+
+
+def _extract_file_id_from_web_view_link(web_view_link: str) -> str:
+    parsed = urlparse(web_view_link)
+    path_parts = [part for part in parsed.path.split("/") if part]
+
+    if "d" in path_parts:
+        idx = path_parts.index("d")
+        if idx + 1 < len(path_parts):
+            return path_parts[idx + 1]
+
+    query_params = parse_qs(parsed.query)
+    for key in ("id", "fileId"):
+        value = query_params.get(key)
+        if value and value[0]:
+            return value[0]
+
+    raise ValueError(
+        f"Unable to extract Drive file id from webViewLink: {web_view_link}"
+    )
+
+
+def get_file_by_web_view_link(
+    service: GoogleDriveService,
+    web_view_link: str,
+    fields: str,
+) -> GoogleDriveFileType:
+    """
+    Retrieve a Google Drive file using its webViewLink.
+    """
+    file_id = _extract_file_id_from_web_view_link(web_view_link)
+    return (
+        service.files()
+        .get(
+            fileId=file_id,
+            supportsAllDrives=True,
+            fields=fields,
+        )
+        .execute()
+    )
+
+
+def get_files_by_web_view_links_batch(
+    service: GoogleDriveService,
+    web_view_links: list[str],
+    field_type: DriveFileFieldType,
+) -> dict[str, GoogleDriveFileType]:
+    fields = _get_fields_for_file_type(field_type)
+    if len(web_view_links) <= MAX_BATCH_SIZE:
+        return _get_files_by_web_view_links_batch(service, web_view_links, fields)
+
+    ret = {}
+    for i in range(0, len(web_view_links), MAX_BATCH_SIZE):
+        batch = web_view_links[i : i + MAX_BATCH_SIZE]
+        ret.update(_get_files_by_web_view_links_batch(service, batch, fields))
+    return ret
+
+
+def _get_files_by_web_view_links_batch(
+    service: GoogleDriveService,
+    web_view_links: list[str],
+    fields: str,
+) -> dict[str, GoogleDriveFileType]:
+    """
+    Retrieve multiple Google Drive files using their webViewLinks in a single batch request.
+    Returns a dict mapping web_view_link to file metadata.
+    Failed requests (due to invalid links or API errors) are omitted from the result.
+    """
+
+    def callback(
+        request_id: str, response: GoogleDriveFileType, exception: Exception | None
+    ) -> None:
+        if exception:
+            logger.warning(f"Error retrieving file {request_id}: {exception}")
+        else:
+            results[request_id] = response
+
+    results: Dict[str, GoogleDriveFileType] = {}
+    batch = cast(BatchHttpRequest, service.new_batch_http_request(callback=callback))
+
+    for web_view_link in web_view_links:
+        try:
+            file_id = _extract_file_id_from_web_view_link(web_view_link)
+            request = service.files().get(
+                fileId=file_id,
+                supportsAllDrives=True,
+                fields=fields,
+            )
+            batch.add(request, request_id=web_view_link)
+        except ValueError as e:
+            logger.warning(f"Failed to extract file ID from {web_view_link}: {e}")
+
+    batch.execute()
+    return results
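
Note: _extract_file_id_from_web_view_link covers the two common webViewLink shapes, the path-style /d/<id>/ form and the query-style ?id=<id> form. Illustrative checks with made-up IDs (these links are fabricated, not real files):

# Both link forms resolve to the embedded file id.
assert _extract_file_id_from_web_view_link(
    "https://docs.google.com/document/d/abc123/edit"
) == "abc123"
assert _extract_file_id_from_web_view_link(
    "https://drive.google.com/open?id=xyz789"
) == "xyz789"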

View File

@@ -269,3 +269,15 @@ class CheckpointedConnectorWithPermSync(CheckpointedConnector[CT]):
checkpoint: CT,
) -> CheckpointOutput[CT]:
raise NotImplementedError
class Resolver:
@abc.abstractmethod
def resolve_errors(
self, errors: list[ConnectorFailure], include_permissions: bool = False
) -> Generator[Document | ConnectorFailure, None, None]:
"""Attempts to yield back ALL the documents described by the error, no checkpointing.
caller's responsibility is to delete the old connectorfailures and replace with the new ones.
If include_permissions is True, the documents will have permissions synced.
"""
raise NotImplementedError
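
Note: since Resolver is a plain mixin rather than part of the checkpointing hierarchy, callers can feature-detect it with isinstance. A hypothetical dispatch sketch (maybe_resolve is not a function in this PR):

# Only connectors that opt into the Resolver mixin get their recorded
# failures replayed; everything else passes its failures through unchanged.
def maybe_resolve(
    connector: object, failures: list[ConnectorFailure]
) -> list[Document | ConnectorFailure]:
    if not isinstance(connector, Resolver):
        return list(failures)
    return list(connector.resolve_errors(failures, include_permissions=False))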