mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-25 11:45:47 +00:00
Compare commits
8 Commits
ci_python_
...
v0.11.0-cl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3457d35d5a | ||
|
|
a5217ac8ce | ||
|
|
a23004abf8 | ||
|
|
e79aa7d2f5 | ||
|
|
6efe24f6bc | ||
|
|
dd40b12c89 | ||
|
|
b31c59642e | ||
|
|
0a94bbac7b |
@@ -48,7 +48,6 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import attributes
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -83,21 +82,19 @@ from danswer.db.auth import SQLAlchemyUserAdminDB
|
||||
from danswer.db.engine import get_async_session_with_tenant
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import UserTenantMapping
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from ee.danswer.server.tenants.provisioning import get_or_create_tenant_id
|
||||
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -190,20 +187,6 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_tenant_id_for_email(email: str) -> str:
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
# Implement logic to get tenant_id from the mapping table
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
|
||||
)
|
||||
tenant_id = result.scalar_one_or_none()
|
||||
if tenant_id is None:
|
||||
raise exceptions.UserNotExists()
|
||||
return tenant_id
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
@@ -238,19 +221,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
try:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(user_create.email)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
)
|
||||
tenant_id = await get_or_create_tenant_id(user_create.email)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -271,7 +242,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
user = None
|
||||
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
@@ -292,7 +263,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return user
|
||||
|
||||
async def oauth_callback(
|
||||
@@ -308,19 +281,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> models.UOAP:
|
||||
# Get tenant_id from mapping table
|
||||
try:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(account_email)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
tenant_id = await get_or_create_tenant_id(account_email)
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
# Proceed with the tenant context
|
||||
token = None
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -371,9 +337,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
|
||||
@@ -119,10 +119,10 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
else:
|
||||
logger.info("Schedule is up to date, no changes needed")
|
||||
|
||||
except (AttributeError, KeyError) as e:
|
||||
logger.exception(f"Failed to process task configuration: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error updating tenant tasks: {str(e)}")
|
||||
except (AttributeError, KeyError):
|
||||
logger.exception("Failed to process task configuration")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error updating tenant tasks")
|
||||
|
||||
def _should_update_schedule(
|
||||
self, current_schedule: dict, new_schedule: dict
|
||||
|
||||
@@ -277,12 +277,14 @@ def get_application() -> FastAPI:
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_register_router(UserRead, UserCreate),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_reset_password_router(),
|
||||
|
||||
@@ -30,7 +30,6 @@ from danswer.auth.schemas import UserStatus
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.auth.users import get_tenant_id_for_email
|
||||
from danswer.auth.users import optional_user
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
|
||||
@@ -66,7 +65,8 @@ from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_comm
|
||||
from ee.danswer.db.user_group import remove_curator_status__no_commit
|
||||
from ee.danswer.server.tenants.billing import register_tenant_users
|
||||
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
|
||||
from ee.danswer.server.tenants.provisioning import remove_users_from_tenant
|
||||
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.danswer.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -359,7 +359,7 @@ def handle_new_chat_message(
|
||||
yield json.dumps(packet) if isinstance(packet, dict) else packet
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in chat message streaming: {e}")
|
||||
logger.exception("Error in chat message streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
finally:
|
||||
|
||||
@@ -279,7 +279,7 @@ def get_answer_with_quote(
|
||||
):
|
||||
yield json.dumps(packet) if isinstance(packet, dict) else packet
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in search answer streaming: {e}")
|
||||
logger.exception("Error in search answer streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="application/json")
|
||||
|
||||
0
backend/ee/danswer/auth/tenant.py
Normal file
0
backend/ee/danswer/auth/tenant.py
Normal file
@@ -7,7 +7,6 @@ from fastapi import Response
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import get_jwt_strategy
|
||||
from danswer.auth.users import get_tenant_id_for_email
|
||||
from danswer.auth.users import User
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
@@ -15,7 +14,6 @@ from danswer.db.notification import create_notification
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.server.settings.store import load_settings
|
||||
from danswer.server.settings.store import store_settings
|
||||
from danswer.setup import setup_danswer
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.auth.users import current_cloud_superuser
|
||||
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
|
||||
@@ -23,15 +21,9 @@ from ee.danswer.server.tenants.access import control_plane_dep
|
||||
from ee.danswer.server.tenants.billing import fetch_billing_information
|
||||
from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.danswer.server.tenants.models import BillingInformation
|
||||
from ee.danswer.server.tenants.models import CreateTenantRequest
|
||||
from ee.danswer.server.tenants.models import ImpersonateRequest
|
||||
from ee.danswer.server.tenants.models import ProductGatingRequest
|
||||
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
|
||||
from ee.danswer.server.tenants.provisioning import configure_default_api_keys
|
||||
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
|
||||
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
|
||||
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
@@ -40,52 +32,6 @@ logger = setup_logger()
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/create")
|
||||
def create_tenant(
|
||||
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
|
||||
) -> dict[str, str]:
|
||||
if not MULTI_TENANT:
|
||||
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
|
||||
|
||||
tenant_id = create_tenant_request.tenant_id
|
||||
email = create_tenant_request.initial_admin_email
|
||||
token = None
|
||||
|
||||
if user_owns_a_tenant(email):
|
||||
raise HTTPException(
|
||||
status_code=409, detail="User already belongs to an organization"
|
||||
)
|
||||
|
||||
try:
|
||||
if not ensure_schema_exists(tenant_id):
|
||||
logger.info(f"Created schema for tenant {tenant_id}")
|
||||
else:
|
||||
logger.info(f"Schema already exists for tenant {tenant_id}")
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
run_alembic_migrations(tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
setup_danswer(db_session, tenant_id)
|
||||
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Tenant {tenant_id} created successfully",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create tenant: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import text
|
||||
import aiohttp # Async HTTP client
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.schema import CreateSchema
|
||||
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.auth.users import exceptions
|
||||
from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from danswer.configs.app_configs import EXPECTED_API_KEY
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.llm import upsert_cloud_embedding_provider
|
||||
@@ -15,49 +16,136 @@ from danswer.db.llm import upsert_llm_provider
|
||||
from danswer.db.models import UserTenantMapping
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.setup import setup_danswer
|
||||
from ee.danswer.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
|
||||
from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from ee.danswer.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.danswer.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.danswer.server.tenants.schema_management import drop_schema
|
||||
from ee.danswer.server.tenants.schema_management import run_alembic_migrations
|
||||
from ee.danswer.server.tenants.user_mapping import add_users_to_tenant
|
||||
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.danswer.server.tenants.user_mapping import user_owns_a_tenant
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
logger = setup_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_alembic_migrations(schema_name: str) -> None:
|
||||
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
|
||||
async def get_or_create_tenant_id(email: str) -> str:
|
||||
"""Get existing tenant ID for an email or create a new tenant if none exists."""
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
try:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
|
||||
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
except exceptions.UserNotExists:
|
||||
# If tenant does not exist and in Multi tenant mode, provision a new tenant
|
||||
try:
|
||||
tenant_id = await create_tenant(email)
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
|
||||
# Configure Alembic
|
||||
alembic_cfg = Config(alembic_ini_path)
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
|
||||
alembic_cfg.set_main_option(
|
||||
"script_location", os.path.join(root_dir, "alembic")
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
)
|
||||
|
||||
# Ensure that logging isn't broken
|
||||
alembic_cfg.attributes["configure_logger"] = False
|
||||
return tenant_id
|
||||
|
||||
# Mimic command-line options by adding 'cmd_opts' to the config
|
||||
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
|
||||
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
|
||||
|
||||
# Run migrations programmatically
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
async def create_tenant(email: str) -> str:
|
||||
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
|
||||
try:
|
||||
# Provision tenant on data plane
|
||||
await provision_tenant(tenant_id, email)
|
||||
# Notify control plane
|
||||
await notify_control_plane(tenant_id, email)
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
return tenant_id
|
||||
|
||||
# Run migrations programmatically
|
||||
logger.info(
|
||||
f"Alembic migrations completed successfully for schema: {schema_name}"
|
||||
|
||||
async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
if not MULTI_TENANT:
|
||||
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
|
||||
|
||||
if user_owns_a_tenant(email):
|
||||
raise HTTPException(
|
||||
status_code=409, detail="User already belongs to an organization"
|
||||
)
|
||||
|
||||
logger.info(f"Provisioning tenant: {tenant_id}")
|
||||
token = None
|
||||
|
||||
try:
|
||||
if not create_schema_if_not_exists(tenant_id):
|
||||
logger.info(f"Created schema for tenant {tenant_id}")
|
||||
else:
|
||||
logger.info(f"Schema already exists for tenant {tenant_id}")
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Await the Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
setup_danswer(db_session, tenant_id)
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
|
||||
raise
|
||||
logger.exception(f"Failed to create tenant {tenant_id}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create tenant: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
async def notify_control_plane(tenant_id: str, email: str) -> None:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {EXPECTED_API_KEY}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {"tenant_id": tenant_id, "email": email}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{CONTROL_PLANE_API_BASE_URL}/tenants/create",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Control plane tenant creation failed: {error_text}")
|
||||
raise Exception(
|
||||
f"Failed to create tenant on control plane: {error_text}"
|
||||
)
|
||||
|
||||
|
||||
async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
# Logic to rollback tenant provisioning on data plane
|
||||
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
|
||||
try:
|
||||
# Drop the tenant's schema to rollback provisioning
|
||||
drop_schema(tenant_id)
|
||||
# Remove tenant mapping
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to rollback tenant provisioning: {e}")
|
||||
|
||||
|
||||
def configure_default_api_keys(db_session: Session) -> None:
|
||||
@@ -81,65 +169,3 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
api_key=COHERE_DEFAULT_API_KEY,
|
||||
)
|
||||
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
|
||||
|
||||
|
||||
def ensure_schema_exists(tenant_id: str) -> bool:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with db_session.begin():
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
|
||||
),
|
||||
{"schema_name": tenant_id},
|
||||
)
|
||||
schema_exists = result.scalar() is not None
|
||||
if not schema_exists:
|
||||
stmt = CreateSchema(tenant_id)
|
||||
db_session.execute(stmt)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# For now, we're implementing a primitive mapping between users and tenants.
|
||||
# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership).
|
||||
def user_owns_a_tenant(email: str) -> bool:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(UserTenantMapping.email == email)
|
||||
.first()
|
||||
)
|
||||
return result is not None
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
for email in emails:
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}")
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
mappings_to_delete = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email.in_(emails),
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for mapping in mappings_to_delete:
|
||||
db_session.delete(mapping)
|
||||
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
|
||||
)
|
||||
db_session.rollback()
|
||||
|
||||
76
backend/ee/danswer/server/tenants/schema_management.py
Normal file
76
backend/ee/danswer/server/tenants/schema_management.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import logging
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.schema import CreateSchema
|
||||
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_alembic_migrations(schema_name: str) -> None:
|
||||
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
|
||||
|
||||
try:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
|
||||
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
|
||||
|
||||
# Configure Alembic
|
||||
alembic_cfg = Config(alembic_ini_path)
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
|
||||
alembic_cfg.set_main_option(
|
||||
"script_location", os.path.join(root_dir, "alembic")
|
||||
)
|
||||
|
||||
# Ensure that logging isn't broken
|
||||
alembic_cfg.attributes["configure_logger"] = False
|
||||
|
||||
# Mimic command-line options by adding 'cmd_opts' to the config
|
||||
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
|
||||
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
|
||||
|
||||
# Run migrations programmatically
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
# Run migrations programmatically
|
||||
logger.info(
|
||||
f"Alembic migrations completed successfully for schema: {schema_name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_schema_if_not_exists(tenant_id: str) -> bool:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with db_session.begin():
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
|
||||
),
|
||||
{"schema_name": tenant_id},
|
||||
)
|
||||
schema_exists = result.scalar() is not None
|
||||
if not schema_exists:
|
||||
stmt = CreateSchema(tenant_id)
|
||||
db_session.execute(stmt)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def drop_schema(tenant_id: str) -> None:
|
||||
if not tenant_id.isidentifier():
|
||||
raise ValueError("Invalid tenant_id.")
|
||||
with get_sqlalchemy_engine().connect() as connection:
|
||||
connection.execute(
|
||||
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
|
||||
{"schema_name": tenant_id},
|
||||
)
|
||||
68
backend/ee/danswer/server/tenants/user_mapping.py
Normal file
68
backend/ee/danswer/server/tenants/user_mapping.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import logging
|
||||
|
||||
from fastapi_users import exceptions
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import UserTenantMapping
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_tenant_id_for_email(email: str) -> str:
|
||||
# Implement logic to get tenant_id from the mapping table
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
|
||||
)
|
||||
tenant_id = result.scalar_one_or_none()
|
||||
if tenant_id is None:
|
||||
raise exceptions.UserNotExists()
|
||||
return tenant_id
|
||||
|
||||
|
||||
def user_owns_a_tenant(email: str) -> bool:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(UserTenantMapping.email == email)
|
||||
.first()
|
||||
)
|
||||
return result is not None
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
for email in emails:
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
except Exception:
|
||||
logger.exception(f"Failed to add users to tenant {tenant_id}")
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
mappings_to_delete = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email.in_(emails),
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for mapping in mappings_to_delete:
|
||||
db_session.delete(mapping)
|
||||
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
|
||||
)
|
||||
db_session.rollback()
|
||||
Reference in New Issue
Block a user