Compare commits

...

4 Commits

Author SHA1 Message Date
justin-tahara
2c57dc54cc fixing tests 2025-12-17 19:06:53 -08:00
justin-tahara
67fc1ce9e0 Adding additional guards 2025-12-17 18:48:21 -08:00
justin-tahara
0af7f8c44e Addressing comments and fixing tests and mypy 2025-12-17 18:47:10 -08:00
justin-tahara
70ed680c83 feat(github): App Tokens 2025-12-17 18:36:13 -08:00
4 changed files with 365 additions and 29 deletions

View File

@@ -9,6 +9,7 @@ from typing import Any
from typing import cast
from github import Github
from github import GithubIntegration
from github import RateLimitExceededException
from github import Repository
from github.GithubException import GithubException
@@ -28,6 +29,8 @@ from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.github.models import SerializedRepository
from onyx.connectors.github.rate_limit_utils import raise_if_approaching_rate_limit
from onyx.connectors.github.rate_limit_utils import RateLimitBudgetLow
from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception
from onyx.connectors.github.utils import deserialize_repository
from onyx.connectors.github.utils import get_external_access_permission
@@ -172,6 +175,7 @@ def _get_batch_rate_limited(
raise RuntimeError(
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
)
raise_if_approaching_rate_limit(github_client)
try:
if cursor_url:
# when this is set, we are resuming from an earlier
@@ -418,18 +422,66 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
self.github_client: Github | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# defaults to 30 items per page, can be set to as high as 100
self.github_client = (
Github(
credentials["github_access_token"],
base_url=GITHUB_CONNECTOR_BASE_URL,
per_page=ITEMS_PER_PAGE,
)
if GITHUB_CONNECTOR_BASE_URL
else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE)
)
self.github_client = self._build_github_client(credentials)
return None
def _build_github_client(self, credentials: dict[str, Any]) -> Github:
    """
    Build a Github client from either a PAT or a GitHub App installation.

    A non-empty ``github_access_token`` takes precedence. Otherwise, if any
    GitHub App field is present, all of ``github_app_id``,
    ``github_app_installation_id`` and ``github_app_private_key`` must be
    supplied; an installation access token is then requested and used to
    build the client.

    Raises:
        ConnectorMissingCredentialError: if no usable credentials are
            supplied, the GitHub App fields are incomplete, or the app /
            installation IDs cannot be converted to integers.
    """

    def _client_for_token(token: str) -> Github:
        # Pass base_url only when one is configured; otherwise the
        # library's default endpoint is used.
        if GITHUB_CONNECTOR_BASE_URL:
            return Github(
                token,
                base_url=GITHUB_CONNECTOR_BASE_URL,
                per_page=ITEMS_PER_PAGE,
            )
        return Github(token, per_page=ITEMS_PER_PAGE)

    github_access_token = credentials.get("github_access_token")
    if github_access_token:
        return _client_for_token(github_access_token)

    app_id = credentials.get("github_app_id")
    installation_id = credentials.get("github_app_installation_id")
    private_key = credentials.get("github_app_private_key")

    # Neither auth path was attempted at all.
    if not (app_id or installation_id or private_key):
        raise ConnectorMissingCredentialError(
            "GitHub credentials not loaded. Provide a PAT or GitHub App credentials."
        )

    # Require all GitHub App fields if using that auth path
    if not (app_id and installation_id and private_key):
        raise ConnectorMissingCredentialError(
            "GitHub App authentication requires app_id, installation_id, and private_key."
        )

    try:
        app_id_int = int(app_id)
        installation_id_int = int(installation_id)
    except (TypeError, ValueError) as err:
        # Chain the original conversion error so the offending value is
        # visible in the traceback.
        raise ConnectorMissingCredentialError(
            "GitHub App credentials must include numeric app and installation IDs."
        ) from err

    integration = (
        GithubIntegration(
            app_id_int, private_key, base_url=GITHUB_CONNECTOR_BASE_URL
        )
        if GITHUB_CONNECTOR_BASE_URL
        else GithubIntegration(app_id_int, private_key)
    )
    # NOTE(review): the installation token is requested once here and baked
    # into the client; confirm whether it needs refreshing for long syncs.
    app_token = integration.get_access_token(installation_id_int).token
    return _client_for_token(app_token)
def get_github_repo(
self, github_client: Github, attempt_num: int = 0
) -> Repository.Repository:
@@ -567,14 +619,22 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
logger.info(f"Fetching PRs for repo: {repo.name}")
pr_batch = _get_batch_rate_limited(
self._pull_requests_func(repo),
checkpoint.curr_page,
checkpoint.cursor_url,
checkpoint.num_retrieved,
cursor_url_callback,
self.github_client,
)
try:
pr_batch = _get_batch_rate_limited(
self._pull_requests_func(repo),
checkpoint.curr_page,
checkpoint.cursor_url,
checkpoint.num_retrieved,
cursor_url_callback,
self.github_client,
)
except RateLimitBudgetLow as e:
logger.info(
"Stopping GitHub fetch early to avoid hitting rate limit "
f"(remaining={e.remaining}, threshold={e.threshold}, "
f"resets_at={e.reset_at.isoformat()}, seconds_until_reset={e.seconds_until_reset:.0f})."
)
return checkpoint
checkpoint.curr_page += 1 # NOTE: not used for cursor-based fallback
done_with_prs = False
num_prs = 0
@@ -640,16 +700,24 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
logger.info(f"Fetching issues for repo: {repo.name}")
issue_batch = list(
_get_batch_rate_limited(
self._issues_func(repo),
checkpoint.curr_page,
checkpoint.cursor_url,
checkpoint.num_retrieved,
cursor_url_callback,
self.github_client,
try:
issue_batch = list(
_get_batch_rate_limited(
self._issues_func(repo),
checkpoint.curr_page,
checkpoint.cursor_url,
checkpoint.num_retrieved,
cursor_url_callback,
self.github_client,
)
)
)
except RateLimitBudgetLow as e:
logger.info(
"Stopping GitHub fetch early to avoid hitting rate limit "
f"(remaining={e.remaining}, threshold={e.threshold}, "
f"resets_at={e.reset_at.isoformat()}, seconds_until_reset={e.seconds_until_reset:.0f})."
)
return checkpoint
logger.info(f"Fetched {len(issue_batch)} issues for repo: {repo.name}")
checkpoint.curr_page += 1
done_with_issues = False

View File

@@ -1,3 +1,4 @@
import os
import time
from datetime import datetime
from datetime import timedelta
@@ -10,6 +11,64 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def _load_minimum_remaining_threshold() -> int:
    """Return the remaining-budget threshold configured via the environment.

    Reads ``GITHUB_RATE_LIMIT_MINIMUM_REMAINING``. An unset or non-integer
    value yields the default of 100; negative values are clamped to 0.
    """
    fallback = 100
    raw_value = os.environ.get("GITHUB_RATE_LIMIT_MINIMUM_REMAINING")
    if raw_value is None:
        return fallback

    try:
        parsed = int(raw_value)
    except ValueError:
        return fallback

    # Never report a negative threshold.
    return parsed if parsed > 0 else 0
# Evaluated once when this module is imported; later changes to the
# environment variable are not picked up by this constant.
MINIMUM_RATE_LIMIT_REMAINING = _load_minimum_remaining_threshold()
class RateLimitBudgetLow(Exception):
    """Signals that the remaining rate-limit budget is low enough to pause early.

    Carries the observed ``remaining`` count, the ``threshold`` it was compared
    against, the ``reset_at`` timestamp, and ``seconds_until_reset`` (computed
    at construction time and never negative).
    """

    def __init__(
        self, remaining: int, threshold: int, reset_at: datetime, *args: object
    ) -> None:
        super().__init__(*args)
        self.remaining = remaining
        self.threshold = threshold
        self.reset_at = reset_at

        # Time left until the reset, measured against the current UTC time.
        time_left = (reset_at - datetime.now(tz=timezone.utc)).total_seconds()
        self.seconds_until_reset = time_left if time_left > 0 else 0
def raise_if_approaching_rate_limit(
    github_client: Github, minimum_remaining: int | None = None
) -> None:
    """Raise RateLimitBudgetLow when the core rate-limit budget is at or below the threshold.

    Args:
        github_client: client whose ``get_rate_limit().core`` is inspected.
        minimum_remaining: threshold override; when ``None`` the module-level
            ``MINIMUM_RATE_LIMIT_REMAINING`` is used. A threshold of 0 or less
            disables the check entirely (no rate-limit lookup is made).

    Raises:
        RateLimitBudgetLow: if the remaining count is <= the threshold.
    """
    if minimum_remaining is None:
        threshold = MINIMUM_RATE_LIMIT_REMAINING
    else:
        threshold = minimum_remaining

    if threshold <= 0:
        return

    core = github_client.get_rate_limit().core
    budget = core.remaining
    try:
        budget_left = int(budget)
    except Exception:
        # If the remaining value is missing or non-numeric (e.g., mocked), skip the guard
        return

    # Stamp the reset time as UTC before comparing, so the exception always
    # carries a timezone-aware value.
    resets_at = core.reset.replace(tzinfo=timezone.utc)
    if budget_left > threshold:
        return

    raise RateLimitBudgetLow(budget_left, threshold, resets_at)
def sleep_after_rate_limit_exception(github_client: Github) -> None:
"""
Sleep until the GitHub rate limit resets.

View File

@@ -18,12 +18,17 @@ from github.RateLimit import RateLimit
from github.Repository import Repository
from github.Requester import Requester
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.github.connector import GithubConnector
from onyx.connectors.github.connector import GithubConnectorStage
from onyx.connectors.github.connector import ITEMS_PER_PAGE
from onyx.connectors.github.models import SerializedRepository
from onyx.connectors.github.rate_limit_utils import raise_if_approaching_rate_limit
from onyx.connectors.github.rate_limit_utils import RateLimitBudgetLow
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
from tests.unit.onyx.connectors.utils import (
@@ -291,6 +296,112 @@ def test_load_from_checkpoint_with_rate_limit(
assert outputs[-1].next_checkpoint.has_more is False
def test_raise_if_approaching_rate_limit_raises_when_under_threshold() -> None:
    """The guard raises and reports the observed values when remaining <= threshold."""
    reset_time = datetime(2024, 1, 1, tzinfo=timezone.utc)

    client = MagicMock(spec=Github)
    limits = MagicMock(spec=RateLimit)
    limits.core = MagicMock(remaining=5, reset=reset_time)
    client.get_rate_limit.return_value = limits

    with pytest.raises(RateLimitBudgetLow) as excinfo:
        raise_if_approaching_rate_limit(client, minimum_remaining=10)

    raised = excinfo.value
    assert raised.remaining == 5
    assert raised.threshold == 10
    assert raised.reset_at == reset_time
def test_raise_if_approaching_rate_limit_allows_when_above_threshold() -> None:
    """The guard returns silently when remaining is above the threshold."""
    reset_time = datetime(2024, 1, 1, tzinfo=timezone.utc)

    client = MagicMock(spec=Github)
    limits = MagicMock(spec=RateLimit)
    limits.core = MagicMock(remaining=50, reset=reset_time)
    client.get_rate_limit.return_value = limits

    # Should not raise
    raise_if_approaching_rate_limit(client, minimum_remaining=10)
def test_fetch_from_github_returns_checkpoint_when_rate_limit_low_prs(
    build_github_connector: Callable[..., GithubConnector],
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
) -> None:
    """A low rate-limit budget during the PR stage ends the generator with the checkpoint unchanged."""
    connector = build_github_connector()
    connector.github_client = mock_github_client
    repo = create_mock_repo()

    checkpoint = connector.build_dummy_checkpoint()
    checkpoint.cached_repo = SerializedRepository(
        id=repo.id, headers=repo.raw_headers, raw_data=repo.raw_data
    )
    checkpoint.cached_repo_ids = []
    checkpoint.stage = GithubConnectorStage.PRS

    low_budget = RateLimitBudgetLow(
        remaining=1, threshold=10, reset_at=datetime.now(timezone.utc)
    )

    with patch.object(SerializedRepository, "to_Repository", return_value=repo):
        with patch(
            "onyx.connectors.github.connector._get_batch_rate_limited",
            side_effect=low_budget,
        ) as mock_batch:
            fetch_gen = connector._fetch_from_github(checkpoint)
            # The generator returns (rather than yields) the checkpoint, so
            # it surfaces as the StopIteration value.
            with pytest.raises(StopIteration) as stop_exc:
                next(fetch_gen)

    result = stop_exc.value.value
    mock_batch.assert_called_once()
    assert result.stage == GithubConnectorStage.PRS
    assert result.curr_page == 0
    assert result.num_retrieved == 0
    assert result.cached_repo is not None
    assert result.cached_repo.id == repo.id
def test_fetch_from_github_returns_checkpoint_when_rate_limit_low_issues(
    build_github_connector: Callable[..., GithubConnector],
    mock_github_client: MagicMock,
    create_mock_repo: Callable[..., MagicMock],
) -> None:
    """A low rate-limit budget during the issues stage ends the generator with the checkpoint unchanged."""
    connector = build_github_connector()
    connector.github_client = mock_github_client
    repo = create_mock_repo()

    checkpoint = connector.build_dummy_checkpoint()
    checkpoint.cached_repo = SerializedRepository(
        id=repo.id, headers=repo.raw_headers, raw_data=repo.raw_data
    )
    checkpoint.cached_repo_ids = []
    checkpoint.stage = GithubConnectorStage.ISSUES

    low_budget = RateLimitBudgetLow(
        remaining=1, threshold=10, reset_at=datetime.now(timezone.utc)
    )

    with patch.object(SerializedRepository, "to_Repository", return_value=repo):
        with patch(
            "onyx.connectors.github.connector._get_batch_rate_limited",
            side_effect=low_budget,
        ) as mock_batch:
            fetch_gen = connector._fetch_from_github(checkpoint)
            # The generator returns (rather than yields) the checkpoint, so
            # it surfaces as the StopIteration value.
            with pytest.raises(StopIteration) as stop_exc:
                next(fetch_gen)

    result = stop_exc.value.value
    mock_batch.assert_called_once()
    assert result.stage == GithubConnectorStage.ISSUES
    assert result.curr_page == 0
    assert result.num_retrieved == 0
    assert result.cached_repo is not None
    assert result.cached_repo.id == repo.id
def test_load_from_checkpoint_with_empty_repo(
build_github_connector: Callable[..., GithubConnector],
mock_github_client: MagicMock,
@@ -927,3 +1038,73 @@ def test_load_from_checkpoint_cursor_pagination_completion(
assert (
pull_requests_func_invocation_count == 3
) # twice for repo2 PRs, once for repo1 PRs
@patch("onyx.connectors.github.connector.Github")
def test_load_credentials_with_pat(
    mock_github: MagicMock, build_github_connector: Callable[..., GithubConnector]
) -> None:
    """Loading a PAT builds exactly one Github client with the token and page size."""
    github_connector = build_github_connector()
    github_connector.load_credentials({"github_access_token": "pat-token"})

    mock_github.assert_called_once()
    client_call = mock_github.call_args
    assert client_call.args[0] == "pat-token"
    assert client_call.kwargs.get("per_page") == ITEMS_PER_PAGE
@patch("onyx.connectors.github.connector.GithubIntegration")
@patch("onyx.connectors.github.connector.Github")
def test_load_credentials_with_github_app_credentials(
    mock_github: MagicMock,
    mock_integration: MagicMock,
    build_github_connector: Callable[..., GithubConnector],
) -> None:
    """GitHub App credentials are exchanged for an installation token that backs the client."""
    github_connector = build_github_connector()

    fake_integration = MagicMock()
    fake_integration.get_access_token.return_value = MagicMock(token="app-token")
    mock_integration.return_value = fake_integration

    app_credentials = {
        "github_app_id": "123",
        "github_app_installation_id": "456",
        "github_app_private_key": "PRIVATEKEY",
    }
    github_connector.load_credentials(app_credentials)

    # The integration is built from the integer app id and the private key.
    mock_integration.assert_called_once()
    integration_call = mock_integration.call_args
    assert integration_call.args[0] == 123
    assert integration_call.args[1] == "PRIVATEKEY"
    assert integration_call.kwargs.get("base_url") == GITHUB_CONNECTOR_BASE_URL

    # The installation id is passed as an integer when requesting the token.
    fake_integration.get_access_token.assert_called_once()
    token_call = fake_integration.get_access_token.call_args
    assert token_call.args[0] == 456

    # The client is constructed from the returned installation token.
    mock_github.assert_called_once()
    client_call = mock_github.call_args
    assert client_call.args[0] == "app-token"
    assert client_call.kwargs.get("per_page") == ITEMS_PER_PAGE
    assert client_call.kwargs.get("base_url") == GITHUB_CONNECTOR_BASE_URL
def test_load_credentials_missing_credentials(
    build_github_connector: Callable[..., GithubConnector],
) -> None:
    """An empty credentials dict is rejected with ConnectorMissingCredentialError."""
    github_connector = build_github_connector()
    no_credentials: dict = {}

    with pytest.raises(ConnectorMissingCredentialError):
        github_connector.load_credentials(no_credentials)
def test_load_credentials_partial_github_app_credentials(
    build_github_connector: Callable[..., GithubConnector],
) -> None:
    """GitHub App credentials without a private key are rejected."""
    github_connector = build_github_connector()
    incomplete_app_credentials = {
        "github_app_id": "123",
        "github_app_installation_id": "456",
    }

    with pytest.raises(ConnectorMissingCredentialError):
        github_connector.load_credentials(incomplete_app_credentials)

View File

@@ -45,7 +45,11 @@ export interface Credential<T> extends CredentialBase<T> {
time_updated: string;
}
export interface GithubCredentialJson {
github_access_token: string;
github_access_token?: string;
github_app_id?: string;
github_app_installation_id?: string;
github_app_private_key?: string;
authentication_method?: string;
}
export interface GitbookCredentialJson {
@@ -278,7 +282,28 @@ export interface TestRailCredentialJson {
}
export const credentialTemplates: Record<ValidSources, any> = {
github: { github_access_token: "" } as GithubCredentialJson,
github: {
authentication_method: "pat",
authMethods: [
{
value: "pat",
label: "Personal Access Token",
fields: { github_access_token: "" },
description: "Use a classic or fine-grained PAT with repo access.",
},
{
value: "app",
label: "GitHub App Installation",
fields: {
github_app_id: "",
github_app_installation_id: "",
github_app_private_key: "",
},
description:
"Use a GitHub App installation token (higher rate limits per installation). Provide the App ID, installation ID, and PEM private key.",
},
],
} as CredentialTemplateWithAuth<GithubCredentialJson>,
gitlab: {
gitlab_url: "",
gitlab_access_token: "",
@@ -485,6 +510,9 @@ export const credentialTemplates: Record<ValidSources, any> = {
export const credentialDisplayNames: Record<string, string> = {
// Github
github_access_token: "GitHub Access Token",
github_app_id: "GitHub App ID",
github_app_installation_id: "GitHub App Installation ID",
github_app_private_key: "GitHub App Private Key (PEM)",
// Gitlab
gitlab_url: "GitLab URL",