mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-22 10:15:46 +00:00
Compare commits
3 Commits
github_lis
...
feature/mu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5ba16d31fd | ||
|
|
8a5a6d3c91 | ||
|
|
09e4e73ba6 |
@@ -1,125 +0,0 @@
|
||||
"""Update GitHub connector repo_name to repositories
|
||||
|
||||
Revision ID: 3934b1bc7b62
|
||||
Revises: b7c2b63c4a03
|
||||
Create Date: 2025-03-05 10:50:30.516962
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import json
|
||||
import logging
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3934b1bc7b62"
|
||||
down_revision = "b7c2b63c4a03"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Get all GitHub connectors
|
||||
conn = op.get_bind()
|
||||
|
||||
# First get all GitHub connectors
|
||||
github_connectors = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, connector_specific_config
|
||||
FROM connector
|
||||
WHERE source = 'GITHUB'
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
# Update each connector's config
|
||||
updated_count = 0
|
||||
for connector_id, config in github_connectors:
|
||||
try:
|
||||
if not config:
|
||||
logger.warning(f"Connector {connector_id} has no config, skipping")
|
||||
continue
|
||||
|
||||
# Parse the config if it's a string
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
|
||||
if "repo_name" not in config:
|
||||
continue
|
||||
|
||||
# Create new config with repositories instead of repo_name
|
||||
new_config = dict(config)
|
||||
repo_name_value = new_config.pop("repo_name")
|
||||
new_config["repositories"] = repo_name_value
|
||||
|
||||
# Update the connector with the new config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = :new_config
|
||||
WHERE id = :connector_id
|
||||
"""
|
||||
),
|
||||
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
|
||||
)
|
||||
updated_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating connector {connector_id}: {str(e)}")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Get all GitHub connectors
|
||||
conn = op.get_bind()
|
||||
|
||||
logger.debug(
|
||||
"Starting rollback of GitHub connectors from repositories to repo_name"
|
||||
)
|
||||
|
||||
github_connectors = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, connector_specific_config
|
||||
FROM connector
|
||||
WHERE source = 'GITHUB'
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
logger.debug(f"Found {len(github_connectors)} GitHub connectors to rollback")
|
||||
|
||||
# Revert each GitHub connector to use repo_name instead of repositories
|
||||
reverted_count = 0
|
||||
for connector_id, config in github_connectors:
|
||||
try:
|
||||
if not config:
|
||||
continue
|
||||
|
||||
# Parse the config if it's a string
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
|
||||
if "repositories" not in config:
|
||||
continue
|
||||
|
||||
# Create new config with repo_name instead of repositories
|
||||
new_config = dict(config)
|
||||
repositories_value = new_config.pop("repositories")
|
||||
new_config["repo_name"] = repositories_value
|
||||
|
||||
# Update the connector with the new config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = :new_config
|
||||
WHERE id = :connector_id
|
||||
"""
|
||||
),
|
||||
{"new_config": json.dumps(new_config), "connector_id": connector_id},
|
||||
)
|
||||
reverted_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error reverting connector {connector_id}: {str(e)}")
|
||||
@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
)
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
|
||||
from ee.onyx.server.oauth.api import router as ee_oauth_router
|
||||
from ee.onyx.server.oauth.api import router as oauth_router
|
||||
from ee.onyx.server.query_and_chat.chat_backend import (
|
||||
router as chat_router,
|
||||
)
|
||||
@@ -27,6 +27,7 @@ from ee.onyx.server.reporting.usage_export_api import router as usage_export_rou
|
||||
from ee.onyx.server.saml import router as saml_router
|
||||
from ee.onyx.server.seeding import seed_db
|
||||
from ee.onyx.server.tenants.api import router as tenants_router
|
||||
from ee.onyx.server.tenants.router import router as new_router
|
||||
from ee.onyx.server.token_rate_limits.api import (
|
||||
router as token_rate_limit_settings_router,
|
||||
)
|
||||
@@ -123,12 +124,13 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, user_group_router)
|
||||
# Analytics endpoints
|
||||
include_router_with_global_prefix_prepended(application, analytics_router)
|
||||
include_router_with_global_prefix_prepended(application, new_router)
|
||||
include_router_with_global_prefix_prepended(application, query_history_router)
|
||||
# EE only backend APIs
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, standard_answer_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_oauth_router)
|
||||
include_router_with_global_prefix_prepended(application, oauth_router)
|
||||
|
||||
# Enterprise-only global settings
|
||||
include_router_with_global_prefix_prepended(
|
||||
|
||||
@@ -25,8 +25,11 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
) -> Response:
|
||||
try:
|
||||
if MULTI_TENANT:
|
||||
print("Shold set tenant id")
|
||||
tenant_id = await _get_tenant_id_from_request(request, logger)
|
||||
print(f"Tenant id: {tenant_id}")
|
||||
else:
|
||||
print("Should not set tenant id")
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -64,8 +67,9 @@ async def _get_tenant_id_from_request(
|
||||
|
||||
try:
|
||||
# Look up token data in Redis
|
||||
|
||||
print("I AM IN THIS FUNCTION 7")
|
||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||
print("I AM IN THIS FUNCTION 8")
|
||||
|
||||
if not token_data:
|
||||
logger.debug(
|
||||
|
||||
@@ -80,7 +80,6 @@ class ConfluenceCloudOAuth:
|
||||
"search:confluence%20"
|
||||
# granular scope
|
||||
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
|
||||
"read:content-details:confluence%20" # for permission sync
|
||||
"offline_access"
|
||||
)
|
||||
|
||||
|
||||
0
backend/ee/onyx/server/tenants/initial_models.py
Normal file
0
backend/ee/onyx/server/tenants/initial_models.py
Normal file
@@ -48,5 +48,4 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
|
||||
|
||||
def get_gated_tenants() -> set[str]:
|
||||
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
|
||||
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
|
||||
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import aiohttp # Async HTTP client
|
||||
@@ -54,24 +55,26 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_or_provision_tenant(
|
||||
email: str, referral_source: str | None = None, request: Request | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Get existing tenant ID for an email or create a new tenant if none exists.
|
||||
This function should only be called after we have verified we want this user's tenant to exist.
|
||||
It returns the tenant ID associated with the email, creating a new tenant if necessary.
|
||||
) -> tuple[str, bool]:
|
||||
"""Get existing tenant ID for an email or create a new tenant if none exists.
|
||||
|
||||
Returns:
|
||||
tuple: (tenant_id, is_newly_created) - The tenant ID and a boolean indicating if it was newly created
|
||||
"""
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
return POSTGRES_DEFAULT_SCHEMA, False
|
||||
|
||||
if referral_source and request:
|
||||
await submit_to_hubspot(email, referral_source, request)
|
||||
|
||||
is_newly_created = False
|
||||
try:
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
except exceptions.UserNotExists:
|
||||
# If tenant does not exist and in Multi tenant mode, provision a new tenant
|
||||
try:
|
||||
tenant_id = await create_tenant(email, referral_source)
|
||||
is_newly_created = True
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
@@ -81,14 +84,22 @@ async def get_or_provision_tenant(
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
return tenant_id, is_newly_created
|
||||
|
||||
|
||||
async def create_tenant(email: str, referral_source: str | None = None) -> str:
|
||||
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
|
||||
try:
|
||||
# Provision tenant on data plane
|
||||
start_time = time.time()
|
||||
await provision_tenant(tenant_id, email)
|
||||
duration = time.time() - start_time
|
||||
logger.error(
|
||||
f"Tenant provisioning for {tenant_id} completed in {duration:.2f} seconds"
|
||||
)
|
||||
print(
|
||||
f"Tenant provisioning for {tenant_id} completed in {duration:.2f} seconds"
|
||||
)
|
||||
# Notify control plane
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
@@ -119,36 +130,23 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Await the Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
# Await the Alembic migrations up to the specified revision
|
||||
start_time = time.time()
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id, "465f78d9b7f9")
|
||||
duration = time.time() - start_time
|
||||
print(
|
||||
f"Alembic migrations for tenant {tenant_id} completed in {duration:.2f} seconds"
|
||||
)
|
||||
logger.info(
|
||||
f"Alembic migrations for tenant {tenant_id} completed in {duration:.2f} seconds"
|
||||
)
|
||||
logger.error(
|
||||
f"Alembic migrations for tenant {tenant_id} completed in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
# Add users to tenant - this is needed for authentication
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create tenant {tenant_id}")
|
||||
raise HTTPException(
|
||||
@@ -353,3 +351,60 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
|
||||
raise Exception(
|
||||
f"Failed to delete tenant on control plane: {error_text}"
|
||||
)
|
||||
|
||||
|
||||
async def complete_tenant_setup(tenant_id: str) -> None:
|
||||
"""Complete the tenant setup process after user creation.
|
||||
|
||||
This function handles the remaining steps of tenant provisioning after the initial
|
||||
schema creation and user authentication:
|
||||
1. Completes the remaining Alembic migrations
|
||||
2. Configures default API keys
|
||||
3. Sets up Onyx
|
||||
4. Creates milestone record
|
||||
"""
|
||||
if not MULTI_TENANT:
|
||||
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
|
||||
|
||||
logger.debug(f"Completing setup for tenant {tenant_id}")
|
||||
token = None
|
||||
|
||||
try:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Complete the remaining Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={},
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
logger.info(f"Tenant setup completed for {tenant_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to complete tenant setup for {tenant_id}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to complete tenant setup: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
40
backend/ee/onyx/server/tenants/router.py
Normal file
40
backend/ee/onyx/server/tenants/router.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ee.onyx.server.tenants.provisioning import complete_tenant_setup
|
||||
from onyx.auth.users import optional_minimal_user
|
||||
from onyx.db.models import MinimalUser
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/tenants", tags=["tenants"])
|
||||
|
||||
|
||||
class CompleteTenantSetupRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
@router.post("/complete-setup")
|
||||
async def api_complete_tenant_setup(
|
||||
request: CompleteTenantSetupRequest,
|
||||
user: MinimalUser = Depends(optional_minimal_user),
|
||||
) -> None:
|
||||
"""Complete the tenant setup process for a user.
|
||||
|
||||
This endpoint is called from the frontend after user creation to complete
|
||||
the tenant setup process (migrations, seeding, etc.).
|
||||
"""
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
await complete_tenant_setup(tenant_id)
|
||||
return {"status": "success"}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to complete tenant setup: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to complete tenant setup")
|
||||
@@ -14,8 +14,10 @@ from onyx.db.engine import get_sqlalchemy_engine
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_alembic_migrations(schema_name: str) -> None:
|
||||
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
|
||||
def run_alembic_migrations(schema_name: str, target_revision: str = "head") -> None:
|
||||
logger.info(
|
||||
f"Starting Alembic migrations for schema: {schema_name} to target: {target_revision}"
|
||||
)
|
||||
|
||||
try:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -37,11 +39,10 @@ def run_alembic_migrations(schema_name: str) -> None:
|
||||
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
|
||||
|
||||
# Run migrations programmatically
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
command.upgrade(alembic_cfg, target_revision)
|
||||
|
||||
# Run migrations programmatically
|
||||
logger.info(
|
||||
f"Alembic migrations completed successfully for schema: {schema_name}"
|
||||
f"Alembic migrations completed successfully for schema: {schema_name} to target: {target_revision}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
@@ -13,6 +14,7 @@ from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import jwt
|
||||
import sqlalchemy.exc
|
||||
from email_validator import EmailNotValidError
|
||||
from email_validator import EmailUndeliverableError
|
||||
from email_validator import validate_email
|
||||
@@ -90,6 +92,7 @@ from onyx.db.engine import get_async_session
|
||||
from onyx.db.engine import get_async_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import MinimalUser
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
@@ -186,6 +189,7 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
return None
|
||||
whitelist = get_invited_users()
|
||||
if not whitelist:
|
||||
return
|
||||
@@ -215,6 +219,7 @@ def verify_email_is_invited(email: str) -> None:
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
|
||||
return None
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if not get_user_by_email(email, db_session):
|
||||
verify_email_is_invited(email)
|
||||
@@ -235,6 +240,36 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
class SimpleUserManager(UUIDIDMixin, BaseUserManager[MinimalUser, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
user_db: SQLAlchemyUserDatabase[MinimalUser, uuid.UUID]
|
||||
|
||||
async def get(self, id: uuid.UUID) -> MinimalUser:
|
||||
"""Get a user by id, with error handling for partial database provisioning."""
|
||||
try:
|
||||
return await super().get(id)
|
||||
except sqlalchemy.exc.ProgrammingError as e:
|
||||
# Handle database schema mismatch during partial provisioning
|
||||
if "column user.temperature_override_enabled does not exist" in str(e):
|
||||
# Create a minimal user with just the required fields
|
||||
# This is a temporary solution during partial provisioning
|
||||
from onyx.db.models import MinimalUser
|
||||
|
||||
return MinimalUser(
|
||||
id=id,
|
||||
email="temp@example.com", # Will be replaced with actual data
|
||||
hashed_password="",
|
||||
is_active=True,
|
||||
is_verified=True,
|
||||
is_superuser=False,
|
||||
role="BASIC",
|
||||
)
|
||||
# Re-raise other database errors
|
||||
raise
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
@@ -247,8 +282,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)(user_email)
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
|
||||
db_session, User, OAuthAccount
|
||||
tenant_user_db = SQLAlchemyUserAdminDB[MinimalUser, uuid.UUID](
|
||||
db_session, MinimalUser, OAuthAccount
|
||||
)
|
||||
user = await tenant_user_db.get_by_email(user_email)
|
||||
else:
|
||||
@@ -265,6 +300,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
start_time = time.time()
|
||||
# We verify the password here to make sure it's valid before we proceed
|
||||
await self.validate_password(
|
||||
user_create.password, cast(schemas.UC, user_create)
|
||||
@@ -277,7 +313,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
else None
|
||||
)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
@@ -286,6 +322,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
referral_source=referral_source,
|
||||
request=request,
|
||||
)
|
||||
print(f"Tenant ID: {tenant_id}, Is newly created: {is_newly_created}")
|
||||
print("duration: ", time.time() - start_time)
|
||||
user: User
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -309,7 +347,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
simple_tennat_user_db = SQLAlchemyUserAdminDB[MinimalUser, uuid.UUID](
|
||||
db_session, MinimalUser, OAuthAccount
|
||||
)
|
||||
user = await SimpleUserManager(simple_tennat_user_db).create(
|
||||
user_create, safe=safe, request=request
|
||||
) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
@@ -374,7 +417,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
getattr(request.state, "referral_source", None) if request else None
|
||||
)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
@@ -511,7 +554,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
@@ -563,7 +606,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
"Your admin has not enabled this feature.",
|
||||
)
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
@@ -587,20 +630,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
) -> Optional[User]:
|
||||
email = credentials.username
|
||||
|
||||
tenant_id: str | None = None
|
||||
try:
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_tenant_id_for_email",
|
||||
None,
|
||||
)(
|
||||
email=email,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"User attempted to login with invalid credentials: {str(e)}"
|
||||
)
|
||||
|
||||
# Get tenant_id from mapping
|
||||
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=email,
|
||||
)
|
||||
if not tenant_id:
|
||||
# User not found in mapping
|
||||
self.password_helper.hash(credentials.password)
|
||||
@@ -679,6 +716,12 @@ async def get_user_manager(
|
||||
yield UserManager(user_db)
|
||||
|
||||
|
||||
async def get_minimal_user_manager(
|
||||
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
|
||||
) -> AsyncGenerator[SimpleUserManager, None]:
|
||||
yield SimpleUserManager(user_db)
|
||||
|
||||
|
||||
cookie_transport = CookieTransport(
|
||||
cookie_max_age=SESSION_EXPIRE_TIME_SECONDS,
|
||||
cookie_secure=WEB_DOMAIN.startswith("https"),
|
||||
@@ -690,6 +733,10 @@ def get_redis_strategy() -> RedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
|
||||
|
||||
def get_minimal_redis_strategy() -> RedisStrategy:
|
||||
return CustomRedisStrategy()
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
@@ -715,7 +762,7 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
async def write_token(self, user: User) -> str:
|
||||
redis = await get_async_redis_connection()
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
@@ -755,14 +802,139 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
await redis.delete(f"{self.key_prefix}{token}")
|
||||
|
||||
|
||||
class CustomRedisStrategy(RedisStrategy[MinimalUser, uuid.UUID]):
|
||||
"""
|
||||
A custom strategy that fetches the actual async Redis connection inside each method.
|
||||
We do NOT pass a synchronous or "coroutine" redis object to the constructor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lifetime_seconds: Optional[int] = SESSION_EXPIRE_TIME_SECONDS,
|
||||
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
|
||||
):
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
self.key_prefix = key_prefix
|
||||
|
||||
async def write_token(self, user: MinimalUser) -> str:
|
||||
redis = await get_async_redis_connection()
|
||||
|
||||
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(email=user.email)
|
||||
|
||||
token_data = {
|
||||
"sub": str(user.id),
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
token = secrets.token_urlsafe()
|
||||
await redis.set(
|
||||
f"{self.key_prefix}{token}",
|
||||
json.dumps(token_data),
|
||||
ex=self.lifetime_seconds,
|
||||
)
|
||||
return token
|
||||
|
||||
async def read_token(
|
||||
self,
|
||||
token: Optional[str],
|
||||
user_manager: BaseUserManager[MinimalUser, uuid.UUID],
|
||||
) -> Optional[MinimalUser]:
|
||||
redis = await get_async_redis_connection()
|
||||
token_data_str = await redis.get(f"{self.key_prefix}{token}")
|
||||
if not token_data_str:
|
||||
return None
|
||||
|
||||
try:
|
||||
token_data = json.loads(token_data_str)
|
||||
user_id = token_data["sub"]
|
||||
parsed_id = user_manager.parse_id(user_id)
|
||||
return await user_manager.get(parsed_id)
|
||||
except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
|
||||
return None
|
||||
|
||||
async def destroy_token(self, token: str, user: MinimalUser) -> None:
|
||||
"""Properly delete the token from async redis."""
|
||||
redis = await get_async_redis_connection()
|
||||
await redis.delete(f"{self.key_prefix}{token}")
|
||||
|
||||
|
||||
# class CustomRedisStrategy(RedisStrategy[MinimalUser, uuid.UUID]):
|
||||
# """Custom Redis strategy that handles database schema mismatches during partial provisioning."""
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# lifetime_seconds: Optional[int] = SESSION_EXPIRE_TIME_SECONDS,
|
||||
# key_prefix: str = REDIS_AUTH_KEY_PREFIX,
|
||||
# ):
|
||||
# self.lifetime_seconds = lifetime_seconds
|
||||
# self.key_prefix = key_prefix
|
||||
|
||||
# async def read_token(
|
||||
# self, token: Optional[str], user_manager: BaseUserManager[MinimalUser, uuid.UUID]
|
||||
# ) -> Optional[MinimalUser]:
|
||||
# try:
|
||||
# redis = await get_async_redis_connection()
|
||||
# token_data_str = await redis.get(f"{self.key_prefix}{token}")
|
||||
# if not token_data_str:
|
||||
# return None
|
||||
|
||||
# try:
|
||||
# token_data = json.loads(token_data_str)
|
||||
# user_id = token_data["sub"]
|
||||
# parsed_id = user_manager.parse_id(user_id)
|
||||
# return await user_manager.get(parsed_id)
|
||||
# except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
|
||||
# return None
|
||||
# except sqlalchemy.exc.ProgrammingError as e:
|
||||
# # Handle database schema mismatch during partial provisioning
|
||||
# if "column user.temperature_override_enabled does not exist" in str(e):
|
||||
# # Return None to allow unauthenticated access during partial provisioning
|
||||
# return None
|
||||
# # Re-raise other database errors
|
||||
# raise
|
||||
|
||||
# async def write_token(self, user: MinimalUser) -> str:
|
||||
# redis = await get_async_redis_connection()
|
||||
|
||||
# token = generate_jwt(
|
||||
# data={"sub": str(user.id)},
|
||||
# secret=USER_AUTH_SECRET,
|
||||
# lifetime_seconds=self.lifetime_seconds,
|
||||
# )
|
||||
|
||||
# await redis.set(
|
||||
# f"{self.key_prefix}{token}",
|
||||
# json.dumps({"sub": str(user.id)}),
|
||||
# ex=self.lifetime_seconds,
|
||||
# )
|
||||
|
||||
# return token
|
||||
|
||||
# async def destroy_token(self, token: str, user: MinimalUser) -> None:
|
||||
# """Properly delete the token from async redis."""
|
||||
# redis = await get_async_redis_connection()
|
||||
# await redis.delete(f"{self.key_prefix}{token}")
|
||||
|
||||
|
||||
if AUTH_BACKEND == AuthBackend.REDIS:
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
|
||||
)
|
||||
minimal_auth_backend = AuthenticationBackend(
|
||||
name="redis",
|
||||
transport=cookie_transport,
|
||||
get_strategy=get_minimal_redis_strategy,
|
||||
)
|
||||
elif AUTH_BACKEND == AuthBackend.POSTGRES:
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="postgres", transport=cookie_transport, get_strategy=get_database_strategy
|
||||
)
|
||||
minimal_auth_backend = AuthenticationBackend(
|
||||
name="postgres", transport=cookie_transport, get_strategy=get_database_strategy
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid auth backend: {AUTH_BACKEND}")
|
||||
|
||||
@@ -809,6 +981,9 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
|
||||
get_user_manager, [auth_backend]
|
||||
)
|
||||
|
||||
fastapi_minimal_users = FastAPIUserWithLogoutRouter[MinimalUser, uuid.UUID](
|
||||
get_minimal_user_manager, [minimal_auth_backend]
|
||||
)
|
||||
|
||||
# NOTE: verified=REQUIRE_EMAIL_VERIFICATION is not used here since we
|
||||
# take care of that in `double_check_user` ourself. This is needed, since
|
||||
@@ -817,6 +992,30 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
|
||||
optional_fastapi_current_user = fastapi_users.current_user(active=True, optional=True)
|
||||
|
||||
|
||||
optional_minimal_user_dependency = fastapi_minimal_users.current_user(
|
||||
active=True, optional=True
|
||||
)
|
||||
|
||||
|
||||
async def optional_minimal_user(
|
||||
user: MinimalUser | None = Depends(optional_minimal_user_dependency),
|
||||
) -> MinimalUser | None:
|
||||
"""NOTE: `request` and `db_session` are not used here, but are included
|
||||
for the EE version of this function."""
|
||||
print("I AM IN THIS FUNCTION 1")
|
||||
try:
|
||||
print("I AM IN THIS FUNCTION 2")
|
||||
return user
|
||||
except sqlalchemy.exc.ProgrammingError as e:
|
||||
print("I AM IN THIS FUNCTION 3")
|
||||
# Handle database schema mismatch during partial provisioning
|
||||
if "column user.temperature_override_enabled does not exist" in str(e):
|
||||
# Return None to allow unauthenticated access during partial provisioning
|
||||
return None
|
||||
# Re-raise other database errors
|
||||
raise
|
||||
|
||||
|
||||
async def optional_user_(
|
||||
request: Request,
|
||||
user: User | None,
|
||||
|
||||
@@ -240,7 +240,7 @@ class ConfluenceConnector(
|
||||
# Extract basic page information
|
||||
page_id = page["id"]
|
||||
page_title = page["title"]
|
||||
page_url = f"{self.wiki_base}{page['_links']['webui']}"
|
||||
page_url = f"{self.wiki_base}/wiki{page['_links']['webui']}"
|
||||
|
||||
# Get the page content
|
||||
page_content = extract_text_from_confluence_html(
|
||||
|
||||
@@ -124,14 +124,14 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
repo_owner: str,
|
||||
repositories: str | None = None,
|
||||
repo_name: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
state_filter: str = "all",
|
||||
include_prs: bool = True,
|
||||
include_issues: bool = False,
|
||||
) -> None:
|
||||
self.repo_owner = repo_owner
|
||||
self.repositories = repositories
|
||||
self.repo_name = repo_name
|
||||
self.batch_size = batch_size
|
||||
self.state_filter = state_filter
|
||||
self.include_prs = include_prs
|
||||
@@ -157,42 +157,11 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
try:
|
||||
return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
|
||||
return github_client.get_repo(f"{self.repo_owner}/{self.repo_name}")
|
||||
except RateLimitExceededException:
|
||||
_sleep_after_rate_limit_exception(github_client)
|
||||
return self._get_github_repo(github_client, attempt_num + 1)
|
||||
|
||||
def _get_github_repos(
|
||||
self, github_client: Github, attempt_num: int = 0
|
||||
) -> list[Repository.Repository]:
|
||||
"""Get specific repositories based on comma-separated repo_name string."""
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
raise RuntimeError(
|
||||
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
|
||||
)
|
||||
|
||||
try:
|
||||
repos = []
|
||||
# Split repo_name by comma and strip whitespace
|
||||
repo_names = [
|
||||
name.strip() for name in (cast(str, self.repositories)).split(",")
|
||||
]
|
||||
|
||||
for repo_name in repo_names:
|
||||
if repo_name: # Skip empty strings
|
||||
try:
|
||||
repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}")
|
||||
repos.append(repo)
|
||||
except GithubException as e:
|
||||
logger.warning(
|
||||
f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}"
|
||||
)
|
||||
|
||||
return repos
|
||||
except RateLimitExceededException:
|
||||
_sleep_after_rate_limit_exception(github_client)
|
||||
return self._get_github_repos(github_client, attempt_num + 1)
|
||||
|
||||
def _get_all_repos(
|
||||
self, github_client: Github, attempt_num: int = 0
|
||||
) -> list[Repository.Repository]:
|
||||
@@ -220,17 +189,11 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub")
|
||||
|
||||
repos = []
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = self._get_github_repos(self.github_client)
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
repos = [self._get_github_repo(self.github_client)]
|
||||
else:
|
||||
# All repositories
|
||||
repos = self._get_all_repos(self.github_client)
|
||||
repos = (
|
||||
[self._get_github_repo(self.github_client)]
|
||||
if self.repo_name
|
||||
else self._get_all_repos(self.github_client)
|
||||
)
|
||||
|
||||
for repo in repos:
|
||||
if self.include_prs:
|
||||
@@ -305,48 +268,11 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
try:
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
# Multiple repositories specified
|
||||
repo_names = [name.strip() for name in self.repositories.split(",")]
|
||||
if not repo_names:
|
||||
raise ConnectorValidationError(
|
||||
"Invalid connector settings: No valid repository names provided."
|
||||
)
|
||||
|
||||
# Validate at least one repository exists and is accessible
|
||||
valid_repos = False
|
||||
validation_errors = []
|
||||
|
||||
for repo_name in repo_names:
|
||||
if not repo_name:
|
||||
continue
|
||||
|
||||
try:
|
||||
test_repo = self.github_client.get_repo(
|
||||
f"{self.repo_owner}/{repo_name}"
|
||||
)
|
||||
test_repo.get_contents("")
|
||||
valid_repos = True
|
||||
# If at least one repo is valid, we can proceed
|
||||
break
|
||||
except GithubException as e:
|
||||
validation_errors.append(
|
||||
f"Repository '{repo_name}': {e.data.get('message', str(e))}"
|
||||
)
|
||||
|
||||
if not valid_repos:
|
||||
error_msg = (
|
||||
"None of the specified repositories could be accessed: "
|
||||
)
|
||||
error_msg += ", ".join(validation_errors)
|
||||
raise ConnectorValidationError(error_msg)
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
test_repo = self.github_client.get_repo(
|
||||
f"{self.repo_owner}/{self.repositories}"
|
||||
)
|
||||
test_repo.get_contents("")
|
||||
if self.repo_name:
|
||||
test_repo = self.github_client.get_repo(
|
||||
f"{self.repo_owner}/{self.repo_name}"
|
||||
)
|
||||
test_repo.get_contents("")
|
||||
else:
|
||||
# Try to get organization first
|
||||
try:
|
||||
@@ -372,15 +298,10 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
|
||||
)
|
||||
elif e.status == 404:
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
raise ConnectorValidationError(
|
||||
f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}"
|
||||
)
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}"
|
||||
)
|
||||
if self.repo_name:
|
||||
raise ConnectorValidationError(
|
||||
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
|
||||
)
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"GitHub user or organization not found: {self.repo_owner}"
|
||||
@@ -389,7 +310,6 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected GitHub error (status={e.status}): {e.data}"
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
raise Exception(
|
||||
f"Unexpected error during GitHub settings validation: {exc}"
|
||||
@@ -401,7 +321,7 @@ if __name__ == "__main__":
|
||||
|
||||
connector = GithubConnector(
|
||||
repo_owner=os.environ["REPO_OWNER"],
|
||||
repositories=os.environ["REPOSITORIES"],
|
||||
repo_name=os.environ["REPO_NAME"],
|
||||
)
|
||||
connector.load_credentials(
|
||||
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}
|
||||
|
||||
@@ -180,6 +180,7 @@ SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
|
||||
|
||||
|
||||
def is_valid_schema_name(name: str) -> bool:
|
||||
print(f"Checking if {name} is valid")
|
||||
return SCHEMA_NAME_REGEX.match(name) is not None
|
||||
|
||||
|
||||
@@ -474,17 +475,27 @@ def get_session_generator_with_tenant() -> Generator[Session, None, None]:
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
tenant_id = get_current_tenant_id()
|
||||
print(f"Retrieved tenant_id: {tenant_id}")
|
||||
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
print("Authentication error: User must authenticate")
|
||||
raise BasicAuthenticationError(detail="User must authenticate")
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
print("SQLAlchemy engine obtained")
|
||||
|
||||
with Session(engine, expire_on_commit=False) as session:
|
||||
if MULTI_TENANT:
|
||||
print("MULTI_TENANT mode enabled")
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
print(f"Invalid tenant ID detected: {tenant_id}")
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
print(f"Setting search_path to tenant schema: {tenant_id}")
|
||||
session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
else:
|
||||
print("MULTI_TENANT mode disabled")
|
||||
yield session
|
||||
print("Session yielded and closed")
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
|
||||
@@ -2318,3 +2318,16 @@ class TenantAnonymousUserPath(Base):
|
||||
anonymous_user_path: Mapped[str] = mapped_column(
|
||||
String, nullable=False, unique=True
|
||||
)
|
||||
|
||||
|
||||
class AdditionalBase(DeclarativeBase):
|
||||
__abstract__ = True
|
||||
|
||||
|
||||
class MinimalUser(SQLAlchemyBaseUserTableUUID, AdditionalBase):
|
||||
# oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||
# "OAuthAccount", lazy="joined", cascade="all, delete-orphan"
|
||||
# )
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
|
||||
@@ -464,29 +464,12 @@ def index_doc_batch(
|
||||
),
|
||||
)
|
||||
|
||||
all_returned_doc_ids = (
|
||||
{record.document_id for record in insertion_records}
|
||||
.union(
|
||||
{
|
||||
record.failed_document.document_id
|
||||
for record in vector_db_write_failures
|
||||
if record.failed_document
|
||||
}
|
||||
)
|
||||
.union(
|
||||
{
|
||||
record.failed_document.document_id
|
||||
for record in embedding_failures
|
||||
if record.failed_document
|
||||
}
|
||||
)
|
||||
)
|
||||
if all_returned_doc_ids != set(updatable_ids):
|
||||
successful_doc_ids = {record.document_id for record in insertion_records}
|
||||
if successful_doc_ids != set(updatable_ids):
|
||||
raise RuntimeError(
|
||||
f"Some documents were not successfully indexed. "
|
||||
f"Updatable IDs: {updatable_ids}, "
|
||||
f"Returned IDs: {all_returned_doc_ids}. "
|
||||
"This should never happen."
|
||||
f"Successful IDs: {successful_doc_ids}"
|
||||
)
|
||||
|
||||
last_modified_ids = []
|
||||
|
||||
@@ -51,7 +51,6 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
|
||||
from onyx.server.documents.connector import router as connector_router
|
||||
from onyx.server.documents.credential import router as credential_router
|
||||
from onyx.server.documents.document import router as document_router
|
||||
from onyx.server.documents.standard_oauth import router as standard_oauth_router
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.folder.api import router as folder_router
|
||||
from onyx.server.features.input_prompt.api import (
|
||||
@@ -323,7 +322,6 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, long_term_logs_router)
|
||||
include_router_with_global_prefix_prepended(application, api_key_router)
|
||||
include_router_with_global_prefix_prepended(application, standard_oauth_router)
|
||||
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
# Server logs this during auth setup verification step
|
||||
|
||||
@@ -315,13 +315,20 @@ async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
|
||||
try:
|
||||
redis = await get_async_redis_connection()
|
||||
redis_key = REDIS_AUTH_KEY_PREFIX + token
|
||||
print("[Redis] Obtained async Redis connection")
|
||||
redis_key = f"{REDIS_AUTH_KEY_PREFIX}{token}"
|
||||
print(f"[Redis] Fetching token data for key: {redis_key}")
|
||||
token_data_str = await redis.get(redis_key)
|
||||
print(
|
||||
f"[Redis] Retrieved token data: {'found' if token_data_str else 'not found'}"
|
||||
)
|
||||
|
||||
if not token_data_str:
|
||||
logger.debug(f"Token key {redis_key} not found or expired in Redis")
|
||||
logger.debug(f"[Redis] Token key '{redis_key}' not found or expired")
|
||||
return None
|
||||
|
||||
print("[Redis] Decoding token data")
|
||||
print(f"[Redis] Token data: {token_data_str}")
|
||||
return json.loads(token_data_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error decoding token data from Redis")
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import current_user_with_expired_token
|
||||
from onyx.auth.users import optional_minimal_user
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.server.onyx_api.ingestion import api_key_dep
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -114,6 +115,7 @@ def check_router_auth(
|
||||
or depends_fn == current_user_with_expired_token
|
||||
or depends_fn == current_chat_accesssible_user
|
||||
or depends_fn == control_plane_dep
|
||||
or depends_fn == optional_minimal_user
|
||||
or depends_fn == current_cloud_superuser
|
||||
):
|
||||
found_auth = True
|
||||
|
||||
@@ -112,6 +112,12 @@ class UserInfo(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class MinimalUserInfo(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
is_active: bool
|
||||
|
||||
|
||||
class UserByEmail(BaseModel):
|
||||
user_email: str
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import optional_minimal_user
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
@@ -44,6 +45,7 @@ from onyx.db.api_key import is_api_key_email_address
|
||||
from onyx.db.auth import get_total_users_count
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import MinimalUser
|
||||
from onyx.db.models import User
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_all_users
|
||||
@@ -55,6 +57,7 @@ from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.manage.models import AllUsersResponse
|
||||
from onyx.server.manage.models import AutoScrollRequest
|
||||
from onyx.server.manage.models import MinimalUserInfo
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.server.manage.models import UserInfo
|
||||
from onyx.server.manage.models import UserPreferences
|
||||
@@ -534,6 +537,27 @@ def get_current_token_creation(
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/me-info")
|
||||
def verify_user_attempting_to_login(
|
||||
request: Request,
|
||||
user: MinimalUser | None = Depends(optional_minimal_user),
|
||||
# db_session: Session = Depends(get_session),
|
||||
) -> MinimalUserInfo:
|
||||
# Check if the authentication cookie exists
|
||||
# Print cookie names for debugging
|
||||
cookie_names = list(request.cookies.keys())
|
||||
logger.info(f"Available cookies: {cookie_names}")
|
||||
|
||||
if not request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME):
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
# print("I AM IN THIS FUNCTION 4")
|
||||
# if user is None:
|
||||
# print("I AM IN THIS FUNCTION 5")
|
||||
# raise HTTPException(status_code=401, detail="User not found")
|
||||
# print("I AM IN THIS FUNCTION 6")
|
||||
return MinimalUserInfo(id="", email="", is_active=True)
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
def verify_user_logged_in(
|
||||
user: User | None = Depends(optional_user),
|
||||
|
||||
@@ -43,6 +43,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str:
|
||||
|
||||
|
||||
def get_or_generate_uuid() -> str:
|
||||
return "hi"
|
||||
# TODO: split out the whole "instance UUID" generation logic into a separate
|
||||
# utility function. Telemetry should not be aware at all of how the UUID is
|
||||
# generated/stored.
|
||||
|
||||
@@ -45,7 +45,7 @@ def test_confluence_connector_basic(
|
||||
with pytest.raises(StopIteration):
|
||||
next(doc_batch_generator)
|
||||
|
||||
assert len(doc_batch) == 2
|
||||
assert len(doc_batch) == 3
|
||||
|
||||
page_within_a_page_doc: Document | None = None
|
||||
page_doc: Document | None = None
|
||||
|
||||
@@ -80,13 +80,3 @@ prod cluster**
|
||||
- `kubectl delete -f .`
|
||||
- To not delete the persistent volumes (Document indexes and Users), specify the specific `.yaml` files instead of
|
||||
`.` without specifying delete on persistent-volumes.yaml.
|
||||
|
||||
### Using Helm to deploy to an existing cluster
|
||||
|
||||
Onyx has a helm chart that is convenient to install all services to an existing Kubernetes cluster. To install:
|
||||
|
||||
* Currently the helm chart is not published so to install, clone the repo.
|
||||
* Configure access to the cluster via kubectl. Ensure the kubectl context is set to the cluster that you want to use
|
||||
* The default secrets, environment variables and other service level configuration are stored in `deployment/helm/charts/onyx/values.yml`. You may create another `override.yml`
|
||||
* `cd deployment/helm/charts/onyx` and run `helm install onyx -n onyx -f override.yaml .`. This will install onyx on the cluster under the `onyx` namespace.
|
||||
* Check the status of the deploy using `kubectl get pods -n onyx`
|
||||
@@ -1,27 +0,0 @@
|
||||
{{- if .Values.ingress.enabled -}}
|
||||
apiVersion: networking.k8s.io/v1
|
||||
kind: Ingress
|
||||
metadata:
|
||||
name: {{ include "onyx-stack.fullname" . }}-ingress-api
|
||||
annotations:
|
||||
kubernetes.io/ingress.class: nginx
|
||||
nginx.ingress.kubernetes.io/rewrite-target: /$2
|
||||
nginx.ingress.kubernetes.io/use-regex: "true"
|
||||
cert-manager.io/cluster-issuer: {{ include "onyx-stack.fullname" . }}-letsencrypt
|
||||
spec:
|
||||
rules:
|
||||
- host: {{ .Values.ingress.api.host }}
|
||||
http:
|
||||
paths:
|
||||
- path: /api(/|$)(.*)
|
||||
pathType: Prefix
|
||||
backend:
|
||||
service:
|
||||
name: {{ include "onyx-stack.fullname" . }}-api-service
|
||||
port:
|
||||
number: {{ .Values.api.service.servicePort }}
|
||||
tls:
|
||||
- hosts:
|
||||
- {{ .Values.ingress.api.host }}
|
||||
secretName: {{ include "onyx-stack.fullname" . }}-ingress-api-tls
|
||||
{{- end }}
|
||||
@@ -1,26 +0,0 @@
|
||||
{{- if .Values.ingress.enabled -}}
|
||||
apiVersion: networking.k8s.io/v1
|
||||
kind: Ingress
|
||||
metadata:
|
||||
name: {{ include "onyx-stack.fullname" . }}-ingress-webserver
|
||||
annotations:
|
||||
kubernetes.io/ingress.class: nginx
|
||||
cert-manager.io/cluster-issuer: {{ include "onyx-stack.fullname" . }}-letsencrypt
|
||||
kubernetes.io/tls-acme: "true"
|
||||
spec:
|
||||
rules:
|
||||
- host: {{ .Values.ingress.webserver.host }}
|
||||
http:
|
||||
paths:
|
||||
- path: /
|
||||
pathType: Prefix
|
||||
backend:
|
||||
service:
|
||||
name: {{ include "onyx-stack.fullname" . }}-webserver
|
||||
port:
|
||||
number: {{ .Values.webserver.service.servicePort }}
|
||||
tls:
|
||||
- hosts:
|
||||
- {{ .Values.ingress.webserver.host }}
|
||||
secretName: {{ include "onyx-stack.fullname" . }}-ingress-webserver-tls
|
||||
{{- end }}
|
||||
@@ -1,20 +0,0 @@
|
||||
{{- if .Values.letsencrypt.enabled -}}
|
||||
apiVersion: cert-manager.io/v1
|
||||
kind: ClusterIssuer
|
||||
metadata:
|
||||
name: {{ include "onyx-stack.fullname" . }}-letsencrypt
|
||||
spec:
|
||||
acme:
|
||||
# The ACME server URL
|
||||
server: https://acme-v02.api.letsencrypt.org/directory
|
||||
# Email address used for ACME registration
|
||||
email: {{ .Values.letsencrypt.email }}
|
||||
# Name of a secret used to store the ACME account private key
|
||||
privateKeySecretRef:
|
||||
name: {{ include "onyx-stack.fullname" . }}-letsencrypt
|
||||
# Enable the HTTP-01 challenge provider
|
||||
solvers:
|
||||
- http01:
|
||||
ingress:
|
||||
class: nginx
|
||||
{{- end }}
|
||||
@@ -376,17 +376,22 @@ redis:
|
||||
existingSecret: onyx-secrets
|
||||
existingSecretPasswordKey: redis_password
|
||||
|
||||
ingress:
|
||||
enabled: false
|
||||
className: ""
|
||||
api:
|
||||
host: onyx.local
|
||||
webserver:
|
||||
host: onyx.local
|
||||
# ingress:
|
||||
# enabled: false
|
||||
# className: ""
|
||||
# annotations: {}
|
||||
# # kubernetes.io/ingress.class: nginx
|
||||
# # kubernetes.io/tls-acme: "true"
|
||||
# hosts:
|
||||
# - host: chart-example.local
|
||||
# paths:
|
||||
# - path: /
|
||||
# pathType: ImplementationSpecific
|
||||
# tls: []
|
||||
# # - secretName: chart-example-tls
|
||||
# # hosts:
|
||||
# # - chart-example.local
|
||||
|
||||
letsencrypt:
|
||||
enabled: false
|
||||
email: "abc@abc.com"
|
||||
|
||||
auth:
|
||||
# existingSecret onyx-secret for storing smtp, oauth, slack, and other secrets
|
||||
|
||||
@@ -290,24 +290,21 @@ export function SettingsForm() {
|
||||
id="chatRetentionInput"
|
||||
placeholder="Infinite Retention"
|
||||
/>
|
||||
<div className="mr-auto flex gap-2">
|
||||
<Button
|
||||
onClick={handleSetChatRetention}
|
||||
variant="submit"
|
||||
size="sm"
|
||||
className="mr-auto"
|
||||
>
|
||||
Set Retention Limit
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleClearChatRetention}
|
||||
variant="default"
|
||||
size="sm"
|
||||
className="mr-auto"
|
||||
>
|
||||
Retain All
|
||||
</Button>
|
||||
</div>
|
||||
<Button
|
||||
onClick={handleSetChatRetention}
|
||||
variant="submit"
|
||||
size="sm"
|
||||
className="mr-3"
|
||||
>
|
||||
Set Retention Limit
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleClearChatRetention}
|
||||
variant="default"
|
||||
size="sm"
|
||||
>
|
||||
Retain All
|
||||
</Button>
|
||||
</>
|
||||
)}
|
||||
|
||||
|
||||
@@ -61,7 +61,6 @@ export function EmailPasswordForm({
|
||||
|
||||
if (!response.ok) {
|
||||
setIsWorking(false);
|
||||
|
||||
const errorDetail = (await response.json()).detail;
|
||||
let errorMsg = "Unknown error";
|
||||
if (typeof errorDetail === "object" && errorDetail.reason) {
|
||||
@@ -97,13 +96,12 @@ export function EmailPasswordForm({
|
||||
} else {
|
||||
setIsWorking(false);
|
||||
const errorDetail = (await loginResponse.json()).detail;
|
||||
|
||||
let errorMsg = "Unknown error";
|
||||
if (errorDetail === "LOGIN_BAD_CREDENTIALS") {
|
||||
errorMsg = "Invalid email or password";
|
||||
} else if (errorDetail === "NO_WEB_LOGIN_AND_HAS_NO_PASSWORD") {
|
||||
errorMsg = "Create an account to set a password";
|
||||
} else if (typeof errorDetail === "string") {
|
||||
errorMsg = errorDetail;
|
||||
}
|
||||
if (loginResponse.status === 429) {
|
||||
errorMsg = "Too many requests. Please try again later.";
|
||||
|
||||
@@ -8,8 +8,6 @@ import {
|
||||
} from "@/lib/userSS";
|
||||
import { redirect } from "next/navigation";
|
||||
import { EmailPasswordForm } from "../login/EmailPasswordForm";
|
||||
import Text from "@/components/ui/text";
|
||||
import Link from "next/link";
|
||||
import { SignInButton } from "../login/SignInButton";
|
||||
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
|
||||
import ReferralSourceSelector from "./ReferralSourceSelector";
|
||||
|
||||
242
web/src/app/auth/waiting-on-setup/WaitingOnSetup.tsx
Normal file
242
web/src/app/auth/waiting-on-setup/WaitingOnSetup.tsx
Normal file
@@ -0,0 +1,242 @@
"use client";

import { redirect } from "next/navigation";
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { MinimalUserInfo } from "@/lib/types";
import Text from "@/components/ui/text";
import { Logo } from "@/components/logo/Logo";
import { completeTenantSetup } from "@/lib/tenant";
import { useEffect, useState, useRef } from "react";
import {
  Card,
  CardContent,
  CardHeader,
  CardFooter,
} from "@/components/ui/card";
import { LoadingSpinner } from "@/app/chat/chat_search/LoadingSpinner";
import { Button } from "@/components/ui/button";
import { logout } from "@/lib/user";
import { FiLogOut } from "react-icons/fi";

export default function WaitingOnSetupPage({
  minimalUserInfo,
}: {
  minimalUserInfo: MinimalUserInfo;
}) {
  const [progress, setProgress] = useState(0);
  const [setupStage, setSetupStage] = useState<string>(
    "Setting up your account"
  );
  const progressRef = useRef<number>(0);
  const animationRef = useRef<number>();
  const startTimeRef = useRef<number>(Date.now());
  const [isReady, setIsReady] = useState(false);
  const pollIntervalRef = useRef<NodeJS.Timeout | null>(null);

  // Setup stages that will cycle through during the loading process
  const setupStages = [
    "Setting up your account",
    "Configuring workspace",
    "Preparing resources",
    "Setting up permissions",
    "Finalizing setup",
  ];

  // Function to poll the /api/me endpoint
  const pollAccountStatus = async () => {
    try {
      const response = await fetch("/api/me");
      if (response.status === 200) {
        // Account is ready
        setIsReady(true);
        if (pollIntervalRef.current) {
          clearInterval(pollIntervalRef.current);
        }
        return true;
      }
    } catch (error) {
      console.error("Error polling account status:", error);
    }
    return false;
  };

  // Handle logout
  const handleLogout = async () => {
    try {
      await logout();
      window.location.href = "/auth/login";
    } catch (error) {
      console.error("Failed to logout:", error);
    }
  };

  useEffect(() => {
    // Animation setup for progress bar
    let lastUpdateTime = 0;
    const updateInterval = 100;
    const normalAnimationDuration = 30000; // 30 seconds for normal animation
    const acceleratedAnimationDuration = 1500; // 1.5 seconds for accelerated animation after ready
    let currentStageIndex = 0;
    let lastStageUpdateTime = Date.now();
    const stageRotationInterval = 3000; // Rotate stages every 3 seconds

    const animate = (timestamp: number) => {
      const elapsedTime = timestamp - startTimeRef.current;
      const now = Date.now();

      // Calculate progress using different curves based on ready status
      const maxProgress = 99;
      let progress;

      if (isReady) {
        // Accelerate to 100% when account is ready
        progress =
          maxProgress +
          (100 - maxProgress) *
            ((now - startTimeRef.current) / acceleratedAnimationDuration);
        if (progress >= 100) progress = 100;
      } else {
        // Slower progress when still waiting
        progress =
          maxProgress * (1 - Math.exp(-elapsedTime / normalAnimationDuration));
      }

      // Update progress if enough time has passed
      if (timestamp - lastUpdateTime > updateInterval) {
        progressRef.current = progress;
        setProgress(Math.round(progress * 10) / 10);

        // Cycle through setup stages
        if (now - lastStageUpdateTime > stageRotationInterval && !isReady) {
          currentStageIndex = (currentStageIndex + 1) % setupStages.length;
          setSetupStage(setupStages[currentStageIndex]);
          lastStageUpdateTime = now;
        }

        lastUpdateTime = timestamp;
      }

      // Continue animation if not completed
      if (progress < 100) {
        animationRef.current = requestAnimationFrame(animate);
      } else if (progress >= 100) {
        // Redirect when progress reaches 100%
        setSetupStage("Setup complete!");
        setTimeout(() => {
          window.location.href = "/chat";
        }, 500);
      }
    };

    // Start animation
    startTimeRef.current = performance.now();
    animationRef.current = requestAnimationFrame(animate);

    // Start polling the /api/me endpoint
    pollIntervalRef.current = setInterval(async () => {
      const ready = await pollAccountStatus();
      if (ready) {
        // If ready, we'll let the animation handle the redirect
        console.log("Account is ready!");
      }
    }, 2000); // Poll every 2 seconds

    // Attempt to complete tenant setup
    // completeTenantSetup(minimalUserInfo.email).catch((error) => {
    //   console.error("Failed to complete tenant setup:", error);
    // });

    // Cleanup function
    return () => {
      if (animationRef.current) {
        cancelAnimationFrame(animationRef.current);
      }
      if (pollIntervalRef.current) {
        clearInterval(pollIntervalRef.current);
      }
    };
  }, [isReady, minimalUserInfo.email]);

  useEffect(() => {
    completeTenantSetup(minimalUserInfo.email).catch((error) => {
      console.error("Failed to complete tenant setup:", error);
    });
  }, []);

  return (
    <main className="min-h-screen bg-gradient-to-b from-white to-neutral-50 dark:from-neutral-900 dark:to-neutral-950">
      <div className="absolute top-0 w-full">
        <HealthCheckBanner />
      </div>
      <div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
        <div className="w-full max-w-md">
          <div className="flex flex-col items-center mb-8">
            <Logo height={80} width={80} className="mx-auto w-fit mb-6" />
            <h1 className="text-2xl font-bold text-neutral-900 dark:text-white">
              Account Setup
            </h1>
          </div>

          <Card className="border-neutral-200 dark:border-neutral-800 shadow-lg">
            <CardHeader className="pb-0">
              <div className="flex items-center justify-between">
                <div className="flex items-center space-x-3">
                  <div className="relative">
                    <LoadingSpinner size="medium" className="text-primary" />
                  </div>
                  <h2 className="text-lg font-semibold text-neutral-800 dark:text-neutral-200">
                    {setupStage}
                  </h2>
                </div>
                <span className="text-sm font-medium text-neutral-500 dark:text-neutral-400">
                  {progress}%
                </span>
              </div>
            </CardHeader>
            <CardContent className="pt-4">
              {/* Progress bar */}
              <div className="w-full h-2 bg-neutral-200 dark:bg-neutral-700 rounded-full mb-6 overflow-hidden">
                <div
                  className="h-full bg-blue-600 dark:bg-blue-500 rounded-full transition-all duration-300 ease-out"
                  style={{ width: `${progress}%` }}
                />
              </div>

              <div className="space-y-4">
                <div className="flex flex-col space-y-1">
                  <Text className="text-neutral-800 dark:text-neutral-200 font-medium">
                    Welcome,{" "}
                    <span className="font-semibold">
                      {minimalUserInfo?.email}
                    </span>
                  </Text>
                  <Text className="text-neutral-600 dark:text-neutral-400 text-sm">
                    We're setting up your account. This may take a few moments.
                  </Text>
                </div>

                <div className="bg-blue-50 dark:bg-blue-900/20 rounded-lg p-4 border border-blue-100 dark:border-blue-800">
                  <Text className="text-sm text-blue-700 dark:text-blue-300">
                    You'll be redirected automatically when your account is
                    ready. If you're not redirected within a minute, please
                    refresh the page.
                  </Text>
                </div>
              </div>
            </CardContent>
            <CardFooter className="flex justify-end pt-4">
              <Button
                variant="outline"
                size="sm"
                onClick={handleLogout}
                className="text-neutral-600 dark:text-neutral-300"
              >
                <FiLogOut className="mr-1" />
                Logout
              </Button>
            </CardFooter>
          </Card>
        </div>
      </div>
    </main>
  );
}
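For reference, the percentage shown by the card above follows an exponential ease-out toward 99% while the /api/me poll is pending, then sprints to 100% once the account is ready. A minimal sketch of the waiting-phase curve, with the constants copied from the component; the standalone function name is illustrative, not part of the diff:

    // Progress in [0, 99] while waiting: approaches 99 asymptotically, never reaches it.
    function waitingProgress(elapsedMs: number): number {
      const maxProgress = 99;
      const normalAnimationDuration = 30000; // 30-second time constant
      return maxProgress * (1 - Math.exp(-elapsedMs / normalAnimationDuration));
    }

    // e.g. waitingProgress(3000) ≈ 9.4, waitingProgress(30000) ≈ 62.6, waitingProgress(90000) ≈ 94.1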
@@ -191,7 +191,6 @@ export const FolderDropdown = forwardRef<HTMLDivElement, FolderDropdownProps>(
            onChange={(e) => setNewFolderName(e.target.value)}
            className="text-sm font-medium bg-transparent outline-none w-full pb-1 border-b border-background-500 transition-colors duration-200"
            onKeyDown={(e) => {
              e.stopPropagation();
              if (e.key === "Enter") {
                handleEdit();
              }
@@ -303,6 +303,7 @@ const FolderItem = ({
              key={chatSession.id}
              chatSession={chatSession}
              isSelected={chatSession.id === currentChatId}
              skipGradient={isDragOver}
              showShareModal={showShareModal}
              showDeleteModal={showDeleteModal}
            />
@@ -32,17 +32,21 @@ export function ChatSessionDisplay({
  chatSession,
  search,
  isSelected,
  skipGradient,
  closeSidebar,
  showShareModal,
  showDeleteModal,
  foldersExisting,
  isDragging,
}: {
  chatSession: ChatSession;
  isSelected: boolean;
  search?: boolean;
  skipGradient?: boolean;
  closeSidebar?: () => void;
  showShareModal?: (chatSession: ChatSession) => void;
  showDeleteModal?: (chatSession: ChatSession) => void;
  foldersExisting?: boolean;
  isDragging?: boolean;
}) {
  const router = useRouter();
@@ -234,12 +238,8 @@ export function ChatSessionDisplay({
                      e.preventDefault();
                      e.stopPropagation();
                    }}
                    onChange={(e) => {
                      setChatName(e.target.value);
                    }}
                    onChange={(e) => setChatName(e.target.value)}
                    onKeyDown={(event) => {
                      event.stopPropagation();

                      if (event.key === "Enter") {
                        onRename();
                        event.preventDefault();
@@ -264,6 +264,7 @@ export function PagesTab({
                    >
                      <ChatSessionDisplay
                        chatSession={chat}
                        foldersExisting={foldersExisting}
                        isSelected={currentChatId === chat.id}
                        showShareModal={showShareModal}
                        showDeleteModal={showDeleteModal}
@@ -20,7 +20,7 @@ import {
import { fetchAssistantData } from "@/lib/chat/fetchAssistantdata";
import { AppProvider } from "@/components/context/AppProvider";
import { PHProvider } from "./providers";
import { getCurrentUserSS } from "@/lib/userSS";
import { getCurrentUserSS, getMinimalUserInfoSS } from "@/lib/userSS";
import { Suspense } from "react";
import PostHogPageView from "./PostHogPageView";
import Script from "next/script";

@@ -30,6 +30,8 @@ import { ThemeProvider } from "next-themes";
import CloudError from "@/components/errorPages/CloudErrorPage";
import Error from "@/components/errorPages/ErrorPage";
import AccessRestrictedPage from "@/components/errorPages/AccessRestrictedPage";
import CompleteTenantSetupPage from "./auth/waiting-on-setup/WaitingOnSetup";
import WaitingOnSetupPage from "./auth/waiting-on-setup/WaitingOnSetup";

const inter = Inter({
  subsets: ["latin"],

@@ -70,12 +72,17 @@ export default async function RootLayout({
}: {
  children: React.ReactNode;
}) {
  const [combinedSettings, assistantsData, user] = await Promise.all([
    fetchSettingsSS(),
    fetchAssistantData(),
    getCurrentUserSS(),
  ]);
  const [combinedSettings, assistantsData, user, minimalUserInfo] =
    await Promise.all([
      fetchSettingsSS(),
      fetchAssistantData(),
      getCurrentUserSS(),
      getMinimalUserInfoSS(),
    ]);

  // if (!user && minimalUserInfo) {
  //   return <CompleteTenantSetupPage />;
  // }
  const productGating =
    combinedSettings?.settings.application_status ?? ApplicationStatus.ACTIVE;

@@ -135,6 +142,11 @@ export default async function RootLayout({
  if (productGating === ApplicationStatus.GATED_ACCESS) {
    return getPageContent(<AccessRestrictedPage />);
  }
  if (!user && minimalUserInfo) {
    return getPageContent(
      <WaitingOnSetupPage minimalUserInfo={minimalUserInfo} />
    );
  }

  if (!combinedSettings) {
    return getPageContent(
@@ -40,12 +40,8 @@ export const ConnectorTitle = ({
    const typedConnector = connector as Connector<GithubConfig>;
    additionalMetadata.set(
      "Repo",
      typedConnector.connector_specific_config.repositories
        ? `${typedConnector.connector_specific_config.repo_owner}/${
            typedConnector.connector_specific_config.repositories.includes(",")
              ? "multiple repos"
              : typedConnector.connector_specific_config.repositories
          }`
      typedConnector.connector_specific_config.repo_name
        ? `${typedConnector.connector_specific_config.repo_owner}/${typedConnector.connector_specific_config.repo_name}`
        : `${typedConnector.connector_specific_config.repo_owner}/*`
    );
  } else if (connector.source === "gitlab") {
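The nested ternary in the hunk above is dense. As a reading aid, the `repositories`-based display logic shown on one side of the change amounts to the following; the helper name is illustrative and assumes the `repositories` field variant of GithubConfig:

    // Label shown for a GitHub connector: "owner/repo", "owner/multiple repos", or "owner/*".
    function githubRepoLabel(config: { repo_owner: string; repositories?: string }): string {
      if (!config.repositories) {
        return `${config.repo_owner}/*`; // no repository filter: all repos for the owner
      }
      return config.repositories.includes(",")
        ? `${config.repo_owner}/multiple repos`
        : `${config.repo_owner}/${config.repositories}`;
    }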
@@ -190,12 +190,10 @@ export const connectorConfigs: Record<
    fields: [
      {
        type: "text",
        query: "Enter the repository name(s):",
        label: "Repository Name(s)",
        name: "repositories",
        query: "Enter the repository name:",
        label: "Repository Name",
        name: "repo_name",
        optional: false,
        description:
          "For multiple repositories, enter comma-separated names (e.g., repo1,repo2,repo3)",
      },
    ],
  },
@@ -1360,7 +1358,7 @@ export interface WebConfig {

export interface GithubConfig {
  repo_owner: string;
  repositories: string; // Comma-separated list of repository names
  repo_name: string;
  include_prs: boolean;
  include_issues: boolean;
}
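To make the two config shapes concrete: the `repositories` variant carries one or many repos as a comma-separated string, while the `repo_name` variant holds a single repository. The values below are hypothetical examples, not data taken from the repo:

    // "repositories" variant: one or many repos in a comma-separated string.
    const multiRepoConfig = {
      repo_owner: "example-org",
      repositories: "repo1,repo2,repo3",
      include_prs: true,
      include_issues: true,
    };

    // "repo_name" variant: exactly one repository.
    const singleRepoConfig = {
      repo_owner: "example-org",
      repo_name: "repo1",
      include_prs: true,
      include_issues: false,
    };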
web/src/lib/tenant.ts (new file, 14 lines)
@@ -0,0 +1,14 @@
export async function completeTenantSetup(email: string): Promise<void> {
  const response = await fetch(`/api/tenants/complete-setup`, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify({ email }),
  });

  if (!response.ok) {
    const errorText = await response.text();
    throw new Error(`Failed to complete tenant setup: ${errorText}`);
  }
}
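In WaitingOnSetup.tsx above, completeTenantSetup is called fire-and-forget with only a console.error on failure. A caller that needs to react to the outcome could await it instead; the wrapper below is a sketch only, the function name is hypothetical:

    async function retrySetup(email: string): Promise<boolean> {
      try {
        await completeTenantSetup(email);
        return true;
      } catch (error) {
        // completeTenantSetup throws with the server-provided error text.
        console.error("Tenant setup failed:", error);
        return false;
      }
    }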
@@ -15,6 +15,12 @@ interface UserPreferences {
  temperature_override_enabled: boolean;
}

export interface MinimalUserInfo {
  id: string;
  email: string;
  is_active: boolean;
}

export enum UserRole {
  LIMITED = "limited",
  BASIC = "basic",

@@ -1,5 +1,5 @@
import { cookies } from "next/headers";
import { User } from "./types";
import { MinimalUserInfo, User } from "./types";
import { buildUrl } from "./utilsSS";
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";

@@ -162,6 +162,30 @@ export const logoutSS = async (
    }
  }
};
export const getMinimalUserInfoSS =
  async (): Promise<MinimalUserInfo | null> => {
    try {
      const response = await fetch(buildUrl("/me-info"), {
        credentials: "include",
        headers: {
          cookie: (await cookies())
            .getAll()
            .map((cookie) => `${cookie.name}=${cookie.value}`)
            .join("; "),
        },

        next: { revalidate: 0 },
      });
      if (!response.ok) {
        console.error("Failed to fetch minimal user info");
        return null;
      }
      return await response.json();
    } catch (e) {
      console.log(`Error fetching minimal user info: ${e}`);
      return null;
    }
  };

export const getCurrentUserSS = async (): Promise<User | null> => {
  try {
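getMinimalUserInfoSS is server-side only: it forwards the request cookies from next/headers and disables caching with revalidate: 0, then falls back to null on any failure. A minimal sketch of calling it from another server component; the page itself is hypothetical and not part of the diff:

    import { getMinimalUserInfoSS } from "@/lib/userSS";
    import { redirect } from "next/navigation";

    export default async function ExamplePage() {
      const minimalUserInfo = await getMinimalUserInfoSS();
      if (!minimalUserInfo) {
        // No session: send the visitor to login.
        redirect("/auth/login");
      }
      return <div>Signed in as {minimalUserInfo.email}</div>;
    }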