Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-18 08:15:48 +00:00)

Compare commits: bugfixfix...prevent_is (21 commits)
Commits in this comparison:

- 63d6931bf9
- bc9b4e4f45
- 178a64f298
- c79f1edf1d
- 7c8e23aa54
- d37b427d52
- a65fefd226
- bb09bde519
- 0f6cf0fc58
- fed06b592d
- 8d92a1524e
- ecfea9f5ed
- b269f1ba06
- 30c878efa5
- 2024776c19
- 431316929c
- c5b9c6e308
- 73dd188b3f
- 79b061abbc
- 552f1ead4f
- 17925b49e8
@@ -102,6 +102,7 @@ COPY ./alembic /app/alembic
 COPY ./alembic_tenants /app/alembic_tenants
 COPY ./alembic.ini /app/alembic.ini
 COPY supervisord.conf /usr/etc/supervisord.conf
+COPY ./static /app/static

 # Escape hatch scripts
 COPY ./scripts/debugging /app/scripts/debugging
@@ -28,6 +28,20 @@ depends_on = None


 def upgrade() -> None:
+    # First, drop any existing indexes to avoid conflicts
+    op.execute("COMMIT")
+    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
+
+    op.execute("COMMIT")
+    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
+
+    op.execute("COMMIT")
+    op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
+
+    # Drop existing columns if they exist
+    op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
+    op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
+
     # Create a GIN index for full-text search on chat_message.message
     op.execute(
         """
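A note on the `op.execute("COMMIT")` calls above: PostgreSQL refuses to run `DROP INDEX CONCURRENTLY` (or `CREATE INDEX CONCURRENTLY`) inside a transaction block, and Alembic wraps each migration in one by default, so the migration closes the implicit transaction first. A minimal sketch of the same pattern; the index names are illustrative, and on Alembic 1.2+ `autocommit_block()` is the tidier equivalent:

```python
from alembic import op

def upgrade() -> None:
    # End Alembic's implicit transaction; CONCURRENTLY cannot run inside one.
    op.execute("COMMIT")
    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_example;")  # hypothetical index

    # Equivalent on Alembic 1.2+: run statements outside the transaction.
    with op.get_context().autocommit_block():
        op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_example_2;")  # hypothetical
```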
@@ -64,7 +64,15 @@ def get_application() -> FastAPI:
         add_tenant_id_middleware(application, logger)

     if AUTH_TYPE == AuthType.CLOUD:
-        oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
+        # For Google OAuth, refresh tokens are requested by:
+        # 1. Adding the right scopes
+        # 2. Properly configuring OAuth in Google Cloud Console to allow offline access
+        oauth_client = GoogleOAuth2(
+            OAUTH_CLIENT_ID,
+            OAUTH_CLIENT_SECRET,
+            # Use standard scopes that include profile and email
+            scopes=["openid", "email", "profile"],
+        )
         include_auth_router_with_prefix(
             application,
             create_onyx_oauth_router(
@@ -87,6 +95,16 @@ def get_application() -> FastAPI:
         )

     if AUTH_TYPE == AuthType.OIDC:
+        # Ensure we request offline_access for refresh tokens
+        try:
+            oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
+            if "offline_access" not in oidc_scopes:
+                oidc_scopes.append("offline_access")
+        except Exception as e:
+            logger.warning(f"Error configuring OIDC scopes: {e}")
+            # Fall back to default scopes if there's an error
+            oidc_scopes = BASE_SCOPES
+
         include_auth_router_with_prefix(
             application,
             create_onyx_oauth_router(
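`offline_access` is the standard OIDC scope that asks the identity provider to issue a refresh token alongside the access token; the `try`/`except` guards against a malformed `OIDC_SCOPE_OVERRIDE` value. The merging logic in isolation, with stand-in values for the two imported constants:

```python
BASE_SCOPES = ["openid", "email"]  # stand-in for the imported constant
OIDC_SCOPE_OVERRIDE = None         # e.g. ["openid", "email", "groups"] from config

oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
if "offline_access" not in oidc_scopes:
    oidc_scopes.append("offline_access")

print(oidc_scopes)  # ['openid', 'email', 'offline_access']
```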
@@ -94,8 +112,8 @@ def get_application() -> FastAPI:
                 OAUTH_CLIENT_ID,
                 OAUTH_CLIENT_SECRET,
                 OPENID_CONFIG_URL,
-                # BASE_SCOPES is the same as not setting this
-                base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
+                # Use the configured scopes
+                base_scopes=oidc_scopes,
             ),
             auth_backend,
             USER_AUTH_SECRET,
@@ -36,8 +36,12 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()
 router = APIRouter(prefix="/auth/saml")

+# Define non-authenticated user roles that should be re-created during SAML login
+NON_AUTHENTICATED_ROLES = {UserRole.SLACK_USER, UserRole.EXT_PERM_USER}
+
+
 async def upsert_saml_user(email: str) -> User:
     logger.debug(f"Attempting to upsert SAML user with email: {email}")
     get_async_session_context = contextlib.asynccontextmanager(
         get_async_session
     )  # type:ignore
@@ -48,9 +52,13 @@ async def upsert_saml_user(email: str) -> User:
     async with get_user_db_context(session) as user_db:
         async with get_user_manager_context(user_db) as user_manager:
             try:
-                return await user_manager.get_by_email(email)
+                user = await user_manager.get_by_email(email)
+                # If user has a non-authenticated role, treat as non-existent
+                if user.role in NON_AUTHENTICATED_ROLES:
+                    raise exceptions.UserNotExists()
+                return user
             except exceptions.UserNotExists:
-                logger.notice("Creating user from SAML login")
+                logger.info("Creating user from SAML login")

                 user_count = await get_user_count()
                 role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC
@@ -59,11 +67,10 @@ async def upsert_saml_user(email: str) -> User:
                 password = fastapi_users_pw_helper.generate()
                 hashed_pass = fastapi_users_pw_helper.hash(password)

-                user: User = await user_manager.create(
+                user = await user_manager.create(
                     UserCreate(
                         email=email,
                         password=hashed_pass,
                         is_verified=True,
                         role=role,
                     )
                 )
@@ -16,10 +16,10 @@ from onyx.configs.app_configs import WEB_DOMAIN
 from onyx.configs.constants import AuthType
 from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
 from onyx.configs.constants import ONYX_SLACK_URL
-from onyx.configs.constants import TENANT_ID_COOKIE_NAME
 from onyx.db.models import User
 from onyx.server.runtime.onyx_runtime import OnyxRuntime
 from onyx.utils.file import FileWithMimeType
+from onyx.utils.url import add_url_params
 from onyx.utils.variable_functionality import fetch_versioned_implementation
 from shared_configs.configs import MULTI_TENANT
@@ -62,6 +62,11 @@ HTML_EMAIL_TEMPLATE = """\
         }}
         .header img {{
             max-width: 140px;
+            width: 140px;
+            height: auto;
+            filter: brightness(1.1) contrast(1.2);
+            border-radius: 8px;
+            padding: 5px;
         }}
         .body-content {{
             padding: 20px 30px;
@@ -78,12 +83,16 @@ HTML_EMAIL_TEMPLATE = """\
         }}
         .cta-button {{
             display: inline-block;
-            padding: 12px 20px;
-            background-color: #000000;
+            padding: 14px 24px;
+            background-color: #0055FF;
             color: #ffffff !important;
             text-decoration: none;
             border-radius: 4px;
-            font-weight: 500;
+            font-weight: 600;
+            font-size: 16px;
             margin-top: 10px;
+            box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
+            text-align: center;
         }}
         .footer {{
             font-size: 13px;
@@ -166,6 +175,7 @@ def send_email(
     if not EMAIL_CONFIGURED:
         raise ValueError("Email is not configured.")

+    # Create a multipart/alternative message - this indicates these are alternative versions of the same content
     msg = MIMEMultipart("alternative")
     msg["Subject"] = subject
     msg["To"] = user_email
@@ -174,17 +184,30 @@ def send_email(
     msg["Date"] = formatdate(localtime=True)
     msg["Message-ID"] = make_msgid(domain="onyx.app")

-    part_text = MIMEText(text_body, "plain")
-    part_html = MIMEText(html_body, "html")
-
-    msg.attach(part_text)
-    msg.attach(part_html)
+    # Add text part first (lowest priority)
+    text_part = MIMEText(text_body, "plain")
+    msg.attach(text_part)

     if inline_png:
+        # For HTML with images, create a multipart/related container
+        related = MIMEMultipart("related")
+
+        # Add the HTML part to the related container
+        html_part = MIMEText(html_body, "html")
+        related.attach(html_part)
+
+        # Add image with proper Content-ID to the related container
         img = MIMEImage(inline_png[1], _subtype="png")
-        img.add_header("Content-ID", inline_png[0])  # CID reference
+        img.add_header("Content-ID", f"<{inline_png[0]}>")
         img.add_header("Content-Disposition", "inline", filename=inline_png[0])
-        msg.attach(img)
+        related.attach(img)
+
+        # Add the related part to the message (higher priority than text)
+        msg.attach(related)
+    else:
+        # No images, just add HTML directly (higher priority than text)
+        html_part = MIMEText(html_body, "html")
+        msg.attach(html_part)

     try:
         with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
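The resulting message is a `multipart/alternative` whose parts ascend in priority (plain text first, HTML last), with the HTML and its inline image grouped inside a nested `multipart/related` part. Note the two spellings of the Content-ID: the header value is wrapped in angle brackets, while the HTML references it as `cid:` without them. A self-contained sketch of the same structure, with placeholder content:

```python
from email.mime.image import MIMEImage
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

msg = MIMEMultipart("alternative")          # clients render the richest supported part
msg.attach(MIMEText("plain-text fallback", "plain"))

related = MIMEMultipart("related")          # groups the HTML with its inline resources
related.attach(MIMEText('<p>Hi!</p><img src="cid:logo">', "html"))

img = MIMEImage(b"\x89PNG...", _subtype="png")  # placeholder PNG bytes
img.add_header("Content-ID", "<logo>")          # header form uses angle brackets
img.add_header("Content-Disposition", "inline", filename="logo.png")
related.attach(img)

msg.attach(related)
print(msg.as_string()[:200])
```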
@@ -332,17 +355,23 @@ def send_forgot_password_email(

     onyx_file = OnyxRuntime.get_emailable_logo()

-    subject = f"{application_name} Forgot Password"
-    link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
-    if MULTI_TENANT:
-        link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
-    message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
+    subject = f"Reset Your {application_name} Password"
+    heading = "Reset Your Password"
+    tenant_param = f"&tenant={tenant_id}" if tenant_id and MULTI_TENANT else ""
+    message = "<p>Please click the button below to reset your password. This link will expire in 24 hours.</p>"
+    cta_text = "Reset Password"
+    cta_link = f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
     html_content = build_html_email(
         application_name,
-        "Reset Your Password",
+        heading,
         message,
         cta_text,
         cta_link,
     )
-    text_content = f"Click the following link to reset your password: {link}"
+    text_content = (
+        f"Please click the following link to reset your password. This link will expire in 24 hours.\n"
+        f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
+    )
     send_email(
         user_email,
         subject,
@@ -356,6 +385,7 @@ def send_forgot_password_email(
 def send_user_verification_email(
     user_email: str,
     token: str,
+    new_organization: bool = False,
     mail_from: str = EMAIL_FROM,
 ) -> None:
     # Builds a verification email
@@ -372,6 +402,8 @@ def send_user_verification_email(

     subject = f"{application_name} Email Verification"
     link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
+    if new_organization:
+        link = add_url_params(link, {"first_user": "true"})
     message = (
         f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
     )
backend/onyx/auth/oauth_refresher.py (new file, 211 lines)
@@ -0,0 +1,211 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional

import httpx
from fastapi_users.manager import BaseUserManager
from sqlalchemy.ext.asyncio import AsyncSession

from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.utils.logger import setup_logger

logger = setup_logger()

# Standard OAuth refresh token endpoints
REFRESH_ENDPOINTS = {
    "google": "https://oauth2.googleapis.com/token",
}


# NOTE: Keeping this as a utility function for potential future debugging,
# but not using it in production code
async def _test_expire_oauth_token(
    user: User,
    oauth_account: OAuthAccount,
    db_session: AsyncSession,
    user_manager: BaseUserManager[User, Any],
    expire_in_seconds: int = 10,
) -> bool:
    """
    Utility function for testing - Sets an OAuth token to expire in a short time
    to facilitate testing of the refresh flow.
    Not used in production code.
    """
    try:
        new_expires_at = int(
            (datetime.now(timezone.utc).timestamp() + expire_in_seconds)
        )

        updated_data: Dict[str, Any] = {"expires_at": new_expires_at}

        await user_manager.user_db.update_oauth_account(
            user, cast(Any, oauth_account), updated_data
        )

        return True
    except Exception as e:
        logger.exception(f"Error setting artificial expiration: {str(e)}")
        return False


async def refresh_oauth_token(
    user: User,
    oauth_account: OAuthAccount,
    db_session: AsyncSession,
    user_manager: BaseUserManager[User, Any],
) -> bool:
    """
    Attempt to refresh an OAuth token that's about to expire or has expired.
    Returns True if successful, False otherwise.
    """
    if not oauth_account.refresh_token:
        logger.warning(
            f"No refresh token available for {user.email}'s {oauth_account.oauth_name} account"
        )
        return False

    provider = oauth_account.oauth_name
    if provider not in REFRESH_ENDPOINTS:
        logger.warning(f"Refresh endpoint not configured for provider: {provider}")
        return False

    try:
        logger.info(f"Refreshing OAuth token for {user.email}'s {provider} account")

        async with httpx.AsyncClient() as client:
            response = await client.post(
                REFRESH_ENDPOINTS[provider],
                data={
                    "client_id": OAUTH_CLIENT_ID,
                    "client_secret": OAUTH_CLIENT_SECRET,
                    "refresh_token": oauth_account.refresh_token,
                    "grant_type": "refresh_token",
                },
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )

            if response.status_code != 200:
                logger.error(
                    f"Failed to refresh OAuth token: Status {response.status_code}"
                )
                return False

            token_data = response.json()

            new_access_token = token_data.get("access_token")
            new_refresh_token = token_data.get(
                "refresh_token", oauth_account.refresh_token
            )
            expires_in = token_data.get("expires_in")

            # Calculate new expiry time if provided
            new_expires_at: Optional[int] = None
            if expires_in:
                new_expires_at = int(
                    (datetime.now(timezone.utc).timestamp() + expires_in)
                )

            # Update the OAuth account
            updated_data: Dict[str, Any] = {
                "access_token": new_access_token,
                "refresh_token": new_refresh_token,
            }

            if new_expires_at:
                updated_data["expires_at"] = new_expires_at

                # Update oidc_expiry in user model if we're tracking it
                if TRACK_EXTERNAL_IDP_EXPIRY:
                    oidc_expiry = datetime.fromtimestamp(
                        new_expires_at, tz=timezone.utc
                    )
                    await user_manager.user_db.update(
                        user, {"oidc_expiry": oidc_expiry}
                    )

            # Update the OAuth account
            await user_manager.user_db.update_oauth_account(
                user, cast(Any, oauth_account), updated_data
            )

            logger.info(f"Successfully refreshed OAuth token for {user.email}")
            return True

    except Exception as e:
        logger.exception(f"Error refreshing OAuth token: {str(e)}")
        return False


async def check_and_refresh_oauth_tokens(
    user: User,
    db_session: AsyncSession,
    user_manager: BaseUserManager[User, Any],
) -> None:
    """
    Check if any OAuth tokens are expired or about to expire and refresh them.
    """
    if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
        return

    now_timestamp = datetime.now(timezone.utc).timestamp()

    # Buffer time to refresh tokens before they expire (in seconds)
    buffer_seconds = 300  # 5 minutes

    for oauth_account in user.oauth_accounts:
        # Skip accounts without refresh tokens
        if not oauth_account.refresh_token:
            continue

        # If token is about to expire, refresh it
        if (
            oauth_account.expires_at
            and oauth_account.expires_at - now_timestamp < buffer_seconds
        ):
            logger.info(f"OAuth token for {user.email} is about to expire - refreshing")
            success = await refresh_oauth_token(
                user, oauth_account, db_session, user_manager
            )

            if not success:
                logger.warning(
                    "Failed to refresh OAuth token. User may need to re-authenticate."
                )


async def check_oauth_account_has_refresh_token(
    user: User,
    oauth_account: OAuthAccount,
) -> bool:
    """
    Check if an OAuth account has a refresh token.
    Returns True if a refresh token exists, False otherwise.
    """
    return bool(oauth_account.refresh_token)


async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
    """
    Returns a list of OAuth accounts for a user that are missing refresh tokens.
    These accounts will need re-authentication to get refresh tokens.
    """
    if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
        return []

    accounts_needing_refresh = []
    for oauth_account in user.oauth_accounts:
        has_refresh_token = await check_oauth_account_has_refresh_token(
            user, oauth_account
        )
        if not has_refresh_token:
            accounts_needing_refresh.append(oauth_account)

    return accounts_needing_refresh
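A sketch of how this module might be driven from calling code; the three arguments follow the signatures above, and `report_unrefreshable` is a hypothetical helper name. Note that `refresh_oauth_token` quietly returns `False` for any provider other than Google, since `REFRESH_ENDPOINTS` only maps `google`:

```python
from onyx.auth.oauth_refresher import (
    check_and_refresh_oauth_tokens,
    get_oauth_accounts_requiring_refresh_token,
)

async def report_unrefreshable(user, db_session, user_manager) -> list[str]:
    # Refresh anything inside the 5-minute expiry buffer first...
    await check_and_refresh_oauth_tokens(user, db_session, user_manager)
    # ...then report providers that can never be refreshed because no
    # refresh_token was stored (these need a full re-authentication).
    stale = await get_oauth_accounts_requiring_refresh_token(user)
    return [account.oauth_name for account in stale]
```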
@@ -5,12 +5,16 @@ import string
 import uuid
 from collections.abc import AsyncGenerator
 from datetime import datetime
+from datetime import timedelta
 from datetime import timezone
 from typing import Any
 from typing import cast
 from typing import Dict
 from typing import List
 from typing import Optional
+from typing import Protocol
+from typing import Tuple
+from typing import TypeVar

 import jwt
 from email_validator import EmailNotValidError
@@ -581,8 +585,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
         logger.notice(
             f"Verification requested for user {user.id}. Verification token: {token}"
         )

-        send_user_verification_email(user.email, token)
+        user_count = await get_user_count()
+        send_user_verification_email(
+            user.email, token, new_organization=user_count == 1
+        )

     async def authenticate(
         self, credentials: OAuth2PasswordRequestForm
@@ -688,16 +694,20 @@ cookie_transport = CookieTransport(
 )


-def get_redis_strategy() -> RedisStrategy:
-    return TenantAwareRedisStrategy()
-
-
-def get_database_strategy(
-    access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
-) -> DatabaseStrategy:
-    return DatabaseStrategy(
-        access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
-    )
+T = TypeVar("T", covariant=True)
+ID = TypeVar("ID", contravariant=True)
+
+
+# Protocol for strategies that support token refreshing without inheritance.
+class RefreshableStrategy(Protocol):
+    """Protocol for authentication strategies that support token refreshing."""
+
+    async def refresh_token(self, token: Optional[str], user: Any) -> str:
+        """
+        Refresh an existing token by extending its lifetime.
+        Returns either the same token with extended expiration or a new token.
+        """
+        ...


 class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
@@ -756,6 +766,75 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
         redis = await get_async_redis_connection()
         await redis.delete(f"{self.key_prefix}{token}")

+    async def refresh_token(self, token: Optional[str], user: User) -> str:
+        """Refresh a token by extending its expiration time in Redis."""
+        if token is None:
+            # If no token provided, create a new one
+            return await self.write_token(user)
+
+        redis = await get_async_redis_connection()
+        token_key = f"{self.key_prefix}{token}"
+
+        # Check if token exists
+        token_data_str = await redis.get(token_key)
+        if not token_data_str:
+            # Token not found, create new one
+            return await self.write_token(user)
+
+        # Token exists, extend its lifetime
+        token_data = json.loads(token_data_str)
+        await redis.set(
+            token_key,
+            json.dumps(token_data),
+            ex=self.lifetime_seconds,
+        )
+
+        return token
+
+
+class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
+    """Database strategy with token refreshing capabilities."""
+
+    def __init__(
+        self,
+        access_token_db: AccessTokenDatabase[AccessToken],
+        lifetime_seconds: Optional[int] = None,
+    ):
+        super().__init__(access_token_db, lifetime_seconds)
+        self._access_token_db = access_token_db
+
+    async def refresh_token(self, token: Optional[str], user: User) -> str:
+        """Refresh a token by updating its expiration time in the database."""
+        if token is None:
+            return await self.write_token(user)
+
+        # Find the token in database
+        access_token = await self._access_token_db.get_by_token(token)
+
+        if access_token is None:
+            # Token not found, create new one
+            return await self.write_token(user)
+
+        # Update expiration time
+        new_expires = datetime.now(timezone.utc) + timedelta(
+            seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
+        )
+        await self._access_token_db.update(access_token, {"expires": new_expires})
+
+        return token
+
+
+def get_redis_strategy() -> TenantAwareRedisStrategy:
+    return TenantAwareRedisStrategy()
+
+
+def get_database_strategy(
+    access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
+) -> RefreshableDatabaseStrategy:
+    return RefreshableDatabaseStrategy(
+        access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
+    )
+
+
 if AUTH_BACKEND == AuthBackend.REDIS:
     auth_backend = AuthenticationBackend(
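The `RefreshableStrategy` Protocol above is only used for static typing; the refresh endpoint below probes with `hasattr` at runtime instead. If a stricter runtime check were ever wanted, `typing.runtime_checkable` would allow an `isinstance` test. This is a possible alternative, not what the diff does:

```python
from typing import Any, Optional, Protocol, runtime_checkable

@runtime_checkable
class RefreshableStrategy(Protocol):
    async def refresh_token(self, token: Optional[str], user: Any) -> str: ...

def supports_refresh(strategy: object) -> bool:
    # isinstance() against a runtime_checkable Protocol verifies that the
    # method exists (it does not validate the signature), which matches what
    # the hasattr/callable probe in the refresh endpoint checks.
    return isinstance(strategy, RefreshableStrategy)
```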
@@ -806,6 +885,88 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):

         return router

+    def get_refresh_router(
+        self,
+        backend: AuthenticationBackend,
+        requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
+    ) -> APIRouter:
+        """
+        Provide a router for session token refreshing.
+        """
+        # Import the oauth_refresher here to avoid circular imports
+        from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
+
+        router = APIRouter()
+
+        get_current_user_token = self.authenticator.current_user_token(
+            active=True, verified=requires_verification
+        )
+
+        refresh_responses: OpenAPIResponseType = {
+            **{
+                status.HTTP_401_UNAUTHORIZED: {
+                    "description": "Missing token or inactive user."
+                }
+            },
+            **backend.transport.get_openapi_login_responses_success(),
+        }
+
+        @router.post(
+            "/refresh", name=f"auth:{backend.name}.refresh", responses=refresh_responses
+        )
+        async def refresh(
+            user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
+            strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
+            user_manager: BaseUserManager[models.UP, models.ID] = Depends(
+                get_user_manager
+            ),
+            db_session: AsyncSession = Depends(get_async_session),
+        ) -> Response:
+            try:
+                user, token = user_token
+                logger.info(f"Processing token refresh request for user {user.email}")
+
+                # Check if user has OAuth accounts that need refreshing
+                await check_and_refresh_oauth_tokens(
+                    user=cast(User, user),
+                    db_session=db_session,
+                    user_manager=cast(Any, user_manager),
+                )
+
+                # Check if strategy supports refreshing
+                supports_refresh = hasattr(strategy, "refresh_token") and callable(
+                    getattr(strategy, "refresh_token")
+                )
+
+                if supports_refresh:
+                    try:
+                        refresh_method = getattr(strategy, "refresh_token")
+                        new_token = await refresh_method(token, user)
+                        logger.info(
+                            f"Successfully refreshed session token for user {user.email}"
+                        )
+                        return await backend.transport.get_login_response(new_token)
+                    except Exception as e:
+                        logger.error(f"Error refreshing session token: {str(e)}")
+                        # Fallback to logout and login if refresh fails
+                        await backend.logout(strategy, user, token)
+                        return await backend.login(strategy, user)
+
+                # Fallback: logout and login again
+                logger.info(
+                    "Strategy doesn't support refresh - using logout/login flow"
+                )
+                await backend.logout(strategy, user, token)
+                return await backend.login(strategy, user)
+            except Exception as e:
+                logger.error(f"Unexpected error in refresh endpoint: {str(e)}")
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=f"Token refresh failed: {str(e)}",
+                )
+
+        return router
+
+
 fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
     get_user_manager, [auth_backend]
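Exercising the new endpoint from a client is a single authenticated POST; a sketch assuming cookie-based sessions, a locally running server, and fastapi-users' default cookie name:

```python
import httpx

with httpx.Client(base_url="http://localhost:8080") as client:
    # An existing session token must already be present; "fastapiusersauth"
    # is the fastapi-users default cookie name (assumed here).
    client.cookies.set("fastapiusersauth", "<existing-session-token>")
    response = client.post("/auth/refresh")
    response.raise_for_status()
    # On success, the transport re-issues the session cookie with a fresh expiry.
    print(dict(response.cookies))
```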
@@ -1039,12 +1200,20 @@ def get_oauth_router(
             "referral_source": referral_source or "default_referral",
         }
         state = generate_state_token(state_data, state_secret)
+
+        # Get the basic authorization URL
         authorization_url = await oauth_client.get_authorization_url(
             authorize_redirect_url,
             state,
             scopes,
         )

+        # For Google OAuth, add parameters to request refresh tokens
+        if oauth_client.name == "google":
+            authorization_url = add_url_params(
+                authorization_url, {"access_type": "offline", "prompt": "consent"}
+            )
+
         return OAuth2AuthorizeResponse(authorization_url=authorization_url)

     @router.get(
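`access_type=offline` asks Google to issue a refresh token, and `prompt=consent` forces the consent screen so the refresh token is re-issued even for returning users (Google otherwise only sends it on the first authorization). The URL manipulation is equivalent to this stdlib-only sketch of `add_url_params`:

```python
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

def add_url_params(url: str, params: dict[str, str]) -> str:
    parts = urlparse(url)
    query = dict(parse_qsl(parts.query))
    query.update(params)
    return urlunparse(parts._replace(query=urlencode(query)))

url = add_url_params(
    "https://accounts.google.com/o/oauth2/v2/auth?client_id=abc&state=xyz",
    {"access_type": "offline", "prompt": "consent"},
)
print(url)
```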
@@ -34,7 +34,6 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
 from onyx.redis.redis_connector_prune import RedisConnectorPrune
 from onyx.redis.redis_document_set import RedisDocumentSet
 from onyx.redis.redis_pool import get_redis_client
-from onyx.redis.redis_pool import get_shared_redis_client
 from onyx.redis.redis_usergroup import RedisUserGroup
 from onyx.utils.logger import ColoredFormatter
 from onyx.utils.logger import PlainFormatter

@@ -225,7 +224,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
     Will raise WorkerShutdown to kill the celery worker if the timeout
     is reached."""

-    r = get_shared_redis_client()
+    r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)

     WAIT_INTERVAL = 5
     WAIT_LIMIT = 60

@@ -311,7 +310,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
     # Set up variables for waiting on primary worker
     WAIT_INTERVAL = 5
     WAIT_LIMIT = 60
-    r = get_shared_redis_client()
+    r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
     time_start = time.monotonic()

     logger.info("Waiting for primary worker to be ready...")
@@ -38,10 +38,11 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
 from onyx.redis.redis_connector_prune import RedisConnectorPrune
 from onyx.redis.redis_connector_stop import RedisConnectorStop
 from onyx.redis.redis_document_set import RedisDocumentSet
-from onyx.redis.redis_pool import get_shared_redis_client
+from onyx.redis.redis_pool import get_redis_client
 from onyx.redis.redis_usergroup import RedisUserGroup
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import MULTI_TENANT
+from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

 logger = setup_logger()

@@ -102,7 +103,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:

     # This is singleton work that should be done on startup exactly once
     # by the primary worker. This is unnecessary in the multi tenant scenario
-    r = get_shared_redis_client()
+    r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)

     # Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
     info: dict[str, Any] = cast(dict, r.info("replication"))

@@ -235,7 +236,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):

         lock: RedisLock = worker.primary_worker_lock

-        r = get_shared_redis_client()
+        r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)

         if lock.owned():
             task_logger.debug("Reacquiring primary worker lock.")
@@ -451,6 +451,8 @@ def monitor_connector_deletion_taskset(
             credential_id=cc_pair.credential_id,
         )

+        db_session.flush()
+
         # finally, delete the cc-pair
         delete_connector_credential_pair__no_commit(
             db_session=db_session,
@@ -68,6 +68,8 @@ from onyx.utils.logger import doc_permission_sync_ctx
 from onyx.utils.logger import format_error_for_logging
 from onyx.utils.logger import LoggerContextVars
 from onyx.utils.logger import setup_logger
+from onyx.utils.telemetry import optional_telemetry
+from onyx.utils.telemetry import RecordType


 logger = setup_logger()

@@ -875,6 +877,21 @@ def monitor_ccpair_permissions_taskset(
         f"remaining={remaining} "
         f"initial={initial}"
     )

+    # Add telemetry for permission syncing progress
+    optional_telemetry(
+        record_type=RecordType.PERMISSION_SYNC_PROGRESS,
+        data={
+            "cc_pair_id": cc_pair_id,
+            "id": payload.id if payload else None,
+            "total_docs": initial if initial is not None else 0,
+            "remaining_docs": remaining,
+            "synced_docs": (initial - remaining) if initial is not None else 0,
+            "is_complete": remaining == 0,
+        },
+        tenant_id=tenant_id,
+    )
+
     if remaining > 0:
         return
@@ -56,9 +56,12 @@ from onyx.indexing.indexing_pipeline import build_indexing_pipeline
 from onyx.natural_language_processing.search_nlp_models import (
     InformationContentClassificationModel,
 )
+from onyx.redis.redis_connector import RedisConnector
 from onyx.utils.logger import setup_logger
 from onyx.utils.logger import TaskAttemptSingleton
 from onyx.utils.telemetry import create_milestone_and_report
+from onyx.utils.telemetry import optional_telemetry
+from onyx.utils.telemetry import RecordType
 from onyx.utils.variable_functionality import global_version
 from shared_configs.configs import MULTI_TENANT

@@ -570,6 +573,22 @@ def _run_indexing(
             if callback:
                 callback.progress("_run_indexing", len(doc_batch_cleaned))

+            # Add telemetry for indexing progress
+            optional_telemetry(
+                record_type=RecordType.INDEXING_PROGRESS,
+                data={
+                    "index_attempt_id": index_attempt_id,
+                    "cc_pair_id": ctx.cc_pair_id,
+                    "connector_id": ctx.connector_id,
+                    "credential_id": ctx.credential_id,
+                    "total_docs_indexed": document_count,
+                    "total_chunks": chunk_count,
+                    "batch_num": batch_num,
+                    "source": ctx.source.value,
+                },
+                tenant_id=tenant_id,
+            )
+
             memory_tracer.increment_and_maybe_trace()

             # make sure the checkpoints aren't getting too large at some regular interval

@@ -585,6 +604,30 @@ def _run_indexing(
                 checkpoint=checkpoint,
             )

+        # Add telemetry for completed indexing
+        redis_connector = RedisConnector(tenant_id, ctx.cc_pair_id)
+        redis_connector_index = redis_connector.new_index(
+            index_attempt_start.search_settings_id
+        )
+        final_progress = redis_connector_index.get_progress() or 0
+
+        optional_telemetry(
+            record_type=RecordType.INDEXING_COMPLETE,
+            data={
+                "index_attempt_id": index_attempt_id,
+                "cc_pair_id": ctx.cc_pair_id,
+                "connector_id": ctx.connector_id,
+                "credential_id": ctx.credential_id,
+                "total_docs_indexed": document_count,
+                "total_chunks": chunk_count,
+                "batch_count": batch_num,
+                "time_elapsed_seconds": time.monotonic() - start_time,
+                "source": ctx.source.value,
+                "redis_progress": final_progress,
+            },
+            tenant_id=tenant_id,
+        )
+
     except Exception as e:
         logger.exception(
             "Connector run exceptioned after elapsed time: "
@@ -1,6 +1,8 @@
 import json
 import os
 import urllib.parse
+from datetime import datetime
+from datetime import timezone
 from typing import cast

 from onyx.auth.schemas import AuthBackend

@@ -383,10 +385,23 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
 # https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
 # https://jira.atlassian.com/browse/CONFCLOUD-69670


+def get_current_tz_offset() -> int:
+    # datetime.now() gets local time; datetime.now(timezone.utc) gets UTC time.
+    # Remove tzinfo to compare non-timezone-aware objects.
+    time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
+    return round(time_diff.total_seconds() / 3600)
+
+
 # enter as a floating point offset from UTC in hours (-24 < val < 24)
 # this will be applied globally, so it probably makes sense to transition this to
 # per-connector at some point.
-CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
+# For the default value, we assume that the user's local timezone is more likely to be
+# correct (i.e. the configured user's timezone or the default server one) than UTC.
+# https://developer.atlassian.com/cloud/confluence/cql-fields/#created
+CONFLUENCE_TIMEZONE_OFFSET = float(
+    os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
+)

 GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
     os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)

@@ -677,3 +692,7 @@ IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
     "IMAGE_ANALYSIS_SYSTEM_PROMPT",
     DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
 )
+
+DISABLE_AUTO_AUTH_REFRESH = (
+    os.environ.get("DISABLE_AUTO_AUTH_REFRESH", "").lower() == "true"
+)
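`get_current_tz_offset` derives the host's UTC offset by subtracting a naive UTC timestamp from a naive local one, so the new `CONFLUENCE_TIMEZONE_OFFSET` default follows the server's clock instead of assuming UTC. Evaluated in isolation:

```python
from datetime import datetime, timezone

def get_current_tz_offset() -> int:
    # naive local "now" minus naive UTC "now" yields the host's UTC offset
    time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
    return round(time_diff.total_seconds() / 3600)

# Prints -5 on a US-Eastern host in winter, 0 on a UTC host.
print(get_current_tz_offset())
```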
@@ -79,6 +79,8 @@ _FULL_EXTENSION_FILTER_STRING = "".join(
     ]
 )

+ONE_HOUR = 3600
+

 class ConfluenceConnector(
     LoadConnector,

@@ -429,7 +431,17 @@ class ConfluenceConnector(
         start: SecondsSinceUnixEpoch | None = None,
         end: SecondsSinceUnixEpoch | None = None,
     ) -> GenerateDocumentsOutput:
-        return self._fetch_document_batches(start, end)
+        try:
+            return self._fetch_document_batches(start, end)
+        except Exception as e:
+            if "field 'updated' is invalid" in str(e) and start is not None:
+                logger.warning(
+                    "Confluence says we provided an invalid 'updated' field. This may indicate "
+                    "a real issue, but can also appear during edge cases like daylight "
+                    f"savings time changes. Retrying with a 1 hour offset. Error: {e}"
+                )
+                return self._fetch_document_batches(start - ONE_HOUR, end)
+            raise

     def retrieve_all_slim_documents(
         self,
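Confluence evaluates CQL `updated` bounds in the instance's configured timezone, so a daylight-saving transition can push a computed window edge into an invalid local time and trigger the "field 'updated' is invalid" error; the connector retries once with the window widened by an hour. The generic shape of that fallback:

```python
# Generic form of the retry used above; the probe string and one-hour offset
# mirror the connector, while the fetch callable is illustrative.
ONE_HOUR = 3600

def fetch_with_dst_fallback(fetch, start, end):
    try:
        return fetch(start, end)
    except Exception as e:
        if "field 'updated' is invalid" in str(e) and start is not None:
            return fetch(start - ONE_HOUR, end)
        raise
```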
@@ -2,11 +2,11 @@ import copy
 import threading
 from collections.abc import Callable
 from collections.abc import Iterator
-from concurrent.futures import as_completed
-from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
 from enum import Enum
 from functools import partial
 from typing import Any
 from typing import cast
+from typing import Protocol
 from urllib.parse import urlparse

@@ -58,13 +58,13 @@ from onyx.connectors.interfaces import SlimConnector
 from onyx.connectors.models import ConnectorFailure
 from onyx.connectors.models import ConnectorMissingCredentialError
 from onyx.connectors.models import Document
-from onyx.connectors.models import DocumentFailure
 from onyx.connectors.models import EntityFailure
 from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.utils.lazy import lazy_eval
 from onyx.utils.logger import setup_logger
 from onyx.utils.retry_wrapper import retry_builder
 from onyx.utils.threadpool_concurrency import parallel_yield
+from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
 from onyx.utils.threadpool_concurrency import ThreadSafeDict

 logger = setup_logger()
@@ -461,6 +461,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
                 DriveRetrievalStage.MY_DRIVE_FILES,
             )
             curr_stage.stage = DriveRetrievalStage.SHARED_DRIVE_FILES
+            resuming = False  # we are starting the next stage for the first time

         if curr_stage.stage == DriveRetrievalStage.SHARED_DRIVE_FILES:

@@ -496,7 +497,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
             )
             yield from _yield_from_drive(drive_id, start)
             curr_stage.stage = DriveRetrievalStage.FOLDER_FILES
-
+            resuming = False  # we are starting the next stage for the first time
         if curr_stage.stage == DriveRetrievalStage.FOLDER_FILES:

         def _yield_from_folder_crawl(

@@ -549,6 +550,16 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
             checkpoint, is_slim, DriveRetrievalStage.MY_DRIVE_FILES
         )

+        # Setup initial completion map on first connector run
+        for email in all_org_emails:
+            # don't overwrite existing completion map on resuming runs
+            if email in checkpoint.completion_map:
+                continue
+            checkpoint.completion_map[email] = StageCompletion(
+                stage=DriveRetrievalStage.START,
+                completed_until=0,
+            )
+
         # we've found all users and drives, now time to actually start
         # fetching stuff
         logger.info(f"Found {len(all_org_emails)} users to impersonate")

@@ -562,11 +573,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
             drive_ids_to_retrieve, checkpoint
         )

-        for email in all_org_emails:
-            checkpoint.completion_map[email] = StageCompletion(
-                stage=DriveRetrievalStage.START,
-                completed_until=0,
-            )
         user_retrieval_gens = [
             self._impersonate_user_for_retrieval(
                 email,

@@ -797,10 +803,12 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
             return

         for file in drive_files:
-            if file.error is not None:
+            if file.error is None:
                 checkpoint.completion_map[file.user_email].update(
                     stage=file.completion_stage,
-                    completed_until=file.drive_file[GoogleFields.MODIFIED_TIME.value],
+                    completed_until=datetime.fromisoformat(
+                        file.drive_file[GoogleFields.MODIFIED_TIME.value]
+                    ).timestamp(),
                     completed_until_parent_id=file.parent_id,
                 )
             yield file
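Two fixes land in the last hunk. The inverted guard (`file.error is None`) means checkpoint progress is now recorded for successful files rather than errored ones. And because Drive's `modifiedTime` is an RFC 3339 string, storing it raw put a string where the checkpoint logic compares epoch seconds, hence the `fromisoformat(...).timestamp()` conversion. One caveat when reusing that conversion: `datetime.fromisoformat` only accepts a trailing `Z` on Python 3.11+, so older interpreters need the suffix normalized first:

```python
from datetime import datetime

raw = "2024-01-15T09:30:00.000Z"  # illustrative Drive modifiedTime value

# On Python 3.11+, fromisoformat(raw) works directly; this form also runs
# on older versions by rewriting "Z" as an explicit UTC offset.
ts = datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp()
print(ts)  # epoch seconds, comparable to completed_until
```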
@@ -902,118 +910,78 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
         checkpoint: GoogleDriveCheckpoint,
         start: SecondsSinceUnixEpoch | None = None,
         end: SecondsSinceUnixEpoch | None = None,
-    ) -> Iterator[list[Document | ConnectorFailure]]:
+    ) -> Iterator[Document | ConnectorFailure]:
         try:
-            # Create a larger process pool for file conversion
-            with ThreadPoolExecutor(max_workers=8) as executor:
-                # Prepare a partial function with the credentials and admin email
-                convert_func = partial(
-                    _convert_single_file,
-                    self.creds,
-                    self.primary_admin_email,
-                    self.allow_images,
-                    self.size_threshold,
-                )
-
-                # Fetch files in batches
-                batches_complete = 0
-                files_batch: list[GoogleDriveFileType] = []
-                for retrieved_file in self._fetch_drive_items(
-                    is_slim=False,
-                    checkpoint=checkpoint,
-                    start=start,
-                    end=end,
-                ):
-                    if retrieved_file.error is not None:
-                        failure_stage = retrieved_file.completion_stage.value
-                        failure_message = (
-                            f"retrieval failure during stage: {failure_stage},"
-                        )
-                        failure_message += f"user: {retrieved_file.user_email},"
-                        failure_message += (
-                            f"parent drive/folder: {retrieved_file.parent_id},"
-                        )
-                        failure_message += f"error: {retrieved_file.error}"
-                        logger.error(failure_message)
-                        yield [
-                            ConnectorFailure(
-                                failed_entity=EntityFailure(
-                                    entity_id=failure_stage,
-                                ),
-                                failure_message=failure_message,
-                                exception=retrieved_file.error,
-                            )
-                        ]
-                        continue
-                    files_batch.append(retrieved_file.drive_file)
-
-                    if len(files_batch) < self.batch_size:
-                        continue
-
-                    # Process the batch
-                    futures = [
-                        executor.submit(convert_func, file) for file in files_batch
-                    ]
-                    documents = []
-                    for future in as_completed(futures):
-                        try:
-                            doc = future.result()
-                            if doc is not None:
-                                documents.append(doc)
-                        except Exception as e:
-                            error_str = f"Error converting file: {e}"
-                            logger.error(error_str)
-                            yield [
-                                ConnectorFailure(
-                                    failed_document=DocumentFailure(
-                                        document_id=retrieved_file.drive_file["id"],
-                                        document_link=retrieved_file.drive_file[
-                                            "webViewLink"
-                                        ],
-                                    ),
-                                    failure_message=error_str,
-                                    exception=e,
-                                )
-                            ]
-
-                    if documents:
-                        yield documents
-                        batches_complete += 1
-                    files_batch = []
-
-                    if batches_complete > BATCHES_PER_CHECKPOINT:
-                        checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
-                        return  # create a new checkpoint
-
-                # Process any remaining files
-                if files_batch:
-                    futures = [
-                        executor.submit(convert_func, file) for file in files_batch
-                    ]
-                    documents = []
-                    for future in as_completed(futures):
-                        try:
-                            doc = future.result()
-                            if doc is not None:
-                                documents.append(doc)
-                        except Exception as e:
-                            error_str = f"Error converting file: {e}"
-                            logger.error(error_str)
-                            yield [
-                                ConnectorFailure(
-                                    failed_document=DocumentFailure(
-                                        document_id=retrieved_file.drive_file["id"],
-                                        document_link=retrieved_file.drive_file[
-                                            "webViewLink"
-                                        ],
-                                    ),
-                                    failure_message=error_str,
-                                    exception=e,
-                                )
-                            ]
-
-                    if documents:
-                        yield documents
+            # Prepare a partial function with the credentials and admin email
+            convert_func = partial(
+                _convert_single_file,
+                self.creds,
+                self.primary_admin_email,
+                self.allow_images,
+                self.size_threshold,
+            )
+            # Fetch files in batches
+            batches_complete = 0
+            files_batch: list[GoogleDriveFileType] = []
+
+            def _yield_batch(
+                files_batch: list[GoogleDriveFileType],
+            ) -> Iterator[Document | ConnectorFailure]:
+                nonlocal batches_complete
+                # Process the batch using run_functions_tuples_in_parallel
+                func_with_args = [(convert_func, (file,)) for file in files_batch]
+                results = cast(
+                    list[Document | ConnectorFailure | None],
+                    run_functions_tuples_in_parallel(func_with_args, max_workers=8),
+                )
+
+                docs_and_failures = [result for result in results if result is not None]
+
+                if docs_and_failures:
+                    yield from docs_and_failures
+                    batches_complete += 1
+
+            for retrieved_file in self._fetch_drive_items(
+                is_slim=False,
+                checkpoint=checkpoint,
+                start=start,
+                end=end,
+            ):
+                if retrieved_file.error is not None:
+                    failure_stage = retrieved_file.completion_stage.value
+                    failure_message = (
+                        f"retrieval failure during stage: {failure_stage},"
+                    )
+                    failure_message += f"user: {retrieved_file.user_email},"
+                    failure_message += (
+                        f"parent drive/folder: {retrieved_file.parent_id},"
+                    )
+                    failure_message += f"error: {retrieved_file.error}"
+                    logger.error(failure_message)
+                    yield ConnectorFailure(
+                        failed_entity=EntityFailure(
+                            entity_id=failure_stage,
+                        ),
+                        failure_message=failure_message,
+                        exception=retrieved_file.error,
+                    )
+                    continue
+                files_batch.append(retrieved_file.drive_file)
+
+                if len(files_batch) < self.batch_size:
+                    continue
+
+                yield from _yield_batch(files_batch)
+                files_batch = []
+
+                if batches_complete > BATCHES_PER_CHECKPOINT:
+                    checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
+                    return  # create a new checkpoint
+
+            # Process any remaining files
+            if files_batch:
+                yield from _yield_batch(files_batch)
         except Exception as e:
             logger.exception(f"Error extracting documents from Google Drive: {e}")
             raise e
@@ -1035,10 +1003,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
         checkpoint = copy.deepcopy(checkpoint)
         self._retrieved_ids = checkpoint.retrieved_folder_and_drive_ids
         try:
-            for doc_list in self._extract_docs_from_google_drive(
-                checkpoint, start, end
-            ):
-                yield from doc_list
+            yield from self._extract_docs_from_google_drive(checkpoint, start, end)
         except Exception as e:
             if MISSING_SCOPES_ERROR_STR in str(e):
                 raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
@@ -1073,9 +1038,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
                 raise RuntimeError(
                     "_extract_slim_docs_from_google_drive: Stop signal detected"
                 )
-
             callback.progress("_extract_slim_docs_from_google_drive", 1)
-
         yield slim_batch

     def retrieve_all_slim_documents(
@@ -123,7 +123,7 @@ def crawl_folders_for_files(
         end=end,
     ):
         found_files = True
-        logger.info(f"Found file: {file['name']}")
+        logger.info(f"Found file: {file['name']}, user email: {user_email}")
         yield RetrievedDriveFile(
             drive_file=file,
             user_email=user_email,
@@ -175,9 +175,12 @@ def _get_tickets_page(
     )


-def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
+def _fetch_author(
+    client: ZendeskClient, author_id: str | int
+) -> BasicExpertInfo | None:
     # Skip fetching if author_id is invalid
-    if not author_id or author_id == "-1":
+    # cast to str to avoid issues with zendesk changing their types
+    if not author_id or str(author_id) == "-1":
         return None

     try:
@@ -8,23 +8,31 @@ from sqlalchemy import and_
 from sqlalchemy import delete
 from sqlalchemy import desc
 from sqlalchemy import func
+from sqlalchemy import Select
 from sqlalchemy import select
 from sqlalchemy import update
 from sqlalchemy.orm import contains_eager
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import Session
-from sqlalchemy.sql import Select

 from onyx.connectors.models import ConnectorFailure
-from onyx.db.engine import get_session_context_manager
+from onyx.db.engine import get_session_with_current_tenant
+from onyx.db.enums import IndexingStatus
+from onyx.db.enums import IndexModelStatus
 from onyx.db.models import ConnectorCredentialPair
 from onyx.db.models import IndexAttempt
 from onyx.db.models import IndexAttemptError
-from onyx.db.models import IndexingStatus
-from onyx.db.models import IndexModelStatus
 from onyx.db.models import SearchSettings
 from onyx.server.documents.models import ConnectorCredentialPair
 from onyx.server.documents.models import ConnectorCredentialPairIdentifier
 from onyx.utils.logger import setup_logger
+from onyx.utils.telemetry import optional_telemetry
+from onyx.utils.telemetry import RecordType

+# Comment out unused imports that cause mypy errors
+# from onyx.auth.models import UserRole
+# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
+# from onyx.db.connector_credential_pair import ConnectorCredentialPairIdentifier
+# from onyx.db.engine import async_query_for_dms

 logger = setup_logger()
@@ -201,6 +209,17 @@ def mark_attempt_in_progress(
         attempt.status = IndexingStatus.IN_PROGRESS
         attempt.time_started = index_attempt.time_started or func.now()  # type: ignore
         db_session.commit()
+
+        # Add telemetry for index attempt status change
+        optional_telemetry(
+            record_type=RecordType.INDEX_ATTEMPT_STATUS,
+            data={
+                "index_attempt_id": index_attempt.id,
+                "status": IndexingStatus.IN_PROGRESS.value,
+                "cc_pair_id": index_attempt.connector_credential_pair_id,
+                "search_settings_id": index_attempt.search_settings_id,
+            },
+        )
     except Exception:
         db_session.rollback()
         raise

@@ -219,6 +238,19 @@ def mark_attempt_succeeded(

         attempt.status = IndexingStatus.SUCCESS
         db_session.commit()
+
+        # Add telemetry for index attempt status change
+        optional_telemetry(
+            record_type=RecordType.INDEX_ATTEMPT_STATUS,
+            data={
+                "index_attempt_id": index_attempt_id,
+                "status": IndexingStatus.SUCCESS.value,
+                "cc_pair_id": attempt.connector_credential_pair_id,
+                "search_settings_id": attempt.search_settings_id,
+                "total_docs_indexed": attempt.total_docs_indexed,
+                "new_docs_indexed": attempt.new_docs_indexed,
+            },
+        )
     except Exception:
         db_session.rollback()
         raise

@@ -237,6 +269,19 @@ def mark_attempt_partially_succeeded(

         attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
         db_session.commit()
+
+        # Add telemetry for index attempt status change
+        optional_telemetry(
+            record_type=RecordType.INDEX_ATTEMPT_STATUS,
+            data={
+                "index_attempt_id": index_attempt_id,
+                "status": IndexingStatus.COMPLETED_WITH_ERRORS.value,
+                "cc_pair_id": attempt.connector_credential_pair_id,
+                "search_settings_id": attempt.search_settings_id,
+                "total_docs_indexed": attempt.total_docs_indexed,
+                "new_docs_indexed": attempt.new_docs_indexed,
+            },
+        )
     except Exception:
         db_session.rollback()
         raise

@@ -259,6 +304,20 @@ def mark_attempt_canceled(
         attempt.status = IndexingStatus.CANCELED
         attempt.error_msg = reason
         db_session.commit()
+
+        # Add telemetry for index attempt status change
+        optional_telemetry(
+            record_type=RecordType.INDEX_ATTEMPT_STATUS,
+            data={
+                "index_attempt_id": index_attempt_id,
+                "status": IndexingStatus.CANCELED.value,
+                "cc_pair_id": attempt.connector_credential_pair_id,
+                "search_settings_id": attempt.search_settings_id,
+                "reason": reason,
+                "total_docs_indexed": attempt.total_docs_indexed,
+                "new_docs_indexed": attempt.new_docs_indexed,
+            },
+        )
     except Exception:
         db_session.rollback()
         raise

@@ -283,6 +342,20 @@ def mark_attempt_failed(
         attempt.error_msg = failure_reason
         attempt.full_exception_trace = full_exception_trace
         db_session.commit()
+
+        # Add telemetry for index attempt status change
+        optional_telemetry(
+            record_type=RecordType.INDEX_ATTEMPT_STATUS,
+            data={
+                "index_attempt_id": index_attempt_id,
+                "status": IndexingStatus.FAILED.value,
+                "cc_pair_id": attempt.connector_credential_pair_id,
+                "search_settings_id": attempt.search_settings_id,
+                "reason": failure_reason,
+                "total_docs_indexed": attempt.total_docs_indexed,
+                "new_docs_indexed": attempt.new_docs_indexed,
+            },
+        )
     except Exception:
         db_session.rollback()
         raise
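All five status transitions now emit the same `INDEX_ATTEMPT_STATUS` record with near-identical payloads. A hypothetical consolidation (not part of the diff) that would keep the field set in one place:

```python
# Names (optional_telemetry, RecordType, IndexingStatus, IndexAttempt) are the
# ones imported in this file; the helper itself is hypothetical.
def _emit_status_telemetry(
    attempt: IndexAttempt, status: IndexingStatus, **extra: object
) -> None:
    optional_telemetry(
        record_type=RecordType.INDEX_ATTEMPT_STATUS,
        data={
            "index_attempt_id": attempt.id,
            "status": status.value,
            "cc_pair_id": attempt.connector_credential_pair_id,
            "search_settings_id": attempt.search_settings_id,
            "total_docs_indexed": attempt.total_docs_indexed,
            "new_docs_indexed": attempt.new_docs_indexed,
            **extra,  # e.g. reason=... for canceled/failed attempts
        },
    )
```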
@@ -434,7 +507,7 @@ def get_latest_index_attempts_parallel(
     eager_load_cc_pair: bool = False,
     only_finished: bool = False,
 ) -> Sequence[IndexAttempt]:
-    with get_session_context_manager() as db_session:
+    with get_session_with_current_tenant() as db_session:
         return get_latest_index_attempts(
             secondary_index,
             db_session,
@@ -24,7 +24,9 @@ from onyx.db.models import User__UserGroup
 from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop


-def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
+def validate_user_role_update(
+    requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
+) -> None:
     """
     Validate that a user role update is valid.
     Assumed only admins can hit this endpoint.

@@ -57,6 +59,9 @@ def validate_user_role_update(requested_role: UserRole, current_role: UserRole)
             detail="To change a Limited User's role, they must first login to Onyx via the web app.",
         )

+    if explicit_override:
+        return
+
     if requested_role == UserRole.CURATOR:
         # This shouldn't happen, but just in case
         raise HTTPException(
@@ -361,7 +361,15 @@ def get_application() -> FastAPI:
         )

     if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
-        oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
+        # For Google OAuth, refresh tokens are requested by:
+        # 1. Adding the right scopes
+        # 2. Properly configuring OAuth in Google Cloud Console to allow offline access
+        oauth_client = GoogleOAuth2(
+            OAUTH_CLIENT_ID,
+            OAUTH_CLIENT_SECRET,
+            # Use standard scopes that include profile and email
+            scopes=["openid", "email", "profile"],
+        )
         include_auth_router_with_prefix(
             application,
             create_onyx_oauth_router(

@@ -383,6 +391,13 @@ def get_application() -> FastAPI:
             prefix="/auth",
         )

+        # Add refresh token endpoint for OAuth as well
+        include_auth_router_with_prefix(
+            application,
+            fastapi_users.get_refresh_router(auth_backend),
+            prefix="/auth",
+        )
+
     application.add_exception_handler(
         RequestValidationError, validation_exception_handler
     )
@@ -15,7 +15,6 @@ from onyx.configs.constants import MessageType
 from onyx.configs.constants import SearchFeedbackType
 from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
 from onyx.connectors.slack.utils import expert_info_from_slack_id
-from onyx.connectors.slack.utils import make_slack_api_rate_limited
 from onyx.context.search.models import SavedSearchDoc
 from onyx.db.chat import get_chat_message
 from onyx.db.chat import translate_db_message_to_chat_message_detail

@@ -553,8 +552,7 @@ def handle_followup_resolved_button(

     # Delete the message with the option to mark resolved
     if not immediate:
-        slack_call = make_slack_api_rate_limited(client.web_client.chat_delete)
-        response = slack_call(
+        response = client.web_client.chat_delete(
             channel=channel_id,
             ts=message_ts,
         )
@@ -18,6 +18,9 @@ from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
+from slack_sdk.http_retry import ConnectionErrorRetryHandler
+from slack_sdk.http_retry import RateLimitErrorRetryHandler
+from slack_sdk.http_retry import RetryHandler
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session

@@ -944,10 +947,21 @@ def _get_socket_client(
) -> TenantSocketModeClient:
    # For more info on how to set this up, check out the docs:
    # https://docs.onyx.app/slack_bot_setup

+    # use the retry handlers built into the slack sdk
+    connection_error_retry_handler = ConnectionErrorRetryHandler()
+    rate_limit_error_retry_handler = RateLimitErrorRetryHandler(max_retry_count=7)
+    slack_retry_handlers: list[RetryHandler] = [
+        connection_error_retry_handler,
+        rate_limit_error_retry_handler,
+    ]
+
    return TenantSocketModeClient(
        # This app-level token will be used only for establishing a connection
        app_token=slack_bot_tokens.app_token,
-        web_client=WebClient(token=slack_bot_tokens.bot_token),
+        web_client=WebClient(
+            token=slack_bot_tokens.bot_token, retry_handlers=slack_retry_handlers
+        ),
        tenant_id=tenant_id,
        slack_bot_id=slack_bot_id,
    )
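The same two handlers can be attached to any standalone WebClient; once registered, rate-limited or connection-failed calls such as chat_postMessage are retried inside the SDK itself. A minimal sketch, using only the slack_sdk APIs imported above (the token is a placeholder):

from slack_sdk import WebClient
from slack_sdk.http_retry import ConnectionErrorRetryHandler
from slack_sdk.http_retry import RateLimitErrorRetryHandler

client = WebClient(
    token="xoxb-...",  # placeholder bot token
    retry_handlers=[
        ConnectionErrorRetryHandler(),
        RateLimitErrorRetryHandler(max_retry_count=7),
    ],
)
# A 429 response here is retried by the SDK, honoring Slack's Retry-After
client.chat_postMessage(channel="#general", text="hello")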
@@ -30,7 +30,6 @@ from onyx.configs.onyxbot_configs import (
from onyx.configs.onyxbot_configs import (
    DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS,
)
-from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.users import get_user_by_email

@@ -125,13 +124,18 @@ def update_emote_react(
            )
            return

-        func = client.reactions_remove if remove else client.reactions_add
-        slack_call = make_slack_api_rate_limited(func)  # type: ignore
-        slack_call(
-            name=emoji,
-            channel=channel,
-            timestamp=message_ts,
-        )
+        if remove:
+            client.reactions_remove(
+                name=emoji,
+                channel=channel,
+                timestamp=message_ts,
+            )
+        else:
+            client.reactions_add(
+                name=emoji,
+                channel=channel,
+                timestamp=message_ts,
+            )
    except SlackApiError as e:
        if remove:
            logger.error(f"Failed to remove Reaction due to: {e}")

@@ -200,9 +204,8 @@ def respond_in_thread_or_channel(

    message_ids: list[str] = []
    if not receiver_ids:
-        slack_call = make_slack_api_rate_limited(client.chat_postMessage)
        try:
-            response = slack_call(
+            response = client.chat_postMessage(
                channel=channel,
                text=text,
                blocks=blocks,

@@ -224,7 +227,7 @@ def respond_in_thread_or_channel(
                blocks_without_urls.append(_build_error_block(str(e)))

                # Try again without blocks containing url
-                response = slack_call(
+                response = client.chat_postMessage(
                    channel=channel,
                    text=text,
                    blocks=blocks_without_urls,

@@ -236,11 +239,9 @@ def respond_in_thread_or_channel(

        message_ids.append(response["message_ts"])
    else:
-        slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
-
        for receiver in receiver_ids:
            try:
-                response = slack_call(
+                response = client.chat_postEphemeral(
                    channel=channel,
                    user=receiver,
                    text=text,

@@ -263,7 +264,7 @@ def respond_in_thread_or_channel(
                blocks_without_urls.append(_build_error_block(str(e)))

                # Try again without blocks containing url
-                response = slack_call(
+                response = client.chat_postEphemeral(
                    channel=channel,
                    user=receiver,
                    text=text,

@@ -500,7 +501,7 @@ def fetch_user_semantic_id_from_id(
    if not user_id:
        return None

-    response = make_slack_api_rate_limited(client.users_info)(user=user_id)
+    response = client.users_info(user=user_id)
    if not response["ok"]:
        return None
@@ -31,6 +31,7 @@ PUBLIC_ENDPOINT_SPECS = [
    # just gets the version of Onyx (e.g. 0.3.11)
    ("/version", {"GET"}),
    # stuff related to basic auth
+    ("/auth/refresh", {"POST"}),
    ("/auth/register", {"POST"}),
    ("/auth/login", {"POST"}),
    ("/auth/logout", {"POST"}),
@@ -132,6 +132,7 @@ class UserByEmail(BaseModel):
class UserRoleUpdateRequest(BaseModel):
    user_email: str
    new_role: UserRole
+    explicit_override: bool = False


class UserRoleResponse(BaseModel):
@@ -102,6 +102,7 @@ def set_user_role(
    validate_user_role_update(
        requested_role=requested_role,
        current_role=current_role,
+        explicit_override=user_role_update_request.explicit_override,
    )

    if user_to_update.id == current_user.id:

@@ -122,6 +123,22 @@ def set_user_role(
    db_session.commit()


+class TestUpsertRequest(BaseModel):
+    email: str
+
+
+@router.post("/manage/users/test-upsert-user")
+async def test_upsert_user(
+    request: TestUpsertRequest,
+    _: User = Depends(current_admin_user),
+) -> None | FullUserSnapshot:
+    """Test endpoint for upsert_saml_user. Only used for integration testing."""
+    user = await fetch_ee_implementation_or_noop(
+        "onyx.server.saml", "upsert_saml_user", None
+    )(email=request.email)
+    return FullUserSnapshot.from_user_model(user) if user else None
+
+
@router.get("/manage/users/accepted")
def list_accepted_users(
    q: str | None = Query(default=None),
@@ -36,6 +36,10 @@ class RecordType(str, Enum):
    LATENCY = "latency"
    FAILURE = "failure"
    METRIC = "metric"
+    INDEXING_PROGRESS = "indexing_progress"
+    INDEXING_COMPLETE = "indexing_complete"
+    PERMISSION_SYNC_PROGRESS = "permission_sync_progress"
+    INDEX_ATTEMPT_STATUS = "index_attempt_status"


def _get_or_generate_customer_id_mt(tenant_id: str) -> str:
@@ -6,14 +6,17 @@ import uuid
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import MutableMapping
+from collections.abc import Sequence
-from concurrent.futures import as_completed
from concurrent.futures import FIRST_COMPLETED
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import wait
from typing import Any
+from typing import cast
from typing import Generic
+from typing import overload
+from typing import Protocol
from typing import TypeVar

from pydantic import GetCoreSchemaHandler

@@ -145,13 +148,20 @@ class ThreadSafeDict(MutableMapping[KT, VT]):
        return collections.abc.ValuesView(self)


+class CallableProtocol(Protocol):
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        ...
+
+
def run_functions_tuples_in_parallel(
-    functions_with_args: list[tuple[Callable, tuple]],
+    functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
    allow_failures: bool = False,
    max_workers: int | None = None,
) -> list[Any]:
    """
    Executes multiple functions in parallel and returns a list of the results for each function.
+    This function preserves contextvars across threads, which is important for maintaining
+    context like tenant IDs in database sessions.

    Args:
        functions_with_args: List of tuples each containing the function callable and a tuple of arguments.

@@ -159,7 +169,7 @@ def run_functions_tuples_in_parallel(
        max_workers: Max number of worker threads

    Returns:
-        dict: A dictionary mapping function names to their results or error messages.
+        list: A list of results from each function, in the same order as the input functions.
    """
    workers = (
        min(max_workers, len(functions_with_args))
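For reference, the updated signature is used like this. A minimal sketch; the module path onyx.utils.threadpool_concurrency is assumed from the repo layout:

from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel


def add(a: int, b: int) -> int:
    return a + b


# Each tuple is (callable, args); results come back in input order
results = run_functions_tuples_in_parallel([(add, (1, 2)), (add, (3, 4))])
assert results == [3, 7]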
@@ -186,7 +196,7 @@ def run_functions_tuples_in_parallel(
                results.append((index, future.result()))
            except Exception as e:
                logger.exception(f"Function at index {index} failed due to {e}")
-                results.append((index, None))
+                results.append((index, None))  # type: ignore

                if not allow_failures:
                    raise

@@ -288,7 +298,7 @@ def run_with_timeout(
    if task.is_alive():
        task.end()

-    return task.result
+    return task.result  # type: ignore


# NOTE: this function should really only be used when run_functions_tuples_in_parallel is

@@ -304,9 +314,9 @@ def run_in_background(
    """
    context = contextvars.copy_context()
    # Timeout not used in the non-blocking case
-    task = TimeoutThread(-1, context.run, func, *args, **kwargs)
+    task = TimeoutThread(-1, context.run, func, *args, **kwargs)  # type: ignore
    task.start()
-    return task
+    return cast(TimeoutThread[R], task)


def wait_on_background(task: TimeoutThread[R]) -> R:
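run_in_background and wait_on_background pair up as a fire-then-join helper. A usage sketch under the same assumed module path:

from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background


def slow_square(x: int) -> int:
    return x * x


# Starts slow_square on a TimeoutThread with contextvars copied over
task = run_in_background(slow_square, 12)
# ... do other work here ...
assert wait_on_background(task) == 144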
@@ -123,10 +123,15 @@ class UserManager:
        user_to_set: DATestUser,
        target_role: UserRole,
        user_performing_action: DATestUser,
+        explicit_override: bool = False,
    ) -> DATestUser:
        response = requests.patch(
            url=f"{API_SERVER_URL}/manage/set-user-role",
-            json={"user_email": user_to_set.email, "new_role": target_role.value},
+            json={
+                "user_email": user_to_set.email,
+                "new_role": target_role.value,
+                "explicit_override": explicit_override,
+            },
            headers=user_performing_action.headers,
        )
        response.raise_for_status()
@@ -0,0 +1,90 @@
import requests

from onyx.auth.schemas import UserRole
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser


def test_saml_user_conversion(reset: None) -> None:
    """
    Test that SAML login correctly converts users with non-authenticated roles
    (SLACK_USER or EXT_PERM_USER) to authenticated roles (BASIC).

    This test:
    1. Creates an admin and a regular user
    2. Changes the regular user's role to EXT_PERM_USER
    3. Simulates a SAML login by calling the test endpoint
    4. Verifies the user's role is converted to BASIC

    This tests the fix that ensures users with non-authenticated roles (SLACK_USER or EXT_PERM_USER)
    are properly converted to authenticated roles during SAML login.
    """
    # Create an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(email="admin@onyx-test.com")

    # Create a regular user that we'll convert to EXT_PERM_USER
    test_user_email = "ext_perm_user@example.com"
    test_user = UserManager.create(email=test_user_email)

    # Verify the user was created with BASIC role initially
    assert UserManager.is_role(test_user, UserRole.BASIC)

    # Change the user's role to EXT_PERM_USER using the UserManager
    UserManager.set_role(
        user_to_set=test_user,
        target_role=UserRole.EXT_PERM_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )

    # Verify the user has EXT_PERM_USER role now
    assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)

    # Simulate SAML login by calling the test endpoint
    response = requests.post(
        f"{API_SERVER_URL}/manage/users/test-upsert-user",
        json={"email": test_user_email},
        headers=admin_user.headers,  # Use admin headers for authorization
    )
    response.raise_for_status()

    # Verify the response indicates the role changed to BASIC
    user_data = response.json()
    assert user_data["role"] == UserRole.BASIC.value

    # Verify user role was changed in the database
    assert UserManager.is_role(test_user, UserRole.BASIC)

    # Do the same test with SLACK_USER
    slack_user_email = "slack_user@example.com"
    slack_user = UserManager.create(email=slack_user_email)

    # Verify the user was created with BASIC role initially
    assert UserManager.is_role(slack_user, UserRole.BASIC)

    # Change the user's role to SLACK_USER
    UserManager.set_role(
        user_to_set=slack_user,
        target_role=UserRole.SLACK_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )

    # Verify the user has SLACK_USER role
    assert UserManager.is_role(slack_user, UserRole.SLACK_USER)

    # Simulate SAML login again
    response = requests.post(
        f"{API_SERVER_URL}/manage/users/test-upsert-user",
        json={"email": slack_user_email},
        headers=admin_user.headers,
    )
    response.raise_for_status()

    # Verify the response indicates the role changed to BASIC
    user_data = response.json()
    assert user_data["role"] == UserRole.BASIC.value

    # Verify the user's role was changed in the database
    assert UserManager.is_role(slack_user, UserRole.BASIC)
backend/tests/unit/onyx/auth/conftest.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from unittest.mock import AsyncMock
from unittest.mock import MagicMock

import pytest

from onyx.db.models import OAuthAccount
from onyx.db.models import User


@pytest.fixture
def mock_user() -> MagicMock:
    """Creates a mock User instance for testing."""
    user = MagicMock(spec=User)
    user.email = "test@example.com"
    user.id = "test-user-id"
    return user


@pytest.fixture
def mock_oauth_account() -> MagicMock:
    """Creates a mock OAuthAccount instance for testing."""
    oauth_account = MagicMock(spec=OAuthAccount)
    oauth_account.oauth_name = "google"
    oauth_account.refresh_token = "test-refresh-token"
    oauth_account.access_token = "test-access-token"
    oauth_account.expires_at = None
    return oauth_account


@pytest.fixture
def mock_user_manager() -> MagicMock:
    """Creates a mock user manager for testing."""
    user_manager = MagicMock()
    user_manager.user_db = MagicMock()
    user_manager.user_db.update_oauth_account = AsyncMock()
    user_manager.user_db.update = AsyncMock()
    return user_manager


@pytest.fixture
def mock_db_session() -> MagicMock:
    """Creates a mock database session for testing."""
    return MagicMock()
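A note on the spec= arguments above: a spec'd MagicMock rejects attribute names the real model class does not define, so typos in the tests fail loudly instead of passing silently. A quick self-contained illustration (Example is a stand-in class, not an Onyx model):

from unittest.mock import MagicMock


class Example:
    name = "x"


mock = MagicMock(spec=Example)
mock.name  # fine: Example defines `name`
try:
    mock.missing_attr  # not on Example, so the spec'd mock raises
except AttributeError:
    print("spec'd mock rejects unknown attributes")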
backend/tests/unit/onyx/auth/test_oauth_refresher.py (new file, 273 lines)
@@ -0,0 +1,273 @@
from datetime import datetime
from datetime import timezone
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from sqlalchemy.ext.asyncio import AsyncSession

from onyx.auth.oauth_refresher import _test_expire_oauth_token
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
from onyx.auth.oauth_refresher import check_oauth_account_has_refresh_token
from onyx.auth.oauth_refresher import get_oauth_accounts_requiring_refresh_token
from onyx.auth.oauth_refresher import refresh_oauth_token
from onyx.db.models import OAuthAccount


@pytest.mark.asyncio
async def test_refresh_oauth_token_success(
    mock_user: MagicMock,
    mock_oauth_account: MagicMock,
    mock_user_manager: MagicMock,
    mock_db_session: AsyncSession,
) -> None:
    """Test successful OAuth token refresh."""
    # Mock HTTP client and response
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        "access_token": "new_token",
        "refresh_token": "new_refresh_token",
        "expires_in": 3600,
    }

    # Create async mock for the client post method
    mock_client = AsyncMock()
    mock_client.post.return_value = mock_response

    # Use fixture values but ensure refresh token exists
    mock_oauth_account.oauth_name = (
        "google"  # Ensure it's google to match the refresh endpoint
    )
    mock_oauth_account.refresh_token = "old_refresh_token"

    # Patch at the module level where it's actually being used
    with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock:
        # Configure the context manager
        client_instance = mock_client
        client_class_mock.return_value.__aenter__.return_value = client_instance

        # Call the function under test
        result = await refresh_oauth_token(
            mock_user, mock_oauth_account, mock_db_session, mock_user_manager
        )

    # Assertions
    assert result is True
    mock_client.post.assert_called_once()
    mock_user_manager.user_db.update_oauth_account.assert_called_once()

    # Verify token data was updated correctly
    update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2]
    assert update_data["access_token"] == "new_token"
    assert update_data["refresh_token"] == "new_refresh_token"
    assert "expires_at" in update_data
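These tests patch httpx.AsyncClient inside onyx.auth.oauth_refresher, so the flow they imply looks roughly like the sketch below: POST the refresh token to the provider's token endpoint and persist the response only on HTTP 200. This is an illustrative reconstruction, not the Onyx code; only Google's publicly documented token endpoint is assumed.

import httpx

GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"


async def refresh_google_access_token(
    refresh_token: str, client_id: str, client_secret: str
) -> dict | None:
    """Exchange a refresh token for a new access token (sketch)."""
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            GOOGLE_TOKEN_URL,
            data={
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client_id,
                "client_secret": client_secret,
            },
        )
    if resp.status_code != 200:
        return None  # mirrors the failure test: nothing is persisted
    # Contains access_token, expires_in, and sometimes a new refresh_token
    return resp.json()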
@pytest.mark.asyncio
async def test_refresh_oauth_token_failure(
    mock_user: MagicMock,
    mock_oauth_account: MagicMock,
    mock_user_manager: MagicMock,
    mock_db_session: AsyncSession,
) -> None:
    """Test OAuth token refresh failure due to HTTP error."""
    # Mock HTTP client with error response
    mock_response = MagicMock()
    mock_response.status_code = 400  # Simulate error

    # Create async mock for the client post method
    mock_client = AsyncMock()
    mock_client.post.return_value = mock_response

    # Ensure refresh token exists and provider is supported
    mock_oauth_account.oauth_name = "google"
    mock_oauth_account.refresh_token = "old_refresh_token"

    # Patch at the module level where it's actually being used
    with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock:
        # Configure the context manager
        client_class_mock.return_value.__aenter__.return_value = mock_client

        # Call the function under test
        result = await refresh_oauth_token(
            mock_user, mock_oauth_account, mock_db_session, mock_user_manager
        )

    # Assertions
    assert result is False
    mock_client.post.assert_called_once()
    mock_user_manager.user_db.update_oauth_account.assert_not_called()
@pytest.mark.asyncio
async def test_refresh_oauth_token_no_refresh_token(
    mock_user: MagicMock,
    mock_oauth_account: MagicMock,
    mock_user_manager: MagicMock,
    mock_db_session: AsyncSession,
) -> None:
    """Test OAuth token refresh when no refresh token is available."""
    # Set refresh token to None
    mock_oauth_account.refresh_token = None
    mock_oauth_account.oauth_name = "google"

    # No need to mock httpx since it shouldn't be called
    result = await refresh_oauth_token(
        mock_user, mock_oauth_account, mock_db_session, mock_user_manager
    )

    # Assertions
    assert result is False


@pytest.mark.asyncio
async def test_check_and_refresh_oauth_tokens(
    mock_user: MagicMock,
    mock_user_manager: MagicMock,
    mock_db_session: AsyncSession,
) -> None:
    """Test checking and refreshing multiple OAuth tokens."""
    # Create mock user with OAuth accounts
    now_timestamp = datetime.now(timezone.utc).timestamp()

    # Create an account that needs refreshing (expiring soon)
    expiring_account = MagicMock(spec=OAuthAccount)
    expiring_account.oauth_name = "google"
    expiring_account.refresh_token = "refresh_token_1"
    expiring_account.expires_at = now_timestamp + 60  # Expires in 1 minute

    # Create an account that doesn't need refreshing (expires later)
    valid_account = MagicMock(spec=OAuthAccount)
    valid_account.oauth_name = "google"
    valid_account.refresh_token = "refresh_token_2"
    valid_account.expires_at = now_timestamp + 3600  # Expires in 1 hour

    # Create an account without a refresh token
    no_refresh_account = MagicMock(spec=OAuthAccount)
    no_refresh_account.oauth_name = "google"
    no_refresh_account.refresh_token = None
    no_refresh_account.expires_at = (
        now_timestamp + 60
    )  # Expiring soon but no refresh token

    # Set oauth_accounts on the mock user
    mock_user.oauth_accounts = [expiring_account, valid_account, no_refresh_account]

    # Mock refresh_oauth_token function
    with patch(
        "onyx.auth.oauth_refresher.refresh_oauth_token", AsyncMock(return_value=True)
    ) as mock_refresh:
        # Call the function under test
        await check_and_refresh_oauth_tokens(
            mock_user, mock_db_session, mock_user_manager
        )

        # Assertions
        assert mock_refresh.call_count == 1  # Should only refresh the expiring account
        # Check it was called with the expiring account
        mock_refresh.assert_called_once_with(
            mock_user, expiring_account, mock_db_session, mock_user_manager
        )
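The selection rule this test pins down (refresh only accounts that both hold a refresh token and expire soon) reduces to a small predicate. A sketch, where the 5-minute window is an assumption for illustration only:

from datetime import datetime, timezone

REFRESH_WINDOW_SECONDS = 300  # assumed window, for illustration only


def needs_refresh(expires_at: float | None, refresh_token: str | None) -> bool:
    """True when the account can be refreshed and is about to expire."""
    if not refresh_token or expires_at is None:
        return False
    now = datetime.now(timezone.utc).timestamp()
    return expires_at - now < REFRESH_WINDOW_SECONDS


now = datetime.now(timezone.utc).timestamp()
assert needs_refresh(now + 60, "tok")  # expiring in 1 minute -> refresh
assert not needs_refresh(now + 3600, "tok")  # 1 hour out -> leave alone
assert not needs_refresh(now + 60, None)  # no refresh token -> skip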
@pytest.mark.asyncio
async def test_get_oauth_accounts_requiring_refresh_token(mock_user: MagicMock) -> None:
    """Test identifying OAuth accounts that need refresh tokens."""
    # Create accounts with and without refresh tokens
    account_with_token = MagicMock(spec=OAuthAccount)
    account_with_token.oauth_name = "google"
    account_with_token.refresh_token = "refresh_token"

    account_without_token = MagicMock(spec=OAuthAccount)
    account_without_token.oauth_name = "google"
    account_without_token.refresh_token = None

    second_account_without_token = MagicMock(spec=OAuthAccount)
    second_account_without_token.oauth_name = "github"
    second_account_without_token.refresh_token = (
        ""  # Empty string should also be treated as missing
    )

    # Set accounts on user
    mock_user.oauth_accounts = [
        account_with_token,
        account_without_token,
        second_account_without_token,
    ]

    # Call the function under test
    accounts_needing_refresh = await get_oauth_accounts_requiring_refresh_token(
        mock_user
    )

    # Assertions
    assert len(accounts_needing_refresh) == 2
    assert account_without_token in accounts_needing_refresh
    assert second_account_without_token in accounts_needing_refresh
    assert account_with_token not in accounts_needing_refresh


@pytest.mark.asyncio
async def test_check_oauth_account_has_refresh_token(
    mock_user: MagicMock, mock_oauth_account: MagicMock
) -> None:
    """Test checking if an OAuth account has a refresh token."""
    # Test with refresh token
    mock_oauth_account.refresh_token = "refresh_token"
    has_token = await check_oauth_account_has_refresh_token(
        mock_user, mock_oauth_account
    )
    assert has_token is True

    # Test with None refresh token
    mock_oauth_account.refresh_token = None
    has_token = await check_oauth_account_has_refresh_token(
        mock_user, mock_oauth_account
    )
    assert has_token is False

    # Test with empty string refresh token
    mock_oauth_account.refresh_token = ""
    has_token = await check_oauth_account_has_refresh_token(
        mock_user, mock_oauth_account
    )
    assert has_token is False


@pytest.mark.asyncio
async def test_test_expire_oauth_token(
    mock_user: MagicMock,
    mock_oauth_account: MagicMock,
    mock_user_manager: MagicMock,
    mock_db_session: AsyncSession,
) -> None:
    """Test the testing utility function for token expiration."""
    # Set up the mock account
    mock_oauth_account.oauth_name = "google"
    mock_oauth_account.refresh_token = "test_refresh_token"
    mock_oauth_account.access_token = "test_access_token"

    # Call the function under test
    result = await _test_expire_oauth_token(
        mock_user,
        mock_oauth_account,
        mock_db_session,
        mock_user_manager,
        expire_in_seconds=10,
    )

    # Assertions
    assert result is True
    mock_user_manager.user_db.update_oauth_account.assert_called_once()

    # Verify the expiration time was set correctly
    update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2]
    assert "expires_at" in update_data

    # The new expiry should land roughly 10 seconds out
    now = datetime.now(timezone.utc).timestamp()
    assert update_data["expires_at"] - now >= 8.9  # allow ~1s for test execution
    assert update_data["expires_at"] - now <= 11.1  # allow ~1s for test execution
@@ -89,7 +89,8 @@ def test_run_in_background_and_wait_success() -> None:
    elapsed = time.time() - start_time

    assert result == 42
-    assert elapsed >= 0.1  # Verify we actually waited for the sleep
+    # sometimes slightly flaky
+    assert elapsed >= 0.095  # Verify we actually waited for the sleep


@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
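Loosening the bound papers over the flakiness; a common root cause is that time.time() follows the wall clock, which can be adjusted mid-test, so a monotonic clock is the more robust way to measure elapsed time. A general-purpose sketch, not the repo's code:

import time

start = time.monotonic()  # immune to wall-clock adjustments
time.sleep(0.1)
elapsed = time.monotonic() - start
assert elapsed >= 0.095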
@@ -5,7 +5,7 @@ envsubst '$DOMAIN $SSL_CERT_FILE_NAME $SSL_CERT_KEY_FILE_NAME' < "/etc/nginx/con
echo "Waiting for API server to boot up; this may take a minute or two..."
echo "If this takes more than ~5 minutes, check the logs of the API server container for errors with the following command:"
echo
-echo "docker logs onyx-stack_api_server-1"
+echo "docker logs onyx-stack-api_server-1"
echo

while true; do
openapi.json (new file; diff suppressed because one or more lines are too long)
@@ -1079,7 +1079,7 @@ export function AssistantEditor({
                  </Tooltip>
                </TooltipProvider>
                <span className="text-sm ml-2">
-                  {values.is_public ? "Public" : "Private"}
+                  Organization Public
                </span>
              </div>

@@ -1088,17 +1088,22 @@ export function AssistantEditor({
                  <InfoIcon size={16} className="mr-2" />
                  <span className="text-sm">
                    Default persona must be public. Visibility has been
-                    automatically set to public.
+                    automatically set to organization public.
                  </span>
                </div>
              )}

              {values.is_public ? (
                <p className="text-sm text-text-dark">
-                  Anyone from your team can view and use this assistant
+                  This assistant will be available to everyone in your
+                  organization
                </p>
              ) : (
                <>
                  <p className="text-sm text-text-dark mb-2">
                    This assistant will only be available to specific
                    users and groups
                  </p>
                  <div className="mt-2">
                    <Label className="mb-2" small>
                      Share with Users and Groups
@@ -100,7 +100,10 @@ export function EmailPasswordForm({
      // server-side provider values)
      window.location.href = "/auth/waiting-on-verification";
    } else {
-      // See above comment
+      // The searchparam is purely for multi-tenant development purposes.
+      // It replicates the behavior of the case where a user
+      // has signed up with email / password as the only user to an instance
+      // and has just completed verification
      window.location.href = nextUrl
        ? encodeURI(nextUrl)
        : `/chat${isSignup && !isJoin ? "?new_team=true" : ""}`;
@@ -7,7 +7,7 @@ import Text from "@/components/ui/text";
import { RequestNewVerificationEmail } from "../waiting-on-verification/RequestNewVerificationEmail";
import { User } from "@/lib/types";
import { Logo } from "@/components/logo/Logo";

+import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
export function Verify({ user }: { user: User | null }) {
  const searchParams = useSearchParams();
  const router = useRouter();

@@ -16,6 +16,8 @@ export function Verify({ user }: { user: User | null }) {

  const verify = useCallback(async () => {
    const token = searchParams.get("token");
+    const firstUser =
+      searchParams.get("first_user") && NEXT_PUBLIC_CLOUD_ENABLED;
    if (!token) {
      setError(
        "Missing verification token. Try requesting a new verification email."

@@ -35,7 +37,7 @@ export function Verify({ user }: { user: User | null }) {
      // Use window.location.href to force a full page reload,
      // ensuring app re-initializes with the new state (including
      // server-side provider values)
-      window.location.href = "/";
+      window.location.href = firstUser ? "/chat?new_team=true" : "/chat";
    } else {
      const errorDetail = (await response.json()).detail;
      setError(
@@ -1158,6 +1158,7 @@ export function ChatPage({
    let frozenSessionId = currentSessionId();
    updateCanContinue(false, frozenSessionId);
    setUncaughtError(null);
+    setLoadingError(null);

    // Mark that we've sent a message for this session in the current page load
    markSessionMessageSent(frozenSessionId);
@@ -2,7 +2,6 @@ import { redirect } from "next/navigation";
import { unstable_noStore as noStore } from "next/cache";
import { fetchChatData } from "@/lib/chat/fetchChatData";
import { ChatProvider } from "@/components/context/ChatContext";
-import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";

export default async function Layout({
  children,

@@ -41,7 +40,6 @@ export default async function Layout({

  return (
    <>
-      <InstantSSRAutoRefresh />
      <ChatProvider
        value={{
          proSearchToggled,
@@ -1,11 +1,18 @@
"use client";

-import React, { createContext, useContext, useState, useEffect } from "react";
+import React, {
+  createContext,
+  useContext,
+  useState,
+  useEffect,
+  useRef,
+} from "react";
import { User, UserRole } from "@/lib/types";
import { getCurrentUser } from "@/lib/user";
import { usePostHog } from "posthog-js/react";
import { CombinedSettings } from "@/app/admin/settings/interfaces";
import { SettingsContext } from "../settings/SettingsProvider";
+import { useTokenRefresh } from "@/hooks/useTokenRefresh";

interface UserContextType {
  user: User | null;

@@ -93,6 +100,10 @@ export function UserProvider({
      console.error("Error fetching current user:", error);
    }
  };

+  // Use the custom token refresh hook
+  useTokenRefresh(upToDateUser, fetchUser);
+
  const updateUserTemperatureOverrideEnabled = async (enabled: boolean) => {
    try {
      setUpToDateUser((prevUser) => {
web/src/hooks/useTokenRefresh.ts (new file, 84 lines)
@@ -0,0 +1,84 @@
import { useState, useEffect, useRef } from "react";
import { User } from "@/lib/types";
import { NO_AUTH_USER_ID } from "@/lib/extension/constants";

// Refresh token every 10 minutes (600000ms)
// This is shorter than the session expiry time to ensure tokens stay valid
const REFRESH_INTERVAL = 600000;

// Custom hook for handling JWT token refresh for current user
export function useTokenRefresh(
  user: User | null,
  onRefreshFail: () => Promise<void>
) {
  // Track last refresh time to avoid unnecessary calls
  const [lastTokenRefresh, setLastTokenRefresh] = useState<number>(Date.now());

  // Use a ref to track first load
  const isFirstLoad = useRef(true);

  useEffect(() => {
    if (!user || user.id === NO_AUTH_USER_ID) return;

    const refreshTokenPeriodically = async () => {
      try {
        // Skip time check if this is first load - we always refresh on first load
        const isTimeToRefresh =
          isFirstLoad.current ||
          Date.now() - lastTokenRefresh > REFRESH_INTERVAL - 60000;

        if (!isTimeToRefresh) {
          return;
        }

        // Reset first load flag
        if (isFirstLoad.current) {
          isFirstLoad.current = false;
        }

        const response = await fetch("/api/auth/refresh", {
          method: "POST",
          credentials: "include",
        });

        if (response.ok) {
          // Update last refresh time on success
          setLastTokenRefresh(Date.now());
          console.debug("Auth token refreshed successfully");
        } else {
          console.warn("Failed to refresh auth token:", response.status);
          // If token refresh fails, try to get current user info
          await onRefreshFail();
        }
      } catch (error) {
        console.error("Error refreshing auth token:", error);
      }
    };

    // Always attempt to refresh on first component mount
    // This helps ensure tokens are fresh, especially after browser refresh
    refreshTokenPeriodically();

    // Set up interval for periodic refreshes
    const intervalId = setInterval(refreshTokenPeriodically, REFRESH_INTERVAL);

    // Also refresh token on window focus, but no more than once per minute
    const handleVisibilityChange = () => {
      if (
        document.visibilityState === "visible" &&
        Date.now() - lastTokenRefresh > 60000
      ) {
        refreshTokenPeriodically();
      }
    };

    document.addEventListener("visibilitychange", handleVisibilityChange);

    return () => {
      clearInterval(intervalId);
      document.removeEventListener("visibilitychange", handleVisibilityChange);
    };
  }, [user, lastTokenRefresh, onRefreshFail]);

  return { lastTokenRefresh };
}
@@ -35,3 +35,5 @@ export const LocalStorageKeys = {
export const SEARCH_PARAMS = {
  DEFAULT_SIDEBAR_OFF: "defaultSidebarOff",
};

+export const NO_AUTH_USER_ID = "__no_auth_user__";