Compare commits

...

3 Commits

Author SHA1 Message Date
Chris Weaver
b5f009572a Fix non-default schema in KV store (#4655)
* Fix non-default schema in KV store

* Fix custom schema
2025-05-06 10:54:12 -07:00
Weves
efab12a962 Remove migration 2025-04-23 13:17:48 -07:00
Weves
a5234a398b Fix google drive group sync 2025-04-22 08:43:05 -07:00
8 changed files with 173 additions and 171 deletions

View File

@@ -5,8 +5,6 @@ Revises: 6a804aeb4830
Create Date: 2025-04-01 15:07:14.977435
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
@@ -17,34 +15,36 @@ depends_on = None
def upgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=8000),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=8000),
existing_nullable=False,
)
# op.alter_column(
# "prompt",
# "system_prompt",
# existing_type=sa.TEXT(),
# type_=sa.String(length=8000),
# existing_nullable=False,
# )
# op.alter_column(
# "prompt",
# "task_prompt",
# existing_type=sa.TEXT(),
# type_=sa.String(length=8000),
# existing_nullable=False,
# )
pass
def downgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.String(length=8000),
type_=sa.TEXT(),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.String(length=8000),
type_=sa.TEXT(),
existing_nullable=False,
)
# op.alter_column(
# "prompt",
# "system_prompt",
# existing_type=sa.String(length=8000),
# type_=sa.TEXT(),
# existing_nullable=False,
# )
# op.alter_column(
# "prompt",
# "task_prompt",
# existing_type=sa.String(length=8000),
# type_=sa.TEXT(),
# existing_nullable=False,
# )
pass

View File

@@ -1,3 +1,5 @@
from googleapiclient.errors import HttpError # type: ignore
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
@@ -12,6 +14,7 @@ logger = setup_logger()
def _get_drive_members(
google_drive_connector: GoogleDriveConnector,
admin_service: AdminService,
) -> dict[str, tuple[set[str], set[str]]]:
"""
This builds a map of drive ids to their members (group and user emails).
@@ -20,6 +23,8 @@ def _get_drive_members(
"drive_id_2": ({"group_email_3"}, {"user_email_3"}),
}
"""
# fetches shared drives only
drive_ids = google_drive_connector.get_all_drive_ids()
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]] = {}
@@ -28,20 +33,44 @@ def _get_drive_members(
google_drive_connector.primary_admin_email,
)
admin_user_info = (
admin_service.users()
.get(userKey=google_drive_connector.primary_admin_email)
.execute()
)
is_admin = admin_user_info.get("isAdmin", False) or admin_user_info.get(
"isDelegatedAdmin", False
)
for drive_id in drive_ids:
group_emails: set[str] = set()
user_emails: set[str] = set()
for permission in execute_paginated_retrieval(
drive_service.permissions().list,
list_key="permissions",
fileId=drive_id,
fields="permissions(emailAddress, type)",
supportsAllDrives=True,
):
if permission["type"] == "group":
group_emails.add(permission["emailAddress"])
elif permission["type"] == "user":
user_emails.add(permission["emailAddress"])
try:
for permission in execute_paginated_retrieval(
drive_service.permissions().list,
list_key="permissions",
fileId=drive_id,
fields="permissions(emailAddress, type)",
supportsAllDrives=True,
# can only set `useDomainAdminAccess` to true if the user
# is an admin
useDomainAdminAccess=is_admin,
):
if permission["type"] == "group":
group_emails.add(permission["emailAddress"])
elif permission["type"] == "user":
user_emails.add(permission["emailAddress"])
except HttpError as e:
if e.status_code == 404:
logger.warning(
f"Error getting permissions for drive id {drive_id}. "
f"User '{google_drive_connector.primary_admin_email}' likely "
f"does not have access to this drive. Exception: {e}"
)
else:
raise e
drive_id_to_members_map[drive_id] = (group_emails, user_emails)
return drive_id_to_members_map
@@ -132,7 +161,7 @@ def gdrive_group_sync(
)
# Get all drive members
drive_id_to_members_map = _get_drive_members(google_drive_connector)
drive_id_to_members_map = _get_drive_members(google_drive_connector, admin_service)
# Get all group emails
all_group_emails = _get_all_groups(

View File

@@ -28,6 +28,7 @@ from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_context_manager
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -39,13 +40,10 @@ router = APIRouter(prefix="/auth/saml")
async def upsert_saml_user(email: str) -> User:
logger.debug(f"Attempting to upsert SAML user with email: {email}")
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
) # type:ignore
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)
async with get_async_session_context() as session:
async with get_async_session_context_manager() as session:
async with get_user_db_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
try:

View File

@@ -92,7 +92,7 @@ from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_async_session_context_manager
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
@@ -252,7 +252,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(user_email)
async with get_async_session_with_tenant(tenant_id) as db_session:
async with get_async_session_context_manager(tenant_id) as db_session:
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
@@ -295,7 +295,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
user: User
async with get_async_session_with_tenant(tenant_id) as db_session:
async with get_async_session_context_manager(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
@@ -395,7 +395,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Proceed with the tenant context
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
async with get_async_session_context_manager(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
verify_email_in_whitelist(account_email, tenant_id)
@@ -634,7 +634,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
return None
# Create a tenant-specific session
async with get_async_session_with_tenant(tenant_id) as tenant_session:
async with get_async_session_context_manager(tenant_id) as tenant_session:
tenant_user_db: SQLAlchemyUserDatabase = SQLAlchemyUserDatabase(
tenant_session, User
)

View File

@@ -17,7 +17,7 @@ from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import UserRole
from onyx.db.api_key import get_api_key_email_pattern
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_async_session_context_manager
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import User
@@ -55,8 +55,9 @@ def get_total_users_count(db_session: Session) -> int:
async def get_user_count(only_admin_users: bool = False) -> int:
async with get_async_session_with_tenant() as session:
stmt = select(func.count(User.id))
async with get_async_session_context_manager() as session:
count_stmt = func.count(User.id) # type: ignore
stmt = select(count_stmt)
if only_admin_users:
stmt = stmt.where(User.role == UserRole.ADMIN)
result = await session.execute(stmt)

View File

@@ -10,7 +10,7 @@ from contextlib import asynccontextmanager
from contextlib import contextmanager
from datetime import datetime
from typing import Any
from typing import ContextManager
from typing import AsyncContextManager
import asyncpg # type: ignore
import boto3
@@ -46,6 +46,7 @@ from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
@@ -352,18 +353,6 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
return _ASYNC_ENGINE
# Listen for events on the synchronous Session class
@event.listens_for(Session, "after_begin")
def _set_search_path(
session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
) -> None:
"""Every time a new transaction is started,
set the search_path from the session's info."""
tenant_id = session.info.get("tenant_id")
if tenant_id:
connection.exec_driver_sql(f'SET search_path = "{tenant_id}"')
engine = get_sqlalchemy_async_engine()
AsyncSessionLocal = sessionmaker( # type: ignore
bind=engine,
@@ -372,33 +361,6 @@ AsyncSessionLocal = sessionmaker( # type: ignore
)
@asynccontextmanager
async def get_async_session_with_tenant(
tenant_id: str | None = None,
) -> AsyncGenerator[AsyncSession, None]:
if tenant_id is None:
tenant_id = get_current_tenant_id()
if not is_valid_schema_name(tenant_id):
logger.error(f"Invalid tenant ID: {tenant_id}")
raise ValueError("Invalid tenant ID")
async with AsyncSessionLocal() as session:
session.sync_session.info["tenant_id"] = tenant_id
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
text(
f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
try:
yield session
finally:
pass
@contextmanager
def get_session_with_current_tenant() -> Generator[Session, None, None]:
tenant_id = get_current_tenant_id()
@@ -416,17 +378,24 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _set_search_path_on_checkout__listener(
    dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
    """Listener to make sure we ALWAYS set the search path on checkout.

    Looks up the current tenant id and, when it is non-empty and accepted by
    `is_valid_schema_name`, issues `SET search_path` on the raw DBAPI
    connection. In every other case the connection is left untouched.
    """
    schema_name = get_current_tenant_id()

    # Guard clause: nothing to do without a usable schema name. The validity
    # check also protects the f-string below, which interpolates the name
    # directly into SQL.
    if not schema_name or not is_valid_schema_name(schema_name):
        return

    with dbapi_conn.cursor() as cur:
        cur.execute(f'SET search_path TO "{schema_name}"')
@contextmanager
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
"""
Generate a database session for a specific tenant.
"""
if tenant_id is None:
tenant_id = POSTGRES_DEFAULT_SCHEMA
engine = get_sqlalchemy_engine()
event.listen(engine, "checkout", set_search_path_on_checkout)
event.listen(engine, "checkout", _set_search_path_on_checkout__listener)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
@@ -457,57 +426,84 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
cursor.close()
def set_search_path_on_checkout(
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
def get_session() -> Generator[Session, None, None]:
"""For use w/ Depends for FastAPI endpoints.
Has some additional validation, and likely should be merged
with get_session_context_manager in the future."""
tenant_id = get_current_tenant_id()
if tenant_id and is_valid_schema_name(tenant_id):
with dbapi_conn.cursor() as cursor:
cursor.execute(f'SET search_path TO "{tenant_id}"')
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
raise BasicAuthenticationError(detail="User must authenticate")
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
with get_session_context_manager() as db_session:
yield db_session
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
@contextlib.contextmanager
def get_session_context_manager() -> Generator[Session, None, None]:
    """Context manager for database sessions.

    Opens a session via `get_session_with_tenant` for the tenant id returned
    by `get_current_tenant_id()` and yields it to the caller.
    """
    with get_session_with_tenant(tenant_id=get_current_tenant_id()) as db_session:
        yield db_session
def get_session() -> Generator[Session, None, None]:
tenant_id = get_current_tenant_id()
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
raise BasicAuthenticationError(detail="User must authenticate")
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
session.execute(text(f'SET search_path = "{tenant_id}"'))
yield session
def _set_search_path_on_transaction__listener(
    session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
) -> None:
    """Every time a new transaction is started,
    set the search_path from the session's info.

    The schema name is read from `session.info["tenant_id"]`; when that entry
    is missing or falsy the connection is left untouched.
    """
    schema_name = session.info.get("tenant_id")
    if not schema_name:
        return

    connection.exec_driver_sql(f'SET search_path = "{schema_name}"')
async def get_async_session(
    tenant_id: str | None = None,
) -> AsyncGenerator[AsyncSession, None]:
    """For use w/ Depends for *async* FastAPI endpoints.

    For standard `async with ... as ...` use, use get_async_session_context_manager.

    Args:
        tenant_id: Schema to run against. When None, the value returned by
            `get_current_tenant_id()` is used.

    Yields:
        An `AsyncSession` (expire_on_commit=False) whose search path has been
        pointed at the tenant schema when that is required.

    Raises:
        HTTPException: 400 when the tenant id is not a valid schema name.
    """
    if tenant_id is None:
        tenant_id = get_current_tenant_id()

    # Validate up front: the tenant id is interpolated into raw SQL below, so
    # it must be checked before any statement is issued on the session.
    if not is_valid_schema_name(tenant_id):
        raise HTTPException(status_code=400, detail="Invalid tenant ID")

    # don't need to set the search path for self-hosted + default schema
    # this is also true for sync sessions, but just not adding it there for
    # now to simplify / not change too much
    needs_search_path = (
        MULTI_TENANT or tenant_id != POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
    )

    engine = get_sqlalchemy_async_engine()
    async with AsyncSession(engine, expire_on_commit=False) as async_session:
        if needs_search_path:
            # The after_begin listener reads the schema from the sync session's
            # `info`; without this entry it silently does nothing and the
            # search path is not re-applied on later transactions.
            async_session.sync_session.info["tenant_id"] = tenant_id

            # set the search path on sync session as well to be extra safe
            event.listen(
                async_session.sync_session,
                "after_begin",
                _set_search_path_on_transaction__listener,
            )

        if POSTGRES_IDLE_SESSIONS_TIMEOUT:
            await async_session.execute(
                text(
                    f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
                )
            )

        if needs_search_path:
            await async_session.execute(text(f'SET search_path = "{tenant_id}"'))

        yield async_session
def get_session_context_manager() -> ContextManager[Session]:
"""Context manager for database sessions."""
return contextlib.contextmanager(get_session_generator_with_tenant)()
def get_session_factory() -> sessionmaker[Session]:
    """Return the module-level session factory, creating it on first use.

    The factory is stored in the global `SessionFactory` and is bound to the
    engine returned by `get_sqlalchemy_engine()`.
    """
    global SessionFactory

    # Already built by an earlier call: hand back the stored instance.
    if SessionFactory is not None:
        return SessionFactory

    SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
    return SessionFactory
def get_async_session_context_manager(
    tenant_id: str | None = None,
) -> AsyncContextManager[AsyncSession]:
    """Wrap `get_async_session` for use in `async with ... as session:` blocks.

    Args:
        tenant_id: Passed through unchanged to `get_async_session`.

    Returns:
        An async context manager built from the `get_async_session` generator.
    """
    session_cm_factory = asynccontextmanager(get_async_session)
    return session_cm_factory(tenant_id)
async def warm_up_connections(

View File

@@ -1,25 +1,15 @@
import json
from collections.abc import Iterator
from contextlib import contextmanager
from typing import cast
from fastapi import HTTPException
from redis.client import Redis
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.engine import is_valid_schema_name
from onyx.db.engine import get_session_context_manager
from onyx.db.models import KVStore
from onyx.key_value_store.interface import KeyValueStore
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.redis.redis_pool import get_redis_client
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -31,26 +21,11 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(self, redis_client: Redis | None = None) -> None:
self.tenant_id = get_current_tenant_id()
# If no redis_client is provided, fall back to the context var
if redis_client is not None:
self.redis_client = redis_client
else:
self.redis_client = get_redis_client(tenant_id=self.tenant_id)
@contextmanager
def _get_session(self) -> Iterator[Session]:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
if self.tenant_id == POSTGRES_DEFAULT_SCHEMA:
raise BasicAuthenticationError(detail="User must authenticate")
if not is_valid_schema_name(self.tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{self.tenant_id}"'))
yield session
self.redis_client = get_redis_client()
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
# Not encrypted in Redis, but encrypted in Postgres
@@ -64,8 +39,8 @@ class PgRedisKVStore(KeyValueStore):
encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
with self._get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
with get_session_context_manager() as db_session:
obj = db_session.query(KVStore).filter_by(key=key).first()
if obj:
obj.value = plain_val
obj.encrypted_value = encrypted_val
@@ -73,9 +48,9 @@ class PgRedisKVStore(KeyValueStore):
obj = KVStore(
key=key, value=plain_val, encrypted_value=encrypted_val
) # type: ignore
session.query(KVStore).filter_by(key=key).delete() # just in case
session.add(obj)
session.commit()
db_session.query(KVStore).filter_by(key=key).delete() # just in case
db_session.add(obj)
db_session.commit()
def load(self, key: str) -> JSON_ro:
try:
@@ -86,8 +61,8 @@ class PgRedisKVStore(KeyValueStore):
except Exception as e:
logger.error(f"Failed to get value from Redis for key '{key}': {str(e)}")
with self._get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
with get_session_context_manager() as db_session:
obj = db_session.query(KVStore).filter_by(key=key).first()
if not obj:
raise KvKeyNotFoundError
@@ -111,8 +86,8 @@ class PgRedisKVStore(KeyValueStore):
except Exception as e:
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
with self._get_session() as session:
result = session.query(KVStore).filter_by(key=key).delete() # type: ignore
with get_session_context_manager() as db_session:
result = db_session.query(KVStore).filter_by(key=key).delete() # type: ignore
if result == 0:
raise KvKeyNotFoundError
session.commit()
db_session.commit()

View File

@@ -141,7 +141,10 @@ else:
# Multi-tenancy configuration
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public"
POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE = "public"
POSTGRES_DEFAULT_SCHEMA = (
os.environ.get("POSTGRES_DEFAULT_SCHEMA") or POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
)
DEFAULT_REDIS_PREFIX = os.environ.get("DEFAULT_REDIS_PREFIX") or "default"