Compare commits

..

21 Commits

Author SHA1 Message Date
pablonyx
63d6931bf9 slight robustification 2025-03-26 16:48:20 -07:00
rkuo-danswer
bc9b4e4f45 use slack's built in rate limit handler for the bot (#4362)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-26 21:55:04 +00:00
evan-danswer
178a64f298 fix issue with drive connector service account indexing (#4356)
* fix issue with drive connector service account indexing

* correct checkpoint resumption

* final set of fixes

* nit

* fix typing

* logging and CW comments

* nit
2025-03-26 20:54:26 +00:00
pablonyx
c79f1edf1d add a flush (#4361) 2025-03-26 14:40:52 -07:00
pablonyx
7c8e23aa54 Fix saml conversion from ext_perm -> basic (#4343)
* fix saml conversion from ext_perm -> basic

* quick nit

* minor fix

* finalize

* update

* quick fix
2025-03-26 20:36:51 +00:00
pablonyx
d37b427d52 fix email flow (#4339) 2025-03-26 18:59:12 +00:00
pablonyx
a65fefd226 test fix 2025-03-26 12:43:38 -07:00
rkuo-danswer
bb09bde519 Bugfix/google drive size threshold 2 (#4355) 2025-03-26 12:06:36 -07:00
Tim Rosenblatt
0f6cf0fc58 Fixes docker logs helper text in run-nginx.sh (#3678)
The docker container name is slightly wrong, and this commit fixes it.
2025-03-26 09:03:35 -07:00
pablonyx
fed06b592d Auto refresh credentials (#4268)
* Auto refresh credentials

* remove dupes

* clean up + tests

* k

* quick nit

* add brief comment

* misc typing
2025-03-26 01:53:31 +00:00
pablonyx
8d92a1524e fix invitation on cloud (#4351)
* fix invitation on cloud

* k
2025-03-26 01:25:17 +00:00
pablonyx
ecfea9f5ed Email formatting devices (#4353)
* update email formatting

* k

* update

* k

* nit
2025-03-25 21:42:32 +00:00
rkuo-danswer
b269f1ba06 fix broken function call (#4354)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-25 21:07:31 +00:00
pablonyx
30c878efa5 Quick fix (#4341)
* quick fix

* Revert "quick fix"

This reverts commit f113616276.

* smaller change
2025-03-25 18:39:55 +00:00
pablonyx
2024776c19 Respect contextvars when parallelizing for Google Drive (#4291)
* k

* k

* fix typing
2025-03-25 17:40:12 +00:00
pablonyx
431316929c k (#4336) 2025-03-25 17:00:35 +00:00
pablonyx
c5b9c6e308 update (#4344) 2025-03-25 16:56:23 +00:00
pablonyx
73dd188b3f update (#4338) 2025-03-25 16:55:25 +00:00
evan-danswer
79b061abbc Daylight savings time handling (#4345)
* confluence timezone improvements

* confluence timezone improvements
2025-03-25 16:11:30 +00:00
rkuo-danswer
552f1ead4f use correct namespace in redis for certain keys (#4340)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-25 04:10:31 +00:00
evan-danswer
17925b49e8 typing fix (#4342)
* typing fix

* changed type hint to help future coders
2025-03-25 01:01:13 +00:00
43 changed files with 1381 additions and 212 deletions

View File

@@ -102,6 +102,7 @@ COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
COPY ./static /app/static
# Escape hatch scripts
COPY ./scripts/debugging /app/scripts/debugging

View File

@@ -28,6 +28,20 @@ depends_on = None
def upgrade() -> None:
# First, drop any existing indexes to avoid conflicts
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
# Drop existing columns if they exist
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
# Create a GIN index for full-text search on chat_message.message
op.execute(
"""

View File

@@ -64,7 +64,15 @@ def get_application() -> FastAPI:
add_tenant_id_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
oauth_client = GoogleOAuth2(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
# Use standard scopes that include profile and email
scopes=["openid", "email", "profile"],
)
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -87,6 +95,16 @@ def get_application() -> FastAPI:
)
if AUTH_TYPE == AuthType.OIDC:
# Ensure we request offline_access for refresh tokens
try:
oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
if "offline_access" not in oidc_scopes:
oidc_scopes.append("offline_access")
except Exception as e:
logger.warning(f"Error configuring OIDC scopes: {e}")
# Fall back to default scopes if there's an error
oidc_scopes = BASE_SCOPES
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -94,8 +112,8 @@ def get_application() -> FastAPI:
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
OPENID_CONFIG_URL,
# BASE_SCOPES is the same as not setting this
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
# Use the configured scopes
base_scopes=oidc_scopes,
),
auth_backend,
USER_AUTH_SECRET,

View File

@@ -36,8 +36,12 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/auth/saml")
# Define non-authenticated user roles that should be re-created during SAML login
NON_AUTHENTICATED_ROLES = {UserRole.SLACK_USER, UserRole.EXT_PERM_USER}
async def upsert_saml_user(email: str) -> User:
logger.debug(f"Attempting to upsert SAML user with email: {email}")
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
) # type:ignore
@@ -48,9 +52,13 @@ async def upsert_saml_user(email: str) -> User:
async with get_user_db_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
try:
return await user_manager.get_by_email(email)
user = await user_manager.get_by_email(email)
# If user has a non-authenticated role, treat as non-existent
if user.role in NON_AUTHENTICATED_ROLES:
raise exceptions.UserNotExists()
return user
except exceptions.UserNotExists:
logger.notice("Creating user from SAML login")
logger.info("Creating user from SAML login")
user_count = await get_user_count()
role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC
@@ -59,11 +67,10 @@ async def upsert_saml_user(email: str) -> User:
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
user: User = await user_manager.create(
user = await user_manager.create(
UserCreate(
email=email,
password=hashed_pass,
is_verified=True,
role=role,
)
)

View File

@@ -16,10 +16,10 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.configs.constants import ONYX_SLACK_URL
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.file import FileWithMimeType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT
@@ -62,6 +62,11 @@ HTML_EMAIL_TEMPLATE = """\
}}
.header img {{
max-width: 140px;
width: 140px;
height: auto;
filter: brightness(1.1) contrast(1.2);
border-radius: 8px;
padding: 5px;
}}
.body-content {{
padding: 20px 30px;
@@ -78,12 +83,16 @@ HTML_EMAIL_TEMPLATE = """\
}}
.cta-button {{
display: inline-block;
padding: 12px 20px;
background-color: #000000;
padding: 14px 24px;
background-color: #0055FF;
color: #ffffff !important;
text-decoration: none;
border-radius: 4px;
font-weight: 500;
font-weight: 600;
font-size: 16px;
margin-top: 10px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
text-align: center;
}}
.footer {{
font-size: 13px;
@@ -166,6 +175,7 @@ def send_email(
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")
# Create a multipart/alternative message - this indicates these are alternative versions of the same content
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
@@ -174,17 +184,30 @@ def send_email(
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")
part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")
msg.attach(part_text)
msg.attach(part_html)
# Add text part first (lowest priority)
text_part = MIMEText(text_body, "plain")
msg.attach(text_part)
if inline_png:
# For HTML with images, create a multipart/related container
related = MIMEMultipart("related")
# Add the HTML part to the related container
html_part = MIMEText(html_body, "html")
related.attach(html_part)
# Add image with proper Content-ID to the related container
img = MIMEImage(inline_png[1], _subtype="png")
img.add_header("Content-ID", inline_png[0]) # CID reference
img.add_header("Content-ID", f"<{inline_png[0]}>")
img.add_header("Content-Disposition", "inline", filename=inline_png[0])
msg.attach(img)
related.attach(img)
# Add the related part to the message (higher priority than text)
msg.attach(related)
else:
# No images, just add HTML directly (higher priority than text)
html_part = MIMEText(html_body, "html")
msg.attach(html_part)
try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
@@ -332,17 +355,23 @@ def send_forgot_password_email(
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"{application_name} Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if MULTI_TENANT:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
subject = f"Reset Your {application_name} Password"
heading = "Reset Your Password"
tenant_param = f"&tenant={tenant_id}" if tenant_id and MULTI_TENANT else ""
message = "<p>Please click the button below to reset your password. This link will expire in 24 hours.</p>"
cta_text = "Reset Password"
cta_link = f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
html_content = build_html_email(
application_name,
"Reset Your Password",
heading,
message,
cta_text,
cta_link,
)
text_content = (
f"Please click the following link to reset your password. This link will expire in 24 hours.\n"
f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
)
text_content = f"Click the following link to reset your password: {link}"
send_email(
user_email,
subject,
@@ -356,6 +385,7 @@ def send_forgot_password_email(
def send_user_verification_email(
user_email: str,
token: str,
new_organization: bool = False,
mail_from: str = EMAIL_FROM,
) -> None:
# Builds a verification email
@@ -372,6 +402,8 @@ def send_user_verification_email(
subject = f"{application_name} Email Verification"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
if new_organization:
link = add_url_params(link, {"first_user": "true"})
message = (
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
)
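
The restructured send_email above produces the standard MIME tree for HTML mail with an inline image: a multipart/alternative envelope holding the plain-text part and a multipart/related part, which in turn holds the HTML and the CID-referenced PNG. A self-contained sketch of that tree, assuming illustrative content and placeholder image bytes:

from email.mime.image import MIMEImage
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

msg = MIMEMultipart("alternative")  # children are alternative renderings
msg.attach(MIMEText("Plain-text fallback", "plain"))  # attached first = lowest priority

related = MIMEMultipart("related")  # HTML plus the resources it references
related.attach(MIMEText('<p>Hello</p><img src="cid:logo.png">', "html"))

img = MIMEImage(b"png-bytes-here", _subtype="png")  # placeholder bytes
img.add_header("Content-ID", "<logo.png>")  # angle brackets are required for CID refs
img.add_header("Content-Disposition", "inline", filename="logo.png")
related.attach(img)

msg.attach(related)  # attached last = preferred alternative

HTML-capable mail clients pick the related part and resolve cid:logo.png against the attached image; text-only clients fall back to the plain part.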

View File

@@ -0,0 +1,211 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
import httpx
from fastapi_users.manager import BaseUserManager
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Standard OAuth refresh token endpoints
REFRESH_ENDPOINTS = {
"google": "https://oauth2.googleapis.com/token",
}
# NOTE: Keeping this as a utility function for potential future debugging,
# but not using it in production code
async def _test_expire_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
expire_in_seconds: int = 10,
) -> bool:
"""
Utility function for testing - Sets an OAuth token to expire in a short time
to facilitate testing of the refresh flow.
Not used in production code.
"""
try:
new_expires_at = int(
(datetime.now(timezone.utc).timestamp() + expire_in_seconds)
)
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
return True
except Exception as e:
logger.exception(f"Error setting artificial expiration: {str(e)}")
return False
async def refresh_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
) -> bool:
"""
Attempt to refresh an OAuth token that's about to expire or has expired.
Returns True if successful, False otherwise.
"""
if not oauth_account.refresh_token:
logger.warning(
f"No refresh token available for {user.email}'s {oauth_account.oauth_name} account"
)
return False
provider = oauth_account.oauth_name
if provider not in REFRESH_ENDPOINTS:
logger.warning(f"Refresh endpoint not configured for provider: {provider}")
return False
try:
logger.info(f"Refreshing OAuth token for {user.email}'s {provider} account")
async with httpx.AsyncClient() as client:
response = await client.post(
REFRESH_ENDPOINTS[provider],
data={
"client_id": OAUTH_CLIENT_ID,
"client_secret": OAUTH_CLIENT_SECRET,
"refresh_token": oauth_account.refresh_token,
"grant_type": "refresh_token",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
if response.status_code != 200:
logger.error(
f"Failed to refresh OAuth token: Status {response.status_code}"
)
return False
token_data = response.json()
new_access_token = token_data.get("access_token")
new_refresh_token = token_data.get(
"refresh_token", oauth_account.refresh_token
)
expires_in = token_data.get("expires_in")
# Calculate new expiry time if provided
new_expires_at: Optional[int] = None
if expires_in:
new_expires_at = int(
(datetime.now(timezone.utc).timestamp() + expires_in)
)
# Update the OAuth account
updated_data: Dict[str, Any] = {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
}
if new_expires_at:
updated_data["expires_at"] = new_expires_at
# Update oidc_expiry in user model if we're tracking it
if TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(
new_expires_at, tz=timezone.utc
)
await user_manager.user_db.update(
user, {"oidc_expiry": oidc_expiry}
)
# Update the OAuth account
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
logger.info(f"Successfully refreshed OAuth token for {user.email}")
return True
except Exception as e:
logger.exception(f"Error refreshing OAuth token: {str(e)}")
return False
async def check_and_refresh_oauth_tokens(
user: User,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
) -> None:
"""
Check if any OAuth tokens are expired or about to expire and refresh them.
"""
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
return
now_timestamp = datetime.now(timezone.utc).timestamp()
# Buffer time to refresh tokens before they expire (in seconds)
buffer_seconds = 300 # 5 minutes
for oauth_account in user.oauth_accounts:
# Skip accounts without refresh tokens
if not oauth_account.refresh_token:
continue
# If token is about to expire, refresh it
if (
oauth_account.expires_at
and oauth_account.expires_at - now_timestamp < buffer_seconds
):
logger.info(f"OAuth token for {user.email} is about to expire - refreshing")
success = await refresh_oauth_token(
user, oauth_account, db_session, user_manager
)
if not success:
logger.warning(
"Failed to refresh OAuth token. User may need to re-authenticate."
)
async def check_oauth_account_has_refresh_token(
user: User,
oauth_account: OAuthAccount,
) -> bool:
"""
Check if an OAuth account has a refresh token.
Returns True if a refresh token exists, False otherwise.
"""
return bool(oauth_account.refresh_token)
async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
"""
Returns a list of OAuth accounts for a user that are missing refresh tokens.
These accounts will need re-authentication to get refresh tokens.
"""
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
return []
accounts_needing_refresh = []
for oauth_account in user.oauth_accounts:
has_refresh_token = await check_oauth_account_has_refresh_token(
user, oauth_account
)
if not has_refresh_token:
accounts_needing_refresh.append(oauth_account)
return accounts_needing_refresh

View File

@@ -5,12 +5,16 @@ import string
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Protocol
from typing import Tuple
from typing import TypeVar
import jwt
from email_validator import EmailNotValidError
@@ -581,8 +585,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
logger.notice(
f"Verification requested for user {user.id}. Verification token: {token}"
)
send_user_verification_email(user.email, token)
user_count = await get_user_count()
send_user_verification_email(
user.email, token, new_organization=user_count == 1
)
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
@@ -688,16 +694,20 @@ cookie_transport = CookieTransport(
)
def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()
T = TypeVar("T", covariant=True)
ID = TypeVar("ID", contravariant=True)
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
# Protocol for strategies that support token refreshing without inheritance.
class RefreshableStrategy(Protocol):
"""Protocol for authentication strategies that support token refreshing."""
async def refresh_token(self, token: Optional[str], user: Any) -> str:
"""
Refresh an existing token by extending its lifetime.
Returns either the same token with extended expiration or a new token.
"""
...
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
@@ -756,6 +766,75 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
redis = await get_async_redis_connection()
await redis.delete(f"{self.key_prefix}{token}")
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by extending its expiration time in Redis."""
if token is None:
# If no token provided, create a new one
return await self.write_token(user)
redis = await get_async_redis_connection()
token_key = f"{self.key_prefix}{token}"
# Check if token exists
token_data_str = await redis.get(token_key)
if not token_data_str:
# Token not found, create new one
return await self.write_token(user)
# Token exists, extend its lifetime
token_data = json.loads(token_data_str)
await redis.set(
token_key,
json.dumps(token_data),
ex=self.lifetime_seconds,
)
return token
class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
"""Database strategy with token refreshing capabilities."""
def __init__(
self,
access_token_db: AccessTokenDatabase[AccessToken],
lifetime_seconds: Optional[int] = None,
):
super().__init__(access_token_db, lifetime_seconds)
self._access_token_db = access_token_db
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by updating its expiration time in the database."""
if token is None:
return await self.write_token(user)
# Find the token in database
access_token = await self._access_token_db.get_by_token(token)
if access_token is None:
# Token not found, create new one
return await self.write_token(user)
# Update expiration time
new_expires = datetime.now(timezone.utc) + timedelta(
seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
)
await self._access_token_db.update(access_token, {"expires": new_expires})
return token
def get_redis_strategy() -> TenantAwareRedisStrategy:
return TenantAwareRedisStrategy()
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> RefreshableDatabaseStrategy:
return RefreshableDatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
if AUTH_BACKEND == AuthBackend.REDIS:
auth_backend = AuthenticationBackend(
@@ -806,6 +885,88 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
return router
def get_refresh_router(
self,
backend: AuthenticationBackend,
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
) -> APIRouter:
"""
Provide a router for session token refreshing.
"""
# Import the oauth_refresher here to avoid circular imports
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
router = APIRouter()
get_current_user_token = self.authenticator.current_user_token(
active=True, verified=requires_verification
)
refresh_responses: OpenAPIResponseType = {
**{
status.HTTP_401_UNAUTHORIZED: {
"description": "Missing token or inactive user."
}
},
**backend.transport.get_openapi_login_responses_success(),
}
@router.post(
"/refresh", name=f"auth:{backend.name}.refresh", responses=refresh_responses
)
async def refresh(
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(
get_user_manager
),
db_session: AsyncSession = Depends(get_async_session),
) -> Response:
try:
user, token = user_token
logger.info(f"Processing token refresh request for user {user.email}")
# Check if user has OAuth accounts that need refreshing
await check_and_refresh_oauth_tokens(
user=cast(User, user),
db_session=db_session,
user_manager=cast(Any, user_manager),
)
# Check if strategy supports refreshing
supports_refresh = hasattr(strategy, "refresh_token") and callable(
getattr(strategy, "refresh_token")
)
if supports_refresh:
try:
refresh_method = getattr(strategy, "refresh_token")
new_token = await refresh_method(token, user)
logger.info(
f"Successfully refreshed session token for user {user.email}"
)
return await backend.transport.get_login_response(new_token)
except Exception as e:
logger.error(f"Error refreshing session token: {str(e)}")
# Fallback to logout and login if refresh fails
await backend.logout(strategy, user, token)
return await backend.login(strategy, user)
# Fallback: logout and login again
logger.info(
"Strategy doesn't support refresh - using logout/login flow"
)
await backend.logout(strategy, user, token)
return await backend.login(strategy, user)
except Exception as e:
logger.error(f"Unexpected error in refresh endpoint: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Token refresh failed: {str(e)}",
)
return router
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
get_user_manager, [auth_backend]
@@ -1039,12 +1200,20 @@ def get_oauth_router(
"referral_source": referral_source or "default_referral",
}
state = generate_state_token(state_data, state_secret)
# Get the basic authorization URL
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
state,
scopes,
)
# For Google OAuth, add parameters to request refresh tokens
if oauth_client.name == "google":
authorization_url = add_url_params(
authorization_url, {"access_type": "offline", "prompt": "consent"}
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
@router.get(
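
For context, the refresh endpoint added above first refreshes any expiring OAuth tokens for the user, then extends the session token if the strategy implements refresh_token, falling back to a logout/login cycle otherwise. A hedged sketch of a client calling it; the base URL and the cookie name (fastapi-users' CookieTransport defaults to "fastapiusersauth", which a deployment may override) are assumptions:

import requests

session = requests.Session()
# Assumed cookie name; use whatever the deployment's cookie transport sets.
session.cookies.set("fastapiusersauth", "<current-session-token>")

resp = session.post("http://localhost:8080/auth/refresh")
resp.raise_for_status()
# On success, the response carries the transport's login response, i.e. a
# Set-Cookie with the same or a new session token and an extended expiry.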

View File

@@ -34,7 +34,6 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import PlainFormatter
@@ -225,7 +224,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
@@ -311,7 +310,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")

View File

@@ -38,10 +38,11 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -102,7 +103,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
@@ -235,7 +236,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
lock: RedisLock = worker.primary_worker_lock
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")

View File

@@ -451,6 +451,8 @@ def monitor_connector_deletion_taskset(
credential_id=cc_pair.credential_id,
)
db_session.flush()
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,

View File

@@ -68,6 +68,8 @@ from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
logger = setup_logger()
@@ -875,6 +877,21 @@ def monitor_ccpair_permissions_taskset(
f"remaining={remaining} "
f"initial={initial}"
)
# Add telemetry for permission syncing progress
optional_telemetry(
record_type=RecordType.PERMISSION_SYNC_PROGRESS,
data={
"cc_pair_id": cc_pair_id,
"id": payload.id if payload else None,
"total_docs": initial if initial is not None else 0,
"remaining_docs": remaining,
"synced_docs": (initial - remaining) if initial is not None else 0,
"is_complete": remaining == 0,
},
tenant_id=tenant_id,
)
if remaining > 0:
return

View File

@@ -56,9 +56,12 @@ from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.redis.redis_connector import RedisConnector
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
@@ -570,6 +573,22 @@ def _run_indexing(
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
# Add telemetry for indexing progress
optional_telemetry(
record_type=RecordType.INDEXING_PROGRESS,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"connector_id": ctx.connector_id,
"credential_id": ctx.credential_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"batch_num": batch_num,
"source": ctx.source.value,
},
tenant_id=tenant_id,
)
memory_tracer.increment_and_maybe_trace()
# make sure the checkpoints aren't getting too large at some regular interval
@@ -585,6 +604,30 @@ def _run_indexing(
checkpoint=checkpoint,
)
# Add telemetry for completed indexing
redis_connector = RedisConnector(tenant_id, ctx.cc_pair_id)
redis_connector_index = redis_connector.new_index(
index_attempt_start.search_settings_id
)
final_progress = redis_connector_index.get_progress() or 0
optional_telemetry(
record_type=RecordType.INDEXING_COMPLETE,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"connector_id": ctx.connector_id,
"credential_id": ctx.credential_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"batch_count": batch_num,
"time_elapsed_seconds": time.monotonic() - start_time,
"source": ctx.source.value,
"redis_progress": final_progress,
},
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(
"Connector run exceptioned after elapsed time: "

View File

@@ -1,6 +1,8 @@
import json
import os
import urllib.parse
from datetime import datetime
from datetime import timezone
from typing import cast
from onyx.auth.schemas import AuthBackend
@@ -383,10 +385,23 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
# https://jira.atlassian.com/browse/CONFCLOUD-69670
def get_current_tz_offset() -> int:
# datetime now() gets local time, datetime.now(timezone.utc) gets UTC time.
# remove tzinfo to compare non-timezone-aware objects.
time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
return round(time_diff.total_seconds() / 3600)
# enter as a floating point offset from UTC in hours (-24 < val < 24)
# this will be applied globally, so it probably makes sense to transition this to per
# connector at some point.
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
# For the default value, we assume that the user's local timezone is more likely to be
# correct (i.e. the configured user's timezone or the default server one) than UTC.
# https://developer.atlassian.com/cloud/confluence/cql-fields/#created
CONFLUENCE_TIMEZONE_OFFSET = float(
os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
)
GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
@@ -677,3 +692,7 @@ IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
)
DISABLE_AUTO_AUTH_REFRESH = (
os.environ.get("DISABLE_AUTO_AUTH_REFRESH", "").lower() == "true"
)
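
To make the new default concrete: get_current_tz_offset subtracts naive UTC "now" from naive local "now" and rounds to whole hours, so a server running in US Pacific daylight time yields -7. A quick check of the arithmetic:

from datetime import datetime, timezone

def get_current_tz_offset() -> int:
    # naive local time minus naive UTC time approximates the local UTC offset
    time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
    return round(time_diff.total_seconds() / 3600)

# e.g. local 16:48 vs UTC 23:48 gives a time_diff of about -7 hours, so -7
print(get_current_tz_offset())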

View File

@@ -79,6 +79,8 @@ _FULL_EXTENSION_FILTER_STRING = "".join(
]
)
ONE_HOUR = 3600
class ConfluenceConnector(
LoadConnector,
@@ -429,7 +431,17 @@ class ConfluenceConnector(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
return self._fetch_document_batches(start, end)
try:
return self._fetch_document_batches(start, end)
except Exception as e:
if "field 'updated' is invalid" in str(e) and start is not None:
logger.warning(
"Confluence says we provided an invalid 'updated' field. This may indicate"
"a real issue, but can also appear during edge cases like daylight"
f"savings time changes. Retrying with a 1 hour offset. Error: {e}"
)
return self._fetch_document_batches(start - ONE_HOUR, end)
raise
def retrieve_all_slim_documents(
self,

View File

@@ -2,11 +2,11 @@ import copy
import threading
from collections.abc import Callable
from collections.abc import Iterator
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any
from typing import cast
from typing import Protocol
from urllib.parse import urlparse
@@ -58,13 +58,13 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.lazy import lazy_eval
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
from onyx.utils.threadpool_concurrency import parallel_yield
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.threadpool_concurrency import ThreadSafeDict
logger = setup_logger()
@@ -461,6 +461,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
DriveRetrievalStage.MY_DRIVE_FILES,
)
curr_stage.stage = DriveRetrievalStage.SHARED_DRIVE_FILES
resuming = False # we are starting the next stage for the first time
if curr_stage.stage == DriveRetrievalStage.SHARED_DRIVE_FILES:
@@ -496,7 +497,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
)
yield from _yield_from_drive(drive_id, start)
curr_stage.stage = DriveRetrievalStage.FOLDER_FILES
resuming = False # we are starting the next stage for the first time
if curr_stage.stage == DriveRetrievalStage.FOLDER_FILES:
def _yield_from_folder_crawl(
@@ -549,6 +550,16 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
checkpoint, is_slim, DriveRetrievalStage.MY_DRIVE_FILES
)
# Setup initial completion map on first connector run
for email in all_org_emails:
# don't overwrite existing completion map on resuming runs
if email in checkpoint.completion_map:
continue
checkpoint.completion_map[email] = StageCompletion(
stage=DriveRetrievalStage.START,
completed_until=0,
)
# we've found all users and drives, now time to actually start
# fetching stuff
logger.info(f"Found {len(all_org_emails)} users to impersonate")
@@ -562,11 +573,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
drive_ids_to_retrieve, checkpoint
)
for email in all_org_emails:
checkpoint.completion_map[email] = StageCompletion(
stage=DriveRetrievalStage.START,
completed_until=0,
)
user_retrieval_gens = [
self._impersonate_user_for_retrieval(
email,
@@ -797,10 +803,12 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
return
for file in drive_files:
if file.error is not None:
if file.error is None:
checkpoint.completion_map[file.user_email].update(
stage=file.completion_stage,
completed_until=file.drive_file[GoogleFields.MODIFIED_TIME.value],
completed_until=datetime.fromisoformat(
file.drive_file[GoogleFields.MODIFIED_TIME.value]
).timestamp(),
completed_until_parent_id=file.parent_id,
)
yield file
@@ -902,118 +910,78 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[list[Document | ConnectorFailure]]:
) -> Iterator[Document | ConnectorFailure]:
try:
# Create a larger process pool for file conversion
with ThreadPoolExecutor(max_workers=8) as executor:
# Prepare a partial function with the credentials and admin email
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
self.size_threshold,
# Prepare a partial function with the credentials and admin email
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
self.size_threshold,
)
# Fetch files in batches
batches_complete = 0
files_batch: list[GoogleDriveFileType] = []
def _yield_batch(
files_batch: list[GoogleDriveFileType],
) -> Iterator[Document | ConnectorFailure]:
nonlocal batches_complete
# Process the batch using run_functions_tuples_in_parallel
func_with_args = [(convert_func, (file,)) for file in files_batch]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
)
# Fetch files in batches
batches_complete = 0
files_batch: list[GoogleDriveFileType] = []
for retrieved_file in self._fetch_drive_items(
is_slim=False,
checkpoint=checkpoint,
start=start,
end=end,
):
if retrieved_file.error is not None:
failure_stage = retrieved_file.completion_stage.value
failure_message = (
f"retrieval failure during stage: {failure_stage},"
)
failure_message += f"user: {retrieved_file.user_email},"
failure_message += (
f"parent drive/folder: {retrieved_file.parent_id},"
)
failure_message += f"error: {retrieved_file.error}"
logger.error(failure_message)
yield [
ConnectorFailure(
failed_entity=EntityFailure(
entity_id=failure_stage,
),
failure_message=failure_message,
exception=retrieved_file.error,
)
]
continue
files_batch.append(retrieved_file.drive_file)
docs_and_failures = [result for result in results if result is not None]
if len(files_batch) < self.batch_size:
continue
if docs_and_failures:
yield from docs_and_failures
batches_complete += 1
# Process the batch
futures = [
executor.submit(convert_func, file) for file in files_batch
]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
error_str = f"Error converting file: {e}"
logger.error(error_str)
yield [
ConnectorFailure(
failed_document=DocumentFailure(
document_id=retrieved_file.drive_file["id"],
document_link=retrieved_file.drive_file[
"webViewLink"
],
),
failure_message=error_str,
exception=e,
)
]
for retrieved_file in self._fetch_drive_items(
is_slim=False,
checkpoint=checkpoint,
start=start,
end=end,
):
if retrieved_file.error is not None:
failure_stage = retrieved_file.completion_stage.value
failure_message = (
f"retrieval failure during stage: {failure_stage},"
)
failure_message += f"user: {retrieved_file.user_email},"
failure_message += (
f"parent drive/folder: {retrieved_file.parent_id},"
)
failure_message += f"error: {retrieved_file.error}"
logger.error(failure_message)
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=failure_stage,
),
failure_message=failure_message,
exception=retrieved_file.error,
)
if documents:
yield documents
batches_complete += 1
files_batch = []
continue
files_batch.append(retrieved_file.drive_file)
if batches_complete > BATCHES_PER_CHECKPOINT:
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
return # create a new checkpoint
if len(files_batch) < self.batch_size:
continue
# Process any remaining files
if files_batch:
futures = [
executor.submit(convert_func, file) for file in files_batch
]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
error_str = f"Error converting file: {e}"
logger.error(error_str)
yield [
ConnectorFailure(
failed_document=DocumentFailure(
document_id=retrieved_file.drive_file["id"],
document_link=retrieved_file.drive_file[
"webViewLink"
],
),
failure_message=error_str,
exception=e,
)
]
yield from _yield_batch(files_batch)
files_batch = []
if documents:
yield documents
if batches_complete > BATCHES_PER_CHECKPOINT:
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
return # create a new checkpoint
# Process any remaining files
if files_batch:
yield from _yield_batch(files_batch)
except Exception as e:
logger.exception(f"Error extracting documents from Google Drive: {e}")
raise e
@@ -1035,10 +1003,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
checkpoint = copy.deepcopy(checkpoint)
self._retrieved_ids = checkpoint.retrieved_folder_and_drive_ids
try:
for doc_list in self._extract_docs_from_google_drive(
checkpoint, start, end
):
yield from doc_list
yield from self._extract_docs_from_google_drive(checkpoint, start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
@@ -1073,9 +1038,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
raise RuntimeError(
"_extract_slim_docs_from_google_drive: Stop signal detected"
)
callback.progress("_extract_slim_docs_from_google_drive", 1)
yield slim_batch
def retrieve_all_slim_documents(
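
One detail of the checkpoint fix above: completed_until is now stored as epoch seconds instead of the raw Drive modifiedTime string, via datetime.fromisoformat(...).timestamp(). A small sketch of that conversion; the sample value is illustrative, and note that Drive's RFC 3339 timestamps end in "Z", which fromisoformat only accepts natively on Python 3.11+:

from datetime import datetime

modified_time = "2025-03-26T16:48:20.000Z"  # sample Drive modifiedTime (RFC 3339)
# On Python < 3.11, normalize the trailing "Z" before parsing.
completed_until = datetime.fromisoformat(
    modified_time.replace("Z", "+00:00")
).timestamp()
print(completed_until)  # epoch seconds, directly comparable to checkpoint bounds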

View File

@@ -123,7 +123,7 @@ def crawl_folders_for_files(
end=end,
):
found_files = True
logger.info(f"Found file: {file['name']}")
logger.info(f"Found file: {file['name']}, user email: {user_email}")
yield RetrievedDriveFile(
drive_file=file,
user_email=user_email,

View File

@@ -175,9 +175,12 @@ def _get_tickets_page(
)
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
def _fetch_author(
client: ZendeskClient, author_id: str | int
) -> BasicExpertInfo | None:
# Skip fetching if author_id is invalid
if not author_id or author_id == "-1":
# cast to str to avoid issues with zendesk changing their types
if not author_id or str(author_id) == "-1":
return None
try:

View File

@@ -8,23 +8,31 @@ from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from onyx.connectors.models import ConnectorFailure
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.server.documents.models import ConnectorCredentialPair
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
# Comment out unused imports that cause mypy errors
# from onyx.auth.models import UserRole
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
# from onyx.db.connector_credential_pair import ConnectorCredentialPairIdentifier
# from onyx.db.engine import async_query_for_dms
logger = setup_logger()
@@ -201,6 +209,17 @@ def mark_attempt_in_progress(
attempt.status = IndexingStatus.IN_PROGRESS
attempt.time_started = index_attempt.time_started or func.now() # type: ignore
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt.id,
"status": IndexingStatus.IN_PROGRESS.value,
"cc_pair_id": index_attempt.connector_credential_pair_id,
"search_settings_id": index_attempt.search_settings_id,
},
)
except Exception:
db_session.rollback()
raise
@@ -219,6 +238,19 @@ def mark_attempt_succeeded(
attempt.status = IndexingStatus.SUCCESS
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.SUCCESS.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -237,6 +269,19 @@ def mark_attempt_partially_succeeded(
attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.COMPLETED_WITH_ERRORS.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -259,6 +304,20 @@ def mark_attempt_canceled(
attempt.status = IndexingStatus.CANCELED
attempt.error_msg = reason
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.CANCELED.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"reason": reason,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -283,6 +342,20 @@ def mark_attempt_failed(
attempt.error_msg = failure_reason
attempt.full_exception_trace = full_exception_trace
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.FAILED.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"reason": failure_reason,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -434,7 +507,7 @@ def get_latest_index_attempts_parallel(
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
return get_latest_index_attempts(
secondary_index,
db_session,

View File

@@ -24,7 +24,9 @@ from onyx.db.models import User__UserGroup
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
def validate_user_role_update(
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
) -> None:
"""
Validate that a user role update is valid.
Assumed only admins can hit this endpoint.
@@ -57,6 +59,9 @@ def validate_user_role_update(requested_role: UserRole, current_role: UserRole)
detail="To change a Limited User's role, they must first login to Onyx via the web app.",
)
if explicit_override:
return
if requested_role == UserRole.CURATOR:
# This shouldn't happen, but just in case
raise HTTPException(

View File

@@ -361,7 +361,15 @@ def get_application() -> FastAPI:
)
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
oauth_client = GoogleOAuth2(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
# Use standard scopes that include profile and email
scopes=["openid", "email", "profile"],
)
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -383,6 +391,13 @@ def get_application() -> FastAPI:
prefix="/auth",
)
# Add refresh token endpoint for OAuth as well
include_auth_router_with_prefix(
application,
fastapi_users.get_refresh_router(auth_backend),
prefix="/auth",
)
application.add_exception_handler(
RequestValidationError, validation_exception_handler
)

View File

@@ -15,7 +15,6 @@ from onyx.configs.constants import MessageType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import get_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
@@ -553,8 +552,7 @@ def handle_followup_resolved_button(
# Delete the message with the option to mark resolved
if not immediate:
slack_call = make_slack_api_rate_limited(client.web_client.chat_delete)
response = slack_call(
response = client.web_client.chat_delete(
channel=channel_id,
ts=message_ts,
)

View File

@@ -18,6 +18,9 @@ from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.http_retry import ConnectionErrorRetryHandler
from slack_sdk.http_retry import RateLimitErrorRetryHandler
from slack_sdk.http_retry import RetryHandler
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
@@ -944,10 +947,21 @@ def _get_socket_client(
) -> TenantSocketModeClient:
# For more info on how to set this up, check out the docs:
# https://docs.onyx.app/slack_bot_setup
# use the retry handlers built into the slack sdk
connection_error_retry_handler = ConnectionErrorRetryHandler()
rate_limit_error_retry_handler = RateLimitErrorRetryHandler(max_retry_count=7)
slack_retry_handlers: list[RetryHandler] = [
connection_error_retry_handler,
rate_limit_error_retry_handler,
]
return TenantSocketModeClient(
# This app-level token will be used only for establishing a connection
app_token=slack_bot_tokens.app_token,
web_client=WebClient(token=slack_bot_tokens.bot_token),
web_client=WebClient(
token=slack_bot_tokens.bot_token, retry_handlers=slack_retry_handlers
),
tenant_id=tenant_id,
slack_bot_id=slack_bot_id,
)
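
These built-in slack_sdk handlers replace the hand-rolled make_slack_api_rate_limited wrapper removed in the diffs below: RateLimitErrorRetryHandler retries rate-limited (HTTP 429) calls and ConnectionErrorRetryHandler retries transient connection failures. A minimal sketch of attaching them to a plain WebClient, with a placeholder token:

from slack_sdk import WebClient
from slack_sdk.http_retry import ConnectionErrorRetryHandler
from slack_sdk.http_retry import RateLimitErrorRetryHandler

client = WebClient(
    token="xoxb-placeholder",  # placeholder bot token
    retry_handlers=[
        ConnectionErrorRetryHandler(),
        RateLimitErrorRetryHandler(max_retry_count=7),
    ],
)
# Calls such as client.chat_postMessage(...) now retry automatically,
# so per-call rate-limit wrappers are no longer needed.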

View File

@@ -30,7 +30,6 @@ from onyx.configs.onyxbot_configs import (
from onyx.configs.onyxbot_configs import (
DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS,
)
from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.users import get_user_by_email
@@ -125,13 +124,18 @@ def update_emote_react(
)
return
func = client.reactions_remove if remove else client.reactions_add
slack_call = make_slack_api_rate_limited(func) # type: ignore
slack_call(
name=emoji,
channel=channel,
timestamp=message_ts,
)
if remove:
client.reactions_remove(
name=emoji,
channel=channel,
timestamp=message_ts,
)
else:
client.reactions_add(
name=emoji,
channel=channel,
timestamp=message_ts,
)
except SlackApiError as e:
if remove:
logger.error(f"Failed to remove Reaction due to: {e}")
@@ -200,9 +204,8 @@ def respond_in_thread_or_channel(
message_ids: list[str] = []
if not receiver_ids:
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
try:
response = slack_call(
response = client.chat_postMessage(
channel=channel,
text=text,
blocks=blocks,
@@ -224,7 +227,7 @@ def respond_in_thread_or_channel(
blocks_without_urls.append(_build_error_block(str(e)))
# Try again without blocks containing url
response = slack_call(
response = client.chat_postMessage(
channel=channel,
text=text,
blocks=blocks_without_urls,
@@ -236,11 +239,9 @@ def respond_in_thread_or_channel(
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
try:
response = slack_call(
response = client.chat_postEphemeral(
channel=channel,
user=receiver,
text=text,
@@ -263,7 +264,7 @@ def respond_in_thread_or_channel(
blocks_without_urls.append(_build_error_block(str(e)))
# Try again wtihout blocks containing url
response = slack_call(
response = client.chat_postEphemeral(
channel=channel,
user=receiver,
text=text,
@@ -500,7 +501,7 @@ def fetch_user_semantic_id_from_id(
if not user_id:
return None
response = make_slack_api_rate_limited(client.users_info)(user=user_id)
response = client.users_info(user=user_id)
if not response["ok"]:
return None

View File

@@ -31,6 +31,7 @@ PUBLIC_ENDPOINT_SPECS = [
# just gets the version of Onyx (e.g. 0.3.11)
("/version", {"GET"}),
# stuff related to basic auth
("/auth/refresh", {"POST"}),
("/auth/register", {"POST"}),
("/auth/login", {"POST"}),
("/auth/logout", {"POST"}),

View File

@@ -132,6 +132,7 @@ class UserByEmail(BaseModel):
class UserRoleUpdateRequest(BaseModel):
user_email: str
new_role: UserRole
explicit_override: bool = False
class UserRoleResponse(BaseModel):

View File

@@ -102,6 +102,7 @@ def set_user_role(
validate_user_role_update(
requested_role=requested_role,
current_role=current_role,
explicit_override=user_role_update_request.explicit_override,
)
if user_to_update.id == current_user.id:
@@ -122,6 +123,22 @@ def set_user_role(
db_session.commit()
class TestUpsertRequest(BaseModel):
email: str
@router.post("/manage/users/test-upsert-user")
async def test_upsert_user(
request: TestUpsertRequest,
_: User = Depends(current_admin_user),
) -> None | FullUserSnapshot:
"""Test endpoint for upsert_saml_user. Only used for integration testing."""
user = await fetch_ee_implementation_or_noop(
"onyx.server.saml", "upsert_saml_user", None
)(email=request.email)
return FullUserSnapshot.from_user_model(user) if user else None
@router.get("/manage/users/accepted")
def list_accepted_users(
q: str | None = Query(default=None),

View File

@@ -36,6 +36,10 @@ class RecordType(str, Enum):
LATENCY = "latency"
FAILURE = "failure"
METRIC = "metric"
INDEXING_PROGRESS = "indexing_progress"
INDEXING_COMPLETE = "indexing_complete"
PERMISSION_SYNC_PROGRESS = "permission_sync_progress"
INDEX_ATTEMPT_STATUS = "index_attempt_status"
def _get_or_generate_customer_id_mt(tenant_id: str) -> str:

View File

@@ -6,14 +6,17 @@ import uuid
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import MutableMapping
from collections.abc import Sequence
from concurrent.futures import as_completed
from concurrent.futures import FIRST_COMPLETED
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import wait
from typing import Any
from typing import cast
from typing import Generic
from typing import overload
from typing import Protocol
from typing import TypeVar
from pydantic import GetCoreSchemaHandler
@@ -145,13 +148,20 @@ class ThreadSafeDict(MutableMapping[KT, VT]):
return collections.abc.ValuesView(self)
class CallableProtocol(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
...
def run_functions_tuples_in_parallel(
functions_with_args: list[tuple[Callable, tuple]],
functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
allow_failures: bool = False,
max_workers: int | None = None,
) -> list[Any]:
"""
Executes multiple functions in parallel and returns a list of the results for each function.
This function preserves contextvars across threads, which is important for maintaining
context like tenant IDs in database sessions.
Args:
functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
@@ -159,7 +169,7 @@ def run_functions_tuples_in_parallel(
max_workers: Max number of worker threads
Returns:
dict: A dictionary mapping function names to their results or error messages.
list: A list of results from each function, in the same order as the input functions.
"""
workers = (
min(max_workers, len(functions_with_args))
@@ -186,7 +196,7 @@ def run_functions_tuples_in_parallel(
results.append((index, future.result()))
except Exception as e:
logger.exception(f"Function at index {index} failed due to {e}")
results.append((index, None))
results.append((index, None)) # type: ignore
if not allow_failures:
raise
@@ -288,7 +298,7 @@ def run_with_timeout(
if task.is_alive():
task.end()
return task.result
return task.result # type: ignore
# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
@@ -304,9 +314,9 @@ def run_in_background(
"""
context = contextvars.copy_context()
# Timeout not used in the non-blocking case
task = TimeoutThread(-1, context.run, func, *args, **kwargs)
task = TimeoutThread(-1, context.run, func, *args, **kwargs) # type: ignore
task.start()
return task
return cast(TimeoutThread[R], task)
def wait_on_background(task: TimeoutThread[R]) -> R:
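
To illustrate the retyped helper above (which, per the "Respect contextvars" commit, also copies contextvars such as the tenant id into its worker threads): results come back in input order, and failures raise unless allow_failures is set. A minimal usage sketch; the add function is illustrative:

from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel

def add(a: int, b: int) -> int:
    return a + b

# Each tuple is (callable, args); results preserve input order.
results = run_functions_tuples_in_parallel(
    [(add, (1, 2)), (add, (3, 4))],
    allow_failures=False,
    max_workers=2,
)
assert results == [3, 7]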

View File

@@ -123,10 +123,15 @@ class UserManager:
user_to_set: DATestUser,
target_role: UserRole,
user_performing_action: DATestUser,
explicit_override: bool = False,
) -> DATestUser:
response = requests.patch(
url=f"{API_SERVER_URL}/manage/set-user-role",
json={"user_email": user_to_set.email, "new_role": target_role.value},
json={
"user_email": user_to_set.email,
"new_role": target_role.value,
"explicit_override": explicit_override,
},
headers=user_performing_action.headers,
)
response.raise_for_status()

View File

@@ -0,0 +1,90 @@
import requests
from onyx.auth.schemas import UserRole
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
def test_saml_user_conversion(reset: None) -> None:
"""
Test that SAML login correctly converts users with non-authenticated roles
(SLACK_USER or EXT_PERM_USER) to authenticated roles (BASIC).
This test:
1. Creates an admin and a regular user
2. Changes the regular user's role to EXT_PERM_USER
3. Simulates a SAML login by calling the test endpoint
4. Verifies the user's role is converted to BASIC
This tests the fix that ensures users with non-authenticated roles (SLACK_USER or EXT_PERM_USER)
are properly converted to authenticated roles during SAML login.
"""
# Create an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(email="admin@onyx-test.com")
# Create a regular user that we'll convert to EXT_PERM_USER
test_user_email = "ext_perm_user@example.com"
test_user = UserManager.create(email=test_user_email)
# Verify the user was created with BASIC role initially
assert UserManager.is_role(test_user, UserRole.BASIC)
# Change the user's role to EXT_PERM_USER using the UserManager
UserManager.set_role(
user_to_set=test_user,
target_role=UserRole.EXT_PERM_USER,
user_performing_action=admin_user,
explicit_override=True,
)
# Verify the user has EXT_PERM_USER role now
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
# Simulate SAML login by calling the test endpoint
response = requests.post(
f"{API_SERVER_URL}/manage/users/test-upsert-user",
json={"email": test_user_email},
headers=admin_user.headers, # Use admin headers for authorization
)
response.raise_for_status()
# Verify the response indicates the role changed to BASIC
user_data = response.json()
assert user_data["role"] == UserRole.BASIC.value
# Verify user role was changed in the database
assert UserManager.is_role(test_user, UserRole.BASIC)
# Do the same test with SLACK_USER
slack_user_email = "slack_user@example.com"
slack_user = UserManager.create(email=slack_user_email)
# Verify the user was created with BASIC role initially
assert UserManager.is_role(slack_user, UserRole.BASIC)
# Change the user's role to SLACK_USER
UserManager.set_role(
user_to_set=slack_user,
target_role=UserRole.SLACK_USER,
user_performing_action=admin_user,
explicit_override=True,
)
# Verify the user has SLACK_USER role
assert UserManager.is_role(slack_user, UserRole.SLACK_USER)
# Simulate SAML login again
response = requests.post(
f"{API_SERVER_URL}/manage/users/test-upsert-user",
json={"email": slack_user_email},
headers=admin_user.headers,
)
response.raise_for_status()
# Verify the response indicates the role changed to BASIC
user_data = response.json()
assert user_data["role"] == UserRole.BASIC.value
# Verify the user's role was changed in the database
assert UserManager.is_role(slack_user, UserRole.BASIC)
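
For reference, the conversion rule this test pins down reduces to a small predicate. The sketch below is illustrative only: role_after_saml_login is a hypothetical name, and the real logic lives in the SAML login/upsert path rather than a standalone helper.

from onyx.auth.schemas import UserRole

NON_AUTHENTICATED_ROLES = {UserRole.SLACK_USER, UserRole.EXT_PERM_USER}

def role_after_saml_login(current_role: UserRole) -> UserRole:
    # SAML is an authenticated flow, so non-authenticated roles are promoted
    # to BASIC; already-authenticated roles (BASIC, ADMIN, ...) are untouched.
    if current_role in NON_AUTHENTICATED_ROLES:
        return UserRole.BASIC
    return current_role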

View File

@@ -0,0 +1,43 @@
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
import pytest
from onyx.db.models import OAuthAccount
from onyx.db.models import User
@pytest.fixture
def mock_user() -> MagicMock:
"""Creates a mock User instance for testing."""
user = MagicMock(spec=User)
user.email = "test@example.com"
user.id = "test-user-id"
return user
@pytest.fixture
def mock_oauth_account() -> MagicMock:
"""Creates a mock OAuthAccount instance for testing."""
oauth_account = MagicMock(spec=OAuthAccount)
oauth_account.oauth_name = "google"
oauth_account.refresh_token = "test-refresh-token"
oauth_account.access_token = "test-access-token"
oauth_account.expires_at = None
return oauth_account
@pytest.fixture
def mock_user_manager() -> MagicMock:
"""Creates a mock user manager for testing."""
user_manager = MagicMock()
user_manager.user_db = MagicMock()
user_manager.user_db.update_oauth_account = AsyncMock()
user_manager.user_db.update = AsyncMock()
return user_manager
@pytest.fixture
def mock_db_session() -> MagicMock:
"""Creates a mock database session for testing."""
return MagicMock()

View File

@@ -0,0 +1,273 @@
from datetime import datetime
from datetime import timezone
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.auth.oauth_refresher import _test_expire_oauth_token
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
from onyx.auth.oauth_refresher import check_oauth_account_has_refresh_token
from onyx.auth.oauth_refresher import get_oauth_accounts_requiring_refresh_token
from onyx.auth.oauth_refresher import refresh_oauth_token
from onyx.db.models import OAuthAccount
@pytest.mark.asyncio
async def test_refresh_oauth_token_success(
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test successful OAuth token refresh."""
# Mock HTTP client and response
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"access_token": "new_token",
"refresh_token": "new_refresh_token",
"expires_in": 3600,
}
# Create async mock for the client post method
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
# Use fixture values but ensure refresh token exists
mock_oauth_account.oauth_name = (
"google" # Ensure it's google to match the refresh endpoint
)
mock_oauth_account.refresh_token = "old_refresh_token"
# Patch at the module level where it's actually being used
with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock:
# Configure the context manager
client_instance = mock_client
client_class_mock.return_value.__aenter__.return_value = client_instance
# Call the function under test
result = await refresh_oauth_token(
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
)
# Assertions
assert result is True
mock_client.post.assert_called_once()
mock_user_manager.user_db.update_oauth_account.assert_called_once()
# Verify token data was updated correctly
update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2]
assert update_data["access_token"] == "new_token"
assert update_data["refresh_token"] == "new_refresh_token"
assert "expires_at" in update_data
@pytest.mark.asyncio
async def test_refresh_oauth_token_failure(
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test OAuth token refresh failure due to HTTP error."""
# Mock HTTP client with error response
mock_response = MagicMock()
mock_response.status_code = 400 # Simulate error
# Create async mock for the client post method
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
# Ensure refresh token exists and provider is supported
mock_oauth_account.oauth_name = "google"
mock_oauth_account.refresh_token = "old_refresh_token"
# Patch at the module level where it's actually being used
with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock:
# Configure the context manager
client_class_mock.return_value.__aenter__.return_value = mock_client
# Call the function under test
result = await refresh_oauth_token(
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
)
# Assertions
assert result is False
mock_client.post.assert_called_once()
mock_user_manager.user_db.update_oauth_account.assert_not_called()
@pytest.mark.asyncio
async def test_refresh_oauth_token_no_refresh_token(
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test OAuth token refresh when no refresh token is available."""
# Set refresh token to None
mock_oauth_account.refresh_token = None
mock_oauth_account.oauth_name = "google"
# No need to mock httpx since it shouldn't be called
result = await refresh_oauth_token(
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
)
# Assertions
assert result is False
@pytest.mark.asyncio
async def test_check_and_refresh_oauth_tokens(
mock_user: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test checking and refreshing multiple OAuth tokens."""
# Create mock user with OAuth accounts
now_timestamp = datetime.now(timezone.utc).timestamp()
# Create an account that needs refreshing (expiring soon)
expiring_account = MagicMock(spec=OAuthAccount)
expiring_account.oauth_name = "google"
expiring_account.refresh_token = "refresh_token_1"
expiring_account.expires_at = now_timestamp + 60 # Expires in 1 minute
# Create an account that doesn't need refreshing (expires later)
valid_account = MagicMock(spec=OAuthAccount)
valid_account.oauth_name = "google"
valid_account.refresh_token = "refresh_token_2"
valid_account.expires_at = now_timestamp + 3600 # Expires in 1 hour
# Create an account without a refresh token
no_refresh_account = MagicMock(spec=OAuthAccount)
no_refresh_account.oauth_name = "google"
no_refresh_account.refresh_token = None
no_refresh_account.expires_at = (
now_timestamp + 60
) # Expiring soon but no refresh token
# Set oauth_accounts on the mock user
mock_user.oauth_accounts = [expiring_account, valid_account, no_refresh_account]
# Mock refresh_oauth_token function
with patch(
"onyx.auth.oauth_refresher.refresh_oauth_token", AsyncMock(return_value=True)
) as mock_refresh:
# Call the function under test
await check_and_refresh_oauth_tokens(
mock_user, mock_db_session, mock_user_manager
)
# Assertions
assert mock_refresh.call_count == 1 # Should only refresh the expiring account
# Check it was called with the expiring account
mock_refresh.assert_called_once_with(
mock_user, expiring_account, mock_db_session, mock_user_manager
)
@pytest.mark.asyncio
async def test_get_oauth_accounts_requiring_refresh_token(mock_user: MagicMock) -> None:
"""Test identifying OAuth accounts that need refresh tokens."""
# Create accounts with and without refresh tokens
account_with_token = MagicMock(spec=OAuthAccount)
account_with_token.oauth_name = "google"
account_with_token.refresh_token = "refresh_token"
account_without_token = MagicMock(spec=OAuthAccount)
account_without_token.oauth_name = "google"
account_without_token.refresh_token = None
second_account_without_token = MagicMock(spec=OAuthAccount)
second_account_without_token.oauth_name = "github"
second_account_without_token.refresh_token = (
"" # Empty string should also be treated as missing
)
# Set accounts on user
mock_user.oauth_accounts = [
account_with_token,
account_without_token,
second_account_without_token,
]
# Call the function under test
accounts_needing_refresh = await get_oauth_accounts_requiring_refresh_token(
mock_user
)
# Assertions
assert len(accounts_needing_refresh) == 2
assert account_without_token in accounts_needing_refresh
assert second_account_without_token in accounts_needing_refresh
assert account_with_token not in accounts_needing_refresh
@pytest.mark.asyncio
async def test_check_oauth_account_has_refresh_token(
mock_user: MagicMock, mock_oauth_account: MagicMock
) -> None:
"""Test checking if an OAuth account has a refresh token."""
# Test with refresh token
mock_oauth_account.refresh_token = "refresh_token"
has_token = await check_oauth_account_has_refresh_token(
mock_user, mock_oauth_account
)
assert has_token is True
# Test with None refresh token
mock_oauth_account.refresh_token = None
has_token = await check_oauth_account_has_refresh_token(
mock_user, mock_oauth_account
)
assert has_token is False
# Test with empty string refresh token
mock_oauth_account.refresh_token = ""
has_token = await check_oauth_account_has_refresh_token(
mock_user, mock_oauth_account
)
assert has_token is False
@pytest.mark.asyncio
async def test_test_expire_oauth_token(
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test the testing utility function for token expiration."""
# Set up the mock account
mock_oauth_account.oauth_name = "google"
mock_oauth_account.refresh_token = "test_refresh_token"
mock_oauth_account.access_token = "test_access_token"
# Call the function under test
result = await _test_expire_oauth_token(
mock_user,
mock_oauth_account,
mock_db_session,
mock_user_manager,
expire_in_seconds=10,
)
# Assertions
assert result is True
mock_user_manager.user_db.update_oauth_account.assert_called_once()
# Verify the expiration time was set correctly
update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2]
assert "expires_at" in update_data
# The stored expiration should be ~10 seconds from now; allow roughly a
# second of slack in each direction for test execution time.
now = datetime.now(timezone.utc).timestamp()
assert update_data["expires_at"] - now >= 8.9
assert update_data["expires_at"] - now <= 11.1

View File

@@ -89,7 +89,8 @@ def test_run_in_background_and_wait_success() -> None:
elapsed = time.time() - start_time
assert result == 42
assert elapsed >= 0.1 # Verify we actually waited for the sleep
# timing is sometimes slightly flaky, so allow a small tolerance
assert elapsed >= 0.095 # Verify we actually waited for the sleep
@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")

View File

@@ -5,7 +5,7 @@ envsubst '$DOMAIN $SSL_CERT_FILE_NAME $SSL_CERT_KEY_FILE_NAME' < "/etc/nginx/con
echo "Waiting for API server to boot up; this may take a minute or two..."
echo "If this takes more than ~5 minutes, check the logs of the API server container for errors with the following command:"
echo
echo "docker logs onyx-stack_api_server-1"
echo "docker logs onyx-stack-api_server-1"
echo
while true; do

openapi.json Normal file

File diff suppressed because one or more lines are too long

View File

@@ -1079,7 +1079,7 @@ export function AssistantEditor({
</Tooltip>
</TooltipProvider>
<span className="text-sm ml-2">
{values.is_public ? "Public" : "Private"}
Organization Public
</span>
</div>
@@ -1088,17 +1088,22 @@ export function AssistantEditor({
<InfoIcon size={16} className="mr-2" />
<span className="text-sm">
Default persona must be public. Visibility has been
automatically set to public.
automatically set to organization public.
</span>
</div>
)}
{values.is_public ? (
<p className="text-sm text-text-dark">
Anyone from your team can view and use this assistant
This assistant will be available to everyone in your
organization
</p>
) : (
<>
<p className="text-sm text-text-dark mb-2">
This assistant will only be available to specific
users and groups
</p>
<div className="mt-2">
<Label className="mb-2" small>
Share with Users and Groups

View File

@@ -100,7 +100,10 @@ export function EmailPasswordForm({
// server-side provider values)
window.location.href = "/auth/waiting-on-verification";
} else {
// See above comment
// The search param exists purely for multi-tenant development purposes.
// It replicates the behavior of the case where a user
// has signed up with email / password as the only user of an instance
// and has just completed verification.
window.location.href = nextUrl
? encodeURI(nextUrl)
: `/chat${isSignup && !isJoin ? "?new_team=true" : ""}`;

View File

@@ -7,7 +7,7 @@ import Text from "@/components/ui/text";
import { RequestNewVerificationEmail } from "../waiting-on-verification/RequestNewVerificationEmail";
import { User } from "@/lib/types";
import { Logo } from "@/components/logo/Logo";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
export function Verify({ user }: { user: User | null }) {
const searchParams = useSearchParams();
const router = useRouter();
@@ -16,6 +16,8 @@ export function Verify({ user }: { user: User | null }) {
const verify = useCallback(async () => {
const token = searchParams.get("token");
const firstUser =
searchParams.get("first_user") && NEXT_PUBLIC_CLOUD_ENABLED;
if (!token) {
setError(
"Missing verification token. Try requesting a new verification email."
@@ -35,7 +37,7 @@ export function Verify({ user }: { user: User | null }) {
// Use window.location.href to force a full page reload,
// ensuring app re-initializes with the new state (including
// server-side provider values)
window.location.href = "/";
window.location.href = firstUser ? "/chat?new_team=true" : "/chat";
} else {
const errorDetail = (await response.json()).detail;
setError(

View File

@@ -1158,6 +1158,7 @@ export function ChatPage({
let frozenSessionId = currentSessionId();
updateCanContinue(false, frozenSessionId);
setUncaughtError(null);
setLoadingError(null);
// Mark that we've sent a message for this session in the current page load
markSessionMessageSent(frozenSessionId);

View File

@@ -2,7 +2,6 @@ import { redirect } from "next/navigation";
import { unstable_noStore as noStore } from "next/cache";
import { fetchChatData } from "@/lib/chat/fetchChatData";
import { ChatProvider } from "@/components/context/ChatContext";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
export default async function Layout({
children,
@@ -41,7 +40,6 @@ export default async function Layout({
return (
<>
<InstantSSRAutoRefresh />
<ChatProvider
value={{
proSearchToggled,

View File

@@ -1,11 +1,18 @@
"use client";
import React, { createContext, useContext, useState, useEffect } from "react";
import React, {
createContext,
useContext,
useState,
useEffect,
useRef,
} from "react";
import { User, UserRole } from "@/lib/types";
import { getCurrentUser } from "@/lib/user";
import { usePostHog } from "posthog-js/react";
import { CombinedSettings } from "@/app/admin/settings/interfaces";
import { SettingsContext } from "../settings/SettingsProvider";
import { useTokenRefresh } from "@/hooks/useTokenRefresh";
interface UserContextType {
user: User | null;
@@ -93,6 +100,10 @@ export function UserProvider({
console.error("Error fetching current user:", error);
}
};
// Use the custom token refresh hook
useTokenRefresh(upToDateUser, fetchUser);
const updateUserTemperatureOverrideEnabled = async (enabled: boolean) => {
try {
setUpToDateUser((prevUser) => {

View File

@@ -0,0 +1,84 @@
import { useState, useEffect, useRef } from "react";
import { User } from "@/lib/types";
import { NO_AUTH_USER_ID } from "@/lib/extension/constants";
// Refresh token every 10 minutes (600000ms)
// This is shorter than the session expiry time to ensure tokens stay valid
const REFRESH_INTERVAL = 600000;
// Custom hook for handling JWT token refresh for the current user
export function useTokenRefresh(
user: User | null,
onRefreshFail: () => Promise<void>
) {
// Track last refresh time to avoid unnecessary calls
const [lastTokenRefresh, setLastTokenRefresh] = useState<number>(Date.now());
// Use a ref to track first load
const isFirstLoad = useRef(true);
useEffect(() => {
if (!user || user.id === NO_AUTH_USER_ID) return;
const refreshTokenPeriodically = async () => {
try {
// Skip time check if this is first load - we always refresh on first load
const isTimeToRefresh =
isFirstLoad.current ||
Date.now() - lastTokenRefresh > REFRESH_INTERVAL - 60000;
if (!isTimeToRefresh) {
return;
}
// Reset first load flag
if (isFirstLoad.current) {
isFirstLoad.current = false;
}
const response = await fetch("/api/auth/refresh", {
method: "POST",
credentials: "include",
});
if (response.ok) {
// Update last refresh time on success
setLastTokenRefresh(Date.now());
console.debug("Auth token refreshed successfully");
} else {
console.warn("Failed to refresh auth token:", response.status);
// If token refresh fails, try to get current user info
await onRefreshFail();
}
} catch (error) {
console.error("Error refreshing auth token:", error);
}
};
// Always attempt to refresh on first component mount
// This helps ensure tokens are fresh, especially after browser refresh
refreshTokenPeriodically();
// Set up interval for periodic refreshes
const intervalId = setInterval(refreshTokenPeriodically, REFRESH_INTERVAL);
// Also refresh token on window focus, but no more than once per minute
const handleVisibilityChange = () => {
if (
document.visibilityState === "visible" &&
Date.now() - lastTokenRefresh > 60000
) {
refreshTokenPeriodically();
}
};
document.addEventListener("visibilitychange", handleVisibilityChange);
return () => {
clearInterval(intervalId);
document.removeEventListener("visibilitychange", handleVisibilityChange);
};
}, [user, lastTokenRefresh, onRefreshFail]);
return { lastTokenRefresh };
}

View File

@@ -35,3 +35,5 @@ export const LocalStorageKeys = {
export const SEARCH_PARAMS = {
DEFAULT_SIDEBAR_OFF: "defaultSidebarOff",
};
export const NO_AUTH_USER_ID = "__no_auth_user__";