Compare commits

..

27 Commits

Author SHA1 Message Date
pablonyx
12cbc7642f update 2025-04-04 11:03:03 -07:00
pablonyx
eec25fb51d fix build 2025-03-27 11:47:22 -07:00
pablonyx
b1af9f7f24 nit 2025-03-26 18:36:43 -07:00
pablonyx
96b1b66e09 nit 2025-03-26 18:24:31 -07:00
pablonyx
f16ab9b419 k 2025-03-26 18:09:12 -07:00
pablonyx
84f5e69826 nit 2025-03-26 17:25:53 -07:00
pablonyx
b493985b30 update 2025-03-26 16:47:06 -07:00
pablonyx
2db0017a95 update 2025-03-26 16:16:42 -07:00
evan-danswer
178a64f298 fix issue with drive connector service account indexing (#4356)
* fix issue with drive connector service account indexing

* correct checkpoint resumption

* final set of fixes

* nit

* fix typing

* logging and CW comments

* nit
2025-03-26 20:54:26 +00:00
pablonyx
c79f1edf1d add a flush (#4361) 2025-03-26 14:40:52 -07:00
pablonyx
7c8e23aa54 Fix saml conversion from ext_perm -> basic (#4343)
* fix saml conversion from ext_perm -> basic

* quick nit

* minor fix

* finalize

* update

* quick fix
2025-03-26 20:36:51 +00:00
pablonyx
d37b427d52 fix email flow (#4339) 2025-03-26 18:59:12 +00:00
pablonyx
a65fefd226 test fix 2025-03-26 12:43:38 -07:00
rkuo-danswer
bb09bde519 Bugfix/google drive size threshold 2 (#4355) 2025-03-26 12:06:36 -07:00
Tim Rosenblatt
0f6cf0fc58 Fixes docker logs helper text in run-nginx.sh (#3678)
The docker container name is slightly wrong, and this commit fixes it.
2025-03-26 09:03:35 -07:00
pablonyx
fed06b592d Auto refresh credentials (#4268)
* Auto refresh credentials

* remove dupes

* clean up + tests

* k

* quick nit

* add brief comment

* misc typing
2025-03-26 01:53:31 +00:00
pablonyx
8d92a1524e fix invitation on cloud (#4351)
* fix invitation on cloud

* k
2025-03-26 01:25:17 +00:00
pablonyx
ecfea9f5ed Email formatting devices (#4353)
* update email formatting

* k

* update

* k

* nit
2025-03-25 21:42:32 +00:00
rkuo-danswer
b269f1ba06 fix broken function call (#4354)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-25 21:07:31 +00:00
pablonyx
30c878efa5 Quick fix (#4341)
* quick fix

* Revert "quick fix"

This reverts commit f113616276.

* smaller change
2025-03-25 18:39:55 +00:00
pablonyx
2024776c19 Respect contextvars when parallelizing for Google Drive (#4291)
* k

* k

* fix typing
2025-03-25 17:40:12 +00:00
pablonyx
431316929c k (#4336) 2025-03-25 17:00:35 +00:00
pablonyx
c5b9c6e308 update (#4344) 2025-03-25 16:56:23 +00:00
pablonyx
73dd188b3f update (#4338) 2025-03-25 16:55:25 +00:00
evan-danswer
79b061abbc Daylight savings time handling (#4345)
* confluence timezone improvements

* confluence timezone improvements
2025-03-25 16:11:30 +00:00
rkuo-danswer
552f1ead4f use correct namespace in redis for certain keys (#4340)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-25 04:10:31 +00:00
evan-danswer
17925b49e8 typing fix (#4342)
* typing fix

* changed type hint to help future coders
2025-03-25 01:01:13 +00:00
58 changed files with 6233 additions and 8471 deletions

View File

@@ -102,6 +102,7 @@ COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
COPY ./static /app/static
# Escape hatch scripts
COPY ./scripts/debugging /app/scripts/debugging

View File

@@ -64,7 +64,15 @@ def get_application() -> FastAPI:
add_tenant_id_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
oauth_client = GoogleOAuth2(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
# Use standard scopes that include profile and email
scopes=["openid", "email", "profile"],
)
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -87,6 +95,16 @@ def get_application() -> FastAPI:
)
if AUTH_TYPE == AuthType.OIDC:
# Ensure we request offline_access for refresh tokens
try:
oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
if "offline_access" not in oidc_scopes:
oidc_scopes.append("offline_access")
except Exception as e:
logger.warning(f"Error configuring OIDC scopes: {e}")
# Fall back to default scopes if there's an error
oidc_scopes = BASE_SCOPES
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -94,8 +112,8 @@ def get_application() -> FastAPI:
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
OPENID_CONFIG_URL,
# BASE_SCOPES is the same as not setting this
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
# Use the configured scopes
base_scopes=oidc_scopes,
),
auth_backend,
USER_AUTH_SECRET,

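The two hunks above request offline access: the Google OAuth client is given explicit scopes, and the OIDC scope list gains offline_access so the IdP can return refresh tokens. Combined with the authorize-URL change further down (access_type=offline, prompt=consent), the Google authorization URL ends up looking roughly like the sketch below, where the client ID, redirect URI, and state are placeholders:

from urllib.parse import urlencode

# Placeholder values; the real client ID, redirect URI, and state come from configuration.
params = {
    "client_id": "<client-id>",
    "redirect_uri": "https://example.com/auth/oauth/callback",
    "response_type": "code",
    "scope": "openid email profile",
    "state": "<state-token>",
    "access_type": "offline",  # ask Google to issue a refresh token
    "prompt": "consent",  # re-prompt so the refresh token is re-issued on later logins
}
print("https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(params))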
View File

@@ -36,8 +36,12 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/auth/saml")
# Define non-authenticated user roles that should be re-created during SAML login
NON_AUTHENTICATED_ROLES = {UserRole.SLACK_USER, UserRole.EXT_PERM_USER}
async def upsert_saml_user(email: str) -> User:
logger.debug(f"Attempting to upsert SAML user with email: {email}")
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
) # type:ignore
@@ -48,9 +52,13 @@ async def upsert_saml_user(email: str) -> User:
async with get_user_db_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
try:
return await user_manager.get_by_email(email)
user = await user_manager.get_by_email(email)
# If user has a non-authenticated role, treat as non-existent
if user.role in NON_AUTHENTICATED_ROLES:
raise exceptions.UserNotExists()
return user
except exceptions.UserNotExists:
logger.notice("Creating user from SAML login")
logger.info("Creating user from SAML login")
user_count = await get_user_count()
role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC
@@ -59,11 +67,10 @@ async def upsert_saml_user(email: str) -> User:
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
user: User = await user_manager.create(
user = await user_manager.create(
UserCreate(
email=email,
password=hashed_pass,
is_verified=True,
role=role,
)
)

View File

@@ -16,10 +16,10 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.configs.constants import ONYX_SLACK_URL
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.file import FileWithMimeType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT
@@ -62,6 +62,11 @@ HTML_EMAIL_TEMPLATE = """\
}}
.header img {{
max-width: 140px;
width: 140px;
height: auto;
filter: brightness(1.1) contrast(1.2);
border-radius: 8px;
padding: 5px;
}}
.body-content {{
padding: 20px 30px;
@@ -78,12 +83,16 @@ HTML_EMAIL_TEMPLATE = """\
}}
.cta-button {{
display: inline-block;
padding: 12px 20px;
background-color: #000000;
padding: 14px 24px;
background-color: #0055FF;
color: #ffffff !important;
text-decoration: none;
border-radius: 4px;
font-weight: 500;
font-weight: 600;
font-size: 16px;
margin-top: 10px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
text-align: center;
}}
.footer {{
font-size: 13px;
@@ -166,6 +175,7 @@ def send_email(
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")
# Create a multipart/alternative message - this indicates these are alternative versions of the same content
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
@@ -174,17 +184,30 @@ def send_email(
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")
part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")
msg.attach(part_text)
msg.attach(part_html)
# Add text part first (lowest priority)
text_part = MIMEText(text_body, "plain")
msg.attach(text_part)
if inline_png:
# For HTML with images, create a multipart/related container
related = MIMEMultipart("related")
# Add the HTML part to the related container
html_part = MIMEText(html_body, "html")
related.attach(html_part)
# Add image with proper Content-ID to the related container
img = MIMEImage(inline_png[1], _subtype="png")
img.add_header("Content-ID", inline_png[0]) # CID reference
img.add_header("Content-ID", f"<{inline_png[0]}>")
img.add_header("Content-Disposition", "inline", filename=inline_png[0])
msg.attach(img)
related.attach(img)
# Add the related part to the message (higher priority than text)
msg.attach(related)
else:
# No images, just add HTML directly (higher priority than text)
html_part = MIMEText(html_body, "html")
msg.attach(html_part)
try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
@@ -332,17 +355,23 @@ def send_forgot_password_email(
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"{application_name} Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if MULTI_TENANT:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
subject = f"Reset Your {application_name} Password"
heading = "Reset Your Password"
tenant_param = f"&tenant={tenant_id}" if tenant_id and MULTI_TENANT else ""
message = "<p>Please click the button below to reset your password. This link will expire in 24 hours.</p>"
cta_text = "Reset Password"
cta_link = f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
html_content = build_html_email(
application_name,
"Reset Your Password",
heading,
message,
cta_text,
cta_link,
)
text_content = (
f"Please click the following link to reset your password. This link will expire in 24 hours.\n"
f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
)
text_content = f"Click the following link to reset your password: {link}"
send_email(
user_email,
subject,
@@ -356,6 +385,7 @@ def send_forgot_password_email(
def send_user_verification_email(
user_email: str,
token: str,
new_organization: bool = False,
mail_from: str = EMAIL_FROM,
) -> None:
# Builds a verification email
@@ -372,6 +402,8 @@ def send_user_verification_email(
subject = f"{application_name} Email Verification"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
if new_organization:
link = add_url_params(link, {"first_user": "true"})
message = (
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
)

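In send_email above, the HTML body and the inline logo are now nested in a multipart/related container inside the top-level multipart/alternative message, so the cid: reference in the HTML resolves while plain-text clients still get the text part. A minimal sketch of that MIME tree using only the standard library (the logo bytes and CID name are placeholders):

from email.mime.image import MIMEImage
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText


def build_message(text_body: str, html_body: str, logo_png: bytes) -> MIMEMultipart:
    # multipart/alternative: clients render the richest part they support
    msg = MIMEMultipart("alternative")
    msg.attach(MIMEText(text_body, "plain"))  # attached first, lowest priority

    # multipart/related: the HTML plus the image it references as <img src="cid:logo">
    related = MIMEMultipart("related")
    related.attach(MIMEText(html_body, "html"))
    img = MIMEImage(logo_png, _subtype="png")
    img.add_header("Content-ID", "<logo>")
    img.add_header("Content-Disposition", "inline", filename="logo")
    related.attach(img)

    msg.attach(related)  # preferred over the plain-text alternative
    return msg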
View File

@@ -0,0 +1,211 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
import httpx
from fastapi_users.manager import BaseUserManager
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Standard OAuth refresh token endpoints
REFRESH_ENDPOINTS = {
"google": "https://oauth2.googleapis.com/token",
}
# NOTE: Keeping this as a utility function for potential future debugging,
# but not using it in production code
async def _test_expire_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
expire_in_seconds: int = 10,
) -> bool:
"""
Utility function for testing - Sets an OAuth token to expire in a short time
to facilitate testing of the refresh flow.
Not used in production code.
"""
try:
new_expires_at = int(
(datetime.now(timezone.utc).timestamp() + expire_in_seconds)
)
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
return True
except Exception as e:
logger.exception(f"Error setting artificial expiration: {str(e)}")
return False
async def refresh_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
) -> bool:
"""
Attempt to refresh an OAuth token that's about to expire or has expired.
Returns True if successful, False otherwise.
"""
if not oauth_account.refresh_token:
logger.warning(
f"No refresh token available for {user.email}'s {oauth_account.oauth_name} account"
)
return False
provider = oauth_account.oauth_name
if provider not in REFRESH_ENDPOINTS:
logger.warning(f"Refresh endpoint not configured for provider: {provider}")
return False
try:
logger.info(f"Refreshing OAuth token for {user.email}'s {provider} account")
async with httpx.AsyncClient() as client:
response = await client.post(
REFRESH_ENDPOINTS[provider],
data={
"client_id": OAUTH_CLIENT_ID,
"client_secret": OAUTH_CLIENT_SECRET,
"refresh_token": oauth_account.refresh_token,
"grant_type": "refresh_token",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
if response.status_code != 200:
logger.error(
f"Failed to refresh OAuth token: Status {response.status_code}"
)
return False
token_data = response.json()
new_access_token = token_data.get("access_token")
new_refresh_token = token_data.get(
"refresh_token", oauth_account.refresh_token
)
expires_in = token_data.get("expires_in")
# Calculate new expiry time if provided
new_expires_at: Optional[int] = None
if expires_in:
new_expires_at = int(
(datetime.now(timezone.utc).timestamp() + expires_in)
)
# Update the OAuth account
updated_data: Dict[str, Any] = {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
}
if new_expires_at:
updated_data["expires_at"] = new_expires_at
# Update oidc_expiry in user model if we're tracking it
if TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(
new_expires_at, tz=timezone.utc
)
await user_manager.user_db.update(
user, {"oidc_expiry": oidc_expiry}
)
# Update the OAuth account
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
logger.info(f"Successfully refreshed OAuth token for {user.email}")
return True
except Exception as e:
logger.exception(f"Error refreshing OAuth token: {str(e)}")
return False
async def check_and_refresh_oauth_tokens(
user: User,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
) -> None:
"""
Check if any OAuth tokens are expired or about to expire and refresh them.
"""
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
return
now_timestamp = datetime.now(timezone.utc).timestamp()
# Buffer time to refresh tokens before they expire (in seconds)
buffer_seconds = 300 # 5 minutes
for oauth_account in user.oauth_accounts:
# Skip accounts without refresh tokens
if not oauth_account.refresh_token:
continue
# If token is about to expire, refresh it
if (
oauth_account.expires_at
and oauth_account.expires_at - now_timestamp < buffer_seconds
):
logger.info(f"OAuth token for {user.email} is about to expire - refreshing")
success = await refresh_oauth_token(
user, oauth_account, db_session, user_manager
)
if not success:
logger.warning(
"Failed to refresh OAuth token. User may need to re-authenticate."
)
async def check_oauth_account_has_refresh_token(
user: User,
oauth_account: OAuthAccount,
) -> bool:
"""
Check if an OAuth account has a refresh token.
Returns True if a refresh token exists, False otherwise.
"""
return bool(oauth_account.refresh_token)
async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
"""
Returns a list of OAuth accounts for a user that are missing refresh tokens.
These accounts will need re-authentication to get refresh tokens.
"""
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
return []
accounts_needing_refresh = []
for oauth_account in user.oauth_accounts:
has_refresh_token = await check_oauth_account_has_refresh_token(
user, oauth_account
)
if not has_refresh_token:
accounts_needing_refresh.append(oauth_account)
return accounts_needing_refresh

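The new oauth_refresher module exposes helpers that the session-refresh endpoint (further down in the users.py diff) calls before extending a session. A hedged usage sketch; refresh_user_oauth here is an illustrative wrapper, not part of this change:

from typing import Any

from sqlalchemy.ext.asyncio import AsyncSession

from onyx.auth.oauth_refresher import (
    check_and_refresh_oauth_tokens,
    get_oauth_accounts_requiring_refresh_token,
)
from onyx.db.models import User


async def refresh_user_oauth(
    user: User, db_session: AsyncSession, user_manager: Any
) -> None:
    # Refresh any linked OAuth access tokens that expire within the 5-minute buffer.
    await check_and_refresh_oauth_tokens(
        user=user, db_session=db_session, user_manager=user_manager
    )
    # Accounts that never received a refresh token need a full re-authentication.
    for account in await get_oauth_accounts_requiring_refresh_token(user):
        print(f"{account.oauth_name} account for {user.email} has no refresh token")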
View File

@@ -5,12 +5,16 @@ import string
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Protocol
from typing import Tuple
from typing import TypeVar
import jwt
from email_validator import EmailNotValidError
@@ -581,8 +585,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
logger.notice(
f"Verification requested for user {user.id}. Verification token: {token}"
)
send_user_verification_email(user.email, token)
user_count = await get_user_count()
send_user_verification_email(
user.email, token, new_organization=user_count == 1
)
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
@@ -688,16 +694,20 @@ cookie_transport = CookieTransport(
)
def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()
T = TypeVar("T", covariant=True)
ID = TypeVar("ID", contravariant=True)
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
# Protocol for strategies that support token refreshing without inheritance.
class RefreshableStrategy(Protocol):
"""Protocol for authentication strategies that support token refreshing."""
async def refresh_token(self, token: Optional[str], user: Any) -> str:
"""
Refresh an existing token by extending its lifetime.
Returns either the same token with extended expiration or a new token.
"""
...
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
@@ -756,6 +766,75 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
redis = await get_async_redis_connection()
await redis.delete(f"{self.key_prefix}{token}")
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by extending its expiration time in Redis."""
if token is None:
# If no token provided, create a new one
return await self.write_token(user)
redis = await get_async_redis_connection()
token_key = f"{self.key_prefix}{token}"
# Check if token exists
token_data_str = await redis.get(token_key)
if not token_data_str:
# Token not found, create new one
return await self.write_token(user)
# Token exists, extend its lifetime
token_data = json.loads(token_data_str)
await redis.set(
token_key,
json.dumps(token_data),
ex=self.lifetime_seconds,
)
return token
class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
"""Database strategy with token refreshing capabilities."""
def __init__(
self,
access_token_db: AccessTokenDatabase[AccessToken],
lifetime_seconds: Optional[int] = None,
):
super().__init__(access_token_db, lifetime_seconds)
self._access_token_db = access_token_db
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by updating its expiration time in the database."""
if token is None:
return await self.write_token(user)
# Find the token in database
access_token = await self._access_token_db.get_by_token(token)
if access_token is None:
# Token not found, create new one
return await self.write_token(user)
# Update expiration time
new_expires = datetime.now(timezone.utc) + timedelta(
seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
)
await self._access_token_db.update(access_token, {"expires": new_expires})
return token
def get_redis_strategy() -> TenantAwareRedisStrategy:
return TenantAwareRedisStrategy()
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> RefreshableDatabaseStrategy:
return RefreshableDatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
if AUTH_BACKEND == AuthBackend.REDIS:
auth_backend = AuthenticationBackend(
@@ -806,6 +885,88 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
return router
def get_refresh_router(
self,
backend: AuthenticationBackend,
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
) -> APIRouter:
"""
Provide a router for session token refreshing.
"""
# Import the oauth_refresher here to avoid circular imports
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
router = APIRouter()
get_current_user_token = self.authenticator.current_user_token(
active=True, verified=requires_verification
)
refresh_responses: OpenAPIResponseType = {
**{
status.HTTP_401_UNAUTHORIZED: {
"description": "Missing token or inactive user."
}
},
**backend.transport.get_openapi_login_responses_success(),
}
@router.post(
"/refresh", name=f"auth:{backend.name}.refresh", responses=refresh_responses
)
async def refresh(
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(
get_user_manager
),
db_session: AsyncSession = Depends(get_async_session),
) -> Response:
try:
user, token = user_token
logger.info(f"Processing token refresh request for user {user.email}")
# Check if user has OAuth accounts that need refreshing
await check_and_refresh_oauth_tokens(
user=cast(User, user),
db_session=db_session,
user_manager=cast(Any, user_manager),
)
# Check if strategy supports refreshing
supports_refresh = hasattr(strategy, "refresh_token") and callable(
getattr(strategy, "refresh_token")
)
if supports_refresh:
try:
refresh_method = getattr(strategy, "refresh_token")
new_token = await refresh_method(token, user)
logger.info(
f"Successfully refreshed session token for user {user.email}"
)
return await backend.transport.get_login_response(new_token)
except Exception as e:
logger.error(f"Error refreshing session token: {str(e)}")
# Fallback to logout and login if refresh fails
await backend.logout(strategy, user, token)
return await backend.login(strategy, user)
# Fallback: logout and login again
logger.info(
"Strategy doesn't support refresh - using logout/login flow"
)
await backend.logout(strategy, user, token)
return await backend.login(strategy, user)
except Exception as e:
logger.error(f"Unexpected error in refresh endpoint: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Token refresh failed: {str(e)}",
)
return router
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
get_user_manager, [auth_backend]
@@ -1039,12 +1200,20 @@ def get_oauth_router(
"referral_source": referral_source or "default_referral",
}
state = generate_state_token(state_data, state_secret)
# Get the basic authorization URL
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
state,
scopes,
)
# For Google OAuth, add parameters to request refresh tokens
if oauth_client.name == "google":
authorization_url = add_url_params(
authorization_url, {"access_type": "offline", "prompt": "consent"}
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
@router.get(

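The refresh router above is mounted at POST /auth/refresh (it is also added to the public endpoint list in a later file), so a logged-in client can periodically extend its session and, transparently, its Google OAuth tokens. A rough client-side sketch, assuming cookie-based sessions and a local server URL:

import requests

API_SERVER_URL = "http://localhost:8080"  # assumption: local dev server

session = requests.Session()
# ... log in first so the session cookie is stored on `session` ...

resp = session.post(f"{API_SERVER_URL}/auth/refresh")
if resp.ok:
    # On success the transport re-issues the session cookie with a new expiry.
    print("session extended")
elif resp.status_code == 401:
    # Missing token or inactive user - a full login is required.
    print("re-authentication needed")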
View File

@@ -34,7 +34,6 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import PlainFormatter
@@ -225,7 +224,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
@@ -311,7 +310,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")

View File

@@ -38,10 +38,11 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -102,7 +103,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
@@ -235,7 +236,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
lock: RedisLock = worker.primary_worker_lock
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")

View File

@@ -451,6 +451,8 @@ def monitor_connector_deletion_taskset(
credential_id=cc_pair.credential_id,
)
db_session.flush()
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,

View File

@@ -68,6 +68,8 @@ from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
logger = setup_logger()
@@ -875,6 +877,21 @@ def monitor_ccpair_permissions_taskset(
f"remaining={remaining} "
f"initial={initial}"
)
# Add telemetry for permission syncing progress
optional_telemetry(
record_type=RecordType.PERMISSION_SYNC_PROGRESS,
data={
"cc_pair_id": cc_pair_id,
"id": payload.id if payload else None,
"total_docs": initial if initial is not None else 0,
"remaining_docs": remaining,
"synced_docs": (initial - remaining) if initial is not None else 0,
"is_complete": remaining == 0,
},
tenant_id=tenant_id,
)
if remaining > 0:
return

View File

@@ -56,9 +56,12 @@ from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.redis.redis_connector import RedisConnector
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
@@ -570,6 +573,22 @@ def _run_indexing(
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
# Add telemetry for indexing progress
optional_telemetry(
record_type=RecordType.INDEXING_PROGRESS,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"connector_id": ctx.connector_id,
"credential_id": ctx.credential_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"batch_num": batch_num,
"source": ctx.source.value,
},
tenant_id=tenant_id,
)
memory_tracer.increment_and_maybe_trace()
# make sure the checkpoints aren't getting too large at some regular interval
@@ -585,6 +604,30 @@ def _run_indexing(
checkpoint=checkpoint,
)
# Add telemetry for completed indexing
redis_connector = RedisConnector(tenant_id, ctx.cc_pair_id)
redis_connector_index = redis_connector.new_index(
index_attempt_start.search_settings_id
)
final_progress = redis_connector_index.get_progress() or 0
optional_telemetry(
record_type=RecordType.INDEXING_COMPLETE,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"connector_id": ctx.connector_id,
"credential_id": ctx.credential_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"batch_count": batch_num,
"time_elapsed_seconds": time.monotonic() - start_time,
"source": ctx.source.value,
"redis_progress": final_progress,
},
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(
"Connector run exceptioned after elapsed time: "

View File

@@ -1,6 +1,8 @@
import json
import os
import urllib.parse
from datetime import datetime
from datetime import timezone
from typing import cast
from onyx.auth.schemas import AuthBackend
@@ -383,10 +385,23 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
# https://jira.atlassian.com/browse/CONFCLOUD-69670
def get_current_tz_offset() -> int:
# datetime.now() gets local time, datetime.now(timezone.utc) gets UTC time.
# remove tzinfo to compare non-timezone-aware objects.
time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
return round(time_diff.total_seconds() / 3600)
# enter as a floating point offset from UTC in hours (-24 < val < 24)
# this will be applied globally, so it probably makes sense to transition this to per
# connector at some point.
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
# For the default value, we assume that the user's local timezone is more likely to be
# correct (i.e. the configured user's timezone or the default server one) than UTC.
# https://developer.atlassian.com/cloud/confluence/cql-fields/#created
CONFLUENCE_TIMEZONE_OFFSET = float(
os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
)
GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
@@ -677,3 +692,7 @@ IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
)
DISABLE_AUTO_AUTH_REFRESH = (
os.environ.get("DISABLE_AUTO_AUTH_REFRESH", "").lower() == "true"
)

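Since CONFLUENCE_TIMEZONE_OFFSET now defaults to the server's local UTC offset rather than 0, a host running in, say, US Pacific Daylight Time yields -7 while a UTC host yields 0; the environment variable still overrides the default. The calculation can be checked in isolation:

from datetime import datetime, timezone

# Local clock minus UTC clock, rounded to whole hours.
time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
print(round(time_diff.total_seconds() / 3600))  # e.g. -7 on a PDT host, 0 on a UTC host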
View File

@@ -79,6 +79,8 @@ _FULL_EXTENSION_FILTER_STRING = "".join(
]
)
ONE_HOUR = 3600
class ConfluenceConnector(
LoadConnector,
@@ -429,7 +431,17 @@ class ConfluenceConnector(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
return self._fetch_document_batches(start, end)
try:
return self._fetch_document_batches(start, end)
except Exception as e:
if "field 'updated' is invalid" in str(e) and start is not None:
logger.warning(
"Confluence says we provided an invalid 'updated' field. This may indicate"
"a real issue, but can also appear during edge cases like daylight"
f"savings time changes. Retrying with a 1 hour offset. Error: {e}"
)
return self._fetch_document_batches(start - ONE_HOUR, end)
raise
def retrieve_all_slim_documents(
self,

View File

@@ -2,11 +2,11 @@ import copy
import threading
from collections.abc import Callable
from collections.abc import Iterator
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any
from typing import cast
from typing import Protocol
from urllib.parse import urlparse
@@ -58,13 +58,13 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.lazy import lazy_eval
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
from onyx.utils.threadpool_concurrency import parallel_yield
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.threadpool_concurrency import ThreadSafeDict
logger = setup_logger()
@@ -461,6 +461,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
DriveRetrievalStage.MY_DRIVE_FILES,
)
curr_stage.stage = DriveRetrievalStage.SHARED_DRIVE_FILES
resuming = False # we are starting the next stage for the first time
if curr_stage.stage == DriveRetrievalStage.SHARED_DRIVE_FILES:
@@ -496,7 +497,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
)
yield from _yield_from_drive(drive_id, start)
curr_stage.stage = DriveRetrievalStage.FOLDER_FILES
resuming = False # we are starting the next stage for the first time
if curr_stage.stage == DriveRetrievalStage.FOLDER_FILES:
def _yield_from_folder_crawl(
@@ -549,6 +550,16 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
checkpoint, is_slim, DriveRetrievalStage.MY_DRIVE_FILES
)
# Setup initial completion map on first connector run
for email in all_org_emails:
# don't overwrite existing completion map on resuming runs
if email in checkpoint.completion_map:
continue
checkpoint.completion_map[email] = StageCompletion(
stage=DriveRetrievalStage.START,
completed_until=0,
)
# we've found all users and drives, now time to actually start
# fetching stuff
logger.info(f"Found {len(all_org_emails)} users to impersonate")
@@ -562,11 +573,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
drive_ids_to_retrieve, checkpoint
)
for email in all_org_emails:
checkpoint.completion_map[email] = StageCompletion(
stage=DriveRetrievalStage.START,
completed_until=0,
)
user_retrieval_gens = [
self._impersonate_user_for_retrieval(
email,
@@ -797,10 +803,12 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
return
for file in drive_files:
if file.error is not None:
if file.error is None:
checkpoint.completion_map[file.user_email].update(
stage=file.completion_stage,
completed_until=file.drive_file[GoogleFields.MODIFIED_TIME.value],
completed_until=datetime.fromisoformat(
file.drive_file[GoogleFields.MODIFIED_TIME.value]
).timestamp(),
completed_until_parent_id=file.parent_id,
)
yield file
@@ -902,118 +910,78 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[list[Document | ConnectorFailure]]:
) -> Iterator[Document | ConnectorFailure]:
try:
# Create a larger process pool for file conversion
with ThreadPoolExecutor(max_workers=8) as executor:
# Prepare a partial function with the credentials and admin email
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
self.size_threshold,
# Prepare a partial function with the credentials and admin email
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
self.size_threshold,
)
# Fetch files in batches
batches_complete = 0
files_batch: list[GoogleDriveFileType] = []
def _yield_batch(
files_batch: list[GoogleDriveFileType],
) -> Iterator[Document | ConnectorFailure]:
nonlocal batches_complete
# Process the batch using run_functions_tuples_in_parallel
func_with_args = [(convert_func, (file,)) for file in files_batch]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
)
# Fetch files in batches
batches_complete = 0
files_batch: list[GoogleDriveFileType] = []
for retrieved_file in self._fetch_drive_items(
is_slim=False,
checkpoint=checkpoint,
start=start,
end=end,
):
if retrieved_file.error is not None:
failure_stage = retrieved_file.completion_stage.value
failure_message = (
f"retrieval failure during stage: {failure_stage},"
)
failure_message += f"user: {retrieved_file.user_email},"
failure_message += (
f"parent drive/folder: {retrieved_file.parent_id},"
)
failure_message += f"error: {retrieved_file.error}"
logger.error(failure_message)
yield [
ConnectorFailure(
failed_entity=EntityFailure(
entity_id=failure_stage,
),
failure_message=failure_message,
exception=retrieved_file.error,
)
]
continue
files_batch.append(retrieved_file.drive_file)
docs_and_failures = [result for result in results if result is not None]
if len(files_batch) < self.batch_size:
continue
if docs_and_failures:
yield from docs_and_failures
batches_complete += 1
# Process the batch
futures = [
executor.submit(convert_func, file) for file in files_batch
]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
error_str = f"Error converting file: {e}"
logger.error(error_str)
yield [
ConnectorFailure(
failed_document=DocumentFailure(
document_id=retrieved_file.drive_file["id"],
document_link=retrieved_file.drive_file[
"webViewLink"
],
),
failure_message=error_str,
exception=e,
)
]
for retrieved_file in self._fetch_drive_items(
is_slim=False,
checkpoint=checkpoint,
start=start,
end=end,
):
if retrieved_file.error is not None:
failure_stage = retrieved_file.completion_stage.value
failure_message = (
f"retrieval failure during stage: {failure_stage},"
)
failure_message += f"user: {retrieved_file.user_email},"
failure_message += (
f"parent drive/folder: {retrieved_file.parent_id},"
)
failure_message += f"error: {retrieved_file.error}"
logger.error(failure_message)
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=failure_stage,
),
failure_message=failure_message,
exception=retrieved_file.error,
)
if documents:
yield documents
batches_complete += 1
files_batch = []
continue
files_batch.append(retrieved_file.drive_file)
if batches_complete > BATCHES_PER_CHECKPOINT:
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
return # create a new checkpoint
if len(files_batch) < self.batch_size:
continue
# Process any remaining files
if files_batch:
futures = [
executor.submit(convert_func, file) for file in files_batch
]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
error_str = f"Error converting file: {e}"
logger.error(error_str)
yield [
ConnectorFailure(
failed_document=DocumentFailure(
document_id=retrieved_file.drive_file["id"],
document_link=retrieved_file.drive_file[
"webViewLink"
],
),
failure_message=error_str,
exception=e,
)
]
yield from _yield_batch(files_batch)
files_batch = []
if documents:
yield documents
if batches_complete > BATCHES_PER_CHECKPOINT:
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
return # create a new checkpoint
# Process any remaining files
if files_batch:
yield from _yield_batch(files_batch)
except Exception as e:
logger.exception(f"Error extracting documents from Google Drive: {e}")
raise e
@@ -1035,10 +1003,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
checkpoint = copy.deepcopy(checkpoint)
self._retrieved_ids = checkpoint.retrieved_folder_and_drive_ids
try:
for doc_list in self._extract_docs_from_google_drive(
checkpoint, start, end
):
yield from doc_list
yield from self._extract_docs_from_google_drive(checkpoint, start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
@@ -1073,9 +1038,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
raise RuntimeError(
"_extract_slim_docs_from_google_drive: Stop signal detected"
)
callback.progress("_extract_slim_docs_from_google_drive", 1)
yield slim_batch
def retrieve_all_slim_documents(

View File

@@ -123,7 +123,7 @@ def crawl_folders_for_files(
end=end,
):
found_files = True
logger.info(f"Found file: {file['name']}")
logger.info(f"Found file: {file['name']}, user email: {user_email}")
yield RetrievedDriveFile(
drive_file=file,
user_email=user_email,

View File

@@ -175,9 +175,12 @@ def _get_tickets_page(
)
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
def _fetch_author(
client: ZendeskClient, author_id: str | int
) -> BasicExpertInfo | None:
# Skip fetching if author_id is invalid
if not author_id or author_id == "-1":
# cast to str to avoid issues with zendesk changing their types
if not author_id or str(author_id) == "-1":
return None
try:

View File

@@ -8,23 +8,31 @@ from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from onyx.connectors.models import ConnectorFailure
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.server.documents.models import ConnectorCredentialPair
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
# Comment out unused imports that cause mypy errors
# from onyx.auth.models import UserRole
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
# from onyx.db.connector_credential_pair import ConnectorCredentialPairIdentifier
# from onyx.db.engine import async_query_for_dms
logger = setup_logger()
@@ -201,6 +209,17 @@ def mark_attempt_in_progress(
attempt.status = IndexingStatus.IN_PROGRESS
attempt.time_started = index_attempt.time_started or func.now() # type: ignore
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt.id,
"status": IndexingStatus.IN_PROGRESS.value,
"cc_pair_id": index_attempt.connector_credential_pair_id,
"search_settings_id": index_attempt.search_settings_id,
},
)
except Exception:
db_session.rollback()
raise
@@ -219,6 +238,19 @@ def mark_attempt_succeeded(
attempt.status = IndexingStatus.SUCCESS
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.SUCCESS.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -237,6 +269,19 @@ def mark_attempt_partially_succeeded(
attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.COMPLETED_WITH_ERRORS.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -259,6 +304,20 @@ def mark_attempt_canceled(
attempt.status = IndexingStatus.CANCELED
attempt.error_msg = reason
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.CANCELED.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"reason": reason,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -283,6 +342,20 @@ def mark_attempt_failed(
attempt.error_msg = failure_reason
attempt.full_exception_trace = full_exception_trace
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.FAILED.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"reason": failure_reason,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -434,7 +507,7 @@ def get_latest_index_attempts_parallel(
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
return get_latest_index_attempts(
secondary_index,
db_session,

View File

@@ -1475,7 +1475,6 @@ class DocumentRetrievalFeedback(Base):
feedback: Mapped[SearchFeedbackType | None] = mapped_column(
Enum(SearchFeedbackType, native_enum=False), nullable=True
)
chat_message: Mapped[ChatMessage] = relationship(
"ChatMessage",
back_populates="document_feedbacks",

View File

@@ -13,6 +13,37 @@ from onyx.setup import setup_logger
logger = setup_logger()
def fetch_paginated_sync_records(
db_session: Session,
entity_id: int,
sync_type: SyncType,
page_num: int,
page_size: int,
) -> tuple[list[SyncRecord], int]:
total_count = (
db_session.query(SyncRecord)
.filter(
SyncRecord.entity_id == entity_id,
SyncRecord.sync_type == sync_type,
)
.count()
)
sync_records = (
db_session.query(SyncRecord)
.filter(
SyncRecord.entity_id == entity_id,
SyncRecord.sync_type == sync_type,
)
.order_by(SyncRecord.sync_start_time.desc())
.offset(page_num * page_size)
.limit(page_size)
.all()
)
return sync_records, total_count
def insert_sync_record(
db_session: Session,
entity_id: int,

View File

@@ -24,7 +24,9 @@ from onyx.db.models import User__UserGroup
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
def validate_user_role_update(
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
) -> None:
"""
Validate that a user role update is valid.
Assumed only admins can hit this endpoint.
@@ -57,6 +59,9 @@ def validate_user_role_update(requested_role: UserRole, current_role: UserRole)
detail="To change a Limited User's role, they must first login to Onyx via the web app.",
)
if explicit_override:
return
if requested_role == UserRole.CURATOR:
# This shouldn't happen, but just in case
raise HTTPException(

View File

@@ -361,7 +361,15 @@ def get_application() -> FastAPI:
)
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
oauth_client = GoogleOAuth2(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
# Use standard scopes that include profile and email
scopes=["openid", "email", "profile"],
)
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -383,6 +391,13 @@ def get_application() -> FastAPI:
prefix="/auth",
)
# Add refresh token endpoint for OAuth as well
include_auth_router_with_prefix(
application,
fastapi_users.get_refresh_router(auth_backend),
prefix="/auth",
)
application.add_exception_handler(
RequestValidationError, validation_exception_handler
)

View File

@@ -31,6 +31,7 @@ PUBLIC_ENDPOINT_SPECS = [
# just gets the version of Onyx (e.g. 0.3.11)
("/version", {"GET"}),
# stuff related to basic auth
("/auth/refresh", {"POST"}),
("/auth/register", {"POST"}),
("/auth/login", {"POST"}),
("/auth/logout", {"POST"}),

View File

@@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.user_group import fetch_user_group
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.celery_utils import get_deletion_attempt_snapshot
@@ -41,6 +42,7 @@ from onyx.db.document import get_documents_for_cc_pair
from onyx.db.engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import count_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import count_index_attempts_for_connector
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
@@ -50,6 +52,7 @@ from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.search_settings import get_current_search_settings
from onyx.db.sync_record import fetch_paginated_sync_records
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CCPairFullInfo
@@ -60,6 +63,7 @@ from onyx.server.documents.models import ConnectorCredentialPairMetadata
from onyx.server.documents.models import DocumentSyncStatus
from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import PaginatedReturn
from onyx.server.documents.models import SyncRecordSnapshot
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -69,6 +73,85 @@ logger = setup_logger()
router = APIRouter(prefix="/manage")
@router.get("/admin/user-group/{group_id}/sync-status")
def get_user_group_sync_status(
group_id: int,
page_num: int = Query(0, ge=0),
page_size: int = Query(10, ge=1, le=1000),
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[SyncRecordSnapshot]:
user_group = fetch_user_group(db_session, group_id)
if not user_group:
raise HTTPException(status_code=404, detail="User group not found")
sync_records, total_count = fetch_paginated_sync_records(
db_session, group_id, SyncType.USER_GROUP, page_num, page_size
)
return PaginatedReturn(
items=[
SyncRecordSnapshot.from_sync_record_db_model(sync_record)
for sync_record in sync_records
],
total_items=total_count,
)
@router.post("/admin/user-group/{group_id}/sync")
def sync_user_group(
group_id: int,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[None]:
"""Triggers sync of a user group immediately"""
get_current_tenant_id()
user_group = fetch_user_group(db_session, group_id)
if not user_group:
raise HTTPException(status_code=404, detail="User group not found")
# Add logic to actually trigger the sync - this would depend on your implementation
# For example:
# try_creating_usergroup_sync_task(primary_app, group_id, get_redis_client(), tenant_id)
logger.info(f"User group sync queued: group_id={group_id}")
return StatusResponse(
success=True,
message="Successfully created the user group sync task.",
)
@router.get("/admin/cc-pair/{cc_pair_id}/sync-status")
def get_cc_pair_sync_status(
cc_pair_id: int,
page_num: int = Query(0, ge=0),
page_size: int = Query(10, ge=1, le=1000),
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[SyncRecordSnapshot]:
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id, db_session, user, get_editable=False
)
if not cc_pair:
raise HTTPException(
status_code=400, detail="CC Pair not found for current user permissions"
)
sync_records, total_count = fetch_paginated_sync_records(
db_session, cc_pair_id, SyncType.EXTERNAL_PERMISSIONS, page_num, page_size
)
return PaginatedReturn(
items=[
SyncRecordSnapshot.from_sync_record_db_model(sync_record)
for sync_record in sync_records
],
total_items=total_count,
)
@router.get("/admin/cc-pair/{cc_pair_id}/index-attempts")
def get_cc_pair_index_attempts(
cc_pair_id: int,

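The new sync-status endpoints return a PaginatedReturn of SyncRecordSnapshot items. A hedged sketch of paging through a connector-credential pair's permission-sync history (base URL, cc-pair id, and auth cookie are placeholders):

import requests

API_SERVER_URL = "http://localhost:8080"  # placeholder
headers = {"Cookie": "<admin session cookie>"}  # placeholder

resp = requests.get(
    f"{API_SERVER_URL}/manage/admin/cc-pair/1/sync-status",
    params={"page_num": 0, "page_size": 10},
    headers=headers,
)
resp.raise_for_status()
page = resp.json()
for record in page["items"]:
    print(record["sync_type"], record["sync_status"], record["num_docs_synced"])
print("total records:", page["total_items"])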
View File

@@ -14,12 +14,15 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Document as DbDocument
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import SyncRecord
from onyx.db.models import TaskStatus
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
@@ -67,6 +70,26 @@ class ConnectorBase(BaseModel):
indexing_start: datetime | None = None
class SyncRecordSnapshot(BaseModel):
id: int
entity_id: int
sync_type: SyncType
created_at: datetime
num_docs_synced: int
sync_status: SyncStatus
@classmethod
def from_sync_record_db_model(cls, sync_record: SyncRecord) -> "SyncRecordSnapshot":
return cls(
id=sync_record.id,
entity_id=sync_record.entity_id,
sync_type=sync_record.sync_type,
created_at=sync_record.sync_start_time,
num_docs_synced=sync_record.num_docs_synced,
sync_status=sync_record.sync_status,
)
class ConnectorUpdateRequest(ConnectorBase):
access_type: AccessType
groups: list[int] = Field(default_factory=list)
@@ -194,6 +217,7 @@ PaginatedType = TypeVar(
InvitedUserSnapshot,
ChatSessionMinimal,
IndexAttemptErrorPydantic,
SyncRecordSnapshot,
)

View File

@@ -132,6 +132,7 @@ class UserByEmail(BaseModel):
class UserRoleUpdateRequest(BaseModel):
user_email: str
new_role: UserRole
explicit_override: bool = False
class UserRoleResponse(BaseModel):

View File

@@ -102,6 +102,7 @@ def set_user_role(
validate_user_role_update(
requested_role=requested_role,
current_role=current_role,
explicit_override=user_role_update_request.explicit_override,
)
if user_to_update.id == current_user.id:
@@ -122,6 +123,22 @@ def set_user_role(
db_session.commit()
class TestUpsertRequest(BaseModel):
email: str
@router.post("/manage/users/test-upsert-user")
async def test_upsert_user(
request: TestUpsertRequest,
_: User = Depends(current_admin_user),
) -> None | FullUserSnapshot:
"""Test endpoint for upsert_saml_user. Only used for integration testing."""
user = await fetch_ee_implementation_or_noop(
"onyx.server.saml", "upsert_saml_user", None
)(email=request.email)
return FullUserSnapshot.from_user_model(user) if user else None
@router.get("/manage/users/accepted")
def list_accepted_users(
q: str | None = Query(default=None),

View File

@@ -36,6 +36,10 @@ class RecordType(str, Enum):
LATENCY = "latency"
FAILURE = "failure"
METRIC = "metric"
INDEXING_PROGRESS = "indexing_progress"
INDEXING_COMPLETE = "indexing_complete"
PERMISSION_SYNC_PROGRESS = "permission_sync_progress"
INDEX_ATTEMPT_STATUS = "index_attempt_status"
def _get_or_generate_customer_id_mt(tenant_id: str) -> str:

View File

@@ -6,14 +6,17 @@ import uuid
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import MutableMapping
from collections.abc import Sequence
from concurrent.futures import as_completed
from concurrent.futures import FIRST_COMPLETED
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import wait
from typing import Any
from typing import cast
from typing import Generic
from typing import overload
from typing import Protocol
from typing import TypeVar
from pydantic import GetCoreSchemaHandler
@@ -145,13 +148,20 @@ class ThreadSafeDict(MutableMapping[KT, VT]):
return collections.abc.ValuesView(self)
class CallableProtocol(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
...
def run_functions_tuples_in_parallel(
functions_with_args: list[tuple[Callable, tuple]],
functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
allow_failures: bool = False,
max_workers: int | None = None,
) -> list[Any]:
"""
Executes multiple functions in parallel and returns a list of the results for each function.
This function preserves contextvars across threads, which is important for maintaining
context like tenant IDs in database sessions.
Args:
functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
@@ -159,7 +169,7 @@ def run_functions_tuples_in_parallel(
max_workers: Max number of worker threads
Returns:
dict: A dictionary mapping function names to their results or error messages.
list: A list of results from each function, in the same order as the input functions.
"""
workers = (
min(max_workers, len(functions_with_args))
@@ -186,7 +196,7 @@ def run_functions_tuples_in_parallel(
results.append((index, future.result()))
except Exception as e:
logger.exception(f"Function at index {index} failed due to {e}")
results.append((index, None))
results.append((index, None)) # type: ignore
if not allow_failures:
raise
@@ -288,7 +298,7 @@ def run_with_timeout(
if task.is_alive():
task.end()
return task.result
return task.result # type: ignore
# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
@@ -304,9 +314,9 @@ def run_in_background(
"""
context = contextvars.copy_context()
# Timeout not used in the non-blocking case
task = TimeoutThread(-1, context.run, func, *args, **kwargs)
task = TimeoutThread(-1, context.run, func, *args, **kwargs) # type: ignore
task.start()
return task
return cast(TimeoutThread[R], task)
def wait_on_background(task: TimeoutThread[R]) -> R:
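A minimal usage sketch of the updated run_functions_tuples_in_parallel signature; the functions and arguments here are illustrative, and the import path is an assumption since the diff does not name the module:

# Assumed import path for illustration only.
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel

def add(a: int, b: int) -> int:
    return a + b

def square(x: int) -> int:
    return x * x

# Results come back in the same order as the input tuples: [3, 16]
results = run_functions_tuples_in_parallel(
    [
        (add, (1, 2)),
        (square, (4,)),
    ],
    allow_failures=False,
)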

View File

@@ -123,10 +123,15 @@ class UserManager:
user_to_set: DATestUser,
target_role: UserRole,
user_performing_action: DATestUser,
explicit_override: bool = False,
) -> DATestUser:
response = requests.patch(
url=f"{API_SERVER_URL}/manage/set-user-role",
json={"user_email": user_to_set.email, "new_role": target_role.value},
json={
"user_email": user_to_set.email,
"new_role": target_role.value,
"explicit_override": explicit_override,
},
headers=user_performing_action.headers,
)
response.raise_for_status()

View File

@@ -0,0 +1,90 @@
import requests
from onyx.auth.schemas import UserRole
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
def test_saml_user_conversion(reset: None) -> None:
"""
Test that SAML login correctly converts users with non-authenticated roles
(SLACK_USER or EXT_PERM_USER) to authenticated roles (BASIC).
This test:
1. Creates an admin and a regular user
2. Changes the regular user's role to EXT_PERM_USER
3. Simulates a SAML login by calling the test endpoint
4. Verifies the user's role is converted to BASIC
This tests the fix that ensures users with non-authenticated roles (SLACK_USER or EXT_PERM_USER)
are properly converted to authenticated roles during SAML login.
"""
# Create an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(email="admin@onyx-test.com")
# Create a regular user that we'll convert to EXT_PERM_USER
test_user_email = "ext_perm_user@example.com"
test_user = UserManager.create(email=test_user_email)
# Verify the user was created with BASIC role initially
assert UserManager.is_role(test_user, UserRole.BASIC)
# Change the user's role to EXT_PERM_USER using the UserManager
UserManager.set_role(
user_to_set=test_user,
target_role=UserRole.EXT_PERM_USER,
user_performing_action=admin_user,
explicit_override=True,
)
# Verify the user has EXT_PERM_USER role now
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
# Simulate SAML login by calling the test endpoint
response = requests.post(
f"{API_SERVER_URL}/manage/users/test-upsert-user",
json={"email": test_user_email},
headers=admin_user.headers, # Use admin headers for authorization
)
response.raise_for_status()
# Verify the response indicates the role changed to BASIC
user_data = response.json()
assert user_data["role"] == UserRole.BASIC.value
# Verify user role was changed in the database
assert UserManager.is_role(test_user, UserRole.BASIC)
# Do the same test with SLACK_USER
slack_user_email = "slack_user@example.com"
slack_user = UserManager.create(email=slack_user_email)
# Verify the user was created with BASIC role initially
assert UserManager.is_role(slack_user, UserRole.BASIC)
# Change the user's role to SLACK_USER
UserManager.set_role(
user_to_set=slack_user,
target_role=UserRole.SLACK_USER,
user_performing_action=admin_user,
explicit_override=True,
)
# Verify the user has SLACK_USER role
assert UserManager.is_role(slack_user, UserRole.SLACK_USER)
# Simulate SAML login again
response = requests.post(
f"{API_SERVER_URL}/manage/users/test-upsert-user",
json={"email": slack_user_email},
headers=admin_user.headers,
)
response.raise_for_status()
# Verify the response indicates the role changed to BASIC
user_data = response.json()
assert user_data["role"] == UserRole.BASIC.value
# Verify the user's role was changed in the database
assert UserManager.is_role(slack_user, UserRole.BASIC)

View File

@@ -0,0 +1,43 @@
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
import pytest
from onyx.db.models import OAuthAccount
from onyx.db.models import User
@pytest.fixture
def mock_user() -> MagicMock:
"""Creates a mock User instance for testing."""
user = MagicMock(spec=User)
user.email = "test@example.com"
user.id = "test-user-id"
return user
@pytest.fixture
def mock_oauth_account() -> MagicMock:
"""Creates a mock OAuthAccount instance for testing."""
oauth_account = MagicMock(spec=OAuthAccount)
oauth_account.oauth_name = "google"
oauth_account.refresh_token = "test-refresh-token"
oauth_account.access_token = "test-access-token"
oauth_account.expires_at = None
return oauth_account
@pytest.fixture
def mock_user_manager() -> MagicMock:
"""Creates a mock user manager for testing."""
user_manager = MagicMock()
user_manager.user_db = MagicMock()
user_manager.user_db.update_oauth_account = AsyncMock()
user_manager.user_db.update = AsyncMock()
return user_manager
@pytest.fixture
def mock_db_session() -> MagicMock:
"""Creates a mock database session for testing."""
return MagicMock()

View File

@@ -0,0 +1,273 @@
from datetime import datetime
from datetime import timezone
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.auth.oauth_refresher import _test_expire_oauth_token
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
from onyx.auth.oauth_refresher import check_oauth_account_has_refresh_token
from onyx.auth.oauth_refresher import get_oauth_accounts_requiring_refresh_token
from onyx.auth.oauth_refresher import refresh_oauth_token
from onyx.db.models import OAuthAccount
@pytest.mark.asyncio
async def test_refresh_oauth_token_success(
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test successful OAuth token refresh."""
# Mock HTTP client and response
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"access_token": "new_token",
"refresh_token": "new_refresh_token",
"expires_in": 3600,
}
# Create async mock for the client post method
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
# Use fixture values but ensure refresh token exists
mock_oauth_account.oauth_name = (
"google" # Ensure it's google to match the refresh endpoint
)
mock_oauth_account.refresh_token = "old_refresh_token"
# Patch at the module level where it's actually being used
with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock:
# Configure the context manager
client_instance = mock_client
client_class_mock.return_value.__aenter__.return_value = client_instance
# Call the function under test
result = await refresh_oauth_token(
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
)
# Assertions
assert result is True
mock_client.post.assert_called_once()
mock_user_manager.user_db.update_oauth_account.assert_called_once()
# Verify token data was updated correctly
update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2]
assert update_data["access_token"] == "new_token"
assert update_data["refresh_token"] == "new_refresh_token"
assert "expires_at" in update_data
@pytest.mark.asyncio
async def test_refresh_oauth_token_failure(
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test OAuth token refresh failure due to HTTP error."""
# Mock HTTP client with error response
mock_response = MagicMock()
mock_response.status_code = 400 # Simulate error
# Create async mock for the client post method
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
# Ensure refresh token exists and provider is supported
mock_oauth_account.oauth_name = "google"
mock_oauth_account.refresh_token = "old_refresh_token"
# Patch at the module level where it's actually being used
with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock:
# Configure the context manager
client_class_mock.return_value.__aenter__.return_value = mock_client
# Call the function under test
result = await refresh_oauth_token(
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
)
# Assertions
assert result is False
mock_client.post.assert_called_once()
mock_user_manager.user_db.update_oauth_account.assert_not_called()
@pytest.mark.asyncio
async def test_refresh_oauth_token_no_refresh_token(
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test OAuth token refresh when no refresh token is available."""
# Set refresh token to None
mock_oauth_account.refresh_token = None
mock_oauth_account.oauth_name = "google"
# No need to mock httpx since it shouldn't be called
result = await refresh_oauth_token(
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
)
# Assertions
assert result is False
@pytest.mark.asyncio
async def test_check_and_refresh_oauth_tokens(
mock_user: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test checking and refreshing multiple OAuth tokens."""
# Create mock user with OAuth accounts
now_timestamp = datetime.now(timezone.utc).timestamp()
# Create an account that needs refreshing (expiring soon)
expiring_account = MagicMock(spec=OAuthAccount)
expiring_account.oauth_name = "google"
expiring_account.refresh_token = "refresh_token_1"
expiring_account.expires_at = now_timestamp + 60 # Expires in 1 minute
# Create an account that doesn't need refreshing (expires later)
valid_account = MagicMock(spec=OAuthAccount)
valid_account.oauth_name = "google"
valid_account.refresh_token = "refresh_token_2"
valid_account.expires_at = now_timestamp + 3600 # Expires in 1 hour
# Create an account without a refresh token
no_refresh_account = MagicMock(spec=OAuthAccount)
no_refresh_account.oauth_name = "google"
no_refresh_account.refresh_token = None
no_refresh_account.expires_at = (
now_timestamp + 60
) # Expiring soon but no refresh token
# Set oauth_accounts on the mock user
mock_user.oauth_accounts = [expiring_account, valid_account, no_refresh_account]
# Mock refresh_oauth_token function
with patch(
"onyx.auth.oauth_refresher.refresh_oauth_token", AsyncMock(return_value=True)
) as mock_refresh:
# Call the function under test
await check_and_refresh_oauth_tokens(
mock_user, mock_db_session, mock_user_manager
)
# Assertions
assert mock_refresh.call_count == 1 # Should only refresh the expiring account
# Check it was called with the expiring account
mock_refresh.assert_called_once_with(
mock_user, expiring_account, mock_db_session, mock_user_manager
)
@pytest.mark.asyncio
async def test_get_oauth_accounts_requiring_refresh_token(mock_user: MagicMock) -> None:
"""Test identifying OAuth accounts that need refresh tokens."""
# Create accounts with and without refresh tokens
account_with_token = MagicMock(spec=OAuthAccount)
account_with_token.oauth_name = "google"
account_with_token.refresh_token = "refresh_token"
account_without_token = MagicMock(spec=OAuthAccount)
account_without_token.oauth_name = "google"
account_without_token.refresh_token = None
second_account_without_token = MagicMock(spec=OAuthAccount)
second_account_without_token.oauth_name = "github"
second_account_without_token.refresh_token = (
"" # Empty string should also be treated as missing
)
# Set accounts on user
mock_user.oauth_accounts = [
account_with_token,
account_without_token,
second_account_without_token,
]
# Call the function under test
accounts_needing_refresh = await get_oauth_accounts_requiring_refresh_token(
mock_user
)
# Assertions
assert len(accounts_needing_refresh) == 2
assert account_without_token in accounts_needing_refresh
assert second_account_without_token in accounts_needing_refresh
assert account_with_token not in accounts_needing_refresh
@pytest.mark.asyncio
async def test_check_oauth_account_has_refresh_token(
mock_user: MagicMock, mock_oauth_account: MagicMock
) -> None:
"""Test checking if an OAuth account has a refresh token."""
# Test with refresh token
mock_oauth_account.refresh_token = "refresh_token"
has_token = await check_oauth_account_has_refresh_token(
mock_user, mock_oauth_account
)
assert has_token is True
# Test with None refresh token
mock_oauth_account.refresh_token = None
has_token = await check_oauth_account_has_refresh_token(
mock_user, mock_oauth_account
)
assert has_token is False
# Test with empty string refresh token
mock_oauth_account.refresh_token = ""
has_token = await check_oauth_account_has_refresh_token(
mock_user, mock_oauth_account
)
assert has_token is False
@pytest.mark.asyncio
async def test_test_expire_oauth_token(
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test the testing utility function for token expiration."""
# Set up the mock account
mock_oauth_account.oauth_name = "google"
mock_oauth_account.refresh_token = "test_refresh_token"
mock_oauth_account.access_token = "test_access_token"
# Call the function under test
result = await _test_expire_oauth_token(
mock_user,
mock_oauth_account,
mock_db_session,
mock_user_manager,
expire_in_seconds=10,
)
# Assertions
assert result is True
mock_user_manager.user_db.update_oauth_account.assert_called_once()
# Verify the expiration time was set correctly
update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2]
assert "expires_at" in update_data
# The expiration should be ~10 seconds from now; allow ~1 second of slack
# in each direction for test execution time
now = datetime.now(timezone.utc).timestamp()
assert update_data["expires_at"] - now >= 8.9
assert update_data["expires_at"] - now <= 11.1
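For context, the HTTP exchange mocked in these tests is the standard OAuth 2.0 refresh_token grant against Google's token endpoint. A minimal standalone sketch of that request (not the actual onyx.auth.oauth_refresher implementation) might look like:

import httpx

async def refresh_google_access_token(
    refresh_token: str, client_id: str, client_secret: str
) -> dict:
    # Standard refresh_token grant; Google returns a new access_token and
    # expires_in, and sometimes a rotated refresh_token.
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "https://oauth2.googleapis.com/token",
            data={
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client_id,
                "client_secret": client_secret,
            },
        )
        response.raise_for_status()
        return response.json()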

View File

@@ -89,7 +89,8 @@ def test_run_in_background_and_wait_success() -> None:
elapsed = time.time() - start_time
assert result == 42
assert elapsed >= 0.1 # Verify we actually waited for the sleep
# sometimes slightly flaky
assert elapsed >= 0.095 # Verify we actually waited for the sleep
@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")

View File

@@ -5,7 +5,7 @@ envsubst '$DOMAIN $SSL_CERT_FILE_NAME $SSL_CERT_KEY_FILE_NAME' < "/etc/nginx/con
echo "Waiting for API server to boot up; this may take a minute or two..."
echo "If this takes more than ~5 minutes, check the logs of the API server container for errors with the following command:"
echo
echo "docker logs onyx-stack_api_server-1"
echo "docker logs onyx-stack-api_server-1"
echo
while true; do

1
openapi.json Normal file

File diff suppressed because one or more lines are too long

12551
web/package-lock.json generated

File diff suppressed because it is too large

View File

@@ -93,11 +93,12 @@
},
"devDependencies": {
"@chromatic-com/playwright": "^0.10.2",
"@playwright/test": "^1.39.0",
"@tailwindcss/typography": "^0.5.10",
"@types/chrome": "^0.0.287",
"@types/jest": "^29.5.14",
"chromatic": "^11.25.2",
"eslint": "^8.48.0",
"eslint": "^8.57.1",
"eslint-config-next": "^14.1.0",
"jest": "^29.7.0",
"prettier": "2.8.8",

View File

@@ -1079,7 +1079,7 @@ export function AssistantEditor({
</Tooltip>
</TooltipProvider>
<span className="text-sm ml-2">
{values.is_public ? "Public" : "Private"}
Organization Public
</span>
</div>
@@ -1088,17 +1088,22 @@ export function AssistantEditor({
<InfoIcon size={16} className="mr-2" />
<span className="text-sm">
Default persona must be public. Visibility has been
automatically set to public.
automatically set to organization public.
</span>
</div>
)}
{values.is_public ? (
<p className="text-sm text-text-dark">
Anyone from your team can view and use this assistant
This assistant will be available to everyone in your
organization
</p>
) : (
<>
<p className="text-sm text-text-dark mb-2">
This assistant will only be available to specific
users and groups
</p>
<div className="mt-2">
<Label className="mb-2" small>
Share with Users and Groups

View File

@@ -13,7 +13,7 @@ import Text from "@/components/ui/text";
import { Callout } from "@/components/ui/callout";
import { CCPairFullInfo } from "./types";
import { IndexAttemptSnapshot } from "@/lib/types";
import { IndexAttemptStatus } from "@/components/Status";
import { AttemptStatus } from "@/components/Status";
import { PageSelector } from "@/components/PageSelector";
import { ThreeDotsLoader } from "@/components/Loading";
import { buildCCPairInfoUrl } from "./lib";
@@ -118,7 +118,7 @@ export function IndexingAttemptsTable({
: "-"}
</TableCell>
<TableCell>
<IndexAttemptStatus
<AttemptStatus
status={indexAttempt.status || "not_started"}
/>
{docsPerMinute ? (

View File

@@ -0,0 +1,148 @@
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { CCPairFullInfo, SyncRecord } from "./types";
import { Button } from "@/components/ui/button";
import {
ChevronDown,
ChevronLeft,
ChevronRight,
ChevronUp,
CheckCircle,
} from "lucide-react";
import { useState } from "react";
import { AttemptStatus } from "@/components/Status";
import { PageSelector } from "@/components/PageSelector";
// Helper function to format date
const formatDateTime = (date: Date): string => {
return new Intl.DateTimeFormat("en-US", {
year: "numeric",
month: "short",
day: "numeric",
hour: "2-digit",
minute: "2-digit",
second: "2-digit",
}).format(date);
};
interface SyncStatus {
status: string;
}
interface SyncStatusTableProps {
ccPair: CCPairFullInfo;
syncRecords: SyncRecord[];
currentPage: number;
totalPages: number;
onPageChange: (page: number) => void;
}
export function SyncStatusTable({
ccPair,
syncRecords,
currentPage,
totalPages,
onPageChange,
}: SyncStatusTableProps) {
const hasPagination = totalPages > 1;
const [isExpanded, setIsExpanded] = useState(false);
// Filter to only show external permissions
const filteredRecords =
syncRecords?.filter(
(record) => record.sync_type === "external_permissions"
) || [];
// Check if we have any records to show
if (filteredRecords.length === 0 && currentPage === 0) {
return (
<div className="text-sm text-gray-500 italic my-2">
No permissions sync records found
</div>
);
}
// The total documents synced is the total number of documents in the cc_pair
const totalDocsSynced = ccPair.num_docs_indexed || 0;
return (
<div className="mb-6">
<div className="flex items-center mb-2">
<Button
variant="ghost"
size="sm"
onClick={() => setIsExpanded(!isExpanded)}
className="mr-2 text-muted-foreground"
>
{isExpanded ? (
<ChevronUp className="h-4 w-4" />
) : (
<ChevronDown className="h-4 w-4" />
)}
{isExpanded ? "Hide details" : "Show sync history"}
</Button>
<div className="text-sm">
<span className="flex items-center text-green-600 dark:text-green-400">
<CheckCircle className="h-4 w-4 mr-1 inline" />
Permissions sync enabled - {totalDocsSynced} documents synced
</span>
</div>
</div>
{isExpanded && (
<>
<Table>
<TableHeader>
<TableRow>
<TableHead>Time</TableHead>
<TableHead>Status</TableHead>
<TableHead></TableHead>
<TableHead>Docs Processed</TableHead>
<TableHead>Error</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{filteredRecords.length === 0 ? (
<TableRow>
<TableCell colSpan={5} className="text-center py-4">
No permission sync records found
</TableCell>
</TableRow>
) : (
filteredRecords.map((record, index) => (
<TableRow key={record.id}>
<TableCell>
{formatDateTime(new Date(record.created_at))}
</TableCell>
<TableCell>
<AttemptStatus status={record.sync_status as any} />
</TableCell>
<TableCell></TableCell>
<TableCell>{record.num_docs_synced}</TableCell>
</TableRow>
))
)}
</TableBody>
</Table>
{hasPagination && (
<div className="flex justify-end mt-4">
<PageSelector
currentPage={currentPage}
totalPages={totalPages}
onPageChange={onPageChange}
/>
</div>
)}
</>
)}
</div>
);
}

View File

@@ -23,6 +23,7 @@ import { AdvancedConfigDisplay, ConfigDisplay } from "./ConfigDisplay";
import { DeletionButton } from "./DeletionButton";
import DeletionErrorStatus from "./DeletionErrorStatus";
import { IndexingAttemptsTable } from "./IndexingAttemptsTable";
import { SyncStatusTable } from "./SyncStatusTable";
import { ModifyStatusButtonCluster } from "./ModifyStatusButtonCluster";
import { ReIndexButton } from "./ReIndexButton";
import { buildCCPairInfoUrl, triggerIndexing } from "./lib";
@@ -32,6 +33,7 @@ import {
ConnectorCredentialPairStatus,
IndexAttemptError,
PaginatedIndexAttemptErrors,
SyncRecord,
} from "./types";
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
import { Button } from "@/components/ui/button";
@@ -90,13 +92,25 @@ function Main({ ccPairId }: { ccPairId: number }) {
endpoint: `${buildCCPairInfoUrl(ccPairId)}/index-attempts`,
});
const {
currentPageData: syncStatus,
isLoading: isLoadingSyncStatus,
currentPage: syncStatusCurrentPage,
totalPages: syncStatusTotalPages,
goToPage: goToSyncStatusPage,
} = usePaginatedFetch<SyncRecord>({
itemsPerPage: ITEMS_PER_PAGE,
pagesPerBatch: PAGES_PER_BATCH,
endpoint: `${buildCCPairInfoUrl(ccPairId)}/sync-status`,
});
const {
currentPageData: indexAttemptErrorsPage,
currentPage: errorsCurrentPage,
totalPages: errorsTotalPages,
goToPage: goToErrorsPage,
} = usePaginatedFetch<IndexAttemptError>({
itemsPerPage: 10,
itemsPerPage: 300,
pagesPerBatch: 1,
endpoint: `/api/manage/admin/cc-pair/${ccPairId}/errors`,
});
@@ -450,6 +464,52 @@ function Main({ ccPairId }: { ccPairId: number }) {
/>
)}
<div className="mt-6">
<div className="flex justify-between items-center">
<Title>Permissions Sync</Title>
{ccPair.is_editable_for_current_user && (
<Button
variant="outline"
size="sm"
onClick={async () => {
try {
const response = await fetch(
`/api/manage/admin/cc-pair/${ccPair.id}/sync-permissions`,
{
method: "POST",
}
);
if (!response.ok) {
throw new Error(await response.text());
}
setPopup({
message: "Permissions sync started successfully",
type: "success",
});
mutate(buildCCPairInfoUrl(ccPairId));
} catch (error) {
setPopup({
message: "Failed to start permissions sync",
type: "error",
});
}
}}
disabled={ccPair.status !== ConnectorCredentialPairStatus.ACTIVE}
>
Sync Now
</Button>
)}
</div>
{syncStatus && (
<SyncStatusTable
ccPair={ccPair}
syncRecords={syncStatus}
currentPage={syncStatusCurrentPage}
totalPages={syncStatusTotalPages}
onPageChange={goToSyncStatusPage}
/>
)}
</div>
<div className="mt-6">
<div className="flex">
<Title>Indexing Attempts</Title>
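The "Sync Now" button above simply POSTs to the permissions-sync endpoint for the connector-credential pair. An equivalent call from a script (the endpoint path is taken from the diff above; the base URL and auth headers are assumptions) would be:

import requests

# Hedged example: assumes an authenticated admin session is carried in `headers`.
def trigger_permissions_sync(base_url: str, cc_pair_id: int, headers: dict) -> None:
    response = requests.post(
        f"{base_url}/api/manage/admin/cc-pair/{cc_pair_id}/sync-permissions",
        headers=headers,
    )
    response.raise_for_status()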

View File

@@ -74,3 +74,11 @@ export interface PaginatedIndexAttemptErrors {
items: IndexAttemptError[];
total_items: number;
}
export interface SyncRecord {
id: number;
entity_id: number;
sync_type: string;
created_at: string;
num_docs_synced: number;
sync_status: ValidStatuses;
}

View File

@@ -9,7 +9,7 @@ import {
} from "@/components/ui/table";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { IndexAttemptStatus } from "@/components/Status";
import { AttemptStatus } from "@/components/Status";
import { timeAgo } from "@/lib/time";
import {
ConnectorIndexingStatus,
@@ -244,7 +244,7 @@ border border-border dark:border-neutral-700
)}
<TableCell>{ccPairsIndexingStatus.docs_indexed}</TableCell>
<TableCell>
<IndexAttemptStatus
<AttemptStatus
status={ccPairsIndexingStatus.last_finished_status || null}
errorMsg={ccPairsIndexingStatus?.latest_index_attempt?.error_msg}
/>

View File

@@ -100,7 +100,10 @@ export function EmailPasswordForm({
// server-side provider values)
window.location.href = "/auth/waiting-on-verification";
} else {
// See above comment
// The search param is purely for multi-tenant development purposes.
// It replicates the behavior of the case where a user
// has signed up with email / password as the only user to an instance
// and has just completed verification
window.location.href = nextUrl
? encodeURI(nextUrl)
: `/chat${isSignup && !isJoin ? "?new_team=true" : ""}`;

View File

@@ -7,7 +7,7 @@ import Text from "@/components/ui/text";
import { RequestNewVerificationEmail } from "../waiting-on-verification/RequestNewVerificationEmail";
import { User } from "@/lib/types";
import { Logo } from "@/components/logo/Logo";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
export function Verify({ user }: { user: User | null }) {
const searchParams = useSearchParams();
const router = useRouter();
@@ -16,6 +16,8 @@ export function Verify({ user }: { user: User | null }) {
const verify = useCallback(async () => {
const token = searchParams.get("token");
const firstUser =
searchParams.get("first_user") && NEXT_PUBLIC_CLOUD_ENABLED;
if (!token) {
setError(
"Missing verification token. Try requesting a new verification email."
@@ -35,7 +37,7 @@ export function Verify({ user }: { user: User | null }) {
// Use window.location.href to force a full page reload,
// ensuring app re-initializes with the new state (including
// server-side provider values)
window.location.href = "/";
window.location.href = firstUser ? "/chat?new_team=true" : "/chat";
} else {
const errorDetail = (await response.json()).detail;
setError(

View File

@@ -1158,6 +1158,7 @@ export function ChatPage({
let frozenSessionId = currentSessionId();
updateCanContinue(false, frozenSessionId);
setUncaughtError(null);
setLoadingError(null);
// Mark that we've sent a message for this session in the current page load
markSessionMessageSent(frozenSessionId);

View File

@@ -2,7 +2,6 @@ import { redirect } from "next/navigation";
import { unstable_noStore as noStore } from "next/cache";
import { fetchChatData } from "@/lib/chat/fetchChatData";
import { ChatProvider } from "@/components/context/ChatContext";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
export default async function Layout({
children,
@@ -41,7 +40,6 @@ export default async function Layout({
return (
<>
<InstantSSRAutoRefresh />
<ChatProvider
value={{
proSearchToggled,

View File

@@ -1,7 +1,7 @@
import { Modal } from "@/components/Modal";
import { updateUserGroup } from "./lib";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { User, UserGroup } from "@/lib/types";
import { CCPairDescriptor, User, UserGroup } from "@/lib/types";
import { UserEditor } from "../UserEditor";
import { useState } from "react";
@@ -42,10 +42,14 @@ export const AddMemberForm: React.FC<AddMemberFormProps> = ({
)
),
];
const response = await updateUserGroup(userGroup.id, {
user_ids: newUserIds,
cc_pair_ids: userGroup.cc_pairs.map((ccPair) => ccPair.id),
cc_pair_ids: userGroup.cc_pairs.map(
(ccPair: CCPairDescriptor<number, number>) => ccPair.id
),
});
if (response.ok) {
setPopup({
message: "Successfully added users to group",

View File

@@ -46,6 +46,10 @@ import { AddTokenRateLimitForm } from "./AddTokenRateLimitForm";
import { GenericTokenRateLimitTable } from "@/app/admin/token-rate-limits/TokenRateLimitTables";
import { useUser } from "@/components/user/UserProvider";
import { GenericConfirmModal } from "@/components/modals/GenericConfirmModal";
import { GroupSyncStatusTable } from "./GroupSyncStatusTable";
import usePaginatedFetch from "@/hooks/usePaginatedFetch";
import Title from "@/components/ui/title";
import { SyncRecord } from "@/app/admin/connector/[ccPairId]/types";
interface GroupDisplayProps {
users: User[];
@@ -178,6 +182,18 @@ export const GroupDisplay = ({
const { isAdmin } = useUser();
const {
currentPageData: syncStatus,
isLoading: isLoadingSyncStatus,
currentPage: syncStatusCurrentPage,
totalPages: syncStatusTotalPages,
goToPage: goToSyncStatusPage,
} = usePaginatedFetch<SyncRecord>({
itemsPerPage: 8,
pagesPerBatch: 8,
endpoint: `/api/manage/admin/user-group/${userGroup.id}/sync-status`,
});
const handlePopup = (message: string, type: "success" | "error") => {
setPopup({ message, type });
};
@@ -204,6 +220,22 @@ export const GroupDisplay = ({
<Separator />
<div className="mt-4 mb-4">
<Title>Group Sync Status</Title>
{syncStatus && (
<GroupSyncStatusTable
userGroup={userGroup}
syncRecords={syncStatus}
currentPage={syncStatusCurrentPage}
totalPages={syncStatusTotalPages}
onPageChange={goToSyncStatusPage}
lastSyncTime={userGroup.last_synced_at}
/>
)}
</div>
<Separator />
<div className="flex w-full">
<h2 className="text-xl font-bold">Users</h2>
</div>

View File

@@ -0,0 +1,145 @@
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { Button } from "@/components/ui/button";
import { ChevronDown, ChevronUp, CheckCircle, Users } from "lucide-react";
import { useState } from "react";
import { PageSelector } from "@/components/PageSelector";
import { UserGroup } from "@/lib/types";
import { SyncRecord } from "@/app/admin/connector/[ccPairId]/types";
// Helper function to format date
const formatDateTime = (date: Date): string => {
return new Intl.DateTimeFormat("en-US", {
year: "numeric",
month: "short",
day: "numeric",
hour: "2-digit",
minute: "2-digit",
second: "2-digit",
}).format(date);
};
interface GroupSyncStatusTableProps {
userGroup: UserGroup;
syncRecords: SyncRecord[];
currentPage: number;
totalPages: number;
onPageChange: (page: number) => void;
lastSyncTime?: string | null;
}
export function GroupSyncStatusTable({
userGroup,
syncRecords,
currentPage,
totalPages,
onPageChange,
lastSyncTime,
}: GroupSyncStatusTableProps) {
const hasPagination = totalPages > 1;
const [isExpanded, setIsExpanded] = useState(false);
// Filter to only show user group syncs
const filteredRecords =
syncRecords?.filter((record) => record.sync_type === "user_group") || [];
// Check if we have any records to show
if (filteredRecords.length === 0 && currentPage === 0) {
return (
<div className="text-sm text-gray-500 italic my-2">
No group sync records found
</div>
);
}
// Estimate the total number of users synced
const totalUsersSynced = userGroup.users?.length || 0;
return (
<div className="mb-6">
<div className="flex items-center justify-between mb-2">
<div className="flex items-center">
<Button
variant="ghost"
size="sm"
onClick={() => setIsExpanded(!isExpanded)}
className="mr-2 text-muted-foreground"
>
{isExpanded ? (
<ChevronUp className="h-4 w-4" />
) : (
<ChevronDown className="h-4 w-4" />
)}
{isExpanded ? "Hide details" : "Show sync history"}
</Button>
<div className="text-sm">
<span className="flex items-center text-green-600 dark:text-green-400">
<CheckCircle className="h-4 w-4 mr-1 inline" />
Groups synced
</span>
</div>
</div>
</div>
{lastSyncTime && (
<div className="text-sm text-muted-foreground mb-2">
Last synced: {new Date(lastSyncTime).toLocaleString()}
</div>
)}
{isExpanded && (
<>
<Table>
<TableHeader>
<TableRow>
<TableHead>Time</TableHead>
<TableHead>Status</TableHead>
<TableHead>Docs Processed</TableHead>
<TableHead>Error</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{filteredRecords.length === 0 ? (
<TableRow>
<TableCell colSpan={4} className="text-center py-4">
No group sync records found
</TableCell>
</TableRow>
) : (
filteredRecords.map((record) => (
<TableRow key={record.id}>
<TableCell>
{formatDateTime(new Date(record.created_at))}
</TableCell>
<TableCell>
<span className="flex items-center text-green-600 dark:text-green-400">
<CheckCircle className="h-4 w-4 mr-1" /> Success
</span>
</TableCell>
<TableCell>{record.num_docs_synced}</TableCell>
</TableRow>
))
)}
</TableBody>
</Table>
{hasPagination && (
<div className="flex justify-end mt-4">
<PageSelector
currentPage={currentPage}
totalPages={totalPages}
onPageChange={onPageChange}
/>
</div>
)}
</>
)}
</div>
);
}

View File

@@ -12,7 +12,7 @@ import {
import { HoverPopup } from "./HoverPopup";
import { ConnectorCredentialPairStatus } from "@/app/admin/connector/[ccPairId]/types";
export function IndexAttemptStatus({
export function AttemptStatus({
status,
errorMsg,
}: {

View File

@@ -1,6 +1,6 @@
import { buildCCPairInfoUrl } from "@/app/admin/connector/[ccPairId]/lib";
import { PageSelector } from "@/components/PageSelector";
import { IndexAttemptStatus } from "@/components/Status";
import { AttemptStatus } from "@/components/Status";
import { deleteCCPair } from "@/lib/documentDeletion";
import { FailedConnectorIndexingStatus } from "@/lib/types";
import { Button } from "@/components/ui/button";
@@ -76,7 +76,7 @@ export function FailedReIndexAttempts({
</Link>
</TableCell>
<TableCell>
<IndexAttemptStatus status="failed" />
<AttemptStatus status="failed" />
</TableCell>
<TableCell>

View File

@@ -1,5 +1,5 @@
import { PageSelector } from "@/components/PageSelector";
import { IndexAttemptStatus } from "@/components/Status";
import { AttemptStatus } from "@/components/Status";
import { ConnectorIndexingStatus } from "@/lib/types";
import {
Table,
@@ -49,7 +49,7 @@ export function ReindexingProgressTable({
</TableCell>
<TableCell>
{reindexingProgress.latest_index_attempt?.status && (
<IndexAttemptStatus
<AttemptStatus
status={reindexingProgress.latest_index_attempt.status}
/>
)}

View File

@@ -1,11 +1,18 @@
"use client";
import React, { createContext, useContext, useState, useEffect } from "react";
import React, {
createContext,
useContext,
useState,
useEffect,
useRef,
} from "react";
import { User, UserRole } from "@/lib/types";
import { getCurrentUser } from "@/lib/user";
import { usePostHog } from "posthog-js/react";
import { CombinedSettings } from "@/app/admin/settings/interfaces";
import { SettingsContext } from "../settings/SettingsProvider";
import { useTokenRefresh } from "@/hooks/useTokenRefresh";
interface UserContextType {
user: User | null;
@@ -93,6 +100,10 @@ export function UserProvider({
console.error("Error fetching current user:", error);
}
};
// Use the custom token refresh hook
useTokenRefresh(upToDateUser, fetchUser);
const updateUserTemperatureOverrideEnabled = async (enabled: boolean) => {
try {
setUpToDateUser((prevUser) => {

View File

@@ -27,6 +27,7 @@ interface PaginationConfig {
query?: string;
filter?: Record<string, string | boolean | number | string[] | Date>;
refreshIntervalInMs?: number;
zeroIndexed?: boolean; // Flag to indicate if backend uses zero-indexed pages
}
interface PaginatedHookReturnData<T extends PaginatedType> {
@@ -46,15 +47,18 @@ function usePaginatedFetch<T extends PaginatedType>({
query,
filter,
refreshIntervalInMs = 5000,
zeroIndexed = true, // Default to true for zero-indexed pages
}: PaginationConfig): PaginatedHookReturnData<T> {
const router = useRouter();
const currentPath = usePathname();
const searchParams = useSearchParams();
// State to initialize and hold the current page number
const [currentPage, setCurrentPage] = useState(() =>
parseInt(searchParams?.get("page") || "1", 10)
);
const [currentPage, setCurrentPage] = useState(() => {
const urlPage = parseInt(searchParams?.get("page") || "1", 10);
// For URL display we use 1-indexed, but for internal state we use the appropriate index based on API
return zeroIndexed ? urlPage - 1 : urlPage;
});
const [currentPageData, setCurrentPageData] = useState<T[] | null>(null);
const [error, setError] = useState<Error | null>(null);
const [isLoading, setIsLoading] = useState<boolean>(false);
@@ -73,10 +77,11 @@ function usePaginatedFetch<T extends PaginatedType>({
// Calculates which batch we're in, and which page within that batch
const batchAndPageIndices = useMemo(() => {
const batchNum = Math.floor((currentPage - 1) / pagesPerBatch);
const batchPageNum = (currentPage - 1) % pagesPerBatch;
const pageForCalc = zeroIndexed ? currentPage : currentPage - 1;
const batchNum = Math.floor(pageForCalc / pagesPerBatch);
const batchPageNum = pageForCalc % pagesPerBatch;
return { batchNum, batchPageNum };
}, [currentPage, pagesPerBatch]);
}, [currentPage, pagesPerBatch, zeroIndexed]);
// Fetches a batch of data and stores it in the cache
const fetchBatchData = useCallback(
@@ -88,9 +93,9 @@ function usePaginatedFetch<T extends PaginatedType>({
ongoingRequestsRef.current.add(batchNum);
try {
// Build query params
// Build query params - use zero-based indexing for backend
const params = new URLSearchParams({
page_num: batchNum.toString(),
page_num: (batchNum * pagesPerBatch).toString(),
page_size: (pagesPerBatch * itemsPerPage).toString(),
});
@@ -145,27 +150,31 @@ function usePaginatedFetch<T extends PaginatedType>({
[endpoint, pagesPerBatch, itemsPerPage, query, filter]
);
// Updates the URL with the current page number
// Updates the URL with the current page number (always 1-indexed for URL)
const updatePageUrl = useCallback(
(page: number) => {
if (currentPath) {
const params = new URLSearchParams(searchParams);
params.set("page", page.toString());
// For URL display we always use 1-indexed
const urlPage = zeroIndexed ? page + 1 : page;
params.set("page", urlPage.toString());
router.replace(`${currentPath}?${params.toString()}`, {
scroll: false,
});
}
},
[currentPath, router, searchParams]
[currentPath, router, searchParams, zeroIndexed]
);
// Updates the current page
const goToPage = useCallback(
(newPage: number) => {
setCurrentPage(newPage);
updatePageUrl(newPage);
// Ensure page is within bounds
const boundedPage = Math.max(0, Math.min(newPage, totalPages - 1));
setCurrentPage(boundedPage);
updatePageUrl(boundedPage);
},
[updatePageUrl]
[updatePageUrl, totalPages]
);
// Loads the current and adjacent batches
@@ -199,7 +208,14 @@ function usePaginatedFetch<T extends PaginatedType>({
if (!cachedBatches[0]) {
fetchBatchData(0);
}
}, [currentPage, cachedBatches, totalPages, pagesPerBatch, fetchBatchData]);
}, [
currentPage,
cachedBatches,
totalPages,
pagesPerBatch,
fetchBatchData,
batchAndPageIndices,
]);
// Updates current page data from the cache
useEffect(() => {
@@ -209,7 +225,7 @@ function usePaginatedFetch<T extends PaginatedType>({
setCurrentPageData(cachedBatches[batchNum][batchPageNum]);
setIsLoading(false);
}
}, [currentPage, cachedBatches, pagesPerBatch]);
}, [currentPage, cachedBatches, pagesPerBatch, batchAndPageIndices]);
// Implements periodic refresh
useEffect(() => {
@@ -221,21 +237,28 @@ function usePaginatedFetch<T extends PaginatedType>({
}, refreshIntervalInMs);
return () => clearInterval(interval);
}, [currentPage, pagesPerBatch, refreshIntervalInMs, fetchBatchData]);
}, [
currentPage,
pagesPerBatch,
refreshIntervalInMs,
fetchBatchData,
batchAndPageIndices,
]);
// Manually refreshes the current batch
const refresh = useCallback(async () => {
const { batchNum } = batchAndPageIndices;
await fetchBatchData(batchNum);
}, [currentPage, pagesPerBatch, fetchBatchData]);
}, [batchAndPageIndices, fetchBatchData]);
// Cache invalidation
useEffect(() => {
setCachedBatches({});
setTotalItems(0);
goToPage(1);
// Start at page 0 for zero-indexed APIs, page 1 for one-indexed
goToPage(zeroIndexed ? 0 : 1);
setError(null);
}, [currentPath, query, filter]);
}, [currentPath, query, filter, zeroIndexed, goToPage]);
return {
currentPage,
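As a worked example of the zero-indexed batch arithmetic above (pagesPerBatch = 5 is an assumed value for illustration):

# Mirrors the zero-indexed math in usePaginatedFetch; illustrative only.
pages_per_batch = 5

for current_page in (0, 4, 5, 12):
    batch_num = current_page // pages_per_batch      # which batch to fetch
    batch_page_num = current_page % pages_per_batch  # page within that batch
    print(f"page {current_page} -> batch {batch_num}, offset {batch_page_num}")

# page 0 -> batch 0, offset 0
# page 4 -> batch 0, offset 4
# page 5 -> batch 1, offset 0
# page 12 -> batch 2, offset 2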

View File

@@ -0,0 +1,84 @@
import { useState, useEffect, useRef } from "react";
import { User } from "@/lib/types";
import { NO_AUTH_USER_ID } from "@/lib/extension/constants";
// Refresh token every 10 minutes (600000ms)
// This is shorter than the session expiry time to ensure tokens stay valid
const REFRESH_INTERVAL = 600000;
// Custom hook that periodically refreshes the auth token for the current user
export function useTokenRefresh(
user: User | null,
onRefreshFail: () => Promise<void>
) {
// Track last refresh time to avoid unnecessary calls
const [lastTokenRefresh, setLastTokenRefresh] = useState<number>(Date.now());
// Use a ref to track first load
const isFirstLoad = useRef(true);
useEffect(() => {
if (!user || user.id === NO_AUTH_USER_ID) return;
const refreshTokenPeriodically = async () => {
try {
// Skip time check if this is first load - we always refresh on first load
const isTimeToRefresh =
isFirstLoad.current ||
Date.now() - lastTokenRefresh > REFRESH_INTERVAL - 60000;
if (!isTimeToRefresh) {
return;
}
// Reset first load flag
if (isFirstLoad.current) {
isFirstLoad.current = false;
}
const response = await fetch("/api/auth/refresh", {
method: "POST",
credentials: "include",
});
if (response.ok) {
// Update last refresh time on success
setLastTokenRefresh(Date.now());
console.debug("Auth token refreshed successfully");
} else {
console.warn("Failed to refresh auth token:", response.status);
// If token refresh fails, try to get current user info
await onRefreshFail();
}
} catch (error) {
console.error("Error refreshing auth token:", error);
}
};
// Always attempt to refresh on first component mount
// This helps ensure tokens are fresh, especially after browser refresh
refreshTokenPeriodically();
// Set up interval for periodic refreshes
const intervalId = setInterval(refreshTokenPeriodically, REFRESH_INTERVAL);
// Also refresh when the tab becomes visible again, but no more than once per minute
const handleVisibilityChange = () => {
if (
document.visibilityState === "visible" &&
Date.now() - lastTokenRefresh > 60000
) {
refreshTokenPeriodically();
}
};
document.addEventListener("visibilitychange", handleVisibilityChange);
return () => {
clearInterval(intervalId);
document.removeEventListener("visibilitychange", handleVisibilityChange);
};
}, [user, lastTokenRefresh, onRefreshFail]);
return { lastTokenRefresh };
}

View File

@@ -35,3 +35,5 @@ export const LocalStorageKeys = {
export const SEARCH_PARAMS = {
DEFAULT_SIDEBAR_OFF: "defaultSidebarOff",
};
export const NO_AUTH_USER_ID = "__no_auth_user__";

View File

@@ -342,6 +342,7 @@ export interface UserGroup {
personas: Persona[];
is_up_to_date: boolean;
is_up_for_deletion: boolean;
last_synced_at: string | null;
}
export enum ValidSources {