1
0
forked from github/onyx

Compare commits

...

2 Commits

Author SHA1 Message Date
pablodanswer
0cc09c8b4d nits 2024-11-02 10:12:12 -07:00
pablodanswer
ec8ae2b5f4 add super user 2024-11-02 10:11:11 -07:00
25 changed files with 237 additions and 29 deletions

View File

@@ -93,9 +93,9 @@ from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -510,19 +510,23 @@ cookie_transport = CookieTransport(
# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def write_token(self, user: User) -> str:
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
tenant_id = get_tenant_id_for_email(user.email)
data = {
"sub": str(user.id),
"aud": self.token_audience,
"tenant_id": tenant_id,
}
return data
async def write_token(self, user: User) -> str:
data = await self._create_token_data(user)
return generate_jwt(
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
)
def get_jwt_strategy() -> JWTStrategy:
def get_jwt_strategy() -> TenantAwareJWTStrategy:
return TenantAwareJWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,

View File

@@ -478,3 +478,7 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration
JWT_ALGORITHM = "HS256"
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")

View File

@@ -27,8 +27,8 @@ from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()

View File

@@ -57,10 +57,10 @@ from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()

View File

@@ -37,10 +37,10 @@ from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()

View File

@@ -16,9 +16,9 @@ from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()

View File

@@ -10,6 +10,7 @@ from danswer.auth.users import current_user
from danswer.auth.users import current_user_with_expired_token
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.server.danswer_api.ingestion import api_key_dep
from ee.danswer.auth.users import current_cloud_superuser
from ee.danswer.server.tenants.access import control_plane_dep
@@ -100,6 +101,7 @@ def check_router_auth(
or depends_fn == api_key_dep
or depends_fn == current_user_with_expired_token
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
):
found_auth = True
break

View File

@@ -57,6 +57,7 @@ class UserInfo(BaseModel):
oidc_expiry: datetime | None = None
current_token_created_at: datetime | None = None
current_token_expiry_length: int | None = None
is_cloud_superuser: bool = False
organization_name: str | None = None
@classmethod
@@ -65,6 +66,7 @@ class UserInfo(BaseModel):
user: User,
current_token_created_at: datetime | None = None,
expiry_length: int | None = None,
is_cloud_superuser: bool = False,
organization_name: str | None = None,
) -> "UserInfo":
return cls(
@@ -90,6 +92,7 @@ class UserInfo(BaseModel):
oidc_expiry=user.oidc_expiry if TRACK_EXTERNAL_IDP_EXPIRY else None,
current_token_created_at=current_token_created_at,
current_token_expiry_length=expiry_length,
is_cloud_superuser=is_cloud_superuser,
)

View File

@@ -35,6 +35,7 @@ from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType
from danswer.db.auth import get_total_users_count
@@ -476,6 +477,7 @@ def verify_user_logged_in(
# NOTE: this does not use `current_user` / `current_admin_user` because we don't want
# to enforce user verification here - the frontend always wants to get the info about
# the current user regardless of if they are currently verified
if user is None:
# if auth type is disabled, return a dummy user with preferences from
# the key-value store
@@ -502,6 +504,7 @@ def verify_user_logged_in(
user,
current_token_created_at=token_created_at,
expiry_length=SESSION_EXPIRE_TIME_SECONDS,
is_cloud_superuser=user.email in SUPER_USERS,
organization_name=organization_name,
)

View File

@@ -21,7 +21,7 @@ from danswer.db.models import User
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()

View File

@@ -5,7 +5,6 @@ from collections.abc import MutableMapping
from logging.handlers import RotatingFileHandler
from typing import Any
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import DEV_LOGGING_ENABLED
from shared_configs.configs import LOG_FILE_NAME
from shared_configs.configs import LOG_LEVEL
@@ -13,6 +12,7 @@ from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logging.addLevelName(logging.INFO + 5, "NOTICE")

View File

@@ -1,9 +1,13 @@
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.constants import AuthType
from danswer.db.engine import get_session
from danswer.db.models import User
@@ -68,3 +72,19 @@ def get_default_admin_user_emails_() -> list[str]:
if seed_config and seed_config.admin_user_emails:
return seed_config.admin_user_emails
return []
async def current_cloud_superuser(
request: Request,
user: User | None = Depends(current_admin_user),
) -> User | None:
api_key = request.headers.get("Authorization", "").replace("Bearer ", "")
if api_key != SUPER_CLOUD_API_KEY:
raise HTTPException(status_code=401, detail="Invalid API key")
if user and user.email not in SUPER_USERS:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User must be a cloud superuser to perform this action.",
)
return user

View File

@@ -28,8 +28,8 @@ from ee.danswer.external_permissions.permission_sync import (
run_external_group_permission_sync,
)
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()

View File

@@ -15,8 +15,8 @@ from ee.danswer.auth.api_key import build_displayable_api_key
from ee.danswer.auth.api_key import generate_api_key
from ee.danswer.auth.api_key import hash_api_key
from ee.danswer.server.api_key.models import APIKeyArgs
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
def get_api_key_email_pattern() -> str:

View File

@@ -29,6 +29,7 @@ def fetch_chat_sessions_eagerly_by_time(
filters: list[ColumnElement | BinaryExpression] = [
ChatSession.time_created.between(start, end)
]
if initial_id:
filters.append(ChatSession.id < initial_id)
subquery = (

View File

@@ -11,9 +11,9 @@ from fastapi import Response
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.db.engine import is_valid_schema_name
from ee.danswer.auth.api_key import extract_tenant_from_api_key_header
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
@@ -22,11 +22,11 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
try:
tenant_id = POSTGRES_DEFAULT_SCHEMA
if MULTI_TENANT:
tenant_id = _get_tenant_id_from_request(request, logger)
tenant_id = (
_get_tenant_id_from_request(request, logger)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return await call_next(request)

View File

@@ -2,29 +2,36 @@ import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from danswer.auth.users import auth_backend
from danswer.auth.users import current_admin_user
from danswer.auth.users import get_jwt_strategy
from danswer.auth.users import get_tenant_id_for_email
from danswer.auth.users import User
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.db.engine import get_session_with_tenant
from danswer.db.notification import create_notification
from danswer.db.users import get_user_by_email
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from ee.danswer.auth.users import current_cloud_superuser
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
from ee.danswer.server.tenants.access import control_plane_dep
from ee.danswer.server.tenants.billing import fetch_billing_information
from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information
from ee.danswer.server.tenants.models import BillingInformation
from ee.danswer.server.tenants.models import CreateTenantRequest
from ee.danswer.server.tenants.models import ImpersonateRequest
from ee.danswer.server.tenants.models import ProductGatingRequest
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
stripe.api_key = STRIPE_SECRET_KEY
@@ -132,3 +139,30 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_jwt_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response

View File

@@ -29,3 +29,7 @@ class BillingInformation(BaseModel):
class CheckoutSessionCreationResponse(BaseModel):
id: str
class ImpersonateRequest(BaseModel):
email: str

View File

@@ -1,4 +1,3 @@
import contextvars
import os
from typing import List
from urllib.parse import urlparse
@@ -134,10 +133,6 @@ MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public"
CURRENT_TENANT_ID_CONTEXTVAR = contextvars.ContextVar(
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
)
# Prefix used for all tenant ids
TENANT_ID_PREFIX = "tenant_"

View File

@@ -0,0 +1,8 @@
import contextvars
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# Context variable for the current tenant id
CURRENT_TENANT_ID_CONTEXTVAR = contextvars.ContextVar(
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
)

View File

@@ -0,0 +1,132 @@
"use client";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { useUser } from "@/components/user/UserProvider";
import { redirect, useRouter } from "next/navigation";
import { Formik, Form, Field } from "formik";
import * as Yup from "yup";
import { usePopup } from "@/components/admin/connectors/Popup";
const ImpersonateSchema = Yup.object().shape({
email: Yup.string().email("Invalid email").required("Required"),
apiKey: Yup.string().required("Required"),
});
export default function ImpersonatePage() {
const router = useRouter();
const { user, isLoadingUser, isCloudSuperuser } = useUser();
const { popup, setPopup } = usePopup();
if (isLoadingUser) {
return null;
}
if (!user) {
redirect("/auth/login");
}
if (!isCloudSuperuser) {
redirect("/search");
}
const handleImpersonate = async (values: {
email: string;
apiKey: string;
}) => {
try {
const response = await fetch("/api/tenants/impersonate", {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${values.apiKey}`,
},
body: JSON.stringify({ email: values.email }),
credentials: "same-origin",
});
if (!response.ok) {
const errorData = await response.json();
setPopup({
message: errorData.detail || "Failed to impersonate user",
type: "error",
});
} else {
router.push("/search");
}
} catch (error) {
setPopup({
message:
error instanceof Error ? error.message : "Failed to impersonate user",
type: "error",
});
}
};
return (
<AuthFlowContainer>
{popup}
<div className="absolute top-10x w-full">
<HealthCheckBanner />
</div>
<div className="flex flex-col w-full justify-center">
<h2 className="text-center text-xl text-strong font-bold mb-8">
Impersonate User
</h2>
<Formik
initialValues={{ email: "", apiKey: "" }}
validationSchema={ImpersonateSchema}
onSubmit={handleImpersonate}
>
{({ errors, touched }) => (
<Form className="flex flex-col items-stretch gap-y-2">
<div className="relative">
<Field
type="email"
name="email"
placeholder="Enter user email to impersonate"
className="w-full px-4 py-3 border border-border rounded-lg bg-input focus:outline-none focus:ring-2 focus:ring-primary transition-all duration-200"
/>
<div className="h-8">
{errors.email && touched.email && (
<div className="text-red-500 text-sm mt-1">
{errors.email}
</div>
)}
</div>
</div>
<div className="relative">
<Field
type="password"
name="apiKey"
placeholder="Enter API Key"
className="w-full px-4 py-3 border border-border rounded-lg bg-input focus:outline-none focus:ring-2 focus:ring-primary transition-all duration-200"
/>
<div className="h-8">
{errors.apiKey && touched.apiKey && (
<div className="text-red-500 text-sm mt-1">
{errors.apiKey}
</div>
)}
</div>
</div>
<button
type="submit"
className="w-full py-3 bg-accent text-white rounded-lg hover:bg-accent/90 transition-colors"
>
Impersonate User
</button>
</Form>
)}
</Formik>
<div className="text-sm text-text-500 mt-4 text-center px-4 rounded-md">
Note: This feature is only available for @danswer.ai administrators
</div>
</div>
</AuthFlowContainer>
);
}

View File

@@ -5,11 +5,6 @@ import { SettingsContext } from "@/components/settings/SettingsProvider";
export const LoginText = () => {
const settings = useContext(SettingsContext);
// if (!settings) {
// throw new Error("SettingsContext is not available");
// }
return (
<>
Log In to{" "}

View File

@@ -7,7 +7,7 @@ export default function AuthFlowContainer({
}) {
return (
<div className="flex flex-col items-center justify-center min-h-screen bg-background">
<div className="w-full max-w-md bg-black p-8 mx-4 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100">
<div className="w-full max-w-md bg-black pt-8 pb-4 px-8 mx-4 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100">
<Logo width={70} height={70} />
{children}
</div>

View File

@@ -11,6 +11,7 @@ interface UserContextType {
isAdmin: boolean;
isCurator: boolean;
refreshUser: () => Promise<void>;
isCloudSuperuser: boolean;
}
const UserContext = createContext<UserContextType | undefined>(undefined);
@@ -67,6 +68,7 @@ export function UserProvider({
refreshUser,
isAdmin: upToDateUser?.role === UserRole.ADMIN,
isCurator: upToDateUser?.role === UserRole.CURATOR,
isCloudSuperuser: upToDateUser?.is_cloud_superuser ?? false,
}}
>
{children}

View File

@@ -42,6 +42,7 @@ export interface User {
current_token_created_at?: Date;
current_token_expiry_length?: number;
oidc_expiry?: Date;
is_cloud_superuser?: boolean;
organization_name: string | null;
}