mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-16 21:22:41 +00:00
Compare commits
4 Commits
v3.0.0
...
jtahara/gi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2c57dc54cc | ||
|
|
67fc1ce9e0 | ||
|
|
0af7f8c44e | ||
|
|
70ed680c83 |
@@ -9,6 +9,7 @@ from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from github import Github
|
||||
from github import GithubIntegration
|
||||
from github import RateLimitExceededException
|
||||
from github import Repository
|
||||
from github.GithubException import GithubException
|
||||
@@ -28,6 +29,8 @@ from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.github.models import SerializedRepository
|
||||
from onyx.connectors.github.rate_limit_utils import raise_if_approaching_rate_limit
|
||||
from onyx.connectors.github.rate_limit_utils import RateLimitBudgetLow
|
||||
from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception
|
||||
from onyx.connectors.github.utils import deserialize_repository
|
||||
from onyx.connectors.github.utils import get_external_access_permission
|
||||
@@ -172,6 +175,7 @@ def _get_batch_rate_limited(
|
||||
raise RuntimeError(
|
||||
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
|
||||
)
|
||||
raise_if_approaching_rate_limit(github_client)
|
||||
try:
|
||||
if cursor_url:
|
||||
# when this is set, we are resuming from an earlier
|
||||
@@ -418,18 +422,66 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
|
||||
self.github_client: Github | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# defaults to 30 items per page, can be set to as high as 100
|
||||
self.github_client = (
|
||||
Github(
|
||||
credentials["github_access_token"],
|
||||
base_url=GITHUB_CONNECTOR_BASE_URL,
|
||||
per_page=ITEMS_PER_PAGE,
|
||||
)
|
||||
if GITHUB_CONNECTOR_BASE_URL
|
||||
else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE)
|
||||
)
|
||||
self.github_client = self._build_github_client(credentials)
|
||||
return None
|
||||
|
||||
def _build_github_client(self, credentials: dict[str, Any]) -> Github:
    """
    Build a Github client from either a PAT or a GitHub App installation.

    A non-empty ``github_access_token`` takes precedence. Otherwise all of
    ``github_app_id``, ``github_app_installation_id`` and
    ``github_app_private_key`` must be present; they are exchanged for an
    installation access token which is then used like a PAT.

    Args:
        credentials: Credential mapping; see the keys listed above.

    Returns:
        A ``Github`` client configured with ``ITEMS_PER_PAGE`` and, when
        ``GITHUB_CONNECTOR_BASE_URL`` is set, that base URL.

    Raises:
        ConnectorMissingCredentialError: If no usable credentials are given,
            if only some of the GitHub App fields are given, or if the app /
            installation IDs are not numeric.
    """
    # Only pass base_url when one is configured, so the library's own
    # default applies otherwise. Built once instead of repeating the
    # conditional at every constructor call.
    client_kwargs: dict[str, Any] = {"per_page": ITEMS_PER_PAGE}
    integration_kwargs: dict[str, Any] = {}
    if GITHUB_CONNECTOR_BASE_URL:
        client_kwargs["base_url"] = GITHUB_CONNECTOR_BASE_URL
        integration_kwargs["base_url"] = GITHUB_CONNECTOR_BASE_URL

    github_access_token = credentials.get("github_access_token")
    if github_access_token:
        return Github(github_access_token, **client_kwargs)

    app_id = credentials.get("github_app_id")
    installation_id = credentials.get("github_app_installation_id")
    private_key = credentials.get("github_app_private_key")

    if not (app_id or installation_id or private_key):
        raise ConnectorMissingCredentialError(
            "GitHub credentials not loaded. Provide a PAT or GitHub App credentials."
        )

    # Require all GitHub App fields if using that auth path
    if not (app_id and installation_id and private_key):
        raise ConnectorMissingCredentialError(
            "GitHub App authentication requires app_id, installation_id, and private_key."
        )

    try:
        app_id_int = int(app_id)
        installation_id_int = int(installation_id)
    except (TypeError, ValueError) as err:
        # Chain the cause so the offending value shows up in tracebacks.
        raise ConnectorMissingCredentialError(
            "GitHub App credentials must include numeric app and installation IDs."
        ) from err

    integration = GithubIntegration(app_id_int, private_key, **integration_kwargs)
    # NOTE(review): installation tokens are presumably short-lived; this
    # client keeps the token fetched here for its whole lifetime - confirm
    # that long-running syncs rebuild the client before the token expires.
    app_token = integration.get_access_token(installation_id_int).token

    return Github(app_token, **client_kwargs)
|
||||
|
||||
def get_github_repo(
|
||||
self, github_client: Github, attempt_num: int = 0
|
||||
) -> Repository.Repository:
|
||||
@@ -567,14 +619,22 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
|
||||
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
|
||||
logger.info(f"Fetching PRs for repo: {repo.name}")
|
||||
|
||||
pr_batch = _get_batch_rate_limited(
|
||||
self._pull_requests_func(repo),
|
||||
checkpoint.curr_page,
|
||||
checkpoint.cursor_url,
|
||||
checkpoint.num_retrieved,
|
||||
cursor_url_callback,
|
||||
self.github_client,
|
||||
)
|
||||
try:
|
||||
pr_batch = _get_batch_rate_limited(
|
||||
self._pull_requests_func(repo),
|
||||
checkpoint.curr_page,
|
||||
checkpoint.cursor_url,
|
||||
checkpoint.num_retrieved,
|
||||
cursor_url_callback,
|
||||
self.github_client,
|
||||
)
|
||||
except RateLimitBudgetLow as e:
|
||||
logger.info(
|
||||
"Stopping GitHub fetch early to avoid hitting rate limit "
|
||||
f"(remaining={e.remaining}, threshold={e.threshold}, "
|
||||
f"resets_at={e.reset_at.isoformat()}, seconds_until_reset={e.seconds_until_reset:.0f})."
|
||||
)
|
||||
return checkpoint
|
||||
checkpoint.curr_page += 1 # NOTE: not used for cursor-based fallback
|
||||
done_with_prs = False
|
||||
num_prs = 0
|
||||
@@ -640,16 +700,24 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
|
||||
if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
|
||||
logger.info(f"Fetching issues for repo: {repo.name}")
|
||||
|
||||
issue_batch = list(
|
||||
_get_batch_rate_limited(
|
||||
self._issues_func(repo),
|
||||
checkpoint.curr_page,
|
||||
checkpoint.cursor_url,
|
||||
checkpoint.num_retrieved,
|
||||
cursor_url_callback,
|
||||
self.github_client,
|
||||
try:
|
||||
issue_batch = list(
|
||||
_get_batch_rate_limited(
|
||||
self._issues_func(repo),
|
||||
checkpoint.curr_page,
|
||||
checkpoint.cursor_url,
|
||||
checkpoint.num_retrieved,
|
||||
cursor_url_callback,
|
||||
self.github_client,
|
||||
)
|
||||
)
|
||||
)
|
||||
except RateLimitBudgetLow as e:
|
||||
logger.info(
|
||||
"Stopping GitHub fetch early to avoid hitting rate limit "
|
||||
f"(remaining={e.remaining}, threshold={e.threshold}, "
|
||||
f"resets_at={e.reset_at.isoformat()}, seconds_until_reset={e.seconds_until_reset:.0f})."
|
||||
)
|
||||
return checkpoint
|
||||
logger.info(f"Fetched {len(issue_batch)} issues for repo: {repo.name}")
|
||||
checkpoint.curr_page += 1
|
||||
done_with_issues = False
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@@ -10,6 +11,64 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _load_minimum_remaining_threshold() -> int:
|
||||
"""Read the configurable remaining-budget threshold from env."""
|
||||
default_threshold = 100
|
||||
try:
|
||||
return max(
|
||||
int(
|
||||
os.environ.get("GITHUB_RATE_LIMIT_MINIMUM_REMAINING", default_threshold)
|
||||
),
|
||||
0,
|
||||
)
|
||||
except ValueError:
|
||||
return default_threshold
|
||||
|
||||
|
||||
# Threshold resolved once, when this module is imported. It is the fallback
# used by raise_if_approaching_rate_limit when no explicit minimum is passed;
# a value of 0 or less disables that guard.
MINIMUM_RATE_LIMIT_REMAINING = _load_minimum_remaining_threshold()
|
||||
|
||||
|
||||
class RateLimitBudgetLow(Exception):
    """Raised when we're close enough to the rate limit that we should pause early.

    Attributes are set in ``__init__``:
        remaining: Requests left in the current rate-limit window.
        threshold: The minimum budget that triggered this exception.
        reset_at: Timezone-aware time at which the window resets.
        seconds_until_reset: Non-negative seconds from construction time
            until ``reset_at`` (0 if the reset time has already passed).
    """

    def __init__(
        self, remaining: int, threshold: int, reset_at: datetime, *args: object
    ) -> None:
        """Store the budget details and derive the time left until reset.

        Args:
            remaining: Requests left in the current window.
            threshold: Minimum budget that was configured.
            reset_at: Reset time. A naive value is interpreted as UTC so the
                subtraction below cannot fail on naive/aware mixing.
            *args: Optional exception args. When omitted, a descriptive
                message is generated so ``str(exc)`` is never empty.
        """
        if reset_at.tzinfo is None:
            reset_at = reset_at.replace(tzinfo=timezone.utc)

        if not args:
            args = (
                "GitHub rate limit budget low: "
                f"remaining={remaining}, threshold={threshold}, "
                f"resets_at={reset_at.isoformat()}",
            )
        super().__init__(*args)

        self.remaining = remaining
        self.threshold = threshold
        self.reset_at = reset_at
        # Snapshot taken at construction time; clamped so a reset time in
        # the past reports 0 rather than a negative duration.
        self.seconds_until_reset = max(
            0, (reset_at - datetime.now(tz=timezone.utc)).total_seconds()
        )
|
||||
|
||||
|
||||
def raise_if_approaching_rate_limit(
    github_client: Github, minimum_remaining: int | None = None
) -> None:
    """Raise if the client is close to its rate limit to avoid long sleeps.

    Args:
        github_client: Client whose core rate limit is queried.
        minimum_remaining: Budget at or below which to raise. ``None`` uses
            ``MINIMUM_RATE_LIMIT_REMAINING``; a value of 0 or less disables
            the check without querying the client.

    Raises:
        RateLimitBudgetLow: If the remaining core budget is at or below the
            threshold.
    """
    threshold = (
        MINIMUM_RATE_LIMIT_REMAINING
        if minimum_remaining is None
        else minimum_remaining
    )
    if threshold <= 0:
        return

    core_rate_limit = github_client.get_rate_limit().core
    remaining = core_rate_limit.remaining
    try:
        remaining_int = int(remaining)
    except (TypeError, ValueError):
        # If the remaining value is missing or non-numeric (e.g., mocked), skip the guard
        return

    if remaining_int > threshold:
        return

    # Only touch the reset time once we know we are going to raise.
    # A naive value is labelled as UTC; an aware value is converted, so a
    # non-UTC offset is not silently relabelled as UTC.
    reset = core_rate_limit.reset
    if reset.tzinfo is None:
        reset_at = reset.replace(tzinfo=timezone.utc)
    else:
        reset_at = reset.astimezone(timezone.utc)

    raise RateLimitBudgetLow(remaining_int, threshold, reset_at)
|
||||
|
||||
|
||||
def sleep_after_rate_limit_exception(github_client: Github) -> None:
|
||||
"""
|
||||
Sleep until the GitHub rate limit resets.
|
||||
|
||||
@@ -18,12 +18,17 @@ from github.RateLimit import RateLimit
|
||||
from github.Repository import Repository
|
||||
from github.Requester import Requester
|
||||
|
||||
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.connectors.github.connector import GithubConnectorStage
|
||||
from onyx.connectors.github.connector import ITEMS_PER_PAGE
|
||||
from onyx.connectors.github.models import SerializedRepository
|
||||
from onyx.connectors.github.rate_limit_utils import raise_if_approaching_rate_limit
|
||||
from onyx.connectors.github.rate_limit_utils import RateLimitBudgetLow
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
|
||||
from tests.unit.onyx.connectors.utils import (
|
||||
@@ -291,6 +296,112 @@ def test_load_from_checkpoint_with_rate_limit(
|
||||
assert outputs[-1].next_checkpoint.has_more is False
|
||||
|
||||
|
||||
def test_raise_if_approaching_rate_limit_raises_when_under_threshold() -> None:
    """A budget below the threshold raises and carries the observed values."""
    reset_time = datetime(2024, 1, 1, tzinfo=timezone.utc)

    limits = MagicMock(spec=RateLimit)
    limits.core = MagicMock(remaining=5, reset=reset_time)
    client = MagicMock(spec=Github)
    client.get_rate_limit.return_value = limits

    with pytest.raises(RateLimitBudgetLow) as excinfo:
        raise_if_approaching_rate_limit(client, minimum_remaining=10)

    raised = excinfo.value
    assert raised.remaining == 5
    assert raised.threshold == 10
    assert raised.reset_at == reset_time
|
||||
|
||||
|
||||
def test_raise_if_approaching_rate_limit_allows_when_above_threshold() -> None:
    """A budget above the threshold lets the call return without raising."""
    reset_time = datetime(2024, 1, 1, tzinfo=timezone.utc)

    limits = MagicMock(spec=RateLimit)
    limits.core = MagicMock(remaining=50, reset=reset_time)
    client = MagicMock(spec=Github)
    client.get_rate_limit.return_value = limits

    # Completing this call without an exception is the assertion.
    raise_if_approaching_rate_limit(client, minimum_remaining=10)
|
||||
|
||||
|
||||
def test_fetch_from_github_returns_checkpoint_when_rate_limit_low_prs(
    build_github_connector: Callable[..., GithubConnector],
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
) -> None:
    """In the PRS stage, a low rate-limit budget ends the fetch and returns the checkpoint."""
    connector = build_github_connector()
    connector.github_client = mock_github_client
    repo = create_mock_repo()

    start_checkpoint = connector.build_dummy_checkpoint()
    start_checkpoint.cached_repo = SerializedRepository(
        id=repo.id, headers=repo.raw_headers, raw_data=repo.raw_data
    )
    start_checkpoint.cached_repo_ids = []
    start_checkpoint.stage = GithubConnectorStage.PRS

    budget_low = RateLimitBudgetLow(
        remaining=1, threshold=10, reset_at=datetime.now(timezone.utc)
    )
    with patch.object(SerializedRepository, "to_Repository", return_value=repo):
        with patch(
            "onyx.connectors.github.connector._get_batch_rate_limited",
            side_effect=budget_low,
        ) as mock_batch:
            fetch_gen = connector._fetch_from_github(start_checkpoint)
            # The generator yields nothing; its return value rides on StopIteration.
            with pytest.raises(StopIteration) as stop_exc:
                next(fetch_gen)

    result = stop_exc.value.value

    mock_batch.assert_called_once()
    assert result.stage == GithubConnectorStage.PRS
    assert result.curr_page == 0
    assert result.num_retrieved == 0
    assert result.cached_repo is not None
    assert result.cached_repo.id == repo.id
|
||||
|
||||
|
||||
def test_fetch_from_github_returns_checkpoint_when_rate_limit_low_issues(
    build_github_connector: Callable[..., GithubConnector],
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
) -> None:
    """In the ISSUES stage, a low rate-limit budget ends the fetch and returns the checkpoint."""
    connector = build_github_connector()
    connector.github_client = mock_github_client
    repo = create_mock_repo()

    start_checkpoint = connector.build_dummy_checkpoint()
    start_checkpoint.cached_repo = SerializedRepository(
        id=repo.id, headers=repo.raw_headers, raw_data=repo.raw_data
    )
    start_checkpoint.cached_repo_ids = []
    start_checkpoint.stage = GithubConnectorStage.ISSUES

    budget_low = RateLimitBudgetLow(
        remaining=1, threshold=10, reset_at=datetime.now(timezone.utc)
    )
    with patch.object(SerializedRepository, "to_Repository", return_value=repo):
        with patch(
            "onyx.connectors.github.connector._get_batch_rate_limited",
            side_effect=budget_low,
        ) as mock_batch:
            fetch_gen = connector._fetch_from_github(start_checkpoint)
            # The generator yields nothing; its return value rides on StopIteration.
            with pytest.raises(StopIteration) as stop_exc:
                next(fetch_gen)

    result = stop_exc.value.value

    mock_batch.assert_called_once()
    assert result.stage == GithubConnectorStage.ISSUES
    assert result.curr_page == 0
    assert result.num_retrieved == 0
    assert result.cached_repo is not None
    assert result.cached_repo.id == repo.id
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_empty_repo(
|
||||
build_github_connector: Callable[..., GithubConnector],
|
||||
mock_github_client: MagicMock,
|
||||
@@ -927,3 +1038,73 @@ def test_load_from_checkpoint_cursor_pagination_completion(
|
||||
assert (
|
||||
pull_requests_func_invocation_count == 3
|
||||
) # twice for repo2 PRs, once for repo1 PRs
|
||||
|
||||
|
||||
@patch("onyx.connectors.github.connector.Github")
def test_load_credentials_with_pat(
    mock_github: MagicMock, build_github_connector: Callable[..., GithubConnector]
) -> None:
    """A PAT credential builds exactly one client, with the token passed positionally."""
    github_connector = build_github_connector()
    github_connector.load_credentials({"github_access_token": "pat-token"})

    mock_github.assert_called_once()
    call_args, call_kwargs = mock_github.call_args
    assert call_args[0] == "pat-token"
    assert call_kwargs.get("per_page") == ITEMS_PER_PAGE
|
||||
|
||||
|
||||
@patch("onyx.connectors.github.connector.GithubIntegration")
@patch("onyx.connectors.github.connector.Github")
def test_load_credentials_with_github_app_credentials(
    mock_github: MagicMock,
    mock_integration: MagicMock,
    build_github_connector: Callable[..., GithubConnector],
) -> None:
    """GitHub App credentials are exchanged for an installation token used to build the client."""
    github_connector = build_github_connector()

    fake_integration = MagicMock()
    fake_integration.get_access_token.return_value = MagicMock(token="app-token")
    mock_integration.return_value = fake_integration

    app_credentials = {
        "github_app_id": "123",
        "github_app_installation_id": "456",
        "github_app_private_key": "PRIVATEKEY",
    }
    github_connector.load_credentials(app_credentials)

    # The integration receives the numeric app id and the private key.
    mock_integration.assert_called_once()
    integration_args, integration_kwargs = mock_integration.call_args
    assert integration_args[0] == 123
    assert integration_args[1] == "PRIVATEKEY"
    assert integration_kwargs.get("base_url") == GITHUB_CONNECTOR_BASE_URL

    # The installation id is converted to an int before the token request.
    fake_integration.get_access_token.assert_called_once()
    token_call_args, _ = fake_integration.get_access_token.call_args
    assert token_call_args[0] == 456

    # The client is built from the installation token.
    mock_github.assert_called_once()
    client_args, client_kwargs = mock_github.call_args
    assert client_args[0] == "app-token"
    assert client_kwargs.get("per_page") == ITEMS_PER_PAGE
    assert client_kwargs.get("base_url") == GITHUB_CONNECTOR_BASE_URL
|
||||
|
||||
|
||||
def test_load_credentials_missing_credentials(
    build_github_connector: Callable[..., GithubConnector],
) -> None:
    """An empty credential mapping is rejected with ConnectorMissingCredentialError."""
    github_connector = build_github_connector()
    empty_credentials: dict = {}

    with pytest.raises(ConnectorMissingCredentialError):
        github_connector.load_credentials(empty_credentials)
|
||||
|
||||
|
||||
def test_load_credentials_partial_github_app_credentials(
    build_github_connector: Callable[..., GithubConnector],
) -> None:
    """GitHub App credentials without the private key are rejected."""
    github_connector = build_github_connector()
    partial_credentials = {
        "github_app_id": "123",
        "github_app_installation_id": "456",
    }

    with pytest.raises(ConnectorMissingCredentialError):
        github_connector.load_credentials(partial_credentials)
|
||||
|
||||
@@ -45,7 +45,11 @@ export interface Credential<T> extends CredentialBase<T> {
|
||||
time_updated: string;
|
||||
}
|
||||
/**
 * Credential payload for the GitHub connector.
 *
 * Every field is optional because two authentication shapes share this type:
 * a personal access token (`github_access_token`), or a GitHub App
 * installation (`github_app_id`, `github_app_installation_id`,
 * `github_app_private_key`). `github_access_token` is declared once, as
 * optional; a second, required declaration of the same property would be a
 * duplicate-identifier error.
 */
export interface GithubCredentialJson {
  github_access_token?: string;
  github_app_id?: string;
  github_app_installation_id?: string;
  github_app_private_key?: string;
  // Selected auth method; the credential template uses "pat" and "app".
  authentication_method?: string;
}
|
||||
|
||||
export interface GitbookCredentialJson {
|
||||
@@ -278,7 +282,28 @@ export interface TestRailCredentialJson {
|
||||
}
|
||||
|
||||
export const credentialTemplates: Record<ValidSources, any> = {
|
||||
github: { github_access_token: "" } as GithubCredentialJson,
|
||||
github: {
|
||||
authentication_method: "pat",
|
||||
authMethods: [
|
||||
{
|
||||
value: "pat",
|
||||
label: "Personal Access Token",
|
||||
fields: { github_access_token: "" },
|
||||
description: "Use a classic or fine-grained PAT with repo access.",
|
||||
},
|
||||
{
|
||||
value: "app",
|
||||
label: "GitHub App Installation",
|
||||
fields: {
|
||||
github_app_id: "",
|
||||
github_app_installation_id: "",
|
||||
github_app_private_key: "",
|
||||
},
|
||||
description:
|
||||
"Use a GitHub App installation token (higher rate limits per installation). Provide the App ID, installation ID, and PEM private key.",
|
||||
},
|
||||
],
|
||||
} as CredentialTemplateWithAuth<GithubCredentialJson>,
|
||||
gitlab: {
|
||||
gitlab_url: "",
|
||||
gitlab_access_token: "",
|
||||
@@ -485,6 +510,9 @@ export const credentialTemplates: Record<ValidSources, any> = {
|
||||
export const credentialDisplayNames: Record<string, string> = {
|
||||
// Github
|
||||
github_access_token: "GitHub Access Token",
|
||||
github_app_id: "GitHub App ID",
|
||||
github_app_installation_id: "GitHub App Installation ID",
|
||||
github_app_private_key: "GitHub App Private Key (PEM)",
|
||||
|
||||
// Gitlab
|
||||
gitlab_url: "GitLab URL",
|
||||
|
||||
Reference in New Issue
Block a user