Compare commits

..

3 Commits

Author SHA1 Message Date
pablonyx
5ba16d31fd improve 2025-03-06 09:38:14 -08:00
pablonyx
8a5a6d3c91 base functionality 2025-03-05 17:06:09 -08:00
pablonyx
09e4e73ba6 k 2025-03-05 13:13:24 -08:00
21 changed files with 758 additions and 400 deletions

View File

@@ -6,8 +6,7 @@ Create Date: 2025-02-26 13:07:56.217791
"""
from alembic import op
import time
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = "3bd4c84fe72f"
@@ -28,357 +27,45 @@ depends_on = None
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
def upgrade():
# --- PART 1: chat_message table ---
# Step 1: Add nullable column (quick, minimal locking)
# op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv")
# op.execute("DROP TRIGGER IF EXISTS chat_message_tsv_trigger ON chat_message")
# op.execute("DROP FUNCTION IF EXISTS update_chat_message_tsv()")
# op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv")
# # Drop chat_session tsv trigger if it exists
# op.execute("DROP TRIGGER IF EXISTS chat_session_tsv_trigger ON chat_session")
# op.execute("DROP FUNCTION IF EXISTS update_chat_session_tsv()")
# op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS title_tsv")
# raise Exception("Stop here")
time.time()
op.execute("ALTER TABLE chat_message ADD COLUMN IF NOT EXISTS message_tsv tsvector")
# Step 2: Create function and trigger for new/updated rows
def upgrade() -> None:
# Create a GIN index for full-text search on chat_message.message
op.execute(
"""
CREATE OR REPLACE FUNCTION update_chat_message_tsv()
RETURNS TRIGGER AS $$
BEGIN
NEW.message_tsv = to_tsvector('english', NEW.message);
RETURN NEW;
END;
$$ LANGUAGE plpgsql
"""
ALTER TABLE chat_message
ADD COLUMN message_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
"""
)
# Create trigger in a separate execute call
# Commit the current transaction before creating concurrent indexes
op.execute("COMMIT")
op.execute(
"""
CREATE TRIGGER chat_message_tsv_trigger
BEFORE INSERT OR UPDATE ON chat_message
FOR EACH ROW EXECUTE FUNCTION update_chat_message_tsv()
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message
USING GIN (message_tsv)
"""
)
# Step 3: Update existing rows in batches using Python
time.time()
# Get connection and count total rows
connection = op.get_bind()
total_count_result = connection.execute(
text("SELECT COUNT(*) FROM chat_message")
).scalar()
total_count = total_count_result if total_count_result is not None else 0
batch_size = 5000
batches = 0
# Calculate total batches needed
total_batches = (
(total_count + batch_size - 1) // batch_size if total_count > 0 else 0
# Also add a stored tsvector column for chat_session.description
op.execute(
"""
ALTER TABLE chat_session
ADD COLUMN description_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
"""
)
# Process in batches - properly handling UUIDs by using OFFSET/LIMIT approach
for batch_num in range(total_batches):
offset = batch_num * batch_size
# Commit again before creating the second concurrent index
op.execute("COMMIT")
# Execute update for this batch using OFFSET/LIMIT which works with UUIDs
connection.execute(
text(
"""
UPDATE chat_message
SET message_tsv = to_tsvector('english', message)
WHERE id IN (
SELECT id FROM chat_message
WHERE message_tsv IS NULL
ORDER BY id
LIMIT :batch_size OFFSET :offset
)
"""
).bindparams(batch_size=batch_size, offset=offset)
)
# Commit each batch
connection.execute(text("COMMIT"))
# Start a new transaction
connection.execute(text("BEGIN"))
batches += 1
# Final check for any remaining NULL values
connection.execute(
text(
"""
UPDATE chat_message SET message_tsv = to_tsvector('english', message)
WHERE message_tsv IS NULL
"""
)
)
# Create GIN index concurrently
connection.execute(text("COMMIT"))
time.time()
connection.execute(
text(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message USING GIN (message_tsv)
"""
)
)
# First drop the trigger as it won't be needed anymore
connection.execute(
text(
"""
DROP TRIGGER IF EXISTS chat_message_tsv_trigger ON chat_message;
"""
)
)
connection.execute(
text(
"""
DROP FUNCTION IF EXISTS update_chat_message_tsv();
"""
)
)
# Add new generated column
time.time()
connection.execute(
text(
"""
ALTER TABLE chat_message
ADD COLUMN message_tsv_gen tsvector
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
"""
)
)
connection.execute(text("COMMIT"))
time.time()
connection.execute(
text(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv_gen
ON chat_message USING GIN (message_tsv_gen)
"""
)
)
# Drop old index and column
connection.execute(text("COMMIT"))
connection.execute(
text(
"""
DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;
"""
)
)
connection.execute(text("COMMIT"))
connection.execute(
text(
"""
ALTER TABLE chat_message DROP COLUMN message_tsv;
"""
)
)
# Rename new column to old name
connection.execute(
text(
"""
ALTER TABLE chat_message RENAME COLUMN message_tsv_gen TO message_tsv;
"""
)
)
# --- PART 2: chat_session table ---
# Step 1: Add nullable column (quick, minimal locking)
time.time()
connection.execute(
text(
"ALTER TABLE chat_session ADD COLUMN IF NOT EXISTS description_tsv tsvector"
)
)
# Step 2: Create function and trigger for new/updated rows - SPLIT INTO SEPARATE CALLS
connection.execute(
text(
"""
CREATE OR REPLACE FUNCTION update_chat_session_tsv()
RETURNS TRIGGER AS $$
BEGIN
NEW.description_tsv = to_tsvector('english', COALESCE(NEW.description, ''));
RETURN NEW;
END;
$$ LANGUAGE plpgsql
"""
)
)
# Create trigger in a separate execute call
connection.execute(
text(
"""
CREATE TRIGGER chat_session_tsv_trigger
BEFORE INSERT OR UPDATE ON chat_session
FOR EACH ROW EXECUTE FUNCTION update_chat_session_tsv()
"""
)
)
# Step 3: Update existing rows in batches using Python
time.time()
# Get the maximum ID to determine batch count
# Cast id to text for MAX function since it's a UUID
max_id_result = connection.execute(
text("SELECT COALESCE(MAX(id::text), '0') FROM chat_session")
).scalar()
max_id_result if max_id_result is not None else "0"
batch_size = 5000
batches = 0
# Get all IDs ordered to process in batches
rows = connection.execute(
text("SELECT id FROM chat_session ORDER BY id")
).fetchall()
total_rows = len(rows)
# Process in batches
for batch_num, batch_start in enumerate(range(0, total_rows, batch_size)):
batch_end = min(batch_start + batch_size, total_rows)
batch_ids = [row[0] for row in rows[batch_start:batch_end]]
if not batch_ids:
continue
# Use IN clause instead of BETWEEN for UUIDs
placeholders = ", ".join([f":id{i}" for i in range(len(batch_ids))])
params = {f"id{i}": id_val for i, id_val in enumerate(batch_ids)}
# Execute update for this batch
connection.execute(
text(
f"""
UPDATE chat_session
SET description_tsv = to_tsvector('english', COALESCE(description, ''))
WHERE id IN ({placeholders})
AND description_tsv IS NULL
"""
).bindparams(**params)
)
# Commit each batch
connection.execute(text("COMMIT"))
# Start a new transaction
connection.execute(text("BEGIN"))
batches += 1
# Final check for any remaining NULL values
connection.execute(
text(
"""
UPDATE chat_session SET description_tsv = to_tsvector('english', COALESCE(description, ''))
WHERE description_tsv IS NULL
"""
)
)
# Create GIN index concurrently
connection.execute(text("COMMIT"))
time.time()
connection.execute(
text(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session USING GIN (description_tsv)
"""
)
)
# After Final check for chat_session
# First drop the trigger as it won't be needed anymore
connection.execute(
text(
"""
DROP TRIGGER IF EXISTS chat_session_tsv_trigger ON chat_session;
"""
)
)
connection.execute(
text(
"""
DROP FUNCTION IF EXISTS update_chat_session_tsv();
"""
)
)
# Add new generated column
time.time()
connection.execute(
text(
"""
ALTER TABLE chat_session
ADD COLUMN description_tsv_gen tsvector
GENERATED ALWAYS AS (to_tsvector('english', COALESCE(description, ''))) STORED;
"""
)
)
# Create new index on generated column
connection.execute(text("COMMIT"))
time.time()
connection.execute(
text(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv_gen
ON chat_session USING GIN (description_tsv_gen)
"""
)
)
# Drop old index and column
connection.execute(text("COMMIT"))
connection.execute(
text(
"""
DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;
"""
)
)
connection.execute(text("COMMIT"))
connection.execute(
text(
"""
ALTER TABLE chat_session DROP COLUMN description_tsv;
"""
)
)
# Rename new column to old name
connection.execute(
text(
"""
ALTER TABLE chat_session RENAME COLUMN description_tsv_gen TO description_tsv;
"""
)
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session
USING GIN (description_tsv)
"""
)

View File

@@ -27,6 +27,7 @@ from ee.onyx.server.reporting.usage_export_api import router as usage_export_rou
from ee.onyx.server.saml import router as saml_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.tenants.router import router as new_router
from ee.onyx.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
@@ -123,6 +124,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, user_group_router)
# Analytics endpoints
include_router_with_global_prefix_prepended(application, analytics_router)
include_router_with_global_prefix_prepended(application, new_router)
include_router_with_global_prefix_prepended(application, query_history_router)
# EE only backend APIs
include_router_with_global_prefix_prepended(application, query_router)

View File

@@ -25,8 +25,11 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
) -> Response:
try:
if MULTI_TENANT:
print("Shold set tenant id")
tenant_id = await _get_tenant_id_from_request(request, logger)
print(f"Tenant id: {tenant_id}")
else:
print("Should not set tenant id")
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -64,8 +67,9 @@ async def _get_tenant_id_from_request(
try:
# Look up token data in Redis
print("I AM IN THIS FUNCTION 7")
token_data = await retrieve_auth_token_data_from_redis(request)
print("I AM IN THIS FUNCTION 8")
if not token_data:
logger.debug(

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import time
import uuid
import aiohttp # Async HTTP client
@@ -54,20 +55,26 @@ logger = logging.getLogger(__name__)
async def get_or_provision_tenant(
email: str, referral_source: str | None = None, request: Request | None = None
) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
) -> tuple[str, bool]:
"""Get existing tenant ID for an email or create a new tenant if none exists.
Returns:
tuple: (tenant_id, is_newly_created) - The tenant ID and a boolean indicating if it was newly created
"""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
return POSTGRES_DEFAULT_SCHEMA, False
if referral_source and request:
await submit_to_hubspot(email, referral_source, request)
is_newly_created = False
try:
tenant_id = get_tenant_id_for_email(email)
except exceptions.UserNotExists:
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
tenant_id = await create_tenant(email, referral_source)
is_newly_created = True
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
@@ -77,14 +84,22 @@ async def get_or_provision_tenant(
status_code=401, detail="User does not belong to an organization"
)
return tenant_id
return tenant_id, is_newly_created
async def create_tenant(email: str, referral_source: str | None = None) -> str:
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
try:
# Provision tenant on data plane
start_time = time.time()
await provision_tenant(tenant_id, email)
duration = time.time() - start_time
logger.error(
f"Tenant provisioning for {tenant_id} completed in {duration:.2f} seconds"
)
print(
f"Tenant provisioning for {tenant_id} completed in {duration:.2f} seconds"
)
# Notify control plane
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
@@ -115,36 +130,23 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
# Await the Alembic migrations up to the specified revision
start_time = time.time()
await asyncio.to_thread(run_alembic_migrations, tenant_id, "465f78d9b7f9")
duration = time.time() - start_time
print(
f"Alembic migrations for tenant {tenant_id} completed in {duration:.2f} seconds"
)
logger.info(
f"Alembic migrations for tenant {tenant_id} completed in {duration:.2f} seconds"
)
logger.error(
f"Alembic migrations for tenant {tenant_id} completed in {duration:.2f} seconds"
)
# Add users to tenant - this is needed for authentication
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
@@ -349,3 +351,60 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)
async def complete_tenant_setup(tenant_id: str) -> None:
"""Complete the tenant setup process after user creation.
This function handles the remaining steps of tenant provisioning after the initial
schema creation and user authentication:
1. Completes the remaining Alembic migrations
2. Configures default API keys
3. Sets up Onyx
4. Creates milestone record
"""
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
logger.debug(f"Completing setup for tenant {tenant_id}")
token = None
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Complete the remaining Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={},
db_session=db_session,
)
logger.info(f"Tenant setup completed for {tenant_id}")
except Exception as e:
logger.exception(f"Failed to complete tenant setup for {tenant_id}")
raise HTTPException(
status_code=500, detail=f"Failed to complete tenant setup: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

View File

@@ -0,0 +1,40 @@
import logging
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from ee.onyx.server.tenants.provisioning import complete_tenant_setup
from onyx.auth.users import optional_minimal_user
from onyx.db.models import MinimalUser
from shared_configs.contextvars import get_current_tenant_id
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/tenants", tags=["tenants"])
class CompleteTenantSetupRequest(BaseModel):
email: str
@router.post("/complete-setup")
async def api_complete_tenant_setup(
request: CompleteTenantSetupRequest,
user: MinimalUser = Depends(optional_minimal_user),
) -> None:
"""Complete the tenant setup process for a user.
This endpoint is called from the frontend after user creation to complete
the tenant setup process (migrations, seeding, etc.).
"""
tenant_id = get_current_tenant_id()
try:
await complete_tenant_setup(tenant_id)
return {"status": "success"}
except Exception as e:
logger.error(f"Failed to complete tenant setup: {e}")
raise HTTPException(status_code=500, detail="Failed to complete tenant setup")

View File

@@ -14,8 +14,10 @@ from onyx.db.engine import get_sqlalchemy_engine
logger = logging.getLogger(__name__)
def run_alembic_migrations(schema_name: str) -> None:
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
def run_alembic_migrations(schema_name: str, target_revision: str = "head") -> None:
logger.info(
f"Starting Alembic migrations for schema: {schema_name} to target: {target_revision}"
)
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -37,11 +39,10 @@ def run_alembic_migrations(schema_name: str) -> None:
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
# Run migrations programmatically
command.upgrade(alembic_cfg, "head")
command.upgrade(alembic_cfg, target_revision)
# Run migrations programmatically
logger.info(
f"Alembic migrations completed successfully for schema: {schema_name}"
f"Alembic migrations completed successfully for schema: {schema_name} to target: {target_revision}"
)
except Exception as e:

View File

@@ -2,6 +2,7 @@ import json
import random
import secrets
import string
import time
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
@@ -13,6 +14,7 @@ from typing import Optional
from typing import Tuple
import jwt
import sqlalchemy.exc
from email_validator import EmailNotValidError
from email_validator import EmailUndeliverableError
from email_validator import validate_email
@@ -90,6 +92,7 @@ from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import MinimalUser
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.db.users import get_user_by_email
@@ -186,6 +189,7 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
def verify_email_is_invited(email: str) -> None:
return None
whitelist = get_invited_users()
if not whitelist:
return
@@ -215,6 +219,7 @@ def verify_email_is_invited(email: str) -> None:
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
return None
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
@@ -235,6 +240,36 @@ def verify_email_domain(email: str) -> None:
)
class SimpleUserManager(UUIDIDMixin, BaseUserManager[MinimalUser, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS
user_db: SQLAlchemyUserDatabase[MinimalUser, uuid.UUID]
async def get(self, id: uuid.UUID) -> MinimalUser:
"""Get a user by id, with error handling for partial database provisioning."""
try:
return await super().get(id)
except sqlalchemy.exc.ProgrammingError as e:
# Handle database schema mismatch during partial provisioning
if "column user.temperature_override_enabled does not exist" in str(e):
# Create a minimal user with just the required fields
# This is a temporary solution during partial provisioning
from onyx.db.models import MinimalUser
return MinimalUser(
id=id,
email="temp@example.com", # Will be replaced with actual data
hashed_password="",
is_active=True,
is_verified=True,
is_superuser=False,
role="BASIC",
)
# Re-raise other database errors
raise
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
@@ -247,8 +282,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)(user_email)
async with get_async_session_with_tenant(tenant_id) as db_session:
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
tenant_user_db = SQLAlchemyUserAdminDB[MinimalUser, uuid.UUID](
db_session, MinimalUser, OAuthAccount
)
user = await tenant_user_db.get_by_email(user_email)
else:
@@ -265,6 +300,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
start_time = time.time()
# We verify the password here to make sure it's valid before we proceed
await self.validate_password(
user_create.password, cast(schemas.UC, user_create)
@@ -277,7 +313,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
else None
)
tenant_id = await fetch_ee_implementation_or_noop(
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
@@ -286,6 +322,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
referral_source=referral_source,
request=request,
)
print(f"Tenant ID: {tenant_id}, Is newly created: {is_newly_created}")
print("duration: ", time.time() - start_time)
user: User
async with get_async_session_with_tenant(tenant_id) as db_session:
@@ -309,7 +347,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
else:
user_create.role = UserRole.BASIC
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
simple_tennat_user_db = SQLAlchemyUserAdminDB[MinimalUser, uuid.UUID](
db_session, MinimalUser, OAuthAccount
)
user = await SimpleUserManager(simple_tennat_user_db).create(
user_create, safe=safe, request=request
) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
@@ -374,7 +417,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
getattr(request.state, "referral_source", None) if request else None
)
tenant_id = await fetch_ee_implementation_or_noop(
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
@@ -511,7 +554,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
tenant_id = await fetch_ee_implementation_or_noop(
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
@@ -563,7 +606,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
status.HTTP_500_INTERNAL_SERVER_ERROR,
"Your admin has not enabled this feature.",
)
tenant_id = await fetch_ee_implementation_or_noop(
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
@@ -587,8 +630,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
) -> Optional[User]:
email = credentials.username
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
# Get tenant_id from mapping
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
@@ -673,6 +716,12 @@ async def get_user_manager(
yield UserManager(user_db)
async def get_minimal_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
) -> AsyncGenerator[SimpleUserManager, None]:
yield SimpleUserManager(user_db)
cookie_transport = CookieTransport(
cookie_max_age=SESSION_EXPIRE_TIME_SECONDS,
cookie_secure=WEB_DOMAIN.startswith("https"),
@@ -684,6 +733,10 @@ def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()
def get_minimal_redis_strategy() -> RedisStrategy:
return CustomRedisStrategy()
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
@@ -709,7 +762,7 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
async def write_token(self, user: User) -> str:
redis = await get_async_redis_connection()
tenant_id = await fetch_ee_implementation_or_noop(
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
@@ -749,14 +802,139 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
await redis.delete(f"{self.key_prefix}{token}")
class CustomRedisStrategy(RedisStrategy[MinimalUser, uuid.UUID]):
"""
A custom strategy that fetches the actual async Redis connection inside each method.
We do NOT pass a synchronous or "coroutine" redis object to the constructor.
"""
def __init__(
self,
lifetime_seconds: Optional[int] = SESSION_EXPIRE_TIME_SECONDS,
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
):
self.lifetime_seconds = lifetime_seconds
self.key_prefix = key_prefix
async def write_token(self, user: MinimalUser) -> str:
redis = await get_async_redis_connection()
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(email=user.email)
token_data = {
"sub": str(user.id),
"tenant_id": tenant_id,
}
token = secrets.token_urlsafe()
await redis.set(
f"{self.key_prefix}{token}",
json.dumps(token_data),
ex=self.lifetime_seconds,
)
return token
async def read_token(
self,
token: Optional[str],
user_manager: BaseUserManager[MinimalUser, uuid.UUID],
) -> Optional[MinimalUser]:
redis = await get_async_redis_connection()
token_data_str = await redis.get(f"{self.key_prefix}{token}")
if not token_data_str:
return None
try:
token_data = json.loads(token_data_str)
user_id = token_data["sub"]
parsed_id = user_manager.parse_id(user_id)
return await user_manager.get(parsed_id)
except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
return None
async def destroy_token(self, token: str, user: MinimalUser) -> None:
"""Properly delete the token from async redis."""
redis = await get_async_redis_connection()
await redis.delete(f"{self.key_prefix}{token}")
# class CustomRedisStrategy(RedisStrategy[MinimalUser, uuid.UUID]):
# """Custom Redis strategy that handles database schema mismatches during partial provisioning."""
# def __init__(
# self,
# lifetime_seconds: Optional[int] = SESSION_EXPIRE_TIME_SECONDS,
# key_prefix: str = REDIS_AUTH_KEY_PREFIX,
# ):
# self.lifetime_seconds = lifetime_seconds
# self.key_prefix = key_prefix
# async def read_token(
# self, token: Optional[str], user_manager: BaseUserManager[MinimalUser, uuid.UUID]
# ) -> Optional[MinimalUser]:
# try:
# redis = await get_async_redis_connection()
# token_data_str = await redis.get(f"{self.key_prefix}{token}")
# if not token_data_str:
# return None
# try:
# token_data = json.loads(token_data_str)
# user_id = token_data["sub"]
# parsed_id = user_manager.parse_id(user_id)
# return await user_manager.get(parsed_id)
# except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
# return None
# except sqlalchemy.exc.ProgrammingError as e:
# # Handle database schema mismatch during partial provisioning
# if "column user.temperature_override_enabled does not exist" in str(e):
# # Return None to allow unauthenticated access during partial provisioning
# return None
# # Re-raise other database errors
# raise
# async def write_token(self, user: MinimalUser) -> str:
# redis = await get_async_redis_connection()
# token = generate_jwt(
# data={"sub": str(user.id)},
# secret=USER_AUTH_SECRET,
# lifetime_seconds=self.lifetime_seconds,
# )
# await redis.set(
# f"{self.key_prefix}{token}",
# json.dumps({"sub": str(user.id)}),
# ex=self.lifetime_seconds,
# )
# return token
# async def destroy_token(self, token: str, user: MinimalUser) -> None:
# """Properly delete the token from async redis."""
# redis = await get_async_redis_connection()
# await redis.delete(f"{self.key_prefix}{token}")
if AUTH_BACKEND == AuthBackend.REDIS:
auth_backend = AuthenticationBackend(
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
)
minimal_auth_backend = AuthenticationBackend(
name="redis",
transport=cookie_transport,
get_strategy=get_minimal_redis_strategy,
)
elif AUTH_BACKEND == AuthBackend.POSTGRES:
auth_backend = AuthenticationBackend(
name="postgres", transport=cookie_transport, get_strategy=get_database_strategy
)
minimal_auth_backend = AuthenticationBackend(
name="postgres", transport=cookie_transport, get_strategy=get_database_strategy
)
else:
raise ValueError(f"Invalid auth backend: {AUTH_BACKEND}")
@@ -803,6 +981,9 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
get_user_manager, [auth_backend]
)
fastapi_minimal_users = FastAPIUserWithLogoutRouter[MinimalUser, uuid.UUID](
get_minimal_user_manager, [minimal_auth_backend]
)
# NOTE: verified=REQUIRE_EMAIL_VERIFICATION is not used here since we
# take care of that in `double_check_user` ourself. This is needed, since
@@ -811,6 +992,30 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
optional_fastapi_current_user = fastapi_users.current_user(active=True, optional=True)
optional_minimal_user_dependency = fastapi_minimal_users.current_user(
active=True, optional=True
)
async def optional_minimal_user(
user: MinimalUser | None = Depends(optional_minimal_user_dependency),
) -> MinimalUser | None:
"""NOTE: `request` and `db_session` are not used here, but are included
for the EE version of this function."""
print("I AM IN THIS FUNCTION 1")
try:
print("I AM IN THIS FUNCTION 2")
return user
except sqlalchemy.exc.ProgrammingError as e:
print("I AM IN THIS FUNCTION 3")
# Handle database schema mismatch during partial provisioning
if "column user.temperature_override_enabled does not exist" in str(e):
# Return None to allow unauthenticated access during partial provisioning
return None
# Re-raise other database errors
raise
async def optional_user_(
request: Request,
user: User | None,

View File

@@ -180,6 +180,7 @@ SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
def is_valid_schema_name(name: str) -> bool:
print(f"Checking if {name} is valid")
return SCHEMA_NAME_REGEX.match(name) is not None
@@ -474,17 +475,27 @@ def get_session_generator_with_tenant() -> Generator[Session, None, None]:
def get_session() -> Generator[Session, None, None]:
tenant_id = get_current_tenant_id()
print(f"Retrieved tenant_id: {tenant_id}")
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
print("Authentication error: User must authenticate")
raise BasicAuthenticationError(detail="User must authenticate")
engine = get_sqlalchemy_engine()
print("SQLAlchemy engine obtained")
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
print("MULTI_TENANT mode enabled")
if not is_valid_schema_name(tenant_id):
print(f"Invalid tenant ID detected: {tenant_id}")
raise HTTPException(status_code=400, detail="Invalid tenant ID")
print(f"Setting search_path to tenant schema: {tenant_id}")
session.execute(text(f'SET search_path = "{tenant_id}"'))
else:
print("MULTI_TENANT mode disabled")
yield session
print("Session yielded and closed")
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:

View File

@@ -2318,3 +2318,16 @@ class TenantAnonymousUserPath(Base):
anonymous_user_path: Mapped[str] = mapped_column(
String, nullable=False, unique=True
)
class AdditionalBase(DeclarativeBase):
__abstract__ = True
class MinimalUser(SQLAlchemyBaseUserTableUUID, AdditionalBase):
# oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
# "OAuthAccount", lazy="joined", cascade="all, delete-orphan"
# )
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
)

View File

@@ -315,13 +315,20 @@ async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
try:
redis = await get_async_redis_connection()
redis_key = REDIS_AUTH_KEY_PREFIX + token
print("[Redis] Obtained async Redis connection")
redis_key = f"{REDIS_AUTH_KEY_PREFIX}{token}"
print(f"[Redis] Fetching token data for key: {redis_key}")
token_data_str = await redis.get(redis_key)
print(
f"[Redis] Retrieved token data: {'found' if token_data_str else 'not found'}"
)
if not token_data_str:
logger.debug(f"Token key {redis_key} not found or expired in Redis")
logger.debug(f"[Redis] Token key '{redis_key}' not found or expired")
return None
print("[Redis] Decoding token data")
print(f"[Redis] Token data: {token_data_str}")
return json.loads(token_data_str)
except json.JSONDecodeError:
logger.error("Error decoding token data from Redis")

View File

@@ -10,6 +10,7 @@ from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import optional_minimal_user
from onyx.configs.app_configs import APP_API_PREFIX
from onyx.server.onyx_api.ingestion import api_key_dep
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -114,6 +115,7 @@ def check_router_auth(
or depends_fn == current_user_with_expired_token
or depends_fn == current_chat_accesssible_user
or depends_fn == control_plane_dep
or depends_fn == optional_minimal_user
or depends_fn == current_cloud_superuser
):
found_auth = True

View File

@@ -112,6 +112,12 @@ class UserInfo(BaseModel):
)
class MinimalUserInfo(BaseModel):
id: str
email: str
is_active: bool
class UserByEmail(BaseModel):
user_email: str

View File

@@ -32,6 +32,7 @@ from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import optional_minimal_user
from onyx.auth.users import optional_user
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DEV_MODE
@@ -44,6 +45,7 @@ from onyx.db.api_key import is_api_key_email_address
from onyx.db.auth import get_total_users_count
from onyx.db.engine import get_session
from onyx.db.models import AccessToken
from onyx.db.models import MinimalUser
from onyx.db.models import User
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_all_users
@@ -55,6 +57,7 @@ from onyx.key_value_store.factory import get_kv_store
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import AllUsersResponse
from onyx.server.manage.models import AutoScrollRequest
from onyx.server.manage.models import MinimalUserInfo
from onyx.server.manage.models import UserByEmail
from onyx.server.manage.models import UserInfo
from onyx.server.manage.models import UserPreferences
@@ -534,6 +537,27 @@ def get_current_token_creation(
return None
@router.get("/me-info")
def verify_user_attempting_to_login(
request: Request,
user: MinimalUser | None = Depends(optional_minimal_user),
# db_session: Session = Depends(get_session),
) -> MinimalUserInfo:
# Check if the authentication cookie exists
# Print cookie names for debugging
cookie_names = list(request.cookies.keys())
logger.info(f"Available cookies: {cookie_names}")
if not request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME):
raise HTTPException(status_code=401, detail="User not found")
# print("I AM IN THIS FUNCTION 4")
# if user is None:
# print("I AM IN THIS FUNCTION 5")
# raise HTTPException(status_code=401, detail="User not found")
# print("I AM IN THIS FUNCTION 6")
return MinimalUserInfo(id="", email="", is_active=True)
@router.get("/me")
def verify_user_logged_in(
user: User | None = Depends(optional_user),

View File

@@ -43,6 +43,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str:
def get_or_generate_uuid() -> str:
return "hi"
# TODO: split out the whole "instance UUID" generation logic into a separate
# utility function. Telemetry should not be aware at all of how the UUID is
# generated/stored.

View File

@@ -8,8 +8,6 @@ import {
} from "@/lib/userSS";
import { redirect } from "next/navigation";
import { EmailPasswordForm } from "../login/EmailPasswordForm";
import Text from "@/components/ui/text";
import Link from "next/link";
import { SignInButton } from "../login/SignInButton";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
import ReferralSourceSelector from "./ReferralSourceSelector";

View File

@@ -0,0 +1,242 @@
"use client";
import { redirect } from "next/navigation";
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { MinimalUserInfo } from "@/lib/types";
import Text from "@/components/ui/text";
import { Logo } from "@/components/logo/Logo";
import { completeTenantSetup } from "@/lib/tenant";
import { useEffect, useState, useRef } from "react";
import {
Card,
CardContent,
CardHeader,
CardFooter,
} from "@/components/ui/card";
import { LoadingSpinner } from "@/app/chat/chat_search/LoadingSpinner";
import { Button } from "@/components/ui/button";
import { logout } from "@/lib/user";
import { FiLogOut } from "react-icons/fi";
export default function WaitingOnSetupPage({
minimalUserInfo,
}: {
minimalUserInfo: MinimalUserInfo;
}) {
const [progress, setProgress] = useState(0);
const [setupStage, setSetupStage] = useState<string>(
"Setting up your account"
);
const progressRef = useRef<number>(0);
const animationRef = useRef<number>();
const startTimeRef = useRef<number>(Date.now());
const [isReady, setIsReady] = useState(false);
const pollIntervalRef = useRef<NodeJS.Timeout | null>(null);
// Setup stages that will cycle through during the loading process
const setupStages = [
"Setting up your account",
"Configuring workspace",
"Preparing resources",
"Setting up permissions",
"Finalizing setup",
];
// Function to poll the /api/me endpoint
const pollAccountStatus = async () => {
try {
const response = await fetch("/api/me");
if (response.status === 200) {
// Account is ready
setIsReady(true);
if (pollIntervalRef.current) {
clearInterval(pollIntervalRef.current);
}
return true;
}
} catch (error) {
console.error("Error polling account status:", error);
}
return false;
};
// Handle logout
const handleLogout = async () => {
try {
await logout();
window.location.href = "/auth/login";
} catch (error) {
console.error("Failed to logout:", error);
}
};
useEffect(() => {
// Animation setup for progress bar
let lastUpdateTime = 0;
const updateInterval = 100;
const normalAnimationDuration = 30000; // 30 seconds for normal animation
const acceleratedAnimationDuration = 1500; // 1.5 seconds for accelerated animation after ready
let currentStageIndex = 0;
let lastStageUpdateTime = Date.now();
const stageRotationInterval = 3000; // Rotate stages every 3 seconds
const animate = (timestamp: number) => {
const elapsedTime = timestamp - startTimeRef.current;
const now = Date.now();
// Calculate progress using different curves based on ready status
const maxProgress = 99;
let progress;
if (isReady) {
// Accelerate to 100% when account is ready
progress =
maxProgress +
(100 - maxProgress) *
((now - startTimeRef.current) / acceleratedAnimationDuration);
if (progress >= 100) progress = 100;
} else {
// Slower progress when still waiting
progress =
maxProgress * (1 - Math.exp(-elapsedTime / normalAnimationDuration));
}
// Update progress if enough time has passed
if (timestamp - lastUpdateTime > updateInterval) {
progressRef.current = progress;
setProgress(Math.round(progress * 10) / 10);
// Cycle through setup stages
if (now - lastStageUpdateTime > stageRotationInterval && !isReady) {
currentStageIndex = (currentStageIndex + 1) % setupStages.length;
setSetupStage(setupStages[currentStageIndex]);
lastStageUpdateTime = now;
}
lastUpdateTime = timestamp;
}
// Continue animation if not completed
if (progress < 100) {
animationRef.current = requestAnimationFrame(animate);
} else if (progress >= 100) {
// Redirect when progress reaches 100%
setSetupStage("Setup complete!");
setTimeout(() => {
window.location.href = "/chat";
}, 500);
}
};
// Start animation
startTimeRef.current = performance.now();
animationRef.current = requestAnimationFrame(animate);
// Start polling the /api/me endpoint
pollIntervalRef.current = setInterval(async () => {
const ready = await pollAccountStatus();
if (ready) {
// If ready, we'll let the animation handle the redirect
console.log("Account is ready!");
}
}, 2000); // Poll every 2 seconds
// Attempt to complete tenant setup
// completeTenantSetup(minimalUserInfo.email).catch((error) => {
// console.error("Failed to complete tenant setup:", error);
// });
// Cleanup function
return () => {
if (animationRef.current) {
cancelAnimationFrame(animationRef.current);
}
if (pollIntervalRef.current) {
clearInterval(pollIntervalRef.current);
}
};
}, [isReady, minimalUserInfo.email]);
useEffect(() => {
completeTenantSetup(minimalUserInfo.email).catch((error) => {
console.error("Failed to complete tenant setup:", error);
});
}, []);
return (
<main className="min-h-screen bg-gradient-to-b from-white to-neutral-50 dark:from-neutral-900 dark:to-neutral-950">
<div className="absolute top-0 w-full">
<HealthCheckBanner />
</div>
<div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
<div className="w-full max-w-md">
<div className="flex flex-col items-center mb-8">
<Logo height={80} width={80} className="mx-auto w-fit mb-6" />
<h1 className="text-2xl font-bold text-neutral-900 dark:text-white">
Account Setup
</h1>
</div>
<Card className="border-neutral-200 dark:border-neutral-800 shadow-lg">
<CardHeader className="pb-0">
<div className="flex items-center justify-between">
<div className="flex items-center space-x-3">
<div className="relative">
<LoadingSpinner size="medium" className="text-primary" />
</div>
<h2 className="text-lg font-semibold text-neutral-800 dark:text-neutral-200">
{setupStage}
</h2>
</div>
<span className="text-sm font-medium text-neutral-500 dark:text-neutral-400">
{progress}%
</span>
</div>
</CardHeader>
<CardContent className="pt-4">
{/* Progress bar */}
<div className="w-full h-2 bg-neutral-200 dark:bg-neutral-700 rounded-full mb-6 overflow-hidden">
<div
className="h-full bg-blue-600 dark:bg-blue-500 rounded-full transition-all duration-300 ease-out"
style={{ width: `${progress}%` }}
/>
</div>
<div className="space-y-4">
<div className="flex flex-col space-y-1">
<Text className="text-neutral-800 dark:text-neutral-200 font-medium">
Welcome,{" "}
<span className="font-semibold">
{minimalUserInfo?.email}
</span>
</Text>
<Text className="text-neutral-600 dark:text-neutral-400 text-sm">
We're setting up your account. This may take a few moments.
</Text>
</div>
<div className="bg-blue-50 dark:bg-blue-900/20 rounded-lg p-4 border border-blue-100 dark:border-blue-800">
<Text className="text-sm text-blue-700 dark:text-blue-300">
You'll be redirected automatically when your account is
ready. If you're not redirected within a minute, please
refresh the page.
</Text>
</div>
</div>
</CardContent>
<CardFooter className="flex justify-end pt-4">
<Button
variant="outline"
size="sm"
onClick={handleLogout}
className="text-neutral-600 dark:text-neutral-300"
>
<FiLogOut className="mr-1" />
Logout
</Button>
</CardFooter>
</Card>
</div>
</div>
</main>
);
}

View File

@@ -20,7 +20,7 @@ import {
import { fetchAssistantData } from "@/lib/chat/fetchAssistantdata";
import { AppProvider } from "@/components/context/AppProvider";
import { PHProvider } from "./providers";
import { getCurrentUserSS } from "@/lib/userSS";
import { getCurrentUserSS, getMinimalUserInfoSS } from "@/lib/userSS";
import { Suspense } from "react";
import PostHogPageView from "./PostHogPageView";
import Script from "next/script";
@@ -30,6 +30,8 @@ import { ThemeProvider } from "next-themes";
import CloudError from "@/components/errorPages/CloudErrorPage";
import Error from "@/components/errorPages/ErrorPage";
import AccessRestrictedPage from "@/components/errorPages/AccessRestrictedPage";
import CompleteTenantSetupPage from "./auth/waiting-on-setup/WaitingOnSetup";
import WaitingOnSetupPage from "./auth/waiting-on-setup/WaitingOnSetup";
const inter = Inter({
subsets: ["latin"],
@@ -70,12 +72,17 @@ export default async function RootLayout({
}: {
children: React.ReactNode;
}) {
const [combinedSettings, assistantsData, user] = await Promise.all([
fetchSettingsSS(),
fetchAssistantData(),
getCurrentUserSS(),
]);
const [combinedSettings, assistantsData, user, minimalUserInfo] =
await Promise.all([
fetchSettingsSS(),
fetchAssistantData(),
getCurrentUserSS(),
getMinimalUserInfoSS(),
]);
// if (!user && minimalUserInfo) {
// return <CompleteTenantSetupPage />;
// }
const productGating =
combinedSettings?.settings.application_status ?? ApplicationStatus.ACTIVE;
@@ -135,6 +142,11 @@ export default async function RootLayout({
if (productGating === ApplicationStatus.GATED_ACCESS) {
return getPageContent(<AccessRestrictedPage />);
}
if (!user && minimalUserInfo) {
return getPageContent(
<WaitingOnSetupPage minimalUserInfo={minimalUserInfo} />
);
}
if (!combinedSettings) {
return getPageContent(

14
web/src/lib/tenant.ts Normal file
View File

@@ -0,0 +1,14 @@
export async function completeTenantSetup(email: string): Promise<void> {
const response = await fetch(`/api/tenants/complete-setup`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ email }),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`Failed to complete tenant setup: ${errorText}`);
}
}

View File

@@ -15,6 +15,12 @@ interface UserPreferences {
temperature_override_enabled: boolean;
}
export interface MinimalUserInfo {
id: string;
email: string;
is_active: boolean;
}
export enum UserRole {
LIMITED = "limited",
BASIC = "basic",

View File

@@ -1,5 +1,5 @@
import { cookies } from "next/headers";
import { User } from "./types";
import { MinimalUserInfo, User } from "./types";
import { buildUrl } from "./utilsSS";
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
@@ -162,6 +162,30 @@ export const logoutSS = async (
}
}
};
export const getMinimalUserInfoSS =
async (): Promise<MinimalUserInfo | null> => {
try {
const response = await fetch(buildUrl("/me-info"), {
credentials: "include",
headers: {
cookie: (await cookies())
.getAll()
.map((cookie) => `${cookie.name}=${cookie.value}`)
.join("; "),
},
next: { revalidate: 0 },
});
if (!response.ok) {
console.error("Failed to fetch minimal user info");
return null;
}
return await response.json();
} catch (e) {
console.log(`Error fetching minimal user info: ${e}`);
return null;
}
};
export const getCurrentUserSS = async (): Promise<User | null> => {
try {