Compare commits

..

3 Commits

Author SHA1 Message Date
pablonyx
5ba16d31fd improve 2025-03-06 09:38:14 -08:00
pablonyx
8a5a6d3c91 base functionality 2025-03-05 17:06:09 -08:00
pablonyx
09e4e73ba6 k 2025-03-05 13:13:24 -08:00
41 changed files with 804 additions and 457 deletions

View File

@@ -1,125 +0,0 @@
"""Update GitHub connector repo_name to repositories
Revision ID: 3934b1bc7b62
Revises: b7c2b63c4a03
Create Date: 2025-03-05 10:50:30.516962
"""
from alembic import op
import sqlalchemy as sa
import json
import logging
# revision identifiers, used by Alembic.
revision = "3934b1bc7b62"
down_revision = "b7c2b63c4a03"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
# First get all GitHub connectors
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
# Update each connector's config
updated_count = 0
for connector_id, config in github_connectors:
try:
if not config:
logger.warning(f"Connector {connector_id} has no config, skipping")
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repo_name" not in config:
continue
# Create new config with repositories instead of repo_name
new_config = dict(config)
repo_name_value = new_config.pop("repo_name")
new_config["repositories"] = repo_name_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
)
updated_count += 1
except Exception as e:
logger.error(f"Error updating connector {connector_id}: {str(e)}")
def downgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
logger.debug(
"Starting rollback of GitHub connectors from repositories to repo_name"
)
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
logger.debug(f"Found {len(github_connectors)} GitHub connectors to rollback")
# Revert each GitHub connector to use repo_name instead of repositories
reverted_count = 0
for connector_id, config in github_connectors:
try:
if not config:
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repositories" not in config:
continue
# Create new config with repo_name instead of repositories
new_config = dict(config)
repositories_value = new_config.pop("repositories")
new_config["repo_name"] = repositories_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"new_config": json.dumps(new_config), "connector_id": connector_id},
)
reverted_count += 1
except Exception as e:
logger.error(f"Error reverting connector {connector_id}: {str(e)}")

View File

@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
)
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.onyx.server.oauth.api import router as ee_oauth_router
from ee.onyx.server.oauth.api import router as oauth_router
from ee.onyx.server.query_and_chat.chat_backend import (
router as chat_router,
)
@@ -27,6 +27,7 @@ from ee.onyx.server.reporting.usage_export_api import router as usage_export_rou
from ee.onyx.server.saml import router as saml_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.tenants.router import router as new_router
from ee.onyx.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
@@ -123,12 +124,13 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, user_group_router)
# Analytics endpoints
include_router_with_global_prefix_prepended(application, analytics_router)
include_router_with_global_prefix_prepended(application, new_router)
include_router_with_global_prefix_prepended(application, query_history_router)
# EE only backend APIs
include_router_with_global_prefix_prepended(application, query_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, ee_oauth_router)
include_router_with_global_prefix_prepended(application, oauth_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(

View File

@@ -25,8 +25,11 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
) -> Response:
try:
if MULTI_TENANT:
print("Shold set tenant id")
tenant_id = await _get_tenant_id_from_request(request, logger)
print(f"Tenant id: {tenant_id}")
else:
print("Should not set tenant id")
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -64,8 +67,9 @@ async def _get_tenant_id_from_request(
try:
# Look up token data in Redis
print("I AM IN THIS FUNCTION 7")
token_data = await retrieve_auth_token_data_from_redis(request)
print("I AM IN THIS FUNCTION 8")
if not token_data:
logger.debug(

View File

@@ -80,7 +80,6 @@ class ConfluenceCloudOAuth:
"search:confluence%20"
# granular scope
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
"read:content-details:confluence%20" # for permission sync
"offline_access"
)

View File

@@ -48,5 +48,4 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import time
import uuid
import aiohttp # Async HTTP client
@@ -54,24 +55,26 @@ logger = logging.getLogger(__name__)
async def get_or_provision_tenant(
email: str, referral_source: str | None = None, request: Request | None = None
) -> str:
"""
Get existing tenant ID for an email or create a new tenant if none exists.
This function should only be called after we have verified we want this user's tenant to exist.
It returns the tenant ID associated with the email, creating a new tenant if necessary.
) -> tuple[str, bool]:
"""Get existing tenant ID for an email or create a new tenant if none exists.
Returns:
tuple: (tenant_id, is_newly_created) - The tenant ID and a boolean indicating if it was newly created
"""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
return POSTGRES_DEFAULT_SCHEMA, False
if referral_source and request:
await submit_to_hubspot(email, referral_source, request)
is_newly_created = False
try:
tenant_id = get_tenant_id_for_email(email)
except exceptions.UserNotExists:
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
tenant_id = await create_tenant(email, referral_source)
is_newly_created = True
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
@@ -81,14 +84,22 @@ async def get_or_provision_tenant(
status_code=401, detail="User does not belong to an organization"
)
return tenant_id
return tenant_id, is_newly_created
async def create_tenant(email: str, referral_source: str | None = None) -> str:
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
try:
# Provision tenant on data plane
start_time = time.time()
await provision_tenant(tenant_id, email)
duration = time.time() - start_time
logger.error(
f"Tenant provisioning for {tenant_id} completed in {duration:.2f} seconds"
)
print(
f"Tenant provisioning for {tenant_id} completed in {duration:.2f} seconds"
)
# Notify control plane
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
@@ -119,36 +130,23 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
# Await the Alembic migrations up to the specified revision
start_time = time.time()
await asyncio.to_thread(run_alembic_migrations, tenant_id, "465f78d9b7f9")
duration = time.time() - start_time
print(
f"Alembic migrations for tenant {tenant_id} completed in {duration:.2f} seconds"
)
logger.info(
f"Alembic migrations for tenant {tenant_id} completed in {duration:.2f} seconds"
)
logger.error(
f"Alembic migrations for tenant {tenant_id} completed in {duration:.2f} seconds"
)
# Add users to tenant - this is needed for authentication
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
@@ -353,3 +351,60 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)
async def complete_tenant_setup(tenant_id: str) -> None:
"""Complete the tenant setup process after user creation.
This function handles the remaining steps of tenant provisioning after the initial
schema creation and user authentication:
1. Completes the remaining Alembic migrations
2. Configures default API keys
3. Sets up Onyx
4. Creates milestone record
"""
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
logger.debug(f"Completing setup for tenant {tenant_id}")
token = None
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Complete the remaining Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={},
db_session=db_session,
)
logger.info(f"Tenant setup completed for {tenant_id}")
except Exception as e:
logger.exception(f"Failed to complete tenant setup for {tenant_id}")
raise HTTPException(
status_code=500, detail=f"Failed to complete tenant setup: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

View File

@@ -0,0 +1,40 @@
import logging
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from ee.onyx.server.tenants.provisioning import complete_tenant_setup
from onyx.auth.users import optional_minimal_user
from onyx.db.models import MinimalUser
from shared_configs.contextvars import get_current_tenant_id
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/tenants", tags=["tenants"])
class CompleteTenantSetupRequest(BaseModel):
email: str
@router.post("/complete-setup")
async def api_complete_tenant_setup(
request: CompleteTenantSetupRequest,
user: MinimalUser = Depends(optional_minimal_user),
) -> None:
"""Complete the tenant setup process for a user.
This endpoint is called from the frontend after user creation to complete
the tenant setup process (migrations, seeding, etc.).
"""
tenant_id = get_current_tenant_id()
try:
await complete_tenant_setup(tenant_id)
return {"status": "success"}
except Exception as e:
logger.error(f"Failed to complete tenant setup: {e}")
raise HTTPException(status_code=500, detail="Failed to complete tenant setup")

View File

@@ -14,8 +14,10 @@ from onyx.db.engine import get_sqlalchemy_engine
logger = logging.getLogger(__name__)
def run_alembic_migrations(schema_name: str) -> None:
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
def run_alembic_migrations(schema_name: str, target_revision: str = "head") -> None:
logger.info(
f"Starting Alembic migrations for schema: {schema_name} to target: {target_revision}"
)
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -37,11 +39,10 @@ def run_alembic_migrations(schema_name: str) -> None:
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
# Run migrations programmatically
command.upgrade(alembic_cfg, "head")
command.upgrade(alembic_cfg, target_revision)
# Run migrations programmatically
logger.info(
f"Alembic migrations completed successfully for schema: {schema_name}"
f"Alembic migrations completed successfully for schema: {schema_name} to target: {target_revision}"
)
except Exception as e:

View File

@@ -2,6 +2,7 @@ import json
import random
import secrets
import string
import time
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
@@ -13,6 +14,7 @@ from typing import Optional
from typing import Tuple
import jwt
import sqlalchemy.exc
from email_validator import EmailNotValidError
from email_validator import EmailUndeliverableError
from email_validator import validate_email
@@ -90,6 +92,7 @@ from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import MinimalUser
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.db.users import get_user_by_email
@@ -186,6 +189,7 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
def verify_email_is_invited(email: str) -> None:
return None
whitelist = get_invited_users()
if not whitelist:
return
@@ -215,6 +219,7 @@ def verify_email_is_invited(email: str) -> None:
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
return None
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
@@ -235,6 +240,36 @@ def verify_email_domain(email: str) -> None:
)
class SimpleUserManager(UUIDIDMixin, BaseUserManager[MinimalUser, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS
user_db: SQLAlchemyUserDatabase[MinimalUser, uuid.UUID]
async def get(self, id: uuid.UUID) -> MinimalUser:
"""Get a user by id, with error handling for partial database provisioning."""
try:
return await super().get(id)
except sqlalchemy.exc.ProgrammingError as e:
# Handle database schema mismatch during partial provisioning
if "column user.temperature_override_enabled does not exist" in str(e):
# Create a minimal user with just the required fields
# This is a temporary solution during partial provisioning
from onyx.db.models import MinimalUser
return MinimalUser(
id=id,
email="temp@example.com", # Will be replaced with actual data
hashed_password="",
is_active=True,
is_verified=True,
is_superuser=False,
role="BASIC",
)
# Re-raise other database errors
raise
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
@@ -247,8 +282,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)(user_email)
async with get_async_session_with_tenant(tenant_id) as db_session:
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
tenant_user_db = SQLAlchemyUserAdminDB[MinimalUser, uuid.UUID](
db_session, MinimalUser, OAuthAccount
)
user = await tenant_user_db.get_by_email(user_email)
else:
@@ -265,6 +300,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
start_time = time.time()
# We verify the password here to make sure it's valid before we proceed
await self.validate_password(
user_create.password, cast(schemas.UC, user_create)
@@ -277,7 +313,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
else None
)
tenant_id = await fetch_ee_implementation_or_noop(
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
@@ -286,6 +322,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
referral_source=referral_source,
request=request,
)
print(f"Tenant ID: {tenant_id}, Is newly created: {is_newly_created}")
print("duration: ", time.time() - start_time)
user: User
async with get_async_session_with_tenant(tenant_id) as db_session:
@@ -309,7 +347,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
else:
user_create.role = UserRole.BASIC
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
simple_tennat_user_db = SQLAlchemyUserAdminDB[MinimalUser, uuid.UUID](
db_session, MinimalUser, OAuthAccount
)
user = await SimpleUserManager(simple_tennat_user_db).create(
user_create, safe=safe, request=request
) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
@@ -374,7 +417,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
getattr(request.state, "referral_source", None) if request else None
)
tenant_id = await fetch_ee_implementation_or_noop(
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
@@ -511,7 +554,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
tenant_id = await fetch_ee_implementation_or_noop(
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
@@ -563,7 +606,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
status.HTTP_500_INTERNAL_SERVER_ERROR,
"Your admin has not enabled this feature.",
)
tenant_id = await fetch_ee_implementation_or_noop(
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
@@ -587,20 +630,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
) -> Optional[User]:
email = credentials.username
tenant_id: str | None = None
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_tenant_id_for_email",
None,
)(
email=email,
)
except Exception as e:
logger.warning(
f"User attempted to login with invalid credentials: {str(e)}"
)
# Get tenant_id from mapping
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=email,
)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
@@ -679,6 +716,12 @@ async def get_user_manager(
yield UserManager(user_db)
async def get_minimal_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
) -> AsyncGenerator[SimpleUserManager, None]:
yield SimpleUserManager(user_db)
cookie_transport = CookieTransport(
cookie_max_age=SESSION_EXPIRE_TIME_SECONDS,
cookie_secure=WEB_DOMAIN.startswith("https"),
@@ -690,6 +733,10 @@ def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()
def get_minimal_redis_strategy() -> RedisStrategy:
return CustomRedisStrategy()
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
@@ -715,7 +762,7 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
async def write_token(self, user: User) -> str:
redis = await get_async_redis_connection()
tenant_id = await fetch_ee_implementation_or_noop(
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
@@ -755,14 +802,139 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
await redis.delete(f"{self.key_prefix}{token}")
class CustomRedisStrategy(RedisStrategy[MinimalUser, uuid.UUID]):
"""
A custom strategy that fetches the actual async Redis connection inside each method.
We do NOT pass a synchronous or "coroutine" redis object to the constructor.
"""
def __init__(
self,
lifetime_seconds: Optional[int] = SESSION_EXPIRE_TIME_SECONDS,
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
):
self.lifetime_seconds = lifetime_seconds
self.key_prefix = key_prefix
async def write_token(self, user: MinimalUser) -> str:
redis = await get_async_redis_connection()
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(email=user.email)
token_data = {
"sub": str(user.id),
"tenant_id": tenant_id,
}
token = secrets.token_urlsafe()
await redis.set(
f"{self.key_prefix}{token}",
json.dumps(token_data),
ex=self.lifetime_seconds,
)
return token
async def read_token(
self,
token: Optional[str],
user_manager: BaseUserManager[MinimalUser, uuid.UUID],
) -> Optional[MinimalUser]:
redis = await get_async_redis_connection()
token_data_str = await redis.get(f"{self.key_prefix}{token}")
if not token_data_str:
return None
try:
token_data = json.loads(token_data_str)
user_id = token_data["sub"]
parsed_id = user_manager.parse_id(user_id)
return await user_manager.get(parsed_id)
except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
return None
async def destroy_token(self, token: str, user: MinimalUser) -> None:
"""Properly delete the token from async redis."""
redis = await get_async_redis_connection()
await redis.delete(f"{self.key_prefix}{token}")
# class CustomRedisStrategy(RedisStrategy[MinimalUser, uuid.UUID]):
# """Custom Redis strategy that handles database schema mismatches during partial provisioning."""
# def __init__(
# self,
# lifetime_seconds: Optional[int] = SESSION_EXPIRE_TIME_SECONDS,
# key_prefix: str = REDIS_AUTH_KEY_PREFIX,
# ):
# self.lifetime_seconds = lifetime_seconds
# self.key_prefix = key_prefix
# async def read_token(
# self, token: Optional[str], user_manager: BaseUserManager[MinimalUser, uuid.UUID]
# ) -> Optional[MinimalUser]:
# try:
# redis = await get_async_redis_connection()
# token_data_str = await redis.get(f"{self.key_prefix}{token}")
# if not token_data_str:
# return None
# try:
# token_data = json.loads(token_data_str)
# user_id = token_data["sub"]
# parsed_id = user_manager.parse_id(user_id)
# return await user_manager.get(parsed_id)
# except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
# return None
# except sqlalchemy.exc.ProgrammingError as e:
# # Handle database schema mismatch during partial provisioning
# if "column user.temperature_override_enabled does not exist" in str(e):
# # Return None to allow unauthenticated access during partial provisioning
# return None
# # Re-raise other database errors
# raise
# async def write_token(self, user: MinimalUser) -> str:
# redis = await get_async_redis_connection()
# token = generate_jwt(
# data={"sub": str(user.id)},
# secret=USER_AUTH_SECRET,
# lifetime_seconds=self.lifetime_seconds,
# )
# await redis.set(
# f"{self.key_prefix}{token}",
# json.dumps({"sub": str(user.id)}),
# ex=self.lifetime_seconds,
# )
# return token
# async def destroy_token(self, token: str, user: MinimalUser) -> None:
# """Properly delete the token from async redis."""
# redis = await get_async_redis_connection()
# await redis.delete(f"{self.key_prefix}{token}")
if AUTH_BACKEND == AuthBackend.REDIS:
auth_backend = AuthenticationBackend(
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
)
minimal_auth_backend = AuthenticationBackend(
name="redis",
transport=cookie_transport,
get_strategy=get_minimal_redis_strategy,
)
elif AUTH_BACKEND == AuthBackend.POSTGRES:
auth_backend = AuthenticationBackend(
name="postgres", transport=cookie_transport, get_strategy=get_database_strategy
)
minimal_auth_backend = AuthenticationBackend(
name="postgres", transport=cookie_transport, get_strategy=get_database_strategy
)
else:
raise ValueError(f"Invalid auth backend: {AUTH_BACKEND}")
@@ -809,6 +981,9 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
get_user_manager, [auth_backend]
)
fastapi_minimal_users = FastAPIUserWithLogoutRouter[MinimalUser, uuid.UUID](
get_minimal_user_manager, [minimal_auth_backend]
)
# NOTE: verified=REQUIRE_EMAIL_VERIFICATION is not used here since we
# take care of that in `double_check_user` ourself. This is needed, since
@@ -817,6 +992,30 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
optional_fastapi_current_user = fastapi_users.current_user(active=True, optional=True)
optional_minimal_user_dependency = fastapi_minimal_users.current_user(
active=True, optional=True
)
async def optional_minimal_user(
user: MinimalUser | None = Depends(optional_minimal_user_dependency),
) -> MinimalUser | None:
"""NOTE: `request` and `db_session` are not used here, but are included
for the EE version of this function."""
print("I AM IN THIS FUNCTION 1")
try:
print("I AM IN THIS FUNCTION 2")
return user
except sqlalchemy.exc.ProgrammingError as e:
print("I AM IN THIS FUNCTION 3")
# Handle database schema mismatch during partial provisioning
if "column user.temperature_override_enabled does not exist" in str(e):
# Return None to allow unauthenticated access during partial provisioning
return None
# Re-raise other database errors
raise
async def optional_user_(
request: Request,
user: User | None,

View File

@@ -240,7 +240,7 @@ class ConfluenceConnector(
# Extract basic page information
page_id = page["id"]
page_title = page["title"]
page_url = f"{self.wiki_base}{page['_links']['webui']}"
page_url = f"{self.wiki_base}/wiki{page['_links']['webui']}"
# Get the page content
page_content = extract_text_from_confluence_html(

View File

@@ -124,14 +124,14 @@ class GithubConnector(LoadConnector, PollConnector):
def __init__(
self,
repo_owner: str,
repositories: str | None = None,
repo_name: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
state_filter: str = "all",
include_prs: bool = True,
include_issues: bool = False,
) -> None:
self.repo_owner = repo_owner
self.repositories = repositories
self.repo_name = repo_name
self.batch_size = batch_size
self.state_filter = state_filter
self.include_prs = include_prs
@@ -157,42 +157,11 @@ class GithubConnector(LoadConnector, PollConnector):
)
try:
return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
return github_client.get_repo(f"{self.repo_owner}/{self.repo_name}")
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repo(github_client, attempt_num + 1)
def _get_github_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
"""Get specific repositories based on comma-separated repo_name string."""
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
)
try:
repos = []
# Split repo_name by comma and strip whitespace
repo_names = [
name.strip() for name in (cast(str, self.repositories)).split(",")
]
for repo_name in repo_names:
if repo_name: # Skip empty strings
try:
repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}")
repos.append(repo)
except GithubException as e:
logger.warning(
f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}"
)
return repos
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repos(github_client, attempt_num + 1)
def _get_all_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
@@ -220,17 +189,11 @@ class GithubConnector(LoadConnector, PollConnector):
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
else:
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
else:
# All repositories
repos = self._get_all_repos(self.github_client)
repos = (
[self._get_github_repo(self.github_client)]
if self.repo_name
else self._get_all_repos(self.github_client)
)
for repo in repos:
if self.include_prs:
@@ -305,48 +268,11 @@ class GithubConnector(LoadConnector, PollConnector):
)
try:
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repo_names = [name.strip() for name in self.repositories.split(",")]
if not repo_names:
raise ConnectorValidationError(
"Invalid connector settings: No valid repository names provided."
)
# Validate at least one repository exists and is accessible
valid_repos = False
validation_errors = []
for repo_name in repo_names:
if not repo_name:
continue
try:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{repo_name}"
)
test_repo.get_contents("")
valid_repos = True
# If at least one repo is valid, we can proceed
break
except GithubException as e:
validation_errors.append(
f"Repository '{repo_name}': {e.data.get('message', str(e))}"
)
if not valid_repos:
error_msg = (
"None of the specified repositories could be accessed: "
)
error_msg += ", ".join(validation_errors)
raise ConnectorValidationError(error_msg)
else:
# Single repository (backward compatibility)
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repositories}"
)
test_repo.get_contents("")
if self.repo_name:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")
else:
# Try to get organization first
try:
@@ -372,15 +298,10 @@ class GithubConnector(LoadConnector, PollConnector):
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
)
elif e.status == 404:
if self.repositories:
if "," in self.repositories:
raise ConnectorValidationError(
f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}"
)
else:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}"
)
if self.repo_name:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
else:
raise ConnectorValidationError(
f"GitHub user or organization not found: {self.repo_owner}"
@@ -389,7 +310,6 @@ class GithubConnector(LoadConnector, PollConnector):
raise ConnectorValidationError(
f"Unexpected GitHub error (status={e.status}): {e.data}"
)
except Exception as exc:
raise Exception(
f"Unexpected error during GitHub settings validation: {exc}"
@@ -401,7 +321,7 @@ if __name__ == "__main__":
connector = GithubConnector(
repo_owner=os.environ["REPO_OWNER"],
repositories=os.environ["REPOSITORIES"],
repo_name=os.environ["REPO_NAME"],
)
connector.load_credentials(
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}

View File

@@ -180,6 +180,7 @@ SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
def is_valid_schema_name(name: str) -> bool:
print(f"Checking if {name} is valid")
return SCHEMA_NAME_REGEX.match(name) is not None
@@ -474,17 +475,27 @@ def get_session_generator_with_tenant() -> Generator[Session, None, None]:
def get_session() -> Generator[Session, None, None]:
tenant_id = get_current_tenant_id()
print(f"Retrieved tenant_id: {tenant_id}")
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
print("Authentication error: User must authenticate")
raise BasicAuthenticationError(detail="User must authenticate")
engine = get_sqlalchemy_engine()
print("SQLAlchemy engine obtained")
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
print("MULTI_TENANT mode enabled")
if not is_valid_schema_name(tenant_id):
print(f"Invalid tenant ID detected: {tenant_id}")
raise HTTPException(status_code=400, detail="Invalid tenant ID")
print(f"Setting search_path to tenant schema: {tenant_id}")
session.execute(text(f'SET search_path = "{tenant_id}"'))
else:
print("MULTI_TENANT mode disabled")
yield session
print("Session yielded and closed")
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:

View File

@@ -2318,3 +2318,16 @@ class TenantAnonymousUserPath(Base):
anonymous_user_path: Mapped[str] = mapped_column(
String, nullable=False, unique=True
)
class AdditionalBase(DeclarativeBase):
__abstract__ = True
class MinimalUser(SQLAlchemyBaseUserTableUUID, AdditionalBase):
# oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
# "OAuthAccount", lazy="joined", cascade="all, delete-orphan"
# )
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
)

View File

@@ -464,29 +464,12 @@ def index_doc_batch(
),
)
all_returned_doc_ids = (
{record.document_id for record in insertion_records}
.union(
{
record.failed_document.document_id
for record in vector_db_write_failures
if record.failed_document
}
)
.union(
{
record.failed_document.document_id
for record in embedding_failures
if record.failed_document
}
)
)
if all_returned_doc_ids != set(updatable_ids):
successful_doc_ids = {record.document_id for record in insertion_records}
if successful_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
"This should never happen."
f"Successful IDs: {successful_doc_ids}"
)
last_modified_ids = []

View File

@@ -51,7 +51,6 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
from onyx.server.documents.connector import router as connector_router
from onyx.server.documents.credential import router as credential_router
from onyx.server.documents.document import router as document_router
from onyx.server.documents.standard_oauth import router as standard_oauth_router
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.folder.api import router as folder_router
from onyx.server.features.input_prompt.api import (
@@ -323,7 +322,6 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
include_router_with_global_prefix_prepended(application, api_key_router)
include_router_with_global_prefix_prepended(application, standard_oauth_router)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step

View File

@@ -315,13 +315,20 @@ async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
try:
redis = await get_async_redis_connection()
redis_key = REDIS_AUTH_KEY_PREFIX + token
print("[Redis] Obtained async Redis connection")
redis_key = f"{REDIS_AUTH_KEY_PREFIX}{token}"
print(f"[Redis] Fetching token data for key: {redis_key}")
token_data_str = await redis.get(redis_key)
print(
f"[Redis] Retrieved token data: {'found' if token_data_str else 'not found'}"
)
if not token_data_str:
logger.debug(f"Token key {redis_key} not found or expired in Redis")
logger.debug(f"[Redis] Token key '{redis_key}' not found or expired")
return None
print("[Redis] Decoding token data")
print(f"[Redis] Token data: {token_data_str}")
return json.loads(token_data_str)
except json.JSONDecodeError:
logger.error("Error decoding token data from Redis")

View File

@@ -10,6 +10,7 @@ from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import optional_minimal_user
from onyx.configs.app_configs import APP_API_PREFIX
from onyx.server.onyx_api.ingestion import api_key_dep
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -114,6 +115,7 @@ def check_router_auth(
or depends_fn == current_user_with_expired_token
or depends_fn == current_chat_accesssible_user
or depends_fn == control_plane_dep
or depends_fn == optional_minimal_user
or depends_fn == current_cloud_superuser
):
found_auth = True

View File

@@ -112,6 +112,12 @@ class UserInfo(BaseModel):
)
class MinimalUserInfo(BaseModel):
id: str
email: str
is_active: bool
class UserByEmail(BaseModel):
user_email: str

View File

@@ -32,6 +32,7 @@ from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import optional_minimal_user
from onyx.auth.users import optional_user
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DEV_MODE
@@ -44,6 +45,7 @@ from onyx.db.api_key import is_api_key_email_address
from onyx.db.auth import get_total_users_count
from onyx.db.engine import get_session
from onyx.db.models import AccessToken
from onyx.db.models import MinimalUser
from onyx.db.models import User
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_all_users
@@ -55,6 +57,7 @@ from onyx.key_value_store.factory import get_kv_store
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import AllUsersResponse
from onyx.server.manage.models import AutoScrollRequest
from onyx.server.manage.models import MinimalUserInfo
from onyx.server.manage.models import UserByEmail
from onyx.server.manage.models import UserInfo
from onyx.server.manage.models import UserPreferences
@@ -534,6 +537,27 @@ def get_current_token_creation(
return None
@router.get("/me-info")
def verify_user_attempting_to_login(
request: Request,
user: MinimalUser | None = Depends(optional_minimal_user),
# db_session: Session = Depends(get_session),
) -> MinimalUserInfo:
# Check if the authentication cookie exists
# Print cookie names for debugging
cookie_names = list(request.cookies.keys())
logger.info(f"Available cookies: {cookie_names}")
if not request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME):
raise HTTPException(status_code=401, detail="User not found")
# print("I AM IN THIS FUNCTION 4")
# if user is None:
# print("I AM IN THIS FUNCTION 5")
# raise HTTPException(status_code=401, detail="User not found")
# print("I AM IN THIS FUNCTION 6")
return MinimalUserInfo(id="", email="", is_active=True)
@router.get("/me")
def verify_user_logged_in(
user: User | None = Depends(optional_user),

View File

@@ -43,6 +43,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str:
def get_or_generate_uuid() -> str:
return "hi"
# TODO: split out the whole "instance UUID" generation logic into a separate
# utility function. Telemetry should not be aware at all of how the UUID is
# generated/stored.

View File

@@ -45,7 +45,7 @@ def test_confluence_connector_basic(
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 2
assert len(doc_batch) == 3
page_within_a_page_doc: Document | None = None
page_doc: Document | None = None

View File

@@ -80,13 +80,3 @@ prod cluster**
- `kubectl delete -f .`
- To not delete the persistent volumes (Document indexes and Users), specify the specific `.yaml` files instead of
`.` without specifying delete on persistent-volumes.yaml.
### Using Helm to deploy to an existing cluster
Onyx has a helm chart that is convenient to install all services to an existing Kubernetes cluster. To install:
* Currently the helm chart is not published so to install, clone the repo.
* Configure access to the cluster via kubectl. Ensure the kubectl context is set to the cluster that you want to use
* The default secrets, environment variables and other service level configuration are stored in `deployment/helm/charts/onyx/values.yml`. You may create another `override.yml`
* `cd deployment/helm/charts/onyx` and run `helm install onyx -n onyx -f override.yaml .`. This will install onyx on the cluster under the `onyx` namespace.
* Check the status of the deploy using `kubectl get pods -n onyx`

View File

@@ -1,27 +0,0 @@
{{- if .Values.ingress.enabled -}}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: {{ include "onyx-stack.fullname" . }}-ingress-api
annotations:
kubernetes.io/ingress.class: nginx
nginx.ingress.kubernetes.io/rewrite-target: /$2
nginx.ingress.kubernetes.io/use-regex: "true"
cert-manager.io/cluster-issuer: {{ include "onyx-stack.fullname" . }}-letsencrypt
spec:
rules:
- host: {{ .Values.ingress.api.host }}
http:
paths:
- path: /api(/|$)(.*)
pathType: Prefix
backend:
service:
name: {{ include "onyx-stack.fullname" . }}-api-service
port:
number: {{ .Values.api.service.servicePort }}
tls:
- hosts:
- {{ .Values.ingress.api.host }}
secretName: {{ include "onyx-stack.fullname" . }}-ingress-api-tls
{{- end }}

View File

@@ -1,26 +0,0 @@
{{- if .Values.ingress.enabled -}}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: {{ include "onyx-stack.fullname" . }}-ingress-webserver
annotations:
kubernetes.io/ingress.class: nginx
cert-manager.io/cluster-issuer: {{ include "onyx-stack.fullname" . }}-letsencrypt
kubernetes.io/tls-acme: "true"
spec:
rules:
- host: {{ .Values.ingress.webserver.host }}
http:
paths:
- path: /
pathType: Prefix
backend:
service:
name: {{ include "onyx-stack.fullname" . }}-webserver
port:
number: {{ .Values.webserver.service.servicePort }}
tls:
- hosts:
- {{ .Values.ingress.webserver.host }}
secretName: {{ include "onyx-stack.fullname" . }}-ingress-webserver-tls
{{- end }}

View File

@@ -1,20 +0,0 @@
{{- if .Values.letsencrypt.enabled -}}
apiVersion: cert-manager.io/v1
kind: ClusterIssuer
metadata:
name: {{ include "onyx-stack.fullname" . }}-letsencrypt
spec:
acme:
# The ACME server URL
server: https://acme-v02.api.letsencrypt.org/directory
# Email address used for ACME registration
email: {{ .Values.letsencrypt.email }}
# Name of a secret used to store the ACME account private key
privateKeySecretRef:
name: {{ include "onyx-stack.fullname" . }}-letsencrypt
# Enable the HTTP-01 challenge provider
solvers:
- http01:
ingress:
class: nginx
{{- end }}

View File

@@ -376,17 +376,22 @@ redis:
existingSecret: onyx-secrets
existingSecretPasswordKey: redis_password
ingress:
enabled: false
className: ""
api:
host: onyx.local
webserver:
host: onyx.local
# ingress:
# enabled: false
# className: ""
# annotations: {}
# # kubernetes.io/ingress.class: nginx
# # kubernetes.io/tls-acme: "true"
# hosts:
# - host: chart-example.local
# paths:
# - path: /
# pathType: ImplementationSpecific
# tls: []
# # - secretName: chart-example-tls
# # hosts:
# # - chart-example.local
letsencrypt:
enabled: false
email: "abc@abc.com"
auth:
# existingSecret onyx-secret for storing smtp, oauth, slack, and other secrets

View File

@@ -290,24 +290,21 @@ export function SettingsForm() {
id="chatRetentionInput"
placeholder="Infinite Retention"
/>
<div className="mr-auto flex gap-2">
<Button
onClick={handleSetChatRetention}
variant="submit"
size="sm"
className="mr-auto"
>
Set Retention Limit
</Button>
<Button
onClick={handleClearChatRetention}
variant="default"
size="sm"
className="mr-auto"
>
Retain All
</Button>
</div>
<Button
onClick={handleSetChatRetention}
variant="submit"
size="sm"
className="mr-3"
>
Set Retention Limit
</Button>
<Button
onClick={handleClearChatRetention}
variant="default"
size="sm"
>
Retain All
</Button>
</>
)}

View File

@@ -61,7 +61,6 @@ export function EmailPasswordForm({
if (!response.ok) {
setIsWorking(false);
const errorDetail = (await response.json()).detail;
let errorMsg = "Unknown error";
if (typeof errorDetail === "object" && errorDetail.reason) {
@@ -97,13 +96,12 @@ export function EmailPasswordForm({
} else {
setIsWorking(false);
const errorDetail = (await loginResponse.json()).detail;
let errorMsg = "Unknown error";
if (errorDetail === "LOGIN_BAD_CREDENTIALS") {
errorMsg = "Invalid email or password";
} else if (errorDetail === "NO_WEB_LOGIN_AND_HAS_NO_PASSWORD") {
errorMsg = "Create an account to set a password";
} else if (typeof errorDetail === "string") {
errorMsg = errorDetail;
}
if (loginResponse.status === 429) {
errorMsg = "Too many requests. Please try again later.";

View File

@@ -8,8 +8,6 @@ import {
} from "@/lib/userSS";
import { redirect } from "next/navigation";
import { EmailPasswordForm } from "../login/EmailPasswordForm";
import Text from "@/components/ui/text";
import Link from "next/link";
import { SignInButton } from "../login/SignInButton";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
import ReferralSourceSelector from "./ReferralSourceSelector";

View File

@@ -0,0 +1,242 @@
"use client";
import { redirect } from "next/navigation";
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { MinimalUserInfo } from "@/lib/types";
import Text from "@/components/ui/text";
import { Logo } from "@/components/logo/Logo";
import { completeTenantSetup } from "@/lib/tenant";
import { useEffect, useState, useRef } from "react";
import {
Card,
CardContent,
CardHeader,
CardFooter,
} from "@/components/ui/card";
import { LoadingSpinner } from "@/app/chat/chat_search/LoadingSpinner";
import { Button } from "@/components/ui/button";
import { logout } from "@/lib/user";
import { FiLogOut } from "react-icons/fi";
export default function WaitingOnSetupPage({
minimalUserInfo,
}: {
minimalUserInfo: MinimalUserInfo;
}) {
const [progress, setProgress] = useState(0);
const [setupStage, setSetupStage] = useState<string>(
"Setting up your account"
);
const progressRef = useRef<number>(0);
const animationRef = useRef<number>();
const startTimeRef = useRef<number>(Date.now());
const [isReady, setIsReady] = useState(false);
const pollIntervalRef = useRef<NodeJS.Timeout | null>(null);
// Setup stages that will cycle through during the loading process
const setupStages = [
"Setting up your account",
"Configuring workspace",
"Preparing resources",
"Setting up permissions",
"Finalizing setup",
];
// Function to poll the /api/me endpoint
const pollAccountStatus = async () => {
try {
const response = await fetch("/api/me");
if (response.status === 200) {
// Account is ready
setIsReady(true);
if (pollIntervalRef.current) {
clearInterval(pollIntervalRef.current);
}
return true;
}
} catch (error) {
console.error("Error polling account status:", error);
}
return false;
};
// Handle logout
const handleLogout = async () => {
try {
await logout();
window.location.href = "/auth/login";
} catch (error) {
console.error("Failed to logout:", error);
}
};
useEffect(() => {
// Animation setup for progress bar
let lastUpdateTime = 0;
const updateInterval = 100;
const normalAnimationDuration = 30000; // 30 seconds for normal animation
const acceleratedAnimationDuration = 1500; // 1.5 seconds for accelerated animation after ready
let currentStageIndex = 0;
let lastStageUpdateTime = Date.now();
const stageRotationInterval = 3000; // Rotate stages every 3 seconds
const animate = (timestamp: number) => {
const elapsedTime = timestamp - startTimeRef.current;
const now = Date.now();
// Calculate progress using different curves based on ready status
const maxProgress = 99;
let progress;
if (isReady) {
// Accelerate to 100% when account is ready
progress =
maxProgress +
(100 - maxProgress) *
((now - startTimeRef.current) / acceleratedAnimationDuration);
if (progress >= 100) progress = 100;
} else {
// Slower progress when still waiting
progress =
maxProgress * (1 - Math.exp(-elapsedTime / normalAnimationDuration));
}
// Update progress if enough time has passed
if (timestamp - lastUpdateTime > updateInterval) {
progressRef.current = progress;
setProgress(Math.round(progress * 10) / 10);
// Cycle through setup stages
if (now - lastStageUpdateTime > stageRotationInterval && !isReady) {
currentStageIndex = (currentStageIndex + 1) % setupStages.length;
setSetupStage(setupStages[currentStageIndex]);
lastStageUpdateTime = now;
}
lastUpdateTime = timestamp;
}
// Continue animation if not completed
if (progress < 100) {
animationRef.current = requestAnimationFrame(animate);
} else if (progress >= 100) {
// Redirect when progress reaches 100%
setSetupStage("Setup complete!");
setTimeout(() => {
window.location.href = "/chat";
}, 500);
}
};
// Start animation
startTimeRef.current = performance.now();
animationRef.current = requestAnimationFrame(animate);
// Start polling the /api/me endpoint
pollIntervalRef.current = setInterval(async () => {
const ready = await pollAccountStatus();
if (ready) {
// If ready, we'll let the animation handle the redirect
console.log("Account is ready!");
}
}, 2000); // Poll every 2 seconds
// Attempt to complete tenant setup
// completeTenantSetup(minimalUserInfo.email).catch((error) => {
// console.error("Failed to complete tenant setup:", error);
// });
// Cleanup function
return () => {
if (animationRef.current) {
cancelAnimationFrame(animationRef.current);
}
if (pollIntervalRef.current) {
clearInterval(pollIntervalRef.current);
}
};
}, [isReady, minimalUserInfo.email]);
useEffect(() => {
completeTenantSetup(minimalUserInfo.email).catch((error) => {
console.error("Failed to complete tenant setup:", error);
});
}, []);
return (
<main className="min-h-screen bg-gradient-to-b from-white to-neutral-50 dark:from-neutral-900 dark:to-neutral-950">
<div className="absolute top-0 w-full">
<HealthCheckBanner />
</div>
<div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
<div className="w-full max-w-md">
<div className="flex flex-col items-center mb-8">
<Logo height={80} width={80} className="mx-auto w-fit mb-6" />
<h1 className="text-2xl font-bold text-neutral-900 dark:text-white">
Account Setup
</h1>
</div>
<Card className="border-neutral-200 dark:border-neutral-800 shadow-lg">
<CardHeader className="pb-0">
<div className="flex items-center justify-between">
<div className="flex items-center space-x-3">
<div className="relative">
<LoadingSpinner size="medium" className="text-primary" />
</div>
<h2 className="text-lg font-semibold text-neutral-800 dark:text-neutral-200">
{setupStage}
</h2>
</div>
<span className="text-sm font-medium text-neutral-500 dark:text-neutral-400">
{progress}%
</span>
</div>
</CardHeader>
<CardContent className="pt-4">
{/* Progress bar */}
<div className="w-full h-2 bg-neutral-200 dark:bg-neutral-700 rounded-full mb-6 overflow-hidden">
<div
className="h-full bg-blue-600 dark:bg-blue-500 rounded-full transition-all duration-300 ease-out"
style={{ width: `${progress}%` }}
/>
</div>
<div className="space-y-4">
<div className="flex flex-col space-y-1">
<Text className="text-neutral-800 dark:text-neutral-200 font-medium">
Welcome,{" "}
<span className="font-semibold">
{minimalUserInfo?.email}
</span>
</Text>
<Text className="text-neutral-600 dark:text-neutral-400 text-sm">
We're setting up your account. This may take a few moments.
</Text>
</div>
<div className="bg-blue-50 dark:bg-blue-900/20 rounded-lg p-4 border border-blue-100 dark:border-blue-800">
<Text className="text-sm text-blue-700 dark:text-blue-300">
You'll be redirected automatically when your account is
ready. If you're not redirected within a minute, please
refresh the page.
</Text>
</div>
</div>
</CardContent>
<CardFooter className="flex justify-end pt-4">
<Button
variant="outline"
size="sm"
onClick={handleLogout}
className="text-neutral-600 dark:text-neutral-300"
>
<FiLogOut className="mr-1" />
Logout
</Button>
</CardFooter>
</Card>
</div>
</div>
</main>
);
}

View File

@@ -191,7 +191,6 @@ export const FolderDropdown = forwardRef<HTMLDivElement, FolderDropdownProps>(
onChange={(e) => setNewFolderName(e.target.value)}
className="text-sm font-medium bg-transparent outline-none w-full pb-1 border-b border-background-500 transition-colors duration-200"
onKeyDown={(e) => {
e.stopPropagation();
if (e.key === "Enter") {
handleEdit();
}

View File

@@ -303,6 +303,7 @@ const FolderItem = ({
key={chatSession.id}
chatSession={chatSession}
isSelected={chatSession.id === currentChatId}
skipGradient={isDragOver}
showShareModal={showShareModal}
showDeleteModal={showDeleteModal}
/>

View File

@@ -32,17 +32,21 @@ export function ChatSessionDisplay({
chatSession,
search,
isSelected,
skipGradient,
closeSidebar,
showShareModal,
showDeleteModal,
foldersExisting,
isDragging,
}: {
chatSession: ChatSession;
isSelected: boolean;
search?: boolean;
skipGradient?: boolean;
closeSidebar?: () => void;
showShareModal?: (chatSession: ChatSession) => void;
showDeleteModal?: (chatSession: ChatSession) => void;
foldersExisting?: boolean;
isDragging?: boolean;
}) {
const router = useRouter();
@@ -234,12 +238,8 @@ export function ChatSessionDisplay({
e.preventDefault();
e.stopPropagation();
}}
onChange={(e) => {
setChatName(e.target.value);
}}
onChange={(e) => setChatName(e.target.value)}
onKeyDown={(event) => {
event.stopPropagation();
if (event.key === "Enter") {
onRename();
event.preventDefault();

View File

@@ -264,6 +264,7 @@ export function PagesTab({
>
<ChatSessionDisplay
chatSession={chat}
foldersExisting={foldersExisting}
isSelected={currentChatId === chat.id}
showShareModal={showShareModal}
showDeleteModal={showDeleteModal}

View File

@@ -20,7 +20,7 @@ import {
import { fetchAssistantData } from "@/lib/chat/fetchAssistantdata";
import { AppProvider } from "@/components/context/AppProvider";
import { PHProvider } from "./providers";
import { getCurrentUserSS } from "@/lib/userSS";
import { getCurrentUserSS, getMinimalUserInfoSS } from "@/lib/userSS";
import { Suspense } from "react";
import PostHogPageView from "./PostHogPageView";
import Script from "next/script";
@@ -30,6 +30,8 @@ import { ThemeProvider } from "next-themes";
import CloudError from "@/components/errorPages/CloudErrorPage";
import Error from "@/components/errorPages/ErrorPage";
import AccessRestrictedPage from "@/components/errorPages/AccessRestrictedPage";
import CompleteTenantSetupPage from "./auth/waiting-on-setup/WaitingOnSetup";
import WaitingOnSetupPage from "./auth/waiting-on-setup/WaitingOnSetup";
const inter = Inter({
subsets: ["latin"],
@@ -70,12 +72,17 @@ export default async function RootLayout({
}: {
children: React.ReactNode;
}) {
const [combinedSettings, assistantsData, user] = await Promise.all([
fetchSettingsSS(),
fetchAssistantData(),
getCurrentUserSS(),
]);
const [combinedSettings, assistantsData, user, minimalUserInfo] =
await Promise.all([
fetchSettingsSS(),
fetchAssistantData(),
getCurrentUserSS(),
getMinimalUserInfoSS(),
]);
// if (!user && minimalUserInfo) {
// return <CompleteTenantSetupPage />;
// }
const productGating =
combinedSettings?.settings.application_status ?? ApplicationStatus.ACTIVE;
@@ -135,6 +142,11 @@ export default async function RootLayout({
if (productGating === ApplicationStatus.GATED_ACCESS) {
return getPageContent(<AccessRestrictedPage />);
}
if (!user && minimalUserInfo) {
return getPageContent(
<WaitingOnSetupPage minimalUserInfo={minimalUserInfo} />
);
}
if (!combinedSettings) {
return getPageContent(

View File

@@ -40,12 +40,8 @@ export const ConnectorTitle = ({
const typedConnector = connector as Connector<GithubConfig>;
additionalMetadata.set(
"Repo",
typedConnector.connector_specific_config.repositories
? `${typedConnector.connector_specific_config.repo_owner}/${
typedConnector.connector_specific_config.repositories.includes(",")
? "multiple repos"
: typedConnector.connector_specific_config.repositories
}`
typedConnector.connector_specific_config.repo_name
? `${typedConnector.connector_specific_config.repo_owner}/${typedConnector.connector_specific_config.repo_name}`
: `${typedConnector.connector_specific_config.repo_owner}/*`
);
} else if (connector.source === "gitlab") {

View File

@@ -190,12 +190,10 @@ export const connectorConfigs: Record<
fields: [
{
type: "text",
query: "Enter the repository name(s):",
label: "Repository Name(s)",
name: "repositories",
query: "Enter the repository name:",
label: "Repository Name",
name: "repo_name",
optional: false,
description:
"For multiple repositories, enter comma-separated names (e.g., repo1,repo2,repo3)",
},
],
},
@@ -1360,7 +1358,7 @@ export interface WebConfig {
export interface GithubConfig {
repo_owner: string;
repositories: string; // Comma-separated list of repository names
repo_name: string;
include_prs: boolean;
include_issues: boolean;
}

14
web/src/lib/tenant.ts Normal file
View File

@@ -0,0 +1,14 @@
export async function completeTenantSetup(email: string): Promise<void> {
const response = await fetch(`/api/tenants/complete-setup`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ email }),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`Failed to complete tenant setup: ${errorText}`);
}
}

View File

@@ -15,6 +15,12 @@ interface UserPreferences {
temperature_override_enabled: boolean;
}
export interface MinimalUserInfo {
id: string;
email: string;
is_active: boolean;
}
export enum UserRole {
LIMITED = "limited",
BASIC = "basic",

View File

@@ -1,5 +1,5 @@
import { cookies } from "next/headers";
import { User } from "./types";
import { MinimalUserInfo, User } from "./types";
import { buildUrl } from "./utilsSS";
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
@@ -162,6 +162,30 @@ export const logoutSS = async (
}
}
};
export const getMinimalUserInfoSS =
async (): Promise<MinimalUserInfo | null> => {
try {
const response = await fetch(buildUrl("/me-info"), {
credentials: "include",
headers: {
cookie: (await cookies())
.getAll()
.map((cookie) => `${cookie.name}=${cookie.value}`)
.join("; "),
},
next: { revalidate: 0 },
});
if (!response.ok) {
console.error("Failed to fetch minimal user info");
return null;
}
return await response.json();
} catch (e) {
console.log(`Error fetching minimal user info: ${e}`);
return null;
}
};
export const getCurrentUserSS = async (): Promise<User | null> => {
try {