mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-25 11:45:47 +00:00
Compare commits
3 Commits
ci_python_
...
v0.26.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b5f009572a | ||
|
|
efab12a962 | ||
|
|
a5234a398b |
@@ -5,8 +5,6 @@ Revises: 6a804aeb4830
|
||||
Create Date: 2025-04-01 15:07:14.977435
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@@ -17,34 +15,36 @@ depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"prompt",
|
||||
"system_prompt",
|
||||
existing_type=sa.TEXT(),
|
||||
type_=sa.String(length=8000),
|
||||
existing_nullable=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"prompt",
|
||||
"task_prompt",
|
||||
existing_type=sa.TEXT(),
|
||||
type_=sa.String(length=8000),
|
||||
existing_nullable=False,
|
||||
)
|
||||
# op.alter_column(
|
||||
# "prompt",
|
||||
# "system_prompt",
|
||||
# existing_type=sa.TEXT(),
|
||||
# type_=sa.String(length=8000),
|
||||
# existing_nullable=False,
|
||||
# )
|
||||
# op.alter_column(
|
||||
# "prompt",
|
||||
# "task_prompt",
|
||||
# existing_type=sa.TEXT(),
|
||||
# type_=sa.String(length=8000),
|
||||
# existing_nullable=False,
|
||||
# )
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"prompt",
|
||||
"system_prompt",
|
||||
existing_type=sa.String(length=8000),
|
||||
type_=sa.TEXT(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"prompt",
|
||||
"task_prompt",
|
||||
existing_type=sa.String(length=8000),
|
||||
type_=sa.TEXT(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
# op.alter_column(
|
||||
# "prompt",
|
||||
# "system_prompt",
|
||||
# existing_type=sa.String(length=8000),
|
||||
# type_=sa.TEXT(),
|
||||
# existing_nullable=False,
|
||||
# )
|
||||
# op.alter_column(
|
||||
# "prompt",
|
||||
# "task_prompt",
|
||||
# existing_type=sa.String(length=8000),
|
||||
# type_=sa.TEXT(),
|
||||
# existing_nullable=False,
|
||||
# )
|
||||
pass
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
@@ -12,6 +14,7 @@ logger = setup_logger()
|
||||
|
||||
def _get_drive_members(
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
admin_service: AdminService,
|
||||
) -> dict[str, tuple[set[str], set[str]]]:
|
||||
"""
|
||||
This builds a map of drive ids to their members (group and user emails).
|
||||
@@ -20,6 +23,8 @@ def _get_drive_members(
|
||||
"drive_id_2": ({"group_email_3"}, {"user_email_3"}),
|
||||
}
|
||||
"""
|
||||
|
||||
# fetches shared drives only
|
||||
drive_ids = google_drive_connector.get_all_drive_ids()
|
||||
|
||||
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]] = {}
|
||||
@@ -28,20 +33,44 @@ def _get_drive_members(
|
||||
google_drive_connector.primary_admin_email,
|
||||
)
|
||||
|
||||
admin_user_info = (
|
||||
admin_service.users()
|
||||
.get(userKey=google_drive_connector.primary_admin_email)
|
||||
.execute()
|
||||
)
|
||||
is_admin = admin_user_info.get("isAdmin", False) or admin_user_info.get(
|
||||
"isDelegatedAdmin", False
|
||||
)
|
||||
|
||||
for drive_id in drive_ids:
|
||||
group_emails: set[str] = set()
|
||||
user_emails: set[str] = set()
|
||||
for permission in execute_paginated_retrieval(
|
||||
drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
fileId=drive_id,
|
||||
fields="permissions(emailAddress, type)",
|
||||
supportsAllDrives=True,
|
||||
):
|
||||
if permission["type"] == "group":
|
||||
group_emails.add(permission["emailAddress"])
|
||||
elif permission["type"] == "user":
|
||||
user_emails.add(permission["emailAddress"])
|
||||
|
||||
try:
|
||||
for permission in execute_paginated_retrieval(
|
||||
drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
fileId=drive_id,
|
||||
fields="permissions(emailAddress, type)",
|
||||
supportsAllDrives=True,
|
||||
# can only set `useDomainAdminAccess` to true if the user
|
||||
# is an admin
|
||||
useDomainAdminAccess=is_admin,
|
||||
):
|
||||
if permission["type"] == "group":
|
||||
group_emails.add(permission["emailAddress"])
|
||||
elif permission["type"] == "user":
|
||||
user_emails.add(permission["emailAddress"])
|
||||
except HttpError as e:
|
||||
if e.status_code == 404:
|
||||
logger.warning(
|
||||
f"Error getting permissions for drive id {drive_id}. "
|
||||
f"User '{google_drive_connector.primary_admin_email}' likely "
|
||||
f"does not have access to this drive. Exception: {e}"
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
|
||||
drive_id_to_members_map[drive_id] = (group_emails, user_emails)
|
||||
return drive_id_to_members_map
|
||||
|
||||
@@ -132,7 +161,7 @@ def gdrive_group_sync(
|
||||
)
|
||||
|
||||
# Get all drive members
|
||||
drive_id_to_members_map = _get_drive_members(google_drive_connector)
|
||||
drive_id_to_members_map = _get_drive_members(google_drive_connector, admin_service)
|
||||
|
||||
# Get all group emails
|
||||
all_group_emails = _get_all_groups(
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.auth import get_user_db
|
||||
from onyx.db.engine import get_async_session
|
||||
from onyx.db.engine import get_async_session_context_manager
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -39,13 +40,10 @@ router = APIRouter(prefix="/auth/saml")
|
||||
|
||||
async def upsert_saml_user(email: str) -> User:
|
||||
logger.debug(f"Attempting to upsert SAML user with email: {email}")
|
||||
get_async_session_context = contextlib.asynccontextmanager(
|
||||
get_async_session
|
||||
) # type:ignore
|
||||
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
|
||||
get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)
|
||||
|
||||
async with get_async_session_context() as session:
|
||||
async with get_async_session_context_manager() as session:
|
||||
async with get_user_db_context(session) as user_db:
|
||||
async with get_user_manager_context(user_db) as user_manager:
|
||||
try:
|
||||
|
||||
@@ -92,7 +92,7 @@ from onyx.db.auth import get_user_count
|
||||
from onyx.db.auth import get_user_db
|
||||
from onyx.db.auth import SQLAlchemyUserAdminDB
|
||||
from onyx.db.engine import get_async_session
|
||||
from onyx.db.engine import get_async_session_with_tenant
|
||||
from onyx.db.engine import get_async_session_context_manager
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
@@ -252,7 +252,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
|
||||
)(user_email)
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
async with get_async_session_context_manager(tenant_id) as db_session:
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
|
||||
db_session, User, OAuthAccount
|
||||
@@ -295,7 +295,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
user: User
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
async with get_async_session_context_manager(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
@@ -395,7 +395,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# Proceed with the tenant context
|
||||
token = None
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
async with get_async_session_context_manager(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
verify_email_in_whitelist(account_email, tenant_id)
|
||||
@@ -634,7 +634,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
return None
|
||||
|
||||
# Create a tenant-specific session
|
||||
async with get_async_session_with_tenant(tenant_id) as tenant_session:
|
||||
async with get_async_session_context_manager(tenant_id) as tenant_session:
|
||||
tenant_user_db: SQLAlchemyUserDatabase = SQLAlchemyUserDatabase(
|
||||
tenant_session, User
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.api_key import get_api_key_email_pattern
|
||||
from onyx.db.engine import get_async_session
|
||||
from onyx.db.engine import get_async_session_with_tenant
|
||||
from onyx.db.engine import get_async_session_context_manager
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
@@ -55,8 +55,9 @@ def get_total_users_count(db_session: Session) -> int:
|
||||
|
||||
|
||||
async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
async with get_async_session_with_tenant() as session:
|
||||
stmt = select(func.count(User.id))
|
||||
async with get_async_session_context_manager() as session:
|
||||
count_stmt = func.count(User.id) # type: ignore
|
||||
stmt = select(count_stmt)
|
||||
if only_admin_users:
|
||||
stmt = stmt.where(User.role == UserRole.ADMIN)
|
||||
result = await session.execute(stmt)
|
||||
|
||||
@@ -10,7 +10,7 @@ from contextlib import asynccontextmanager
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import ContextManager
|
||||
from typing import AsyncContextManager
|
||||
|
||||
import asyncpg # type: ignore
|
||||
import boto3
|
||||
@@ -46,6 +46,7 @@ from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -352,18 +353,6 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
# Listen for events on the synchronous Session class
|
||||
@event.listens_for(Session, "after_begin")
|
||||
def _set_search_path(
|
||||
session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""Every time a new transaction is started,
|
||||
set the search_path from the session's info."""
|
||||
tenant_id = session.info.get("tenant_id")
|
||||
if tenant_id:
|
||||
connection.exec_driver_sql(f'SET search_path = "{tenant_id}"')
|
||||
|
||||
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
AsyncSessionLocal = sessionmaker( # type: ignore
|
||||
bind=engine,
|
||||
@@ -372,33 +361,6 @@ AsyncSessionLocal = sessionmaker( # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_async_session_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
if tenant_id is None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
logger.error(f"Invalid tenant ID: {tenant_id}")
|
||||
raise ValueError("Invalid tenant ID")
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
session.sync_session.info["tenant_id"] = tenant_id
|
||||
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
await session.execute(
|
||||
text(
|
||||
f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_current_tenant() -> Generator[Session, None, None]:
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -416,17 +378,24 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _set_search_path_on_checkout__listener(
|
||||
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
|
||||
) -> None:
|
||||
"""Listener to make sure we ALWAYS set the search path on checkout."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
if tenant_id and is_valid_schema_name(tenant_id):
|
||||
with dbapi_conn.cursor() as cursor:
|
||||
cursor.execute(f'SET search_path TO "{tenant_id}"')
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Generate a database session for a specific tenant.
|
||||
"""
|
||||
if tenant_id is None:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||
event.listen(engine, "checkout", _set_search_path_on_checkout__listener)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
@@ -457,57 +426,84 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
cursor.close()
|
||||
|
||||
|
||||
def set_search_path_on_checkout(
|
||||
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
|
||||
) -> None:
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
"""For use w/ Depends for FastAPI endpoints.
|
||||
|
||||
Has some additional validation, and likely should be merged
|
||||
with get_session_context_manager in the future."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
if tenant_id and is_valid_schema_name(tenant_id):
|
||||
with dbapi_conn.cursor() as cursor:
|
||||
cursor.execute(f'SET search_path TO "{tenant_id}"')
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
raise BasicAuthenticationError(detail="User must authenticate")
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
yield db_session
|
||||
|
||||
|
||||
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
|
||||
@contextlib.contextmanager
|
||||
def get_session_context_manager() -> Generator[Session, None, None]:
|
||||
"""Context manager for database sessions."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as session:
|
||||
yield session
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
tenant_id = get_current_tenant_id()
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
raise BasicAuthenticationError(detail="User must authenticate")
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
with Session(engine, expire_on_commit=False) as session:
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
yield session
|
||||
def _set_search_path_on_transaction__listener(
|
||||
session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""Every time a new transaction is started,
|
||||
set the search_path from the session's info."""
|
||||
tenant_id = session.info.get("tenant_id")
|
||||
if tenant_id:
|
||||
connection.exec_driver_sql(f'SET search_path = "{tenant_id}"')
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
tenant_id = get_current_tenant_id()
|
||||
async def get_async_session(
|
||||
tenant_id: str | None = None,
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""For use w/ Depends for *async* FastAPI endpoints.
|
||||
|
||||
For standard `async with ... as ...` use, use get_async_session_context_manager.
|
||||
"""
|
||||
|
||||
if tenant_id is None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
|
||||
async with AsyncSession(engine, expire_on_commit=False) as async_session:
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# set the search path on sync session as well to be extra safe
|
||||
event.listen(
|
||||
async_session.sync_session,
|
||||
"after_begin",
|
||||
_set_search_path_on_transaction__listener,
|
||||
)
|
||||
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
await async_session.execute(
|
||||
text(
|
||||
f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
# don't need to set the search path for self-hosted + default schema
|
||||
# this is also true for sync sessions, but just not adding it there for
|
||||
# now to simplify / not change too much
|
||||
if MULTI_TENANT or tenant_id != POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
|
||||
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
|
||||
yield async_session
|
||||
|
||||
|
||||
def get_session_context_manager() -> ContextManager[Session]:
|
||||
"""Context manager for database sessions."""
|
||||
return contextlib.contextmanager(get_session_generator_with_tenant)()
|
||||
|
||||
|
||||
def get_session_factory() -> sessionmaker[Session]:
|
||||
global SessionFactory
|
||||
if SessionFactory is None:
|
||||
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
|
||||
return SessionFactory
|
||||
def get_async_session_context_manager(
|
||||
tenant_id: str | None = None,
|
||||
) -> AsyncContextManager[AsyncSession]:
|
||||
return asynccontextmanager(get_async_session)(tenant_id)
|
||||
|
||||
|
||||
async def warm_up_connections(
|
||||
|
||||
@@ -1,25 +1,15 @@
|
||||
import json
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
from redis.client import Redis
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.engine import is_valid_schema_name
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.models import KVStore
|
||||
from onyx.key_value_store.interface import KeyValueStore
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -31,26 +21,11 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
|
||||
|
||||
class PgRedisKVStore(KeyValueStore):
|
||||
def __init__(self, redis_client: Redis | None = None) -> None:
|
||||
self.tenant_id = get_current_tenant_id()
|
||||
|
||||
# If no redis_client is provided, fall back to the context var
|
||||
if redis_client is not None:
|
||||
self.redis_client = redis_client
|
||||
else:
|
||||
self.redis_client = get_redis_client(tenant_id=self.tenant_id)
|
||||
|
||||
@contextmanager
|
||||
def _get_session(self) -> Iterator[Session]:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine, expire_on_commit=False) as session:
|
||||
if MULTI_TENANT:
|
||||
if self.tenant_id == POSTGRES_DEFAULT_SCHEMA:
|
||||
raise BasicAuthenticationError(detail="User must authenticate")
|
||||
if not is_valid_schema_name(self.tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# Set the search_path to the tenant's schema
|
||||
session.execute(text(f'SET search_path = "{self.tenant_id}"'))
|
||||
yield session
|
||||
self.redis_client = get_redis_client()
|
||||
|
||||
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
|
||||
# Not encrypted in Redis, but encrypted in Postgres
|
||||
@@ -64,8 +39,8 @@ class PgRedisKVStore(KeyValueStore):
|
||||
|
||||
encrypted_val = val if encrypt else None
|
||||
plain_val = val if not encrypt else None
|
||||
with self._get_session() as session:
|
||||
obj = session.query(KVStore).filter_by(key=key).first()
|
||||
with get_session_context_manager() as db_session:
|
||||
obj = db_session.query(KVStore).filter_by(key=key).first()
|
||||
if obj:
|
||||
obj.value = plain_val
|
||||
obj.encrypted_value = encrypted_val
|
||||
@@ -73,9 +48,9 @@ class PgRedisKVStore(KeyValueStore):
|
||||
obj = KVStore(
|
||||
key=key, value=plain_val, encrypted_value=encrypted_val
|
||||
) # type: ignore
|
||||
session.query(KVStore).filter_by(key=key).delete() # just in case
|
||||
session.add(obj)
|
||||
session.commit()
|
||||
db_session.query(KVStore).filter_by(key=key).delete() # just in case
|
||||
db_session.add(obj)
|
||||
db_session.commit()
|
||||
|
||||
def load(self, key: str) -> JSON_ro:
|
||||
try:
|
||||
@@ -86,8 +61,8 @@ class PgRedisKVStore(KeyValueStore):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get value from Redis for key '{key}': {str(e)}")
|
||||
|
||||
with self._get_session() as session:
|
||||
obj = session.query(KVStore).filter_by(key=key).first()
|
||||
with get_session_context_manager() as db_session:
|
||||
obj = db_session.query(KVStore).filter_by(key=key).first()
|
||||
if not obj:
|
||||
raise KvKeyNotFoundError
|
||||
|
||||
@@ -111,8 +86,8 @@ class PgRedisKVStore(KeyValueStore):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
|
||||
|
||||
with self._get_session() as session:
|
||||
result = session.query(KVStore).filter_by(key=key).delete() # type: ignore
|
||||
with get_session_context_manager() as db_session:
|
||||
result = db_session.query(KVStore).filter_by(key=key).delete() # type: ignore
|
||||
if result == 0:
|
||||
raise KvKeyNotFoundError
|
||||
session.commit()
|
||||
db_session.commit()
|
||||
|
||||
@@ -141,7 +141,10 @@ else:
|
||||
# Multi-tenancy configuration
|
||||
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
|
||||
|
||||
POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public"
|
||||
POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE = "public"
|
||||
POSTGRES_DEFAULT_SCHEMA = (
|
||||
os.environ.get("POSTGRES_DEFAULT_SCHEMA") or POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
|
||||
)
|
||||
DEFAULT_REDIS_PREFIX = os.environ.get("DEFAULT_REDIS_PREFIX") or "default"
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user