Compare commits

...

7 Commits

Author SHA1 Message Date
pablonyx
69638b4c4e misc typing 2025-03-15 11:21:26 -07:00
pablonyx
8821f399f0 add brief comment 2025-03-15 10:44:43 -07:00
pablonyx
cebc341991 quick nit 2025-03-13 16:11:32 -07:00
pablonyx
81cb98aaa7 k 2025-03-13 16:09:53 -07:00
pablonyx
38afc8fa3a clean up + tests 2025-03-13 16:07:37 -07:00
pablonyx
185aa07526 remove dupes 2025-03-13 15:47:33 -07:00
pablonyx
3ba554843c Auto refresh credentials 2025-03-13 10:22:27 -07:00
11 changed files with 842 additions and 13 deletions

View File

@@ -64,7 +64,15 @@ def get_application() -> FastAPI:
add_tenant_id_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
oauth_client = GoogleOAuth2(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
# Use standard scopes that include profile and email
scopes=["openid", "email", "profile"],
)
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -87,6 +95,16 @@ def get_application() -> FastAPI:
)
if AUTH_TYPE == AuthType.OIDC:
# Ensure we request offline_access for refresh tokens
try:
oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
if "offline_access" not in oidc_scopes:
oidc_scopes.append("offline_access")
except Exception as e:
logger.warning(f"Error configuring OIDC scopes: {e}")
# Fall back to default scopes if there's an error
oidc_scopes = BASE_SCOPES
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -94,8 +112,8 @@ def get_application() -> FastAPI:
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
OPENID_CONFIG_URL,
# BASE_SCOPES is the same as not setting this
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
# Use the configured scopes
base_scopes=oidc_scopes,
),
auth_backend,
USER_AUTH_SECRET,

View File

@@ -0,0 +1,211 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
import httpx
from fastapi_users.manager import BaseUserManager
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Standard OAuth refresh token endpoints
REFRESH_ENDPOINTS = {
"google": "https://oauth2.googleapis.com/token",
}
# NOTE: Keeping this as a utility function for potential future debugging,
# but not using it in production code
async def _test_expire_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
expire_in_seconds: int = 10,
) -> bool:
"""
Utility function for testing - Sets an OAuth token to expire in a short time
to facilitate testing of the refresh flow.
Not used in production code.
"""
try:
new_expires_at = int(
(datetime.now(timezone.utc).timestamp() + expire_in_seconds)
)
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
return True
except Exception as e:
logger.exception(f"Error setting artificial expiration: {str(e)}")
return False
async def refresh_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
) -> bool:
"""
Attempt to refresh an OAuth token that's about to expire or has expired.
Returns True if successful, False otherwise.
"""
if not oauth_account.refresh_token:
logger.warning(
f"No refresh token available for {user.email}'s {oauth_account.oauth_name} account"
)
return False
provider = oauth_account.oauth_name
if provider not in REFRESH_ENDPOINTS:
logger.warning(f"Refresh endpoint not configured for provider: {provider}")
return False
try:
logger.info(f"Refreshing OAuth token for {user.email}'s {provider} account")
async with httpx.AsyncClient() as client:
response = await client.post(
REFRESH_ENDPOINTS[provider],
data={
"client_id": OAUTH_CLIENT_ID,
"client_secret": OAUTH_CLIENT_SECRET,
"refresh_token": oauth_account.refresh_token,
"grant_type": "refresh_token",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
if response.status_code != 200:
logger.error(
f"Failed to refresh OAuth token: Status {response.status_code}"
)
return False
token_data = response.json()
new_access_token = token_data.get("access_token")
new_refresh_token = token_data.get(
"refresh_token", oauth_account.refresh_token
)
expires_in = token_data.get("expires_in")
# Calculate new expiry time if provided
new_expires_at: Optional[int] = None
if expires_in:
new_expires_at = int(
(datetime.now(timezone.utc).timestamp() + expires_in)
)
# Update the OAuth account
updated_data: Dict[str, Any] = {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
}
if new_expires_at:
updated_data["expires_at"] = new_expires_at
# Update oidc_expiry in user model if we're tracking it
if TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(
new_expires_at, tz=timezone.utc
)
await user_manager.user_db.update(
user, {"oidc_expiry": oidc_expiry}
)
# Update the OAuth account
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
logger.info(f"Successfully refreshed OAuth token for {user.email}")
return True
except Exception as e:
logger.exception(f"Error refreshing OAuth token: {str(e)}")
return False
async def check_and_refresh_oauth_tokens(
user: User,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
) -> None:
"""
Check if any OAuth tokens are expired or about to expire and refresh them.
"""
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
return
now_timestamp = datetime.now(timezone.utc).timestamp()
# Buffer time to refresh tokens before they expire (in seconds)
buffer_seconds = 300 # 5 minutes
for oauth_account in user.oauth_accounts:
# Skip accounts without refresh tokens
if not oauth_account.refresh_token:
continue
# If token is about to expire, refresh it
if (
oauth_account.expires_at
and oauth_account.expires_at - now_timestamp < buffer_seconds
):
logger.info(f"OAuth token for {user.email} is about to expire - refreshing")
success = await refresh_oauth_token(
user, oauth_account, db_session, user_manager
)
if not success:
logger.warning(
"Failed to refresh OAuth token. User may need to re-authenticate."
)
async def check_oauth_account_has_refresh_token(
user: User,
oauth_account: OAuthAccount,
) -> bool:
"""
Check if an OAuth account has a refresh token.
Returns True if a refresh token exists, False otherwise.
"""
return bool(oauth_account.refresh_token)
async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
"""
Returns a list of OAuth accounts for a user that are missing refresh tokens.
These accounts will need re-authentication to get refresh tokens.
"""
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
return []
accounts_needing_refresh = []
for oauth_account in user.oauth_accounts:
has_refresh_token = await check_oauth_account_has_refresh_token(
user, oauth_account
)
if not has_refresh_token:
accounts_needing_refresh.append(oauth_account)
return accounts_needing_refresh

View File

@@ -5,12 +5,16 @@ import string
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Protocol
from typing import Tuple
from typing import TypeVar
import jwt
from email_validator import EmailNotValidError
@@ -687,16 +691,20 @@ cookie_transport = CookieTransport(
)
def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()
T = TypeVar("T", covariant=True)
ID = TypeVar("ID", contravariant=True)
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
# Protocol for strategies that support token refreshing without inheritance.
class RefreshableStrategy(Protocol):
"""Protocol for authentication strategies that support token refreshing."""
async def refresh_token(self, token: Optional[str], user: Any) -> str:
"""
Refresh an existing token by extending its lifetime.
Returns either the same token with extended expiration or a new token.
"""
...
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
@@ -755,6 +763,75 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
redis = await get_async_redis_connection()
await redis.delete(f"{self.key_prefix}{token}")
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by extending its expiration time in Redis."""
if token is None:
# If no token provided, create a new one
return await self.write_token(user)
redis = await get_async_redis_connection()
token_key = f"{self.key_prefix}{token}"
# Check if token exists
token_data_str = await redis.get(token_key)
if not token_data_str:
# Token not found, create new one
return await self.write_token(user)
# Token exists, extend its lifetime
token_data = json.loads(token_data_str)
await redis.set(
token_key,
json.dumps(token_data),
ex=self.lifetime_seconds,
)
return token
class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
"""Database strategy with token refreshing capabilities."""
def __init__(
self,
access_token_db: AccessTokenDatabase[AccessToken],
lifetime_seconds: Optional[int] = None,
):
super().__init__(access_token_db, lifetime_seconds)
self._access_token_db = access_token_db
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by updating its expiration time in the database."""
if token is None:
return await self.write_token(user)
# Find the token in database
access_token = await self._access_token_db.get_by_token(token)
if access_token is None:
# Token not found, create new one
return await self.write_token(user)
# Update expiration time
new_expires = datetime.now(timezone.utc) + timedelta(
seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
)
await self._access_token_db.update(access_token, {"expires": new_expires})
return token
def get_redis_strategy() -> TenantAwareRedisStrategy:
return TenantAwareRedisStrategy()
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> RefreshableDatabaseStrategy:
return RefreshableDatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
if AUTH_BACKEND == AuthBackend.REDIS:
auth_backend = AuthenticationBackend(
@@ -805,6 +882,88 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
return router
def get_refresh_router(
self,
backend: AuthenticationBackend,
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
) -> APIRouter:
"""
Provide a router for session token refreshing.
"""
# Import the oauth_refresher here to avoid circular imports
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
router = APIRouter()
get_current_user_token = self.authenticator.current_user_token(
active=True, verified=requires_verification
)
refresh_responses: OpenAPIResponseType = {
**{
status.HTTP_401_UNAUTHORIZED: {
"description": "Missing token or inactive user."
}
},
**backend.transport.get_openapi_login_responses_success(),
}
@router.post(
"/refresh", name=f"auth:{backend.name}.refresh", responses=refresh_responses
)
async def refresh(
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(
get_user_manager
),
db_session: AsyncSession = Depends(get_async_session),
) -> Response:
try:
user, token = user_token
logger.info(f"Processing token refresh request for user {user.email}")
# Check if user has OAuth accounts that need refreshing
await check_and_refresh_oauth_tokens(
user=cast(User, user),
db_session=db_session,
user_manager=cast(Any, user_manager),
)
# Check if strategy supports refreshing
supports_refresh = hasattr(strategy, "refresh_token") and callable(
getattr(strategy, "refresh_token")
)
if supports_refresh:
try:
refresh_method = getattr(strategy, "refresh_token")
new_token = await refresh_method(token, user)
logger.info(
f"Successfully refreshed session token for user {user.email}"
)
return await backend.transport.get_login_response(new_token)
except Exception as e:
logger.error(f"Error refreshing session token: {str(e)}")
# Fallback to logout and login if refresh fails
await backend.logout(strategy, user, token)
return await backend.login(strategy, user)
# Fallback: logout and login again
logger.info(
"Strategy doesn't support refresh - using logout/login flow"
)
await backend.logout(strategy, user, token)
return await backend.login(strategy, user)
except Exception as e:
logger.error(f"Unexpected error in refresh endpoint: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Token refresh failed: {str(e)}",
)
return router
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
get_user_manager, [auth_backend]
@@ -1038,12 +1197,20 @@ def get_oauth_router(
"referral_source": referral_source or "default_referral",
}
state = generate_state_token(state_data, state_secret)
# Get the basic authorization URL
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
state,
scopes,
)
# For Google OAuth, add parameters to request refresh tokens
if oauth_client.name == "google":
authorization_url = add_url_params(
authorization_url, {"access_type": "offline", "prompt": "consent"}
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
@router.get(

View File

@@ -667,3 +667,7 @@ IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
)
DISABLE_AUTO_AUTH_REFRESH = (
os.environ.get("DISABLE_AUTO_AUTH_REFRESH", "").lower() == "true"
)

View File

@@ -361,7 +361,15 @@ def get_application() -> FastAPI:
)
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
oauth_client = GoogleOAuth2(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
# Use standard scopes that include profile and email
scopes=["openid", "email", "profile"],
)
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -383,6 +391,13 @@ def get_application() -> FastAPI:
prefix="/auth",
)
# Add refresh token endpoint for OAuth as well
include_auth_router_with_prefix(
application,
fastapi_users.get_refresh_router(auth_backend),
prefix="/auth",
)
application.add_exception_handler(
RequestValidationError, validation_exception_handler
)

View File

@@ -31,6 +31,7 @@ PUBLIC_ENDPOINT_SPECS = [
# just gets the version of Onyx (e.g. 0.3.11)
("/version", {"GET"}),
# stuff related to basic auth
("/auth/refresh", {"POST"}),
("/auth/register", {"POST"}),
("/auth/login", {"POST"}),
("/auth/logout", {"POST"}),

View File

@@ -0,0 +1,43 @@
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
import pytest
from onyx.db.models import OAuthAccount
from onyx.db.models import User
@pytest.fixture
def mock_user() -> MagicMock:
"""Creates a mock User instance for testing."""
user = MagicMock(spec=User)
user.email = "test@example.com"
user.id = "test-user-id"
return user
@pytest.fixture
def mock_oauth_account() -> MagicMock:
"""Creates a mock OAuthAccount instance for testing."""
oauth_account = MagicMock(spec=OAuthAccount)
oauth_account.oauth_name = "google"
oauth_account.refresh_token = "test-refresh-token"
oauth_account.access_token = "test-access-token"
oauth_account.expires_at = None
return oauth_account
@pytest.fixture
def mock_user_manager() -> MagicMock:
"""Creates a mock user manager for testing."""
user_manager = MagicMock()
user_manager.user_db = MagicMock()
user_manager.user_db.update_oauth_account = AsyncMock()
user_manager.user_db.update = AsyncMock()
return user_manager
@pytest.fixture
def mock_db_session() -> MagicMock:
"""Creates a mock database session for testing."""
return MagicMock()

View File

@@ -0,0 +1,273 @@
from datetime import datetime
from datetime import timezone
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.auth.oauth_refresher import _test_expire_oauth_token
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
from onyx.auth.oauth_refresher import check_oauth_account_has_refresh_token
from onyx.auth.oauth_refresher import get_oauth_accounts_requiring_refresh_token
from onyx.auth.oauth_refresher import refresh_oauth_token
from onyx.db.models import OAuthAccount
@pytest.mark.asyncio
async def test_refresh_oauth_token_success(
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test successful OAuth token refresh."""
# Mock HTTP client and response
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"access_token": "new_token",
"refresh_token": "new_refresh_token",
"expires_in": 3600,
}
# Create async mock for the client post method
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
# Use fixture values but ensure refresh token exists
mock_oauth_account.oauth_name = (
"google" # Ensure it's google to match the refresh endpoint
)
mock_oauth_account.refresh_token = "old_refresh_token"
# Patch at the module level where it's actually being used
with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock:
# Configure the context manager
client_instance = mock_client
client_class_mock.return_value.__aenter__.return_value = client_instance
# Call the function under test
result = await refresh_oauth_token(
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
)
# Assertions
assert result is True
mock_client.post.assert_called_once()
mock_user_manager.user_db.update_oauth_account.assert_called_once()
# Verify token data was updated correctly
update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2]
assert update_data["access_token"] == "new_token"
assert update_data["refresh_token"] == "new_refresh_token"
assert "expires_at" in update_data
@pytest.mark.asyncio
async def test_refresh_oauth_token_failure(
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> bool:
"""Test OAuth token refresh failure due to HTTP error."""
# Mock HTTP client with error response
mock_response = MagicMock()
mock_response.status_code = 400 # Simulate error
# Create async mock for the client post method
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
# Ensure refresh token exists and provider is supported
mock_oauth_account.oauth_name = "google"
mock_oauth_account.refresh_token = "old_refresh_token"
# Patch at the module level where it's actually being used
with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock:
# Configure the context manager
client_class_mock.return_value.__aenter__.return_value = mock_client
# Call the function under test
result = await refresh_oauth_token(
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
)
# Assertions
assert result is False
mock_client.post.assert_called_once()
mock_user_manager.user_db.update_oauth_account.assert_not_called()
return True
@pytest.mark.asyncio
async def test_refresh_oauth_token_no_refresh_token(
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test OAuth token refresh when no refresh token is available."""
# Set refresh token to None
mock_oauth_account.refresh_token = None
mock_oauth_account.oauth_name = "google"
# No need to mock httpx since it shouldn't be called
result = await refresh_oauth_token(
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
)
# Assertions
assert result is False
@pytest.mark.asyncio
async def test_check_and_refresh_oauth_tokens(
mock_user: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test checking and refreshing multiple OAuth tokens."""
# Create mock user with OAuth accounts
now_timestamp = datetime.now(timezone.utc).timestamp()
# Create an account that needs refreshing (expiring soon)
expiring_account = MagicMock(spec=OAuthAccount)
expiring_account.oauth_name = "google"
expiring_account.refresh_token = "refresh_token_1"
expiring_account.expires_at = now_timestamp + 60 # Expires in 1 minute
# Create an account that doesn't need refreshing (expires later)
valid_account = MagicMock(spec=OAuthAccount)
valid_account.oauth_name = "google"
valid_account.refresh_token = "refresh_token_2"
valid_account.expires_at = now_timestamp + 3600 # Expires in 1 hour
# Create an account without a refresh token
no_refresh_account = MagicMock(spec=OAuthAccount)
no_refresh_account.oauth_name = "google"
no_refresh_account.refresh_token = None
no_refresh_account.expires_at = (
now_timestamp + 60
) # Expiring soon but no refresh token
# Set oauth_accounts on the mock user
mock_user.oauth_accounts = [expiring_account, valid_account, no_refresh_account]
# Mock refresh_oauth_token function
with patch(
"onyx.auth.oauth_refresher.refresh_oauth_token", AsyncMock(return_value=True)
) as mock_refresh:
# Call the function under test
await check_and_refresh_oauth_tokens(
mock_user, mock_db_session, mock_user_manager
)
# Assertions
assert mock_refresh.call_count == 1 # Should only refresh the expiring account
# Check it was called with the expiring account
mock_refresh.assert_called_once_with(
mock_user, expiring_account, mock_db_session, mock_user_manager
)
@pytest.mark.asyncio
async def test_get_oauth_accounts_requiring_refresh_token(mock_user: MagicMock) -> None:
"""Test identifying OAuth accounts that need refresh tokens."""
# Create accounts with and without refresh tokens
account_with_token = MagicMock(spec=OAuthAccount)
account_with_token.oauth_name = "google"
account_with_token.refresh_token = "refresh_token"
account_without_token = MagicMock(spec=OAuthAccount)
account_without_token.oauth_name = "google"
account_without_token.refresh_token = None
second_account_without_token = MagicMock(spec=OAuthAccount)
second_account_without_token.oauth_name = "github"
second_account_without_token.refresh_token = (
"" # Empty string should also be treated as missing
)
# Set accounts on user
mock_user.oauth_accounts = [
account_with_token,
account_without_token,
second_account_without_token,
]
# Call the function under test
accounts_needing_refresh = await get_oauth_accounts_requiring_refresh_token(
mock_user
)
# Assertions
assert len(accounts_needing_refresh) == 2
assert account_without_token in accounts_needing_refresh
assert second_account_without_token in accounts_needing_refresh
assert account_with_token not in accounts_needing_refresh
@pytest.mark.asyncio
async def test_check_oauth_account_has_refresh_token(
mock_user: MagicMock, mock_oauth_account: MagicMock
) -> None:
"""Test checking if an OAuth account has a refresh token."""
# Test with refresh token
mock_oauth_account.refresh_token = "refresh_token"
has_token = await check_oauth_account_has_refresh_token(
mock_user, mock_oauth_account
)
assert has_token is True
# Test with None refresh token
mock_oauth_account.refresh_token = None
has_token = await check_oauth_account_has_refresh_token(
mock_user, mock_oauth_account
)
assert has_token is False
# Test with empty string refresh token
mock_oauth_account.refresh_token = ""
has_token = await check_oauth_account_has_refresh_token(
mock_user, mock_oauth_account
)
assert has_token is False
@pytest.mark.asyncio
async def test_test_expire_oauth_token(
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test the testing utility function for token expiration."""
# Set up the mock account
mock_oauth_account.oauth_name = "google"
mock_oauth_account.refresh_token = "test_refresh_token"
mock_oauth_account.access_token = "test_access_token"
# Call the function under test
result = await _test_expire_oauth_token(
mock_user,
mock_oauth_account,
mock_db_session,
mock_user_manager,
expire_in_seconds=10,
)
# Assertions
assert result is True
mock_user_manager.user_db.update_oauth_account.assert_called_once()
# Verify the expiration time was set correctly
update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2]
assert "expires_at" in update_data
# Now should be within 10-11 seconds of the set expiration
now = datetime.now(timezone.utc).timestamp()
assert update_data["expires_at"] - now >= 9 # Allow 1 second for test execution
assert update_data["expires_at"] - now <= 11 # Allow 1 second for test execution

View File

@@ -1,11 +1,18 @@
"use client";
import React, { createContext, useContext, useState, useEffect } from "react";
import React, {
createContext,
useContext,
useState,
useEffect,
useRef,
} from "react";
import { User, UserRole } from "@/lib/types";
import { getCurrentUser } from "@/lib/user";
import { usePostHog } from "posthog-js/react";
import { CombinedSettings } from "@/app/admin/settings/interfaces";
import { SettingsContext } from "../settings/SettingsProvider";
import { useTokenRefresh } from "@/hooks/useTokenRefresh";
interface UserContextType {
user: User | null;
@@ -93,6 +100,10 @@ export function UserProvider({
console.error("Error fetching current user:", error);
}
};
// Use the custom token refresh hook
useTokenRefresh(upToDateUser, fetchUser);
const updateUserTemperatureOverrideEnabled = async (enabled: boolean) => {
try {
setUpToDateUser((prevUser) => {

View File

@@ -0,0 +1,84 @@
import { useState, useEffect, useRef } from "react";
import { User } from "@/lib/types";
import { NO_AUTH_USER_ID } from "@/lib/extension/constants";
// Refresh token every 10 minutes (600000ms)
// This is shorter than the session expiry time to ensure tokens stay valid
const REFRESH_INTERVAL = 600000;
// Custom hook for handling JWT token refresh for current user
export function useTokenRefresh(
user: User | null,
onRefreshFail: () => Promise<void>
) {
// Track last refresh time to avoid unnecessary calls
const [lastTokenRefresh, setLastTokenRefresh] = useState<number>(Date.now());
// Use a ref to track first load
const isFirstLoad = useRef(true);
useEffect(() => {
if (!user || user.id === NO_AUTH_USER_ID) return;
const refreshTokenPeriodically = async () => {
try {
// Skip time check if this is first load - we always refresh on first load
const isTimeToRefresh =
isFirstLoad.current ||
Date.now() - lastTokenRefresh > REFRESH_INTERVAL - 60000;
if (!isTimeToRefresh) {
return;
}
// Reset first load flag
if (isFirstLoad.current) {
isFirstLoad.current = false;
}
const response = await fetch("/api/auth/refresh", {
method: "POST",
credentials: "include",
});
if (response.ok) {
// Update last refresh time on success
setLastTokenRefresh(Date.now());
console.debug("Auth token refreshed successfully");
} else {
console.warn("Failed to refresh auth token:", response.status);
// If token refresh fails, try to get current user info
await onRefreshFail();
}
} catch (error) {
console.error("Error refreshing auth token:", error);
}
};
// Always attempt to refresh on first component mount
// This helps ensure tokens are fresh, especially after browser refresh
refreshTokenPeriodically();
// Set up interval for periodic refreshes
const intervalId = setInterval(refreshTokenPeriodically, REFRESH_INTERVAL);
// Also refresh token on window focus, but no more than once per minute
const handleVisibilityChange = () => {
if (
document.visibilityState === "visible" &&
Date.now() - lastTokenRefresh > 60000
) {
refreshTokenPeriodically();
}
};
document.addEventListener("visibilitychange", handleVisibilityChange);
return () => {
clearInterval(intervalId);
document.removeEventListener("visibilitychange", handleVisibilityChange);
};
}, [user, lastTokenRefresh, onRefreshFail]);
return { lastTokenRefresh };
}

View File

@@ -35,3 +35,5 @@ export const LocalStorageKeys = {
export const SEARCH_PARAMS = {
DEFAULT_SIDEBAR_OFF: "defaultSidebarOff",
};
export const NO_AUTH_USER_ID = "__no_auth_user__";