mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-16 23:35:46 +00:00
Compare commits
5 Commits
v2.12.0-cl
...
feat/resol
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d81337e345 | ||
|
|
df3c8982a1 | ||
|
|
dbb720e7f9 | ||
|
|
f68b9526fb | ||
|
|
6460b5df4b |
@@ -564,7 +564,7 @@ REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
|
||||
if platform.system() == "Darwin":
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
|
||||
else:
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class OnyxCallTypes(str, Enum):
|
||||
|
||||
@@ -8,7 +8,6 @@ from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Protocol
|
||||
@@ -39,6 +38,9 @@ from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
|
||||
from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_all_files_in_my_drive_and_shared,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_files_by_web_view_links_batch,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
from onyx.connectors.google_drive.file_retrieval import has_link_only_permission
|
||||
@@ -64,6 +66,7 @@ from onyx.connectors.google_utils.shared_constants import USER_FIELDS
|
||||
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import Resolver
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
@@ -154,7 +157,9 @@ class DriveIdStatus(Enum):
|
||||
|
||||
|
||||
class GoogleDriveConnector(
|
||||
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
|
||||
SlimConnectorWithPermSync,
|
||||
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
|
||||
Resolver,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1130,26 +1135,67 @@ class GoogleDriveConnector(
|
||||
end=end,
|
||||
)
|
||||
|
||||
def _extract_docs_from_google_drive(
|
||||
def _convert_retrieved_files_to_documents(
|
||||
self,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
drive_files_iter: Iterator[RetrievedDriveFile],
|
||||
include_permissions: bool,
|
||||
) -> Iterator[Document | ConnectorFailure]:
|
||||
"""
|
||||
Retrieves and converts Google Drive files to documents.
|
||||
Converts retrieved files to documents.
|
||||
"""
|
||||
field_type = (
|
||||
DriveFileFieldType.WITH_PERMISSIONS
|
||||
if include_permissions or self.exclude_domain_link_only
|
||||
else DriveFileFieldType.STANDARD
|
||||
files_batch: list[RetrievedDriveFile] = []
|
||||
for retrieved_file in drive_files_iter:
|
||||
if self.exclude_domain_link_only and has_link_only_permission(
|
||||
retrieved_file.drive_file
|
||||
):
|
||||
continue
|
||||
if retrieved_file.error is None:
|
||||
files_batch.append(retrieved_file)
|
||||
continue
|
||||
# handle retrieval errors
|
||||
failure_stage = retrieved_file.completion_stage.value
|
||||
failure_message = f"retrieval failure during stage: {failure_stage},"
|
||||
failure_message += f"user: {retrieved_file.user_email},"
|
||||
failure_message += f"parent drive/folder: {retrieved_file.parent_id},"
|
||||
failure_message += f"error: {retrieved_file.error}"
|
||||
logger.error(failure_message)
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=failure_stage,
|
||||
),
|
||||
failure_message=failure_message,
|
||||
exception=retrieved_file.error,
|
||||
)
|
||||
|
||||
# Process the batch using run_functions_tuples_in_parallel
|
||||
func_with_args = [
|
||||
(
|
||||
self._convert_retrieved_file_to_document,
|
||||
(retrieved_file, include_permissions),
|
||||
)
|
||||
for retrieved_file in files_batch
|
||||
]
|
||||
results = cast(
|
||||
list[Document | ConnectorFailure | None],
|
||||
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
|
||||
)
|
||||
|
||||
results_cleaned = [result for result in results if result is not None]
|
||||
logger.debug(f"batch has {len(results_cleaned)} docs or failures")
|
||||
|
||||
yield from results_cleaned
|
||||
|
||||
def _convert_retrieved_file_to_document(
|
||||
self,
|
||||
retrieved_file: RetrievedDriveFile,
|
||||
include_permissions: bool,
|
||||
) -> Document | ConnectorFailure | None:
|
||||
"""
|
||||
Converts a retrieved file to a document.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Prepare a partial function with the credentials and admin email
|
||||
convert_func = partial(
|
||||
convert_drive_item_to_document,
|
||||
return convert_drive_item_to_document(
|
||||
self.creds,
|
||||
self.allow_images,
|
||||
self.size_threshold,
|
||||
@@ -1161,83 +1207,15 @@ class GoogleDriveConnector(
|
||||
if include_permissions
|
||||
else None
|
||||
),
|
||||
)
|
||||
# Fetch files in batches
|
||||
batches_complete = 0
|
||||
files_batch: list[RetrievedDriveFile] = []
|
||||
|
||||
def _yield_batch(
|
||||
files_batch: list[RetrievedDriveFile],
|
||||
) -> Iterator[Document | ConnectorFailure]:
|
||||
nonlocal batches_complete
|
||||
# Process the batch using run_functions_tuples_in_parallel
|
||||
func_with_args = [
|
||||
(
|
||||
convert_func,
|
||||
(
|
||||
[file.user_email, self.primary_admin_email]
|
||||
+ get_file_owners(
|
||||
file.drive_file, self.primary_admin_email
|
||||
),
|
||||
file.drive_file,
|
||||
),
|
||||
)
|
||||
for file in files_batch
|
||||
]
|
||||
results = cast(
|
||||
list[Document | ConnectorFailure | None],
|
||||
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
|
||||
)
|
||||
logger.debug(
|
||||
f"finished processing batch {batches_complete} with {len(results)} results"
|
||||
)
|
||||
|
||||
docs_and_failures = [result for result in results if result is not None]
|
||||
logger.debug(
|
||||
f"batch {batches_complete} has {len(docs_and_failures)} docs or failures"
|
||||
)
|
||||
|
||||
if docs_and_failures:
|
||||
yield from docs_and_failures
|
||||
batches_complete += 1
|
||||
logger.debug(f"finished yielding batch {batches_complete}")
|
||||
|
||||
for retrieved_file in self._fetch_drive_items(
|
||||
field_type=field_type,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
if self.exclude_domain_link_only and has_link_only_permission(
|
||||
retrieved_file.drive_file
|
||||
):
|
||||
continue
|
||||
if retrieved_file.error is None:
|
||||
files_batch.append(retrieved_file)
|
||||
continue
|
||||
|
||||
# handle retrieval errors
|
||||
failure_stage = retrieved_file.completion_stage.value
|
||||
failure_message = f"retrieval failure during stage: {failure_stage},"
|
||||
failure_message += f"user: {retrieved_file.user_email},"
|
||||
failure_message += f"parent drive/folder: {retrieved_file.parent_id},"
|
||||
failure_message += f"error: {retrieved_file.error}"
|
||||
logger.error(failure_message)
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=failure_stage,
|
||||
),
|
||||
failure_message=failure_message,
|
||||
exception=retrieved_file.error,
|
||||
)
|
||||
|
||||
yield from _yield_batch(files_batch)
|
||||
checkpoint.retrieved_folder_and_drive_ids = (
|
||||
self._retrieved_folder_and_drive_ids
|
||||
[retrieved_file.user_email, self.primary_admin_email]
|
||||
+ get_file_owners(retrieved_file.drive_file, self.primary_admin_email),
|
||||
retrieved_file.drive_file,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error extracting documents from Google Drive: {e}")
|
||||
logger.exception(
|
||||
f"Error extracting document: {retrieved_file.drive_file.get('name')} from Google Drive"
|
||||
)
|
||||
raise e
|
||||
|
||||
def _load_from_checkpoint(
|
||||
@@ -1262,8 +1240,19 @@ class GoogleDriveConnector(
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
self._retrieved_folder_and_drive_ids = checkpoint.retrieved_folder_and_drive_ids
|
||||
try:
|
||||
yield from self._extract_docs_from_google_drive(
|
||||
checkpoint, start, end, include_permissions
|
||||
field_type = (
|
||||
DriveFileFieldType.WITH_PERMISSIONS
|
||||
if include_permissions or self.exclude_domain_link_only
|
||||
else DriveFileFieldType.STANDARD
|
||||
)
|
||||
drive_files_iter = self._fetch_drive_items(
|
||||
field_type=field_type,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
yield from self._convert_retrieved_files_to_documents(
|
||||
drive_files_iter, include_permissions
|
||||
)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
@@ -1300,6 +1289,43 @@ class GoogleDriveConnector(
|
||||
start, end, checkpoint, include_permissions=True
|
||||
)
|
||||
|
||||
@override
|
||||
def resolve_errors(
|
||||
self, errors: list[ConnectorFailure], include_permissions: bool = False
|
||||
) -> Generator[Document | ConnectorFailure, None, None]:
|
||||
"""Attempts to yield back ALL the documents described by the error, no checkpointing.
|
||||
caller's responsibility is to delete the old connectorfailures and replace with the new ones.
|
||||
"""
|
||||
if self._creds is None or self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Credentials missing, should not call this method before calling load_credentials"
|
||||
)
|
||||
|
||||
logger.info(f"Resolving {len(errors)} errors")
|
||||
doc_ids = set(
|
||||
failure.failed_document.document_id
|
||||
for failure in errors
|
||||
if failure.failed_document
|
||||
)
|
||||
service = get_drive_service(self.creds, self.primary_admin_email)
|
||||
field_type = (
|
||||
DriveFileFieldType.WITH_PERMISSIONS
|
||||
if include_permissions or self.exclude_domain_link_only
|
||||
else DriveFileFieldType.STANDARD
|
||||
)
|
||||
files = get_files_by_web_view_links_batch(service, list(doc_ids), field_type)
|
||||
retrieved_iter = (
|
||||
RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=self.primary_admin_email,
|
||||
completion_stage=DriveRetrievalStage.DONE,
|
||||
)
|
||||
for file in files.values()
|
||||
)
|
||||
yield from self._convert_retrieved_files_to_documents(
|
||||
retrieved_iter, include_permissions
|
||||
)
|
||||
|
||||
def _extract_slim_docs_from_google_drive(
|
||||
self,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
|
||||
@@ -3,9 +3,14 @@ from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from googleapiclient.http import BatchHttpRequest # type: ignore
|
||||
|
||||
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
|
||||
@@ -52,6 +57,8 @@ SLIM_FILE_FIELDS = (
|
||||
)
|
||||
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
|
||||
|
||||
MAX_BATCH_SIZE = 100
|
||||
|
||||
|
||||
def generate_time_range_filter(
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
@@ -393,3 +400,98 @@ def get_root_folder_id(service: Resource) -> str:
|
||||
.get(fileId="root", fields=GoogleFields.ID.value)
|
||||
.execute()[GoogleFields.ID.value]
|
||||
)
|
||||
|
||||
|
||||
def _extract_file_id_from_web_view_link(web_view_link: str) -> str:
|
||||
parsed = urlparse(web_view_link)
|
||||
path_parts = [part for part in parsed.path.split("/") if part]
|
||||
|
||||
if "d" in path_parts:
|
||||
idx = path_parts.index("d")
|
||||
if idx + 1 < len(path_parts):
|
||||
return path_parts[idx + 1]
|
||||
|
||||
query_params = parse_qs(parsed.query)
|
||||
for key in ("id", "fileId"):
|
||||
value = query_params.get(key)
|
||||
if value and value[0]:
|
||||
return value[0]
|
||||
|
||||
raise ValueError(
|
||||
f"Unable to extract Drive file id from webViewLink: {web_view_link}"
|
||||
)
|
||||
|
||||
|
||||
def get_file_by_web_view_link(
|
||||
service: GoogleDriveService,
|
||||
web_view_link: str,
|
||||
fields: str,
|
||||
) -> GoogleDriveFileType:
|
||||
"""
|
||||
Retrieve a Google Drive file using its webViewLink.
|
||||
"""
|
||||
file_id = _extract_file_id_from_web_view_link(web_view_link)
|
||||
return (
|
||||
service.files()
|
||||
.get(
|
||||
fileId=file_id,
|
||||
supportsAllDrives=True,
|
||||
fields=fields,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
def get_files_by_web_view_links_batch(
|
||||
service: GoogleDriveService,
|
||||
web_view_links: list[str],
|
||||
field_type: DriveFileFieldType,
|
||||
) -> dict[str, GoogleDriveFileType]:
|
||||
fields = _get_fields_for_file_type(field_type)
|
||||
if len(web_view_links) <= MAX_BATCH_SIZE:
|
||||
return _get_files_by_web_view_links_batch(service, web_view_links, fields)
|
||||
|
||||
ret = {}
|
||||
for i in range(0, len(web_view_links), MAX_BATCH_SIZE):
|
||||
batch = web_view_links[i : i + MAX_BATCH_SIZE]
|
||||
ret.update(_get_files_by_web_view_links_batch(service, batch, fields))
|
||||
return ret
|
||||
|
||||
|
||||
def _get_files_by_web_view_links_batch(
|
||||
service: GoogleDriveService,
|
||||
web_view_links: list[str],
|
||||
fields: str,
|
||||
) -> dict[str, GoogleDriveFileType]:
|
||||
"""
|
||||
Retrieve multiple Google Drive files using their webViewLinks in a single batch request.
|
||||
|
||||
Returns a dict mapping web_view_link to file metadata.
|
||||
Failed requests (due to invalid links or API errors) are omitted from the result.
|
||||
"""
|
||||
|
||||
def callback(
|
||||
request_id: str, response: GoogleDriveFileType, exception: Exception | None
|
||||
) -> None:
|
||||
if exception:
|
||||
logger.warning(f"Error retrieving file {request_id}: {exception}")
|
||||
else:
|
||||
results[request_id] = response
|
||||
|
||||
results: Dict[str, GoogleDriveFileType] = {}
|
||||
batch = cast(BatchHttpRequest, service.new_batch_http_request(callback=callback))
|
||||
|
||||
for web_view_link in web_view_links:
|
||||
try:
|
||||
file_id = _extract_file_id_from_web_view_link(web_view_link)
|
||||
request = service.files().get(
|
||||
fileId=file_id,
|
||||
supportsAllDrives=True,
|
||||
fields=fields,
|
||||
)
|
||||
batch.add(request, request_id=web_view_link)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to extract file ID from {web_view_link}: {e}")
|
||||
|
||||
batch.execute()
|
||||
return results
|
||||
|
||||
@@ -269,3 +269,15 @@ class CheckpointedConnectorWithPermSync(CheckpointedConnector[CT]):
|
||||
checkpoint: CT,
|
||||
) -> CheckpointOutput[CT]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Resolver:
|
||||
@abc.abstractmethod
|
||||
def resolve_errors(
|
||||
self, errors: list[ConnectorFailure], include_permissions: bool = False
|
||||
) -> Generator[Document | ConnectorFailure, None, None]:
|
||||
"""Attempts to yield back ALL the documents described by the error, no checkpointing.
|
||||
caller's responsibility is to delete the old connectorfailures and replace with the new ones.
|
||||
If include_permissions is True, the documents will have permissions synced.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user