Compare commits

...

2 Commits

Author SHA1 Message Date
Evan Lohn
ca9e5cf935 remove future stuff 2025-12-29 16:34:01 -08:00
Evan Lohn
d1c55e7f11 refactor: drive connector 2025-12-29 16:34:01 -08:00
2 changed files with 116 additions and 90 deletions

View File

@@ -8,7 +8,6 @@ from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any
from typing import cast
from typing import Protocol
@@ -154,7 +153,8 @@ class DriveIdStatus(Enum):
class GoogleDriveConnector(
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
SlimConnectorWithPermSync,
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
):
def __init__(
self,
@@ -1130,26 +1130,67 @@ class GoogleDriveConnector(
end=end,
)
def _extract_docs_from_google_drive(
def _convert_retrieved_files_to_documents(
self,
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
drive_files_iter: Iterator[RetrievedDriveFile],
include_permissions: bool,
) -> Iterator[Document | ConnectorFailure]:
"""
Retrieves and converts Google Drive files to documents.
Converts retrieved files to documents.
"""
field_type = (
DriveFileFieldType.WITH_PERMISSIONS
if include_permissions or self.exclude_domain_link_only
else DriveFileFieldType.STANDARD
files_batch: list[RetrievedDriveFile] = []
for retrieved_file in drive_files_iter:
if self.exclude_domain_link_only and has_link_only_permission(
retrieved_file.drive_file
):
continue
if retrieved_file.error is None:
files_batch.append(retrieved_file)
continue
# handle retrieval errors
failure_stage = retrieved_file.completion_stage.value
failure_message = f"retrieval failure during stage: {failure_stage},"
failure_message += f"user: {retrieved_file.user_email},"
failure_message += f"parent drive/folder: {retrieved_file.parent_id},"
failure_message += f"error: {retrieved_file.error}"
logger.error(failure_message)
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=failure_stage,
),
failure_message=failure_message,
exception=retrieved_file.error,
)
# Process the batch using run_functions_tuples_in_parallel
func_with_args = [
(
self._convert_retrieved_file_to_document,
(retrieved_file, include_permissions),
)
for retrieved_file in files_batch
]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
)
results = [result for result in results if result is not None]
logger.debug(f"batch has {len(results)} docs or failures")
yield from results
def _convert_retrieved_file_to_document(
self,
retrieved_file: RetrievedDriveFile,
include_permissions: bool,
) -> Document | ConnectorFailure | None:
"""
Converts a retrieved file to a document.
"""
try:
# Prepare a partial function with the credentials and admin email
convert_func = partial(
convert_drive_item_to_document,
return convert_drive_item_to_document(
self.creds,
self.allow_images,
self.size_threshold,
@@ -1161,83 +1202,15 @@ class GoogleDriveConnector(
if include_permissions
else None
),
)
# Fetch files in batches
batches_complete = 0
files_batch: list[RetrievedDriveFile] = []
def _yield_batch(
files_batch: list[RetrievedDriveFile],
) -> Iterator[Document | ConnectorFailure]:
nonlocal batches_complete
# Process the batch using run_functions_tuples_in_parallel
func_with_args = [
(
convert_func,
(
[file.user_email, self.primary_admin_email]
+ get_file_owners(
file.drive_file, self.primary_admin_email
),
file.drive_file,
),
)
for file in files_batch
]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
)
logger.debug(
f"finished processing batch {batches_complete} with {len(results)} results"
)
docs_and_failures = [result for result in results if result is not None]
logger.debug(
f"batch {batches_complete} has {len(docs_and_failures)} docs or failures"
)
if docs_and_failures:
yield from docs_and_failures
batches_complete += 1
logger.debug(f"finished yielding batch {batches_complete}")
for retrieved_file in self._fetch_drive_items(
field_type=field_type,
checkpoint=checkpoint,
start=start,
end=end,
):
if self.exclude_domain_link_only and has_link_only_permission(
retrieved_file.drive_file
):
continue
if retrieved_file.error is None:
files_batch.append(retrieved_file)
continue
# handle retrieval errors
failure_stage = retrieved_file.completion_stage.value
failure_message = f"retrieval failure during stage: {failure_stage},"
failure_message += f"user: {retrieved_file.user_email},"
failure_message += f"parent drive/folder: {retrieved_file.parent_id},"
failure_message += f"error: {retrieved_file.error}"
logger.error(failure_message)
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=failure_stage,
),
failure_message=failure_message,
exception=retrieved_file.error,
)
yield from _yield_batch(files_batch)
checkpoint.retrieved_folder_and_drive_ids = (
self._retrieved_folder_and_drive_ids
[retrieved_file.user_email, self.primary_admin_email]
+ get_file_owners(retrieved_file.drive_file, self.primary_admin_email),
retrieved_file.drive_file,
)
except Exception as e:
logger.exception(f"Error extracting documents from Google Drive: {e}")
logger.exception(
f"Error extracting document: {retrieved_file.drive_file.get('name')} from Google Drive"
)
raise e
def _load_from_checkpoint(
@@ -1262,8 +1235,19 @@ class GoogleDriveConnector(
checkpoint = copy.deepcopy(checkpoint)
self._retrieved_folder_and_drive_ids = checkpoint.retrieved_folder_and_drive_ids
try:
yield from self._extract_docs_from_google_drive(
checkpoint, start, end, include_permissions
field_type = (
DriveFileFieldType.WITH_PERMISSIONS
if include_permissions or self.exclude_domain_link_only
else DriveFileFieldType.STANDARD
)
drive_files_iter = self._fetch_drive_items(
field_type=field_type,
checkpoint=checkpoint,
start=start,
end=end,
)
yield from self._convert_retrieved_files_to_documents(
drive_files_iter, include_permissions
)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):

View File

@@ -3,6 +3,8 @@ from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from enum import Enum
from urllib.parse import parse_qs
from urllib.parse import urlparse
from googleapiclient.discovery import Resource # type: ignore
from googleapiclient.errors import HttpError # type: ignore
@@ -393,3 +395,43 @@ def get_root_folder_id(service: Resource) -> str:
.get(fileId="root", fields=GoogleFields.ID.value)
.execute()[GoogleFields.ID.value]
)
def _extract_file_id_from_web_view_link(web_view_link: str) -> str:
parsed = urlparse(web_view_link)
path_parts = [part for part in parsed.path.split("/") if part]
if "d" in path_parts:
idx = path_parts.index("d")
if idx + 1 < len(path_parts):
return path_parts[idx + 1]
query_params = parse_qs(parsed.query)
for key in ("id", "fileId"):
value = query_params.get(key)
if value and value[0]:
return value[0]
raise ValueError(
f"Unable to extract Drive file id from webViewLink: {web_view_link}"
)
def get_file_by_web_view_link(
service: GoogleDriveService,
web_view_link: str,
fields: str,
) -> GoogleDriveFileType:
"""
Retrieve a Google Drive file using its webViewLink.
"""
file_id = _extract_file_id_from_web_view_link(web_view_link)
return (
service.files()
.get(
fileId=file_id,
supportsAllDrives=True,
fields=fields,
)
.execute()
)