feat(billing): fetch Stripe publishable key from S3 (#7595)

This commit is contained in:
Nikolas Garza
2026-01-20 17:32:57 -08:00
committed by GitHub
parent 8a2e4ed36f
commit a8db236e37
7 changed files with 262 additions and 13 deletions

View File

@@ -10,6 +10,8 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
("/enterprise-settings/logo", {"GET"}),
("/enterprise-settings/logotype", {"GET"}),
("/enterprise-settings/custom-analytics-script", {"GET"}),
# Stripe publishable key is safe to expose publicly
("/tenants/stripe-publishable-key", {"GET"}),
]

View File

@@ -1,3 +1,6 @@
import asyncio
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
@@ -12,11 +15,14 @@ from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import StripePublishableKeyResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -26,6 +32,10 @@ logger = setup_logger()
router = APIRouter(prefix="/tenants")
# Cache for Stripe publishable key to avoid hitting S3 on every request
_stripe_publishable_key_cache: str | None = None
_stripe_key_lock = asyncio.Lock()
@router.post("/product-gating")
def gate_product(
@@ -113,3 +123,67 @@ async def create_subscription_session(
except Exception as e:
logger.exception("Failed to create subscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/stripe-publishable-key")
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
"""
Fetch the Stripe publishable key.
Priority: env var override (for testing) > S3 bucket (production).
This endpoint is public (no auth required) since publishable keys are safe to expose.
The key is cached in memory to avoid hitting S3 on every request.
"""
global _stripe_publishable_key_cache
# Fast path: return cached value without lock
if _stripe_publishable_key_cache:
return StripePublishableKeyResponse(
publishable_key=_stripe_publishable_key_cache
)
# Use lock to prevent concurrent S3 requests
async with _stripe_key_lock:
# Double-check after acquiring lock (another request may have populated cache)
if _stripe_publishable_key_cache:
return StripePublishableKeyResponse(
publishable_key=_stripe_publishable_key_cache
)
# Check for env var override first (for local testing with pk_test_* keys)
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
# Fall back to S3 bucket
if not STRIPE_PUBLISHABLE_KEY_URL:
raise HTTPException(
status_code=500,
detail="Stripe publishable key is not configured",
)
try:
async with httpx.AsyncClient() as client:
response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
response.raise_for_status()
key = response.text.strip()
# Validate key format
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
except httpx.HTTPError:
raise HTTPException(
status_code=500,
detail="Failed to fetch Stripe publishable key",
)

View File

@@ -105,3 +105,7 @@ class PendingUserSnapshot(BaseModel):
class ApproveUserRequest(BaseModel):
email: str
class StripePublishableKeyResponse(BaseModel):
publishable_key: str

View File

@@ -1027,3 +1027,14 @@ INSTANCE_TYPE = (
## Discord Bot Configuration
DISCORD_BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
DISCORD_BOT_INVOKE_CHAR = os.environ.get("DISCORD_BOT_INVOKE_CHAR", "!")
## Stripe Configuration
# URL to fetch the Stripe publishable key from a public S3 bucket.
# Publishable keys are safe to expose publicly - they can only initialize
# Stripe.js and tokenize payment info, not make charges or access data.
STRIPE_PUBLISHABLE_KEY_URL = (
"https://onyx-stripe-public.s3.amazonaws.com/publishable-key.txt"
)
# Override for local testing with Stripe test keys (pk_test_*)
STRIPE_PUBLISHABLE_KEY_OVERRIDE = os.environ.get("STRIPE_PUBLISHABLE_KEY")

View File

@@ -0,0 +1,158 @@
"""Tests for billing API endpoints."""
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
class TestGetStripePublishableKey:
"""Tests for get_stripe_publishable_key endpoint."""
def setup_method(self) -> None:
"""Reset the cache before each test."""
import ee.onyx.server.tenants.billing_api as billing_api
billing_api._stripe_publishable_key_cache = None
@pytest.mark.asyncio
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
@patch(
"ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_URL",
"https://example.com/key.txt",
)
async def test_fetches_from_s3_when_no_override(self) -> None:
"""Should fetch key from S3 when no env var override is set."""
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
mock_response = MagicMock()
mock_response.text = "pk_live_test123"
mock_response.raise_for_status = MagicMock()
with patch("httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
result = await get_stripe_publishable_key()
assert result.publishable_key == "pk_live_test123"
@pytest.mark.asyncio
@patch(
"ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE",
"pk_test_override123",
)
async def test_uses_env_var_override_when_set(self) -> None:
"""Should use env var override instead of fetching from S3."""
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
with patch("httpx.AsyncClient") as mock_client:
result = await get_stripe_publishable_key()
# Should not call S3
mock_client.assert_not_called()
assert result.publishable_key == "pk_test_override123"
@pytest.mark.asyncio
@patch(
"ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE",
"invalid_key",
)
async def test_rejects_invalid_env_var_key_format(self) -> None:
"""Should reject keys that don't start with pk_."""
from fastapi import HTTPException
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
with pytest.raises(HTTPException) as exc_info:
await get_stripe_publishable_key()
assert exc_info.value.status_code == 500
assert "Invalid Stripe publishable key format" in exc_info.value.detail
@pytest.mark.asyncio
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
@patch(
"ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_URL",
"https://example.com/key.txt",
)
async def test_rejects_invalid_s3_key_format(self) -> None:
"""Should reject keys from S3 that don't start with pk_."""
from fastapi import HTTPException
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
mock_response = MagicMock()
mock_response.text = "invalid_key"
mock_response.raise_for_status = MagicMock()
with patch("httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
with pytest.raises(HTTPException) as exc_info:
await get_stripe_publishable_key()
assert exc_info.value.status_code == 500
assert "Invalid Stripe publishable key format" in exc_info.value.detail
@pytest.mark.asyncio
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
@patch(
"ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_URL",
"https://example.com/key.txt",
)
async def test_handles_s3_fetch_error(self) -> None:
"""Should return error when S3 fetch fails."""
from fastapi import HTTPException
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
with patch("httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
side_effect=httpx.HTTPError("Connection failed")
)
with pytest.raises(HTTPException) as exc_info:
await get_stripe_publishable_key()
assert exc_info.value.status_code == 500
assert "Failed to fetch Stripe publishable key" in exc_info.value.detail
@pytest.mark.asyncio
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_URL", None)
async def test_error_when_no_config(self) -> None:
"""Should return error when neither env var nor S3 URL is configured."""
from fastapi import HTTPException
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
with pytest.raises(HTTPException) as exc_info:
await get_stripe_publishable_key()
assert exc_info.value.status_code == 500
assert "not configured" in exc_info.value.detail
@pytest.mark.asyncio
@patch(
"ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE",
"pk_test_cached",
)
async def test_caches_key_after_first_fetch(self) -> None:
"""Should cache the key and return it on subsequent calls."""
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
# First call
result1 = await get_stripe_publishable_key()
assert result1.publishable_key == "pk_test_cached"
# Second call - should use cache even if we change the override
with patch(
"ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE",
"pk_test_different",
):
result2 = await get_stripe_publishable_key()
# Should still return cached value
assert result2.publishable_key == "pk_test_cached"

View File

@@ -7,14 +7,21 @@ import Button from "@/refresh-components/buttons/Button";
import InlineExternalLink from "@/refresh-components/InlineExternalLink";
import { logout } from "@/lib/user";
import { loadStripe } from "@stripe/stripe-js";
import {
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY,
NEXT_PUBLIC_CLOUD_ENABLED,
} from "@/lib/constants";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import Text from "@/refresh-components/texts/Text";
import { SvgLock } from "@opal/icons";
const linkClassName = "text-action-link-05 hover:text-action-link-06";
const fetchStripePublishableKey = async (): Promise<string> => {
const response = await fetch("/api/tenants/stripe-publishable-key");
if (!response.ok) {
throw new Error("Failed to fetch Stripe publishable key");
}
const data = await response.json();
return data.publishable_key;
};
const fetchResubscriptionSession = async () => {
const response = await fetch("/api/tenants/create-subscription-session", {
method: "POST",
@@ -35,14 +42,10 @@ export default function AccessRestricted() {
const handleResubscribe = async () => {
setIsLoading(true);
setError(null);
if (!NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY) {
setError("Stripe public key not found");
setIsLoading(false);
return;
}
try {
const publishableKey = await fetchStripePublishableKey();
const { sessionId } = await fetchResubscriptionSession();
const stripe = await loadStripe(NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY);
const stripe = await loadStripe(publishableKey);
if (stripe) {
await stripe.redirectToCheckout({ sessionId });

View File

@@ -93,9 +93,6 @@ export const NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK =
process.env.NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK?.toLowerCase() ===
"true";
export const NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY =
process.env.NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY;
// Restrict markdown links to safe protocols
export const ALLOWED_URL_PROTOCOLS = ["http:", "https:", "mailto:"] as const;