mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-16 23:35:46 +00:00
feat(billing): fetch Stripe publishable key from S3 (#7595)
This commit is contained in:
@@ -10,6 +10,8 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
|
||||
("/enterprise-settings/logo", {"GET"}),
|
||||
("/enterprise-settings/logotype", {"GET"}),
|
||||
("/enterprise-settings/custom-analytics-script", {"GET"}),
|
||||
# Stripe publishable key is safe to expose publicly
|
||||
("/tenants/stripe-publishable-key", {"GET"}),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
@@ -12,11 +15,14 @@ from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import StripePublishableKeyResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
@@ -26,6 +32,10 @@ logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
# Cache for Stripe publishable key to avoid hitting S3 on every request
|
||||
_stripe_publishable_key_cache: str | None = None
|
||||
_stripe_key_lock = asyncio.Lock()
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
@@ -113,3 +123,67 @@ async def create_subscription_session(
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create subscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
"""
|
||||
Fetch the Stripe publishable key.
|
||||
Priority: env var override (for testing) > S3 bucket (production).
|
||||
This endpoint is public (no auth required) since publishable keys are safe to expose.
|
||||
The key is cached in memory to avoid hitting S3 on every request.
|
||||
"""
|
||||
global _stripe_publishable_key_cache
|
||||
|
||||
# Fast path: return cached value without lock
|
||||
if _stripe_publishable_key_cache:
|
||||
return StripePublishableKeyResponse(
|
||||
publishable_key=_stripe_publishable_key_cache
|
||||
)
|
||||
|
||||
# Use lock to prevent concurrent S3 requests
|
||||
async with _stripe_key_lock:
|
||||
# Double-check after acquiring lock (another request may have populated cache)
|
||||
if _stripe_publishable_key_cache:
|
||||
return StripePublishableKeyResponse(
|
||||
publishable_key=_stripe_publishable_key_cache
|
||||
)
|
||||
|
||||
# Check for env var override first (for local testing with pk_test_* keys)
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
|
||||
response.raise_for_status()
|
||||
key = response.text.strip()
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
)
|
||||
|
||||
@@ -105,3 +105,7 @@ class PendingUserSnapshot(BaseModel):
|
||||
|
||||
class ApproveUserRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class StripePublishableKeyResponse(BaseModel):
|
||||
publishable_key: str
|
||||
|
||||
@@ -1027,3 +1027,14 @@ INSTANCE_TYPE = (
|
||||
## Discord Bot Configuration
|
||||
DISCORD_BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||
DISCORD_BOT_INVOKE_CHAR = os.environ.get("DISCORD_BOT_INVOKE_CHAR", "!")
|
||||
|
||||
|
||||
## Stripe Configuration
|
||||
# URL to fetch the Stripe publishable key from a public S3 bucket.
|
||||
# Publishable keys are safe to expose publicly - they can only initialize
|
||||
# Stripe.js and tokenize payment info, not make charges or access data.
|
||||
STRIPE_PUBLISHABLE_KEY_URL = (
|
||||
"https://onyx-stripe-public.s3.amazonaws.com/publishable-key.txt"
|
||||
)
|
||||
# Override for local testing with Stripe test keys (pk_test_*)
|
||||
STRIPE_PUBLISHABLE_KEY_OVERRIDE = os.environ.get("STRIPE_PUBLISHABLE_KEY")
|
||||
|
||||
158
backend/tests/unit/ee/onyx/server/tenants/test_billing_api.py
Normal file
158
backend/tests/unit/ee/onyx/server/tenants/test_billing_api.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Tests for billing API endpoints."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
class TestGetStripePublishableKey:
|
||||
"""Tests for get_stripe_publishable_key endpoint."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
"""Reset the cache before each test."""
|
||||
import ee.onyx.server.tenants.billing_api as billing_api
|
||||
|
||||
billing_api._stripe_publishable_key_cache = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@patch(
|
||||
"ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_URL",
|
||||
"https://example.com/key.txt",
|
||||
)
|
||||
async def test_fetches_from_s3_when_no_override(self) -> None:
|
||||
"""Should fetch key from S3 when no env var override is set."""
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "pk_live_test123"
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
result = await get_stripe_publishable_key()
|
||||
|
||||
assert result.publishable_key == "pk_live_test123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE",
|
||||
"pk_test_override123",
|
||||
)
|
||||
async def test_uses_env_var_override_when_set(self) -> None:
|
||||
"""Should use env var override instead of fetching from S3."""
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
result = await get_stripe_publishable_key()
|
||||
# Should not call S3
|
||||
mock_client.assert_not_called()
|
||||
|
||||
assert result.publishable_key == "pk_test_override123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE",
|
||||
"invalid_key",
|
||||
)
|
||||
async def test_rejects_invalid_env_var_key_format(self) -> None:
|
||||
"""Should reject keys that don't start with pk_."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "Invalid Stripe publishable key format" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@patch(
|
||||
"ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_URL",
|
||||
"https://example.com/key.txt",
|
||||
)
|
||||
async def test_rejects_invalid_s3_key_format(self) -> None:
|
||||
"""Should reject keys from S3 that don't start with pk_."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "invalid_key"
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "Invalid Stripe publishable key format" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@patch(
|
||||
"ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_URL",
|
||||
"https://example.com/key.txt",
|
||||
)
|
||||
async def test_handles_s3_fetch_error(self) -> None:
|
||||
"""Should return error when S3 fetch fails."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
side_effect=httpx.HTTPError("Connection failed")
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "Failed to fetch Stripe publishable key" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE", None)
|
||||
@patch("ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_URL", None)
|
||||
async def test_error_when_no_config(self) -> None:
|
||||
"""Should return error when neither env var nor S3 URL is configured."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_stripe_publishable_key()
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "not configured" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE",
|
||||
"pk_test_cached",
|
||||
)
|
||||
async def test_caches_key_after_first_fetch(self) -> None:
|
||||
"""Should cache the key and return it on subsequent calls."""
|
||||
from ee.onyx.server.tenants.billing_api import get_stripe_publishable_key
|
||||
|
||||
# First call
|
||||
result1 = await get_stripe_publishable_key()
|
||||
assert result1.publishable_key == "pk_test_cached"
|
||||
|
||||
# Second call - should use cache even if we change the override
|
||||
with patch(
|
||||
"ee.onyx.server.tenants.billing_api.STRIPE_PUBLISHABLE_KEY_OVERRIDE",
|
||||
"pk_test_different",
|
||||
):
|
||||
result2 = await get_stripe_publishable_key()
|
||||
# Should still return cached value
|
||||
assert result2.publishable_key == "pk_test_cached"
|
||||
@@ -7,14 +7,21 @@ import Button from "@/refresh-components/buttons/Button";
|
||||
import InlineExternalLink from "@/refresh-components/InlineExternalLink";
|
||||
import { logout } from "@/lib/user";
|
||||
import { loadStripe } from "@stripe/stripe-js";
|
||||
import {
|
||||
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY,
|
||||
NEXT_PUBLIC_CLOUD_ENABLED,
|
||||
} from "@/lib/constants";
|
||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgLock } from "@opal/icons";
|
||||
|
||||
const linkClassName = "text-action-link-05 hover:text-action-link-06";
|
||||
|
||||
const fetchStripePublishableKey = async (): Promise<string> => {
|
||||
const response = await fetch("/api/tenants/stripe-publishable-key");
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to fetch Stripe publishable key");
|
||||
}
|
||||
const data = await response.json();
|
||||
return data.publishable_key;
|
||||
};
|
||||
|
||||
const fetchResubscriptionSession = async () => {
|
||||
const response = await fetch("/api/tenants/create-subscription-session", {
|
||||
method: "POST",
|
||||
@@ -35,14 +42,10 @@ export default function AccessRestricted() {
|
||||
const handleResubscribe = async () => {
|
||||
setIsLoading(true);
|
||||
setError(null);
|
||||
if (!NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY) {
|
||||
setError("Stripe public key not found");
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const publishableKey = await fetchStripePublishableKey();
|
||||
const { sessionId } = await fetchResubscriptionSession();
|
||||
const stripe = await loadStripe(NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY);
|
||||
const stripe = await loadStripe(publishableKey);
|
||||
|
||||
if (stripe) {
|
||||
await stripe.redirectToCheckout({ sessionId });
|
||||
|
||||
@@ -93,9 +93,6 @@ export const NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK =
|
||||
process.env.NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK?.toLowerCase() ===
|
||||
"true";
|
||||
|
||||
export const NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY =
|
||||
process.env.NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY;
|
||||
|
||||
// Restrict markdown links to safe protocols
|
||||
export const ALLOWED_URL_PROTOCOLS = ["http:", "https:", "mailto:"] as const;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user