Compare commits

..

15 Commits

Author SHA1 Message Date
Evan Lohn
09340dafd9 100 items per page 2025-03-21 17:50:13 -07:00
Evan Lohn
7dedaac090 fixed tests 2025-03-21 17:43:00 -07:00
Evan Lohn
10810d4a20 remove prints 2025-03-21 16:10:42 -07:00
Evan Lohn
4fbf818e04 validation fix 2025-03-20 17:03:14 -07:00
Evan Lohn
a83f531392 validation fix 2025-03-20 16:55:47 -07:00
Evan Lohn
a22259a60d address CW comments 2025-03-20 15:55:43 -07:00
Evan Lohn
a81338a036 connector failures 2025-03-20 12:05:05 -07:00
Evan Lohn
bc558a0549 unit tests and bug fix 2025-03-20 12:05:05 -07:00
Evan Lohn
6dbe9136e9 secrets cant start with GITHUB_ 2025-03-20 12:05:05 -07:00
Evan Lohn
d693f6323d connector test env var 2025-03-20 12:05:05 -07:00
Evan Lohn
fc7c9b3c68 github basic connector test 2025-03-20 12:05:05 -07:00
Evan Lohn
85c55e2270 CW comments 2025-03-20 12:05:05 -07:00
Evan Lohn
13863a3255 nit 2025-03-20 12:05:05 -07:00
Evan Lohn
61bb69e09a first draft of github checkpointing 2025-03-20 12:05:05 -07:00
Evan Lohn
d4ae522014 WIP github checkpointing 2025-03-20 12:05:05 -07:00
23 changed files with 776 additions and 264 deletions

View File

@@ -45,6 +45,8 @@ env:
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
# Github
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}

View File

@@ -65,17 +65,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
app.state.gpu_type = gpu_type
try:
if TEMP_HF_CACHE_PATH.is_dir():
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
except Exception as e:
logger.warning(
f"Error moving contents of temp_huggingface to huggingface cache: {e}. "
"This is not a critical error and the model server will continue to run."
)
if TEMP_HF_CACHE_PATH.is_dir():
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.notice(f"Torch Threads: {torch.get_num_threads()}")

View File

@@ -435,7 +435,7 @@ def _run_indexing(
while checkpoint.has_more:
logger.info(
f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint

View File

@@ -30,7 +30,7 @@ from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.gpu_utils import fast_gpu_status_request
from onyx.utils.gpu_utils import gpu_status_request
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -88,9 +88,7 @@ class Answer:
rerank_settings is not None
and rerank_settings.rerank_provider_type is not None
)
allow_agent_reranking = (
fast_gpu_status_request(indexing=False) or using_cloud_reranking
)
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
# TODO: this is a hack to force the query to be used for the search tool
# this should be removed once we fully unify graph inputs (i.e.

View File

@@ -157,10 +157,7 @@ VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
try:
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
except ValueError:
INDEX_BATCH_SIZE = 16
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE") or 16)
MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))

View File

@@ -114,7 +114,6 @@ class ConfluenceConnector(
self.timezone_offset = timezone_offset
self._confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
self.allow_images = False
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
@@ -159,9 +158,6 @@ class ConfluenceConnector(
"max_backoff_seconds": 60,
}
def set_allow_images(self, value: bool) -> None:
self.allow_images = value
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
@@ -237,9 +233,7 @@ class ConfluenceConnector(
# Extract basic page information
page_id = page["id"]
page_title = page["title"]
page_url = build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
)
page_url = f"{self.wiki_base}{page['_links']['webui']}"
# Get the page content
page_content = extract_text_from_confluence_html(
@@ -270,7 +264,6 @@ class ConfluenceConnector(
self.confluence_client,
attachment,
page_id,
self.allow_images,
)
if result and result.text:
@@ -311,14 +304,13 @@ class ConfluenceConnector(
if "version" in page and "by" in page["version"]:
author = page["version"]["by"]
display_name = author.get("displayName", "Unknown")
email = author.get("email", "unknown@domain.invalid")
primary_owners.append(
BasicExpertInfo(display_name=display_name, email=email)
)
primary_owners.append(BasicExpertInfo(display_name=display_name))
# Create the document
return Document(
id=page_url,
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
),
sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=page_title,
@@ -381,7 +373,6 @@ class ConfluenceConnector(
confluence_client=self.confluence_client,
attachment=attachment,
page_id=page["id"],
allow_images=self.allow_images,
)
if response is None:
continue

View File

@@ -112,7 +112,6 @@ def process_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
parent_content_id: str | None,
allow_images: bool,
) -> AttachmentProcessingResult:
"""
Processes a Confluence attachment. If it's a document, extracts text,
@@ -120,7 +119,7 @@ def process_attachment(
"""
try:
# Get the media type from the attachment metadata
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
media_type = attachment.get("metadata", {}).get("mediaType", "")
# Validate the attachment type
if not validate_attachment_filetype(attachment):
return AttachmentProcessingResult(
@@ -139,14 +138,7 @@ def process_attachment(
attachment_size = attachment["extensions"]["fileSize"]
if media_type.startswith("image/"):
if not allow_images:
return AttachmentProcessingResult(
text=None,
file_name=None,
error="Image downloading is not enabled",
)
else:
if not media_type.startswith("image/"):
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {attachment_link} due to size. "
@@ -302,7 +294,6 @@ def convert_attachment_to_content(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_id: str,
allow_images: bool,
) -> tuple[str | None, str | None] | None:
"""
Facade function which:
@@ -318,7 +309,7 @@ def convert_attachment_to_content(
)
return None
result = process_attachment(confluence_client, attachment, page_id, allow_images)
result = process_attachment(confluence_client, attachment, page_id)
if result.error is not None:
logger.warning(
f"Attachment {attachment['title']} encountered error: {result.error}"

View File

@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
@@ -185,8 +184,6 @@ def instantiate_connector(
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
connector.set_allow_images(get_image_extraction_and_analysis_enabled())
return connector

View File

@@ -1,8 +1,10 @@
import copy
import time
from collections.abc import Iterator
from collections.abc import Generator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from enum import Enum
from typing import Any
from typing import cast
@@ -13,26 +15,30 @@ from github.GithubException import GithubException
from github.Issue import Issue
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
from github.Requester import Requester
from pydantic import BaseModel
from typing_extensions import override
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import TextSection
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
logger = setup_logger()
ITEMS_PER_PAGE = 100
_MAX_NUM_RATE_LIMIT_RETRIES = 5
@@ -48,7 +54,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
def _get_batch_rate_limited(
git_objs: PaginatedList, page_num: int, github_client: Github, attempt_num: int = 0
) -> list[Any]:
) -> list[PullRequest | Issue]:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
@@ -69,21 +75,6 @@ def _get_batch_rate_limited(
)
def _batch_github_objects(
git_objs: PaginatedList, github_client: Github, batch_size: int
) -> Iterator[list[Any]]:
page_num = 0
while True:
batch = _get_batch_rate_limited(git_objs, page_num, github_client)
page_num += 1
if not batch:
break
for mini_batch in batch_generator(batch, batch_size=batch_size):
yield mini_batch
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
return Document(
id=pull_request.html_url,
@@ -95,7 +86,9 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
# updated_at is UTC time but is timezone unaware, explicitly add UTC
# as there is logic in indexing to prevent wrong timestamped docs
# due to local time discrepancies with UTC
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc),
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc)
if pull_request.updated_at
else None,
metadata={
"merged": str(pull_request.merged),
"state": pull_request.state,
@@ -122,31 +115,58 @@ def _convert_issue_to_document(issue: Issue) -> Document:
)
class GithubConnector(LoadConnector, PollConnector):
class SerializedRepository(BaseModel):
# Pydantic model holding the pieces of a PyGithub repository object (id,
# headers and raw data) so the repository can be stored inside a
# GithubConnectorCheckpoint and rebuilt later with to_Repository().
# id is part of the raw_data as well, just pulled out for convenience
id: int
# filled from the repository's raw_headers when the checkpoint is built
headers: dict[str, str | int]
# filled from the repository's raw_data when the checkpoint is built
raw_data: dict[str, Any]
# Rebuild a Repository from the stored headers and raw data, bound to the
# given requester.
# NOTE(review): completed=True presumably marks the object as fully loaded so
# PyGithub does not lazily re-fetch it -- confirm against the PyGithub docs.
def to_Repository(self, requester: Requester) -> Repository.Repository:
return Repository.Repository(
requester, self.headers, self.raw_data, completed=True
)
class GithubConnectorStage(Enum):
# Which kind of object the connector is fetching for the repository that is
# currently cached in the checkpoint.
# NOTE(review): START is not referenced by the code visible here; the dummy
# checkpoint starts directly in PRS -- confirm whether START is still needed.
START = "start"
# fetching pull requests page by page
PRS = "prs"
# fetching issues page by page
ISSUES = "issues"
class GithubConnectorCheckpoint(ConnectorCheckpoint):
# Resumable state for GithubConnector: the current stage, the next page to
# request within that stage, and the repositories still to be processed.
# stage currently being fetched for cached_repo (PRs or issues)
stage: GithubConnectorStage
# index of the next page to request from the paginated PR/issue list
curr_page: int
# None until the first run has listed the repositories; afterwards the
# sorted repository ids, consumed with pop() as each repository is finished
# NOTE(review): the first run pairs cached_repo_ids[0] (smallest id) with the
# data of repos[0]; verify that repos[0] really is the repo with that id.
cached_repo_ids: list[int] | None = None
# serialized form of the repository currently being processed
cached_repo: SerializedRepository | None = None
class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
def __init__(
self,
repo_owner: str,
repositories: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
state_filter: str = "all",
include_prs: bool = True,
include_issues: bool = False,
) -> None:
self.repo_owner = repo_owner
self.repositories = repositories
self.batch_size = batch_size
self.state_filter = state_filter
self.include_prs = include_prs
self.include_issues = include_issues
self.github_client: Github | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# defaults to 30 items per page, can be set to as high as 100
self.github_client = (
Github(
credentials["github_access_token"], base_url=GITHUB_CONNECTOR_BASE_URL
credentials["github_access_token"],
base_url=GITHUB_CONNECTOR_BASE_URL,
per_page=ITEMS_PER_PAGE,
)
if GITHUB_CONNECTOR_BASE_URL
else Github(credentials["github_access_token"])
else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE)
)
return None
@@ -217,85 +237,193 @@ class GithubConnector(LoadConnector, PollConnector):
return self._get_all_repos(github_client, attempt_num + 1)
def _fetch_from_github(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
self,
checkpoint: GithubConnectorCheckpoint,
start: datetime | None = None,
end: datetime | None = None,
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
checkpoint = copy.deepcopy(checkpoint)
# First run of the connector, fetch all repos and store in checkpoint
if checkpoint.cached_repo_ids is None:
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
else:
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
else:
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
else:
# All repositories
repos = self._get_all_repos(self.github_client)
# All repositories
repos = self._get_all_repos(self.github_client)
if not repos:
checkpoint.has_more = False
return checkpoint
for repo in repos:
if self.include_prs:
logger.info(f"Fetching PRs for repo: {repo.name}")
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
checkpoint.cached_repo_ids = sorted([repo.id for repo in repos])
checkpoint.cached_repo = SerializedRepository(
id=checkpoint.cached_repo_ids[0],
headers=repos[0].raw_headers,
raw_data=repos[0].raw_data,
)
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.curr_page = 0
# save checkpoint with repo ids retrieved
return checkpoint
for pr_batch in _batch_github_objects(
pull_requests, self.github_client, self.batch_size
assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
repo = checkpoint.cached_repo.to_Repository(self.github_client.requester)
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
logger.info(f"Fetching PRs for repo: {repo.name}")
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
doc_batch: list[Document] = []
pr_batch = _get_batch_rate_limited(
pull_requests, checkpoint.curr_page, self.github_client
)
checkpoint.curr_page += 1
done_with_prs = False
for pr in pr_batch:
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) < start
):
doc_batch: list[Document] = []
for pr in pr_batch:
if start is not None and pr.updated_at < start:
yield doc_batch
break
if end is not None and pr.updated_at > end:
continue
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
yield doc_batch
if self.include_issues:
logger.info(f"Fetching issues for repo: {repo.name}")
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
for issue_batch in _batch_github_objects(
issues, self.github_client, self.batch_size
yield from doc_batch
done_with_prs = True
break
# Skip PRs updated after the end date
if (
end is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) > end
):
doc_batch = []
for issue in issue_batch:
issue = cast(Issue, issue)
if start is not None and issue.updated_at < start:
yield doc_batch
break
if end is not None and issue.updated_at > end:
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
doc_batch.append(_convert_issue_to_document(issue))
yield doc_batch
continue
try:
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(pr.id), document_link=pr.html_url
),
failure_message=error_msg,
exception=e,
)
continue
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_github()
# if we found any PRs on the page, yield any associated documents and return the checkpoint
if not done_with_prs and len(pr_batch) > 0:
yield from doc_batch
return checkpoint
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
# if we went past the start date during the loop or there are no more
# prs to get, we move on to issues
checkpoint.stage = GithubConnectorStage.ISSUES
checkpoint.curr_page = 0
checkpoint.stage = GithubConnectorStage.ISSUES
if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
logger.info(f"Fetching issues for repo: {repo.name}")
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
doc_batch = []
issue_batch = _get_batch_rate_limited(
issues, checkpoint.curr_page, self.github_client
)
checkpoint.curr_page += 1
done_with_issues = False
for issue in cast(list[Issue], issue_batch):
# we iterate backwards in time, so at this point we stop processing issues
if (
start is not None
and issue.updated_at.replace(tzinfo=timezone.utc) < start
):
yield from doc_batch
done_with_issues = True
break
# Skip issues updated after the end date
if (
end is not None
and issue.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
try:
doc_batch.append(_convert_issue_to_document(issue))
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(issue.id),
document_link=issue.html_url,
),
failure_message=error_msg,
exception=e,
)
continue
# if we found any issues on the page, yield them and return the checkpoint
if not done_with_issues and len(issue_batch) > 0:
yield from doc_batch
return checkpoint
# if we went past the start date during the loop or there are no more
# issues to get, we move on to the next repo
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.curr_page = 0
checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
if checkpoint.cached_repo_ids:
next_id = checkpoint.cached_repo_ids.pop()
next_repo = self.github_client.get_repo(next_id)
checkpoint.cached_repo = SerializedRepository(
id=next_id,
headers=next_repo.raw_headers,
raw_data=next_repo.raw_data,
)
return checkpoint
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GithubConnectorCheckpoint,
) -> CheckpointOutput[GithubConnectorCheckpoint]:
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
# Move start time back by 3 hours, since some Issues/PRs are getting dropped
# Could be due to delayed processing on GitHub side
# The non-updated issues since last poll will be shortcut-ed and not embedded
adjusted_start_datetime = start_datetime - timedelta(hours=3)
epoch = datetime.utcfromtimestamp(0)
epoch = datetime.fromtimestamp(0, tz=timezone.utc)
if adjusted_start_datetime < epoch:
adjusted_start_datetime = epoch
return self._fetch_from_github(adjusted_start_datetime, end_datetime)
return self._fetch_from_github(
checkpoint, start=adjusted_start_datetime, end=end_datetime
)
def validate_connector_settings(self) -> None:
if self.github_client is None:
@@ -397,6 +525,16 @@ class GithubConnector(LoadConnector, PollConnector):
f"Unexpected error during GitHub settings validation: {exc}"
)
# Parse a JSON string into a GithubConnectorCheckpoint using the pydantic
# model's model_validate_json.
def validate_checkpoint_json(
self, checkpoint_json: str
) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)
# Build the initial checkpoint for a fresh run: PR stage, page 0, has_more=True.
# cached_repo_ids and cached_repo keep their None defaults, which
# _fetch_from_github treats as the first run and uses to list the repositories.
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint(
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
)
if __name__ == "__main__":
import os
@@ -406,7 +544,9 @@ if __name__ == "__main__":
repositories=os.environ["REPOSITORIES"],
)
connector.load_credentials(
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}
{"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
)
document_batches = connector.load_from_checkpoint(
0, time.time(), connector.build_dummy_checkpoint()
)
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -86,7 +86,6 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
def _convert_single_file(
creds: Any,
primary_admin_email: str,
allow_images: bool,
file: dict[str, Any],
) -> Document | ConnectorFailure | None:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
@@ -102,7 +101,6 @@ def _convert_single_file(
file=file,
drive_service=user_drive_service,
docs_service=docs_service,
allow_images=allow_images,
)
@@ -236,10 +234,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._retrieved_ids: set[str] = set()
self.allow_images = False
def set_allow_images(self, value: bool) -> None:
self.allow_images = value
@property
def primary_admin_email(self) -> str:
@@ -906,7 +900,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
)
# Fetch files in batches
@@ -1104,9 +1097,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
drive_service.files().list(pageSize=1, fields="files(id)").execute()
if isinstance(self._creds, ServiceAccountCredentials):
# default is ~17mins of retries, don't do that here since this is called from
# the UI
retry_builder(tries=3, delay=0.1)(get_root_folder_id)(drive_service)
retry_builder()(get_root_folder_id)(drive_service)
except HttpError as e:
status_code = e.resp.status if e.resp else None

View File

@@ -79,7 +79,6 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
def _extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
allow_images: bool,
) -> list[TextSection | ImageSection]:
"""Extract text and images from a Google Drive file."""
file_id = file["id"]
@@ -88,10 +87,6 @@ def _extract_sections_basic(
link = file.get("webViewLink", "")
try:
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
@@ -212,7 +207,6 @@ def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: Callable[[], GoogleDriveService],
docs_service: Callable[[], GoogleDocsService],
allow_images: bool,
) -> Document | ConnectorFailure | None:
"""
Main entry point for converting a Google Drive file => Document object.
@@ -242,7 +236,7 @@ def convert_drive_item_to_document(
# If we don't have sections yet, use the basic extraction method
if not sections:
sections = _extract_sections_basic(file, drive_service(), allow_images)
sections = _extract_sections_basic(file, drive_service())
# If we still don't have any sections, skip this file
if not sections:

View File

@@ -1,7 +1,6 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from googleapiclient.discovery import Resource # type: ignore
@@ -37,12 +36,12 @@ def _generate_time_range_filter(
) -> str:
time_range_filter = ""
if start is not None:
time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
time_range_filter += (
f" and {GoogleFields.MODIFIED_TIME.value} >= '{time_start}'"
)
if end is not None:
time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'"
return time_range_filter

View File

@@ -17,12 +17,9 @@ logger = setup_logger()
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. This is now addressed by checkpointing.
#
# NOTE: We previously tried to combat this here by adding a very
# long retry period (~20 minutes of trying, one request a minute.)
# This is no longer necessary due to checkpointing.
add_retries = retry_builder(tries=5, max_delay=10)
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
NEXT_PAGE_TOKEN_KEY = "nextPageToken"
PAGE_TOKEN_KEY = "pageToken"
@@ -40,14 +37,14 @@ class GoogleFields(str, Enum):
def _execute_with_retry(request: Any) -> Any:
max_attempts = 6
max_attempts = 10
attempt = 1
while attempt < max_attempts:
# Note for reasons unknown, the Google API will sometimes return a 429
# and even after waiting the retry period, it will return another 429.
# It could be due to a few possibilities:
# 1. Other things are also requesting from the Drive/Gmail API with the same key
# 1. Other things are also requesting from the Gmail API with the same key
# 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
# 3. The retry-after has a maximum and we've already hit the limit for the day
# or it's something else...

View File

@@ -60,10 +60,6 @@ class BaseConnector(abc.ABC, Generic[CT]):
Default is a no-op (always successful).
"""
def set_allow_images(self, value: bool) -> None:
"""Implement if the underlying connector wants to skip/allow image downloading
based on the application level image analysis setting."""
def build_dummy_checkpoint(self) -> CT:
# TODO: find a way to make this work without type: ignore
return ConnectorCheckpoint(has_more=True) # type: ignore

View File

@@ -324,7 +324,7 @@ def update_default_multipass_indexing(db_session: Session) -> None:
logger.info(
"No existing docs or connectors found. Checking GPU availability for multipass indexing."
)
gpu_available = gpu_status_request(indexing=True)
gpu_available = gpu_status_request()
logger.info(f"GPU available: {gpu_available}")
current_settings = get_current_search_settings(db_session)

View File

@@ -1,5 +1,3 @@
from functools import lru_cache
import requests
from retry import retry
@@ -12,7 +10,8 @@ from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
def _get_gpu_status_from_model_server(indexing: bool) -> bool:
@retry(tries=5, delay=5)
def gpu_status_request(indexing: bool = True) -> bool:
if indexing:
model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}"
else:
@@ -29,14 +28,3 @@ def _get_gpu_status_from_model_server(indexing: bool) -> bool:
except requests.RequestException as e:
logger.error(f"Error: Unable to fetch GPU status. Error: {str(e)}")
raise # Re-raise exception to trigger a retry
@retry(tries=5, delay=5)
def gpu_status_request(indexing: bool) -> bool:
return _get_gpu_status_from_model_server(indexing)
@lru_cache(maxsize=1)
def fast_gpu_status_request(indexing: bool) -> bool:
"""For use in sync flows, where we don't want to retry / we want to cache this."""
return gpu_status_request(indexing=indexing)

View File

@@ -56,7 +56,7 @@ puremagic==1.28
pyairtable==3.0.1
pycryptodome==3.19.1
pydantic==2.8.2
PyGithub==1.58.2
PyGithub==2.5.0
python-dateutil==2.8.2
python-gitlab==3.9.0
python-pptx==0.6.23

View File

@@ -1,6 +1,5 @@
import os
import time
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -8,16 +7,15 @@ import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.utils import AttachmentProcessingResult
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
from onyx.connectors.models import Document
@pytest.fixture
def confluence_connector(space: str) -> ConfluenceConnector:
def confluence_connector() -> ConfluenceConnector:
connector = ConfluenceConnector(
wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"],
space=space,
space=os.environ["CONFLUENCE_TEST_SPACE"],
is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true",
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
)
@@ -34,15 +32,14 @@ def confluence_connector(space: str) -> ConfluenceConnector:
return connector
@pytest.mark.parametrize("space", [os.environ["CONFLUENCE_TEST_SPACE"]])
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
@pytest.mark.skip(reason="Skipping this test")
def test_confluence_connector_basic(
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:
confluence_connector.set_allow_images(False)
doc_batch_generator = confluence_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
@@ -53,14 +50,15 @@ def test_confluence_connector_basic(
page_within_a_page_doc: Document | None = None
page_doc: Document | None = None
txt_doc: Document | None = None
for doc in doc_batch:
if doc.semantic_identifier == "DailyConnectorTestSpace Home":
page_doc = doc
elif ".txt" in doc.semantic_identifier:
txt_doc = doc
elif doc.semantic_identifier == "Page Within A Page":
page_within_a_page_doc = doc
else:
pass
assert page_within_a_page_doc is not None
assert page_within_a_page_doc.semantic_identifier == "Page Within A Page"
@@ -81,7 +79,7 @@ def test_confluence_connector_basic(
assert page_doc.metadata["labels"] == ["testlabel"]
assert page_doc.primary_owners
assert page_doc.primary_owners[0].email == "hagen@danswer.ai"
assert len(page_doc.sections) == 2 # page text + attachment text
assert len(page_doc.sections) == 1
page_section = page_doc.sections[0]
assert page_section.text == "test123 " + page_within_a_page_text
@@ -90,65 +88,13 @@ def test_confluence_connector_basic(
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
)
text_attachment_section = page_doc.sections[1]
assert text_attachment_section.text == "small"
assert text_attachment_section.link
assert text_attachment_section.link.endswith("small-file.txt")
@pytest.mark.parametrize("space", ["MI"])
@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
    return_value=None,
)
def test_confluence_connector_skip_images(
    mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:
    """Poll the "MI" space with images disabled and check the single batch.

    The poll must produce exactly one batch, holding 8 documents whose
    sections add up to 8 in total.
    """
    confluence_connector.set_allow_images(False)

    batches = confluence_connector.poll_source(0, time.time())
    documents = next(batches)

    # Nothing may follow the first batch.
    with pytest.raises(StopIteration):
        next(batches)

    assert len(documents) == 8
    section_total = sum(len(document.sections) for document in documents)
    assert section_total == 8
def mock_process_image_attachment(
*args: Any, **kwargs: Any
) -> AttachmentProcessingResult:
"""We need this mock to bypass DB access happening in the connector. Which shouldn't
be done as a rule to begin with, but life is not perfect. Fix it later"""
return AttachmentProcessingResult(
text="Hi_text",
file_name="Hi_filename",
error=None,
assert txt_doc is not None
assert txt_doc.semantic_identifier == "small-file.txt"
assert len(txt_doc.sections) == 1
assert txt_doc.sections[0].text == "small"
assert txt_doc.primary_owners
assert txt_doc.primary_owners[0].email == "chris@onyx.app"
assert (
txt_doc.sections[0].link
== "https://danswerai.atlassian.net/wiki/pages/viewpageattachments.action?pageId=52494430&preview=%2F52494430%2F52527123%2Fsmall-file.txt"
)
@pytest.mark.parametrize("space", ["MI"])
@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
    return_value=None,
)
@patch(
    "onyx.connectors.confluence.utils._process_image_attachment",
    side_effect=mock_process_image_attachment,
)
def test_confluence_connector_allow_images(
    mock_image_processor: MagicMock,
    mock_get_api_key: MagicMock,
    confluence_connector: ConfluenceConnector,
) -> None:
    """Poll the "MI" space with images enabled and check the single batch.

    Stacked ``@patch`` decorators inject their mocks bottom-up, so the first
    positional mock belongs to ``_process_image_attachment`` (the decorator
    closest to the function) and the second to ``get_unstructured_api_key``.
    The parameter names follow that order, and the image mock no longer
    shadows the module-level ``mock_process_image_attachment`` helper.

    The poll must produce exactly one batch, holding 8 documents whose
    sections add up to 12 in total.
    """
    confluence_connector.set_allow_images(True)

    doc_batch_generator = confluence_connector.poll_source(0, time.time())
    doc_batch = next(doc_batch_generator)

    # Nothing may follow the first batch.
    with pytest.raises(StopIteration):
        next(doc_batch_generator)

    assert len(doc_batch) == 8
    assert sum(len(doc.sections) for doc in doc_batch) == 12

View File

@@ -0,0 +1,54 @@
import os
import time
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.github.connector import GithubConnector
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
@pytest.fixture
def github_connector() -> GithubConnector:
    """Build a GithubConnector for onyx-dot-app/documentation with credentials loaded.

    PRs and issues are both enabled. The access token is read from the
    ``ACCESS_TOKEN_GITHUB`` environment variable; a missing variable raises
    ``KeyError``.
    """
    connector = GithubConnector(
        repo_owner="onyx-dot-app",
        repositories="documentation",
        include_prs=True,
        include_issues=True,
    )

    access_token = os.environ["ACCESS_TOKEN_GITHUB"]
    connector.load_credentials({"github_access_token": access_token})
    return connector
def test_github_connector_basic(github_connector: GithubConnector) -> None:
    """Load every document from epoch 0 until now and inspect the first one."""
    all_docs = load_all_docs_from_checkpoint_connector(
        connector=github_connector,
        start=0,
        end=time.time(),
    )
    assert len(all_docs) > 0  # We expect at least one PR to exist

    # Only the first document's structure is inspected
    first_doc = all_docs[0]

    # Basic document properties
    assert first_doc.source == DocumentSource.GITHUB
    assert first_doc.secondary_owners is None
    assert first_doc.from_ingestion_api is False
    assert first_doc.additional_info is None

    # GitHub-specific properties
    assert "github.com" in first_doc.id  # Should be a GitHub URL
    assert first_doc.metadata is not None
    for required_key in ("state", "merged"):
        assert required_key in first_doc.metadata

    # Sections
    assert len(first_doc.sections) == 1
    only_section = first_doc.sections[0]
    assert only_section.link == first_doc.id  # Section link should match document ID
    assert isinstance(only_section.text, str)  # Should have some text content

View File

@@ -50,7 +50,7 @@ def answer_instance(
mocker: MockerFixture,
) -> Answer:
mocker.patch(
"onyx.chat.answer.fast_gpu_status_request",
"onyx.chat.answer.gpu_status_request",
return_value=True,
)
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
@@ -400,7 +400,7 @@ def test_no_slow_reranking(
mocker: MockerFixture,
) -> None:
mocker.patch(
"onyx.chat.answer.fast_gpu_status_request",
"onyx.chat.answer.gpu_status_request",
return_value=gpu_enabled,
)
rerank_settings = (

View File

@@ -39,7 +39,7 @@ def test_skip_gen_ai_answer_generation_flag(
mocker: MockerFixture,
) -> None:
mocker.patch(
"onyx.chat.answer.fast_gpu_status_request",
"onyx.chat.answer.gpu_status_request",
return_value=True,
)
question = config["question"]

View File

@@ -0,0 +1,441 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from github import Github
from github import GithubException
from github import RateLimitExceededException
from github.Issue import Issue
from github.PullRequest import PullRequest
from github.RateLimit import RateLimit
from github.Repository import Repository
from github.Requester import Requester
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.github.connector import GithubConnector
from onyx.connectors.github.connector import SerializedRepository
from onyx.connectors.models import Document
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
@pytest.fixture
def repo_owner() -> str:
    """Owner name handed to the connector under test."""
    owner_name = "test-org"
    return owner_name
@pytest.fixture
def repositories() -> str:
    """Repository name handed to the connector under test."""
    repository_name = "test-repo"
    return repository_name
@pytest.fixture
def mock_github_client() -> MagicMock:
    """Create a mock GitHub client with proper typing"""
    repo_stub = MagicMock(spec=Repository)
    rate_limit_stub = MagicMock(spec=RateLimit)

    client = MagicMock(spec=Github)

    # Typed return values for the lookups that hand back PyGithub objects
    client.get_repo = MagicMock(return_value=repo_stub)
    client.get_rate_limit = MagicMock(return_value=rate_limit_stub)

    # Plain mocks for the organization and user lookups
    client.get_organization = MagicMock()
    client.get_user = MagicMock()

    # Requester is needed for repository deserialization
    client.requester = MagicMock(spec=Requester)

    return client
@pytest.fixture
def github_connector(
    repo_owner: str, repositories: str, mock_github_client: MagicMock
) -> Generator[GithubConnector, None, None]:
    """Yield a GithubConnector (PRs and issues enabled) wired to the mock client."""
    connector_under_test = GithubConnector(
        include_prs=True,
        include_issues=True,
        repo_owner=repo_owner,
        repositories=repositories,
    )
    # Swap in the mock so no test reaches the real GitHub API
    connector_under_test.github_client = mock_github_client
    yield connector_under_test
@pytest.fixture
def create_mock_pr() -> Callable[..., MagicMock]:
    """Provide a factory that builds ``PullRequest``-spec'd mocks."""

    def _create_mock_pr(
        number: int = 1,
        title: str = "Test PR",
        body: str = "Test Description",
        state: str = "open",
        merged: bool = False,
        updated_at: datetime = datetime(2023, 1, 1, tzinfo=timezone.utc),
    ) -> MagicMock:
        """Helper to create a mock PullRequest object"""
        pull_request = MagicMock(spec=PullRequest)
        pull_request.configure_mock(
            number=number,
            title=title,
            body=body,
            state=state,
            merged=merged,
            updated_at=updated_at,
            html_url=f"https://github.com/test-org/test-repo/pull/{number}",
        )
        return pull_request

    return _create_mock_pr
@pytest.fixture
def create_mock_issue() -> Callable[..., MagicMock]:
    """Provide a factory that builds ``Issue``-spec'd mocks."""

    def _create_mock_issue(
        number: int = 1,
        title: str = "Test Issue",
        body: str = "Test Description",
        state: str = "open",
        updated_at: datetime = datetime(2023, 1, 1, tzinfo=timezone.utc),
    ) -> MagicMock:
        """Helper to create a mock Issue object"""
        attributes = {
            "number": number,
            "title": title,
            "body": body,
            "state": state,
            "updated_at": updated_at,
            "html_url": f"https://github.com/test-org/test-repo/issues/{number}",
            "pull_request": None,  # Not a PR
        }
        issue = MagicMock(spec=Issue)
        for attr_name, attr_value in attributes.items():
            setattr(issue, attr_name, attr_value)
        return issue

    return _create_mock_issue
@pytest.fixture
def create_mock_repo() -> Callable[..., MagicMock]:
    """Provide a factory that builds ``Repository``-spec'd mocks."""

    def _create_mock_repo(
        name: str = "test-repo",
        id: int = 1,
    ) -> MagicMock:
        """Helper to create a mock Repository object"""
        response_headers = {"status": "200 OK", "content-type": "application/json"}
        # Every value in the raw payload is stored as a string
        raw_payload = {
            "id": str(id),
            "name": name,
            "full_name": f"test-org/{name}",
            "private": str(False),
            "description": "Test repository",
        }

        repo = MagicMock(spec=Repository)
        repo.configure_mock(
            name=name,
            id=id,
            raw_headers=response_headers,
            raw_data=raw_payload,
        )
        return repo

    return _create_mock_repo
def test_load_from_checkpoint_happy_path(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
    create_mock_pr: Callable[..., MagicMock],
    create_mock_issue: Callable[..., MagicMock],
) -> None:
    """Test loading from checkpoint - happy path

    Expects four outputs in order: an empty repo batch, the two PRs, the two
    issues, and an empty final batch whose checkpoint has ``has_more`` False.
    """
    # Mocked repo, returned by the client
    repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = repo

    # Mocked PRs and issues
    pr_page = [
        create_mock_pr(number=1, title="PR 1"),
        create_mock_pr(number=2, title="PR 2"),
    ]
    issue_page = [
        create_mock_issue(number=1, title="Issue 1"),
        create_mock_issue(number=2, title="Issue 2"),
    ]

    # get_pulls / get_issues each serve one populated page, then an empty one
    repo.get_pulls.return_value = MagicMock()
    repo.get_pulls.return_value.get_page.side_effect = [pr_page, []]
    repo.get_issues.return_value = MagicMock()
    repo.get_issues.return_value.get_page.side_effect = [issue_page, []]

    expected_pr_ids = (
        "https://github.com/test-org/test-repo/pull/1",
        "https://github.com/test-org/test-repo/pull/2",
    )
    expected_issue_ids = (
        "https://github.com/test-org/test-repo/issues/1",
        "https://github.com/test-org/test-repo/issues/2",
    )

    # SerializedRepository.to_Repository is patched to return our mock repo
    with patch.object(SerializedRepository, "to_Repository", return_value=repo):
        end_time = time.time()
        outputs = load_everything_from_checkpoint_connector(
            github_connector, 0, end_time
        )

        # All documents arrived and the final checkpoint has has_more=False
        assert len(outputs) == 4
        repo_batch, pr_batch, issue_batch, final_batch = outputs

        assert len(repo_batch.items) == 0
        assert repo_batch.next_checkpoint.has_more is True

        # PR batch
        assert len(pr_batch.items) == 2
        for pr_doc, expected_id in zip(pr_batch.items, expected_pr_ids):
            assert isinstance(pr_doc, Document)
            assert pr_doc.id == expected_id
        assert pr_batch.next_checkpoint.curr_page == 1

        # Issue batch
        assert len(issue_batch.items) == 2
        for issue_doc, expected_id in zip(issue_batch.items, expected_issue_ids):
            assert isinstance(issue_doc, Document)
            assert issue_doc.id == expected_id
        assert issue_batch.next_checkpoint.has_more

        # Finished checkpoint
        assert len(final_batch.items) == 0
        assert final_batch.next_checkpoint.has_more is False
def test_load_from_checkpoint_with_rate_limit(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
    create_mock_pr: Callable[..., MagicMock],
) -> None:
    """Test loading from checkpoint with rate limit handling

    The first PR page request raises ``RateLimitExceededException``. The test
    expects exactly one call to the patched sleep helper and the PR to be
    delivered afterwards.
    """
    # Mocked repo, returned by the client
    repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = repo

    single_pr = create_mock_pr()

    # The first get_page call raises; the next two serve the PR and an empty page
    rate_limit_error = RateLimitExceededException(
        403, {"message": "Rate limit exceeded"}, {}
    )
    repo.get_pulls.return_value = MagicMock()
    repo.get_pulls.return_value.get_page.side_effect = [
        rate_limit_error,
        [single_pr],
        [],
    ]

    # Rate limit reset time reported by the client
    rate_limit = MagicMock(spec=RateLimit)
    rate_limit.core.reset = datetime.now(timezone.utc)
    mock_github_client.get_rate_limit.return_value = rate_limit

    # SerializedRepository.to_Repository is patched to return our mock repo
    with patch.object(SerializedRepository, "to_Repository", return_value=repo):
        end_time = time.time()
        # The sleep helper is patched out and its calls are counted
        with patch(
            "onyx.connectors.github.connector._sleep_after_rate_limit_exception"
        ) as sleep_stub:
            outputs = load_everything_from_checkpoint_connector(
                github_connector, 0, end_time
            )
            assert sleep_stub.call_count == 1

        # The document arrives once the rate limit has been handled
        assert len(outputs) >= 2
        pr_items = outputs[1].items
        assert len(pr_items) == 1
        assert isinstance(pr_items[0], Document)
        assert pr_items[0].id == "https://github.com/test-org/test-repo/pull/1"

        assert outputs[-1].next_checkpoint.has_more is False
def test_load_from_checkpoint_with_empty_repo(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
) -> None:
    """Test loading from checkpoint with an empty repository"""
    # Mocked repo, returned by the client
    repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = repo

    # Every page of PRs and issues comes back empty
    pulls_listing = MagicMock()
    pulls_listing.get_page.return_value = []
    repo.get_pulls.return_value = pulls_listing

    issues_listing = MagicMock()
    issues_listing.get_page.return_value = []
    repo.get_issues.return_value = issues_listing

    # SerializedRepository.to_Repository is patched to return our mock repo
    with patch.object(SerializedRepository, "to_Repository", return_value=repo):
        end_time = time.time()
        outputs = load_everything_from_checkpoint_connector(
            github_connector, 0, end_time
        )

        # Two outputs in total; the last one is empty and has has_more unset
        assert len(outputs) == 2
        last_output = outputs[-1]
        assert len(last_output.items) == 0
        assert not last_output.next_checkpoint.has_more
def test_load_from_checkpoint_with_prs_only(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
    create_mock_pr: Callable[..., MagicMock],
) -> None:
    """Test loading from checkpoint with only PRs enabled

    The two documents are counted in ``outputs[1]``, so the PR-type check
    runs against that batch. It previously iterated ``outputs[0].items``,
    which the happy-path test asserts to be empty.
    """
    # Configure connector to only include PRs
    github_connector.include_prs = True
    github_connector.include_issues = False

    # Set up mocked repo
    mock_repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = mock_repo

    # Set up mocked PRs
    mock_pr1 = create_mock_pr(number=1, title="PR 1")
    mock_pr2 = create_mock_pr(number=2, title="PR 2")

    # Mock get_pulls: one populated page, then an empty page
    mock_repo.get_pulls.return_value = MagicMock()
    mock_repo.get_pulls.return_value.get_page.side_effect = [
        [mock_pr1, mock_pr2],
        [],
    ]

    # Mock SerializedRepository.to_Repository to return our mock repo
    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
        # Call load_from_checkpoint
        end_time = time.time()
        outputs = load_everything_from_checkpoint_connector(
            github_connector, 0, end_time
        )

        # Check that we only got PRs
        assert len(outputs) >= 2
        pr_items = outputs[1].items
        assert len(pr_items) == 2
        assert all(
            isinstance(doc, Document) and "pull" in doc.id for doc in pr_items
        )  # All documents should be PRs

        assert outputs[-1].next_checkpoint.has_more is False
def test_load_from_checkpoint_with_issues_only(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
    create_mock_issue: Callable[..., MagicMock],
) -> None:
    """Test loading from checkpoint with only issues enabled

    The two documents are counted in ``outputs[1]``, so the issue-type check
    runs against that batch. It previously iterated ``outputs[0].items``,
    which the happy-path test asserts to be empty.
    """
    # Configure connector to only include issues
    github_connector.include_prs = False
    github_connector.include_issues = True

    # Set up mocked repo
    mock_repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = mock_repo

    # Set up mocked issues
    mock_issue1 = create_mock_issue(number=1, title="Issue 1")
    mock_issue2 = create_mock_issue(number=2, title="Issue 2")

    # Mock get_issues: one populated page, then an empty page
    mock_repo.get_issues.return_value = MagicMock()
    mock_repo.get_issues.return_value.get_page.side_effect = [
        [mock_issue1, mock_issue2],
        [],
    ]

    # Mock SerializedRepository.to_Repository to return our mock repo
    with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
        # Call load_from_checkpoint
        end_time = time.time()
        outputs = load_everything_from_checkpoint_connector(
            github_connector, 0, end_time
        )

        # Check that we only got issues
        assert len(outputs) >= 2
        issue_items = outputs[1].items
        assert len(issue_items) == 2
        assert all(
            isinstance(doc, Document) and "issues" in doc.id for doc in issue_items
        )  # All documents should be issues
        assert outputs[1].next_checkpoint.has_more

        assert outputs[-1].next_checkpoint.has_more is False
@pytest.mark.parametrize(
    "status_code,expected_exception,expected_message",
    [
        (
            401,
            CredentialExpiredError,
            "GitHub credential appears to be invalid or expired",
        ),
        (
            403,
            InsufficientPermissionsError,
            "Your GitHub token does not have sufficient permissions",
        ),
        (
            404,
            ConnectorValidationError,
            "GitHub repository not found",
        ),
    ],
)
def test_validate_connector_settings_errors(
    github_connector: GithubConnector,
    status_code: int,
    expected_exception: type[Exception],
    expected_message: str,
) -> None:
    """Test validation with various error scenarios

    ``get_repo`` is made to raise a ``GithubException`` carrying the
    parametrized status code; validation must raise the paired exception
    type with the paired message fragment.
    """
    client = cast(Github, github_connector.github_client)
    # Every get_repo call now raises the simulated API error
    cast(MagicMock, client.get_repo).side_effect = GithubException(
        status=status_code, data={}, headers={}
    )

    with pytest.raises(expected_exception) as raised:
        github_connector.validate_connector_settings()

    assert expected_message in str(raised.value)
def test_validate_connector_settings_success(
    github_connector: GithubConnector,
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
) -> None:
    """Test successful validation"""
    # Mocked repo, returned by the client
    repo = create_mock_repo()
    github_connector.github_client = mock_github_client
    mock_github_client.get_repo.return_value = repo

    # get_contents returns a mock to simulate successful access
    repo.get_contents.return_value = MagicMock()

    github_connector.validate_connector_settings()

    # The repo must have been looked up exactly once, as "<owner>/<repositories>"
    expected_full_name = (
        f"{github_connector.repo_owner}/{github_connector.repositories}"
    )
    github_connector.github_client.get_repo.assert_called_once_with(
        expected_full_name
    )

View File

@@ -1383,7 +1383,6 @@ export function ChatPage({
if (!packet) {
continue;
}
console.log("Packet:", JSON.stringify(packet));
if (!initialFetchDetails) {
if (!Object.hasOwn(packet, "user_message_id")) {
@@ -1729,7 +1728,6 @@ export function ChatPage({
}
}
} catch (e: any) {
console.log("Error:", e);
const errorMsg = e.message;
upsertToCompleteMessageMap({
messages: [
@@ -1757,13 +1755,11 @@ export function ChatPage({
completeMessageMapOverride: currentMessageMap(completeMessageDetail),
});
}
console.log("Finished streaming");
setAgenticGenerating(false);
resetRegenerationState(currentSessionId());
updateChatState("input");
if (isNewSession) {
console.log("Setting up new session");
if (finalMessage) {
setSelectedMessageForDocDisplay(finalMessage.message_id);
}