Compare commits

...

7 Commits

Author    SHA1        Message                       Date
pablonyx  f150b2ad27  k                             2025-02-18 16:08:47 -08:00
pablonyx  819a246cf6  nit                           2025-02-18 15:05:10 -08:00
pablonyx  46704cce0e  merge                         2025-02-18 14:18:11 -08:00
pablonyx  fb0da64763  nit                           2025-02-17 19:19:17 -08:00
pablonyx  ff14a866e3  k                             2025-02-17 19:19:17 -08:00
pablonyx  4d6d320191  k                             2025-02-17 19:19:17 -08:00
pablonyx  5e6774e118  strict tenant id enforcement  2025-02-17 19:19:17 -08:00
68 changed files with 390 additions and 357 deletions

View File

@@ -21,7 +21,7 @@ logger = setup_logger()
def perform_ttl_management_task(
retention_limit_days: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)
@@ -44,7 +44,7 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
settings = load_settings()
retention_limit_days = settings.maximum_chat_retention_days
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if should_perform_chat_ttl_check(retention_limit_days, db_session):
perform_ttl_management_task.apply_async(
kwargs=dict(
@@ -62,7 +62,7 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_new_usage_report(
db_session=db_session,
user_id=None,

View File

@@ -33,7 +33,7 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
return await call_next(request)
except Exception as e:
logger.error(f"Error in tenant ID middleware: {str(e)}")
logger.exception(f"Error in tenant ID middleware: {str(e)}")
raise
@@ -49,7 +49,7 @@ async def _get_tenant_id_from_request(
"""
# Check for API key
tenant_id = extract_tenant_from_api_key_header(request)
if tenant_id:
if tenant_id is not None:
return tenant_id
# Check for anonymous user cookie
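The change above from a truthiness test to an explicit None check is the strict part of the enforcement: only a genuinely missing tenant (None) now falls through to the anonymous-user cookie lookup, while any string value produced by the API-key header is returned as-is. A minimal sketch of the behavioral difference, using a made-up empty value:

    tenant_id = ""               # hypothetical: header present but empty
    if tenant_id:                # old check: empty string is falsy, so it fell through
        print("returned (old check)")
    if tenant_id is not None:    # new check: only None falls through
        print("returned (new check)")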

View File

@@ -36,12 +36,12 @@ from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -271,12 +271,12 @@ def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
tenant_id = get_current_tenant_id()
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
@@ -329,7 +329,6 @@ def handle_slack_oauth_callback(
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
@@ -337,7 +336,7 @@ def handle_slack_oauth_callback(
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
# recover the state
padded_state = state + "=" * (
@@ -523,7 +522,6 @@ def handle_google_drive_oauth_callback(
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
@@ -531,7 +529,7 @@ def handle_google_drive_oauth_callback(
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
# recover the state
padded_state = state + "=" * (

View File

@@ -28,7 +28,7 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
def _check_token_rate_limits(user: User | None, tenant_id: str) -> None:
if user is None:
# Unauthenticated users are only rate limited by global settings
_user_is_rate_limited_by_global(tenant_id)
@@ -52,8 +52,8 @@ User rate limits
"""
def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
def _user_is_rate_limited(user_id: UUID, tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
user_rate_limits = fetch_all_user_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
@@ -94,7 +94,7 @@ User Group rate limits
def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
if group_rate_limits:

View File

@@ -41,14 +41,15 @@ from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
@@ -57,13 +58,14 @@ router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
tenant_id: str | None = Depends(get_current_tenant_id),
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_tenant(tenant_id=None) as db_session:
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@@ -72,15 +74,15 @@ async def get_anonymous_user_path_api(
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
tenant_id: str = Depends(get_current_tenant_id),
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_tenant(tenant_id=None) as db_session:
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
@@ -101,7 +103,7 @@ async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_tenant(tenant_id=None) as db_session:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
@@ -150,14 +152,17 @@ async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(_: User = Depends(current_admin_user)) -> dict:
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
# Fetch tenant_id and current tenant's information
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
@@ -181,6 +186,8 @@ async def create_subscription_session(
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
@@ -197,7 +204,7 @@ async def impersonate_user(
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id) as tenant_session:
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
@@ -221,8 +228,9 @@ async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"

View File

@@ -118,7 +118,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
@@ -134,7 +134,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,

View File

@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
def remove_all_users_from_tenant(tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()

View File

@@ -86,7 +86,6 @@ from onyx.db.auth import get_user_db
from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
@@ -103,6 +102,7 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -193,7 +193,7 @@ def verify_email_is_invited(email: str) -> None:
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
@@ -819,8 +819,9 @@ async def current_limited_user(
async def current_chat_accesssible_user(
user: User | None = Depends(optional_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> User | None:
tenant_id = get_current_tenant_id()
return await double_check_user(
user, allow_anonymous_access=anonymous_user_enabled(tenant_id=tenant_id)
)

View File

@@ -33,6 +33,7 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import PlainFormatter
@@ -58,13 +59,35 @@ else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
class TenantAwareTask(Task):
"""A custom base Task that sets tenant_id in a contextvar before running."""
abstract = True # So Celery knows not to register this as a real task.
def __call__(self, *args: Any, **kwargs: Any) -> Any:
# Grab tenant_id from the kwargs, or fallback to default if missing.
tenant_id = kwargs.get("tenant_id", None) or POSTGRES_DEFAULT_SCHEMA
# Set the context var
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Actually run the task now
try:
return super().__call__(*args, **kwargs)
finally:
# Clear or reset after the task runs
# so it does not leak into any subsequent tasks on the same worker process
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
@task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
**kwds: Any,
**other_kwargs: Any,
) -> None:
pass
@@ -117,7 +140,7 @@ def on_task_postrun(
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
@@ -201,7 +224,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
r = get_redis_client(tenant_id=None)
r = get_shared_redis_client()
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
@@ -287,7 +310,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_redis_client(tenant_id=None)
r = get_shared_redis_client()
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")
@@ -439,24 +462,6 @@ class TenantContextFilter(logging.Filter):
return True
@task_prerun.connect
def set_tenant_id(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
**other_kwargs: Any,
) -> None:
"""Signal handler to set tenant ID in context var before task starts."""
tenant_id = (
kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if kwargs
else POSTGRES_DEFAULT_SCHEMA
)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@task_postrun.connect
def reset_tenant_id(
sender: Any | None = None,
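The new TenantAwareTask base replaces the set_tenant_id prerun signal handler that this diff removes: instead of a signal copying tenant_id into the contextvar, the task's own __call__ sets it before the body runs and clears it afterwards. A minimal sketch of how a task presumably picks this up, mirroring the celery_app.Task assignment made in the worker app modules of this PR (the task name and body are hypothetical):

    from celery import Celery
    from onyx.background.celery.apps import app_base
    from onyx.db.engine import get_session_with_current_tenant

    celery_app = Celery(__name__)
    celery_app.Task = app_base.TenantAwareTask  # as in the light/heavy/indexing/primary apps

    @celery_app.task
    def hypothetical_tenant_task(*, tenant_id: str | None) -> None:
        # TenantAwareTask.__call__ has already copied tenant_id into
        # CURRENT_TENANT_ID_CONTEXTVAR, so the contextvar-based session helper
        # resolves the right schema without tenant_id being threaded through.
        with get_session_with_current_tenant() as db_session:
            ...  # tenant-scoped queries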

View File

@@ -132,6 +132,7 @@ class DynamicTenantScheduler(PersistentScheduler):
f"Adding options to task {tenant_task_name}: {options}"
)
tenant_task["options"] = options
new_schedule[tenant_task_name] = tenant_task
return new_schedule
@@ -256,3 +257,4 @@ def on_setup_logging(
celery_app.conf.beat_scheduler = DynamicTenantScheduler
celery_app.conf.task_default_base = app_base.TenantAwareTask

View File

@@ -20,6 +20,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.heavy")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -21,6 +21,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.indexing")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -23,6 +23,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.light")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -20,6 +20,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.monitoring")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -24,7 +24,7 @@ from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from onyx.db.engine import get_session_with_default_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import SqlEngine
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
@@ -38,7 +38,7 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -47,6 +47,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.primary")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
@@ -101,7 +102,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
r = get_shared_redis_client()
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
@@ -158,7 +159,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
RedisConnectorExternalGroupSync.reset_all(r)
# mark orphaned index attempts as failed
with get_session_with_default_tenant() as db_session:
with get_session_with_current_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
@@ -234,7 +235,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
lock: RedisLock = worker.primary_worker_lock
r = get_redis_client(tenant_id=None)
r = get_shared_redis_client()
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")

View File

@@ -27,7 +27,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
@@ -62,8 +62,8 @@ class TaskDependencyError(RuntimeError):
def check_for_connector_deletion_task(
self: Task, *, tenant_id: str | None
) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r = get_redis_client()
r_replica = get_redis_replica_client()
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -77,14 +77,14 @@ def check_for_connector_deletion_task(
try:
# collect cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
cc_pair_ids.append(cc_pair.id)
# try running cleanup on the cc_pair_ids
for cc_pair_id in cc_pair_ids:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
@@ -277,7 +277,7 @@ def monitor_connector_deletion_taskset(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
)
if remaining > 0:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
@@ -287,7 +287,7 @@ def monitor_connector_deletion_taskset(
)
return
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,

View File

@@ -45,7 +45,7 @@ from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -119,13 +119,13 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None:
# TODO(rkuo): merge into check function after lookup table for fences is added
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
@@ -140,7 +140,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
try:
# get all cc pairs that need to be synced
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
@@ -189,7 +189,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
@@ -247,7 +247,7 @@ def try_creating_permissions_sync_task(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
@@ -321,7 +321,7 @@ def connector_permission_sync_generator_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
@@ -378,7 +378,7 @@ def connector_permission_sync_generator_task(
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -480,7 +480,8 @@ def update_external_document_permissions_task(
external_access = document_external_access.external_access
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
# Add the users to the DB if they don't exist
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),

View File

@@ -39,7 +39,7 @@ from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -122,8 +122,8 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
@@ -140,7 +140,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
try:
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
# We only want to sync one cc_pair per source type in
@@ -230,7 +230,7 @@ def try_creating_external_group_sync_task(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
@@ -296,7 +296,7 @@ def connector_external_group_sync_generator_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
@@ -357,7 +357,7 @@ def connector_external_group_sync_generator_task(
payload.started = datetime.now(timezone.utc)
redis_connector.external_group_sync.set_fence(payload)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -408,7 +408,7 @@ def connector_external_group_sync_generator_task(
task_logger.exception(msg)
emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
@@ -459,7 +459,6 @@ def validate_external_group_sync_fences(
)
lock_beat.reacquire()
return

View File

@@ -50,7 +50,7 @@ from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import IndexingMode
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import get_index_attempt
@@ -348,12 +348,13 @@ def monitor_ccpair_indexing_taskset(
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occasionally does some validation of existing state to clear up error conditions"""
time_start = time.monotonic()
tasks_created = 0
locked = False
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client_replica = get_redis_replica_client(tenant_id=tenant_id)
redis_client = get_redis_client()
redis_client_replica = get_redis_replica_client()
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
@@ -391,7 +392,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# 1/3: KICKOFF
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
@@ -412,7 +413,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# gather cc_pair_ids
lock_beat.reacquire()
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
@@ -422,7 +423,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
search_settings_list = get_active_search_settings_list(db_session)
for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
@@ -500,7 +501,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
db_session, redis_client
)
@@ -552,7 +553,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
monitor_ccpair_indexing_taskset(
tenant_id, key_bytes, redis_client_replica, db_session
)
@@ -583,8 +584,8 @@ def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
tenant_id: str | None,
) -> int | None:
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -635,7 +636,7 @@ def connector_indexing_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
if redis_connector.delete.fenced:
raise SimpleJobException(
@@ -729,7 +730,7 @@ def connector_indexing_task(
redis_connector_index.set_fence(payload)
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise SimpleJobException(
@@ -907,9 +908,8 @@ def connector_indexing_proxy_task(
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
global_version.is_ee_version(),
pure=False,
tenant_id,
)
if not job:
@@ -929,7 +929,7 @@ def connector_indexing_proxy_task(
redis_connector_index = redis_connector.new_index(search_settings_id)
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
@@ -977,7 +977,7 @@ def connector_indexing_proxy_task(
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
@@ -1005,7 +1005,7 @@ def connector_indexing_proxy_task(
if result.exception_str is not None:
# print with exception
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
failure_reason = (
f"Spawned task exceptioned: exit_code={result.exit_code}"
)
@@ -1045,7 +1045,7 @@ def connector_indexing_proxy_task(
# print without exception
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
@@ -1095,7 +1095,7 @@ def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
try:
locked = True
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
old_attempts = get_index_attempts_with_old_checkpoints(db_session)
for attempt in old_attempts:
task_logger.info(
@@ -1131,5 +1131,5 @@ def cleanup_checkpoint_task(
self: Task, *, index_attempt_id: int, tenant_id: str | None
) -> None:
"""Clean up a checkpoint for a given index attempt"""
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cleanup_checkpoint(db_session, index_attempt_id)

View File

@@ -23,7 +23,7 @@ from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
@@ -318,7 +318,7 @@ def validate_indexing_fences(
if not key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
continue
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,

View File

@@ -8,7 +8,7 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import LLMProvider
@@ -75,7 +75,7 @@ def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | N
return None
# Then update the database with the fetched models
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
# Get the default LLM provider
default_provider = (
db_session.query(LLMProvider)

View File

@@ -26,7 +26,8 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
@@ -42,7 +43,6 @@ from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60 # 6 minutes
@@ -668,7 +668,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
task_logger.info("Starting background monitoring")
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
lock_monitoring: RedisLock = r.lock(
OnyxRedisLocks.MONITOR_BACKGROUND_PROCESSES_LOCK,
@@ -683,7 +683,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
try:
# Get Redis client for Celery broker
redis_celery = self.app.broker_connection().channel().client # type: ignore
redis_std = get_redis_client(tenant_id=tenant_id)
redis_std = get_redis_client()
# Define metric collection functions and their dependencies
metric_functions: list[Callable[[], list[Metric]]] = [
@@ -693,7 +693,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
]
# Collect and log each metric
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
for metric_fn in metric_functions:
metrics = metric_fn()
for metric in metrics:
@@ -771,12 +771,11 @@ def cloud_check_alembic() -> bool | None:
if tenant_id is None:
continue
with get_session_with_tenant(tenant_id=None) as session:
with get_session_with_shared_schema() as session:
try:
result = session.execute(
text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
)
result_scalar: str | None = result.scalar_one_or_none()
if result_scalar is None:
raise ValueError("Alembic version should not be None.")

View File

@@ -15,7 +15,7 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import PostgresAdvisoryLocks
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
@shared_task(
@@ -36,7 +36,7 @@ def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),

View File

@@ -41,7 +41,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_documents_for_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
@@ -108,8 +108,8 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
bind=True,
)
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
@@ -127,14 +127,14 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
# but pruning only kicks off once per hour
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -182,7 +182,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
except SoftTimeLimitExceeded:
task_logger.info(
@@ -337,7 +337,7 @@ def connector_pruning_generator_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
@@ -395,7 +395,7 @@ def connector_pruning_generator_task(
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,

View File

@@ -27,7 +27,7 @@ from onyx.db.document import mark_document_as_modified
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
@@ -79,7 +79,7 @@ def document_by_cc_pair_cleanup_task(
start = time.monotonic()
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
action = "skip"
chunks_affected = 0
@@ -205,7 +205,7 @@ def document_by_cc_pair_cleanup_task(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"doc={document_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(

View File

@@ -34,7 +34,7 @@ from onyx.db.document_set import fetch_document_sets
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import mark_document_set_as_synced
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import DocumentSet
@@ -84,8 +84,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
time_start = time.monotonic()
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r = get_redis_client()
r_replica = get_redis_replica_client()
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@@ -98,7 +98,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
try:
# 1/3: KICKOFF
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
try_generate_stale_document_sync_tasks(
self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
)
@@ -106,7 +106,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
# region document set scan
lock_beat.reacquire()
document_set_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
@@ -117,7 +117,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
for document_set_id in document_set_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
try_generate_document_set_sync_tasks(
self.app, document_set_id, db_session, r, lock_beat, tenant_id
)
@@ -136,7 +136,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
pass
else:
usergroup_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
@@ -146,7 +146,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
for usergroup_id in usergroup_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
try_generate_user_group_sync_tasks(
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
)
@@ -167,7 +167,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
monitor_connector_taskset(r)
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
monitor_usergroup_taskset = (
@@ -177,7 +177,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
noop_fallback,
)
)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
except SoftTimeLimitExceeded:
@@ -523,12 +523,12 @@ def monitor_document_set_taskset(
max_retries=3,
)
def vespa_metadata_sync_task(
self: Task, document_id: str, tenant_id: str | None
self: Task, document_id: str, *, tenant_id: str | None
) -> bool:
start = time.monotonic()
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
search_settings=active_search_settings.primary,

View File

@@ -1,5 +1,5 @@
from onyx.db.background_error import create_background_error
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
def emit_background_error(
@@ -9,5 +9,5 @@ def emit_background_error(
"""Currently just saves a row in the background_errors table.
In the future, could create notifications based on the severity."""
with get_session_with_tenant() as db_session:
with get_session_with_current_tenant() as db_session:
create_background_error(db_session, message, cc_pair_id)

View File

@@ -28,7 +28,7 @@ from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.index_attempt import get_index_attempt
@@ -244,7 +244,7 @@ def _run_indexing(
"""
start_time = time.monotonic() # just used for logging
with get_session_with_tenant(tenant_id) as db_session_temp:
with get_session_with_current_tenant() as db_session_temp:
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_start:
raise ValueError(
@@ -370,7 +370,7 @@ def _run_indexing(
document_count = 0
chunk_count = 0
try:
with get_session_with_tenant(tenant_id) as db_session_temp:
with get_session_with_current_tenant() as db_session_temp:
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
@@ -430,7 +430,7 @@ def _run_indexing(
raise ConnectorStopSignal("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
with get_session_with_tenant(tenant_id) as db_session_temp:
with get_session_with_current_tenant() as db_session_temp:
# will exception if the connector/index attempt is marked as paused/failed
_check_connector_and_attempt_status(
db_session_temp, ctx, index_attempt_id
@@ -439,7 +439,7 @@ def _run_indexing(
# save record of any failures at the connector level
if failure is not None:
total_failures += 1
with get_session_with_tenant(tenant_id) as db_session_temp:
with get_session_with_current_tenant() as db_session_temp:
create_index_attempt_error(
index_attempt_id,
ctx.cc_pair_id,
@@ -503,7 +503,7 @@ def _run_indexing(
if document.id not in failed_document_ids
]
for document_id in successful_document_ids:
with get_session_with_tenant(tenant_id) as db_session_temp:
with get_session_with_current_tenant() as db_session_temp:
if document_id in doc_id_to_unresolved_errors:
logger.info(
f"Resolving IndexAttemptError for document '{document_id}'"
@@ -516,7 +516,7 @@ def _run_indexing(
# add brand new failures
if index_pipeline_result.failures:
total_failures += len(index_pipeline_result.failures)
with get_session_with_tenant(tenant_id) as db_session_temp:
with get_session_with_current_tenant() as db_session_temp:
for failure in index_pipeline_result.failures:
create_index_attempt_error(
index_attempt_id,
@@ -533,7 +533,7 @@ def _run_indexing(
)
# This new value is updated every batch, so UI can refresh per batch update
with get_session_with_tenant(tenant_id) as db_session_temp:
with get_session_with_current_tenant() as db_session_temp:
# NOTE: Postgres uses the start of the transactions when computing `NOW()`
# so we need either to commit() or to use a new session
update_docs_indexed(
@@ -555,7 +555,7 @@ def _run_indexing(
check_checkpoint_size(checkpoint)
# save latest checkpoint
with get_session_with_tenant(tenant_id) as db_session_temp:
with get_session_with_current_tenant() as db_session_temp:
save_checkpoint(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
@@ -569,7 +569,7 @@ def _run_indexing(
)
if isinstance(e, ConnectorStopSignal):
with get_session_with_tenant(tenant_id) as db_session_temp:
with get_session_with_current_tenant() as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
@@ -587,7 +587,7 @@ def _run_indexing(
memory_tracer.stop()
raise e
else:
with get_session_with_tenant(tenant_id) as db_session_temp:
with get_session_with_current_tenant() as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
@@ -609,7 +609,7 @@ def _run_indexing(
memory_tracer.stop()
elapsed_time = time.monotonic() - start_time
with get_session_with_tenant(tenant_id) as db_session_temp:
with get_session_with_current_tenant() as db_session_temp:
# resolve entity-based errors
for error in entity_based_unresolved_errors:
logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
@@ -669,7 +669,7 @@ def run_indexing_entrypoint(
TaskAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
@@ -690,7 +690,7 @@ def run_indexing_entrypoint(
f"credentials='{credential_id}'"
)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
logger.info(

View File

@@ -143,7 +143,7 @@ from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -342,7 +342,7 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
tenant_id = get_current_tenant_id()
use_existing_user_message = new_msg_req.use_existing_user_message
existing_assistant_message_id = new_msg_req.existing_assistant_message_id

View File

@@ -181,7 +181,7 @@ class LocalFileConnector(LoadConnector):
documents: list[Document] = []
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
with get_session_with_tenant(self.tenant_id) as db_session:
with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
for file_path in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _read_files_and_metadata(

View File

@@ -23,7 +23,6 @@ from onyx.context.search.preprocessing.access_filters import (
from onyx.context.search.retrieval.search_runner import (
remove_stop_words_and_punctuation,
)
from onyx.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
from onyx.llm.interfaces import LLM
@@ -35,6 +34,7 @@ from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.timing import log_function_time
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -166,7 +166,7 @@ def retrieval_preprocessing(
time_cutoff=time_filter or predicted_time_cutoff,
tags=preset_filters.tags, # Tags are never auto-extracted
access_control_list=user_acl_filters,
tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get() if MULTI_TENANT else None,
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
)
llm_evaluation_type = LLMEvaluationType.BASIC

View File

@@ -17,7 +17,7 @@ from onyx.db.models import ApiKey
from onyx.db.models import User
from onyx.server.api_key.models import APIKeyArgs
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
def get_api_key_email_pattern() -> str:
@@ -71,7 +71,7 @@ def insert_api_key(
std_password_helper = PasswordHelper()
# Get tenant_id from context var (will be default schema for single tenant)
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
tenant_id = get_current_tenant_id()
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
api_key_user_id = uuid.uuid4()

View File

@@ -1,5 +1,4 @@
import contextlib
import json
import os
import re
import ssl
@@ -16,7 +15,6 @@ from typing import ContextManager
import asyncpg # type: ignore
import boto3
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy import text
@@ -44,13 +42,13 @@ from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from onyx.configs.constants import SSL_CERT_FILE
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -265,7 +263,7 @@ def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
return [None]
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
with get_session_with_shared_schema() as session:
result = session.execute(
text(
f"""
@@ -353,38 +351,6 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
return _ASYNC_ENGINE
async def get_current_tenant_id(request: Request) -> str:
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return tenant_id
try:
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)
if not token_data:
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
logger.debug(
f"Token data not found or expired in Redis, defaulting to {current_value}"
)
return current_value
tenant_id = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return tenant_id
except json.JSONDecodeError:
logger.error("Error decoding token data from Redis")
return POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
# Listen for events on the synchronous Session class
@event.listens_for(Session, "after_begin")
def _set_search_path(
@@ -410,7 +376,7 @@ async def get_async_session_with_tenant(
tenant_id: str | None = None,
) -> AsyncGenerator[AsyncSession, None]:
if tenant_id is None:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
tenant_id = get_current_tenant_id()
if not is_valid_schema_name(tenant_id):
logger.error(f"Invalid tenant ID: {tenant_id}")
@@ -433,82 +399,80 @@ async def get_async_session_with_tenant(
@contextmanager
def get_session_with_default_tenant() -> Generator[Session, None, None]:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
def get_session_with_current_tenant() -> Generator[Session, None, None]:
tenant_id = get_current_tenant_id()
with get_session_with_tenant(tenant_id=tenant_id) as session:
yield session
# Used in multi-tenant mode when we need to refer to the shared `public` schema
@contextmanager
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
def get_session_with_shared_schema() -> Generator[Session, None, None]:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
yield session
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@contextmanager
def get_session_with_tenant(*, tenant_id: str | None) -> Generator[Session, None, None]:
"""
Generate a database session for a specific tenant.
This function:
1. Sets the database schema to the specified tenant's schema.
2. Preserves the tenant ID across the session.
3. Reverts to the previous tenant ID after the session is closed.
4. Uses the default schema if no tenant ID is provided.
"""
engine = get_sqlalchemy_engine()
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
if tenant_id is None:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
engine = get_sqlalchemy_engine()
event.listen(engine, "checkout", set_search_path_on_checkout)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
try:
with engine.connect() as connection:
dbapi_connection = connection.connection
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
with engine.connect() as connection:
dbapi_connection = connection.connection
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
finally:
cursor.close()
)
finally:
cursor.close()
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
finally:
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
def set_search_path_on_checkout(
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
tenant_id = get_current_tenant_id()
if tenant_id and is_valid_schema_name(tenant_id):
with dbapi_conn.cursor() as cursor:
cursor.execute(f'SET search_path TO "{tenant_id}"')
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
tenant_id = get_current_tenant_id()
with get_session_with_tenant(tenant_id=tenant_id) as session:
yield session
def get_session() -> Generator[Session, None, None]:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
tenant_id = get_current_tenant_id()
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
raise BasicAuthenticationError(detail="User must authenticate")
@@ -523,7 +487,7 @@ def get_session() -> Generator[Session, None, None]:
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
tenant_id = get_current_tenant_id()
engine = get_sqlalchemy_async_engine()
async with AsyncSession(engine, expire_on_commit=False) as async_session:
if MULTI_TENANT:

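The net effect of this file's changes is a three-way split of the session helpers: get_session_with_current_tenant() for the common contextvar-driven path, a keyword-only get_session_with_tenant(tenant_id=...) for call sites that carry an explicit tenant, and get_session_with_shared_schema() for multi-tenant code that must touch the shared public schema. A minimal usage sketch under those assumptions ("tenant_abc" is a placeholder):

from onyx.db.engine import (
    get_session_with_current_tenant,
    get_session_with_shared_schema,
    get_session_with_tenant,
)

# 1) Common case: the tenant was already set on the contextvar by middleware
#    or a task wrapper.
with get_session_with_current_tenant() as db_session:
    ...

# 2) Explicit tenant (e.g. Slack handlers); tenant_id is now keyword-only,
#    so a stray positional call raises a TypeError instead of passing silently.
with get_session_with_tenant(tenant_id="tenant_abc") as db_session:
    ...

# 3) Cross-tenant bookkeeping against the shared `public` schema.
with get_session_with_shared_schema() as db_session:
    ...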
View File

@@ -13,7 +13,7 @@ from onyx.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from onyx.context.search.models import SavedSearchSettings
from onyx.db.engine import get_session_with_default_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.llm import fetch_embedding_provider
from onyx.db.models import CloudEmbeddingProvider
from onyx.db.models import IndexAttempt
@@ -189,7 +189,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with get_session_with_default_tenant() as db_session:
with get_session_with_current_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_current_search_settings(db_session)

View File

@@ -73,7 +73,7 @@ from onyx.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
from onyx.document_index.vespa_constants import YQL_BASE
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.factory import get_shared_kv_store
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -193,7 +193,7 @@ class VespaIndex(DocumentIndex):
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
)
kv_store = get_kv_store()
kv_store = get_shared_kv_store()
needs_reindexing = False
try:
@@ -240,6 +240,9 @@ class VespaIndex(DocumentIndex):
headers = {"Content-Type": "application/zip"}
response = requests.post(deploy_url, headers=headers, data=zip_file)
if response.status_code != 200:
logger.error(
f"Failed to prepare Vespa Onyx Index. Response: {response.text}"
)
raise RuntimeError(
f"Failed to prepare Vespa Onyx Index. Response: {response.text}"
)
@@ -277,7 +280,7 @@ class VespaIndex(DocumentIndex):
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
)
kv_store = get_kv_store()
kv_store = get_shared_kv_store()
needs_reindexing = False
try:

View File

@@ -36,7 +36,9 @@ def build_vespa_filters(
eq_elems = [f'{key} contains "{elem}"' for elem in valid_vals]
or_clause = " or ".join(eq_elems)
return f"({or_clause}) and "
result = f"({or_clause}) and "
return result
def _build_time_filter(
cutoff: datetime | None,

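For reference, the filter fragment this helper assembles from a key and its valid values (the key and values below are placeholders, not necessarily real Vespa fields):

key = "document_sets"  # placeholder field name
valid_vals = ["engineering", "sales"]

eq_elems = [f'{key} contains "{elem}"' for elem in valid_vals]
or_clause = " or ".join(eq_elems)
result = f"({or_clause}) and "
# result == '(document_sets contains "engineering" or document_sets contains "sales") and '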
View File

@@ -320,7 +320,13 @@ def eml_to_text(file: IO[Any]) -> str:
text_content = []
for part in message.walk():
if part.get_content_type().startswith("text/plain"):
text_content.append(part.get_payload())
payload = part.get_payload()
if isinstance(payload, str):
text_content.append(payload)
elif isinstance(payload, list):
text_content.extend(item for item in payload if isinstance(item, str))
else:
logger.warning(f"Unexpected payload type: {type(payload)}")
return TEXT_SECTION_SEPARATOR.join(text_content)

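The added isinstance checks guard against the two shapes email.message.Message.get_payload() can return: a str for leaf parts and a list of sub-messages for multipart containers. A small, self-contained illustration using only the stdlib (the sample message is made up):

import email

raw = (
    "Subject: sample\n"
    "Content-Type: text/plain\n"
    "\n"
    "hello world\n"
)
message = email.message_from_string(raw)

for part in message.walk():
    if part.get_content_type().startswith("text/plain"):
        payload = part.get_payload()  # str here; a list of Message objects for multipart containers
        print(type(payload).__name__, repr(payload))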
View File

@@ -57,7 +57,7 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
"""NOTE: using multiple sessions here, since this is often called
using multithreading. In practice, sharing a session has resulted in
weird errors."""
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
response = requests.get(url)
response.raise_for_status()
@@ -76,7 +76,7 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
unique_id = str(uuid4())
file_store = get_default_file_store(db_session)
file_store.save_file(

View File

@@ -1,8 +1,18 @@
from onyx.key_value_store.interface import KeyValueStore
from onyx.key_value_store.store import PgRedisKVStore
from shared_configs.configs import DEFAULT_REDIS_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
def get_kv_store() -> KeyValueStore:
# In the multi-tenant case, the tenant context is picked up automatically; it does not need to be passed in.
# It's read from the global thread-level variable
return PgRedisKVStore()
def get_shared_kv_store() -> KeyValueStore:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(DEFAULT_REDIS_PREFIX)
try:
return get_kv_store()
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

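get_shared_kv_store() works by pinning the tenant contextvar to DEFAULT_REDIS_PREFIX just long enough to construct the store, then resetting it in the finally block, so the caller's tenant context is untouched afterwards; this assumes PgRedisKVStore resolves its tenant at construction time. A brief usage sketch (the key name is a placeholder; .load() mirrors how the store is used elsewhere in this diff):

from onyx.key_value_store.factory import get_kv_store, get_shared_kv_store

# Tenant-scoped store: the tenant is picked up from the contextvar automatically.
kv_store = get_kv_store()

# Shared store: constructed against the DEFAULT_REDIS_PREFIX namespace.
shared_kv_store = get_shared_kv_store()
value = shared_kv_store.load("some_shared_key")  # placeholder key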
View File

@@ -415,7 +415,7 @@ def _build_continue_in_web_ui_block(
) -> Block:
if message_id is None:
raise ValueError("No message id provided to build continue in web ui block")
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
chat_session = get_chat_session_by_message_id(
db_session=db_session,
message_id=message_id,

View File

@@ -114,7 +114,7 @@ def handle_generate_answer_button(
thread_ts=thread_ts,
)
with get_session_with_tenant(client.tenant_id) as db_session:
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
@@ -155,7 +155,7 @@ def handle_slack_feedback(
) -> None:
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
create_chat_message_feedback(
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
@@ -246,7 +246,7 @@ def handle_followup_button(
tag_ids: list[str] = []
group_ids: list[str] = []
with get_session_with_tenant(client.tenant_id) as db_session:
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
channel_name, is_dm = get_channel_name_from_id(
client=client.web_client, channel_id=channel_id
)

View File

@@ -211,7 +211,7 @@ def handle_message(
except SlackApiError as e:
logger.error(f"Was not able to react to user message due to: {e}")
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if message_info.email:
add_slack_user_if_not_exists(db_session, message_info.email)

View File

@@ -23,6 +23,7 @@ from onyx.configs.onyxbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import RetrievalDetails
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import SlackChannelConfig
from onyx.db.models import User
@@ -86,7 +87,7 @@ def handle_regular_answer(
user = None
if message_info.is_bot_dm:
if message_info.email:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
user = get_user_by_email(message_info.email, db_session)
document_set_names: list[str] | None = None
@@ -95,7 +96,7 @@ def handle_regular_answer(
# This way slack flow always has a persona
persona = slack_channel_config.persona
if not persona:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
document_set_names = [
document_set.name for document_set in persona.document_sets
@@ -107,7 +108,7 @@ def handle_regular_answer(
]
prompt = persona.prompts[0] if persona.prompts else None
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
expecting_search_result = persona_has_search_tool(persona.id, db_session)
# TODO: Add in support for Slack to truncate messages based on max LLM context
@@ -156,7 +157,7 @@ def handle_regular_answer(
def _get_slack_answer(
new_message_request: CreateChatMessageRequest, onyx_user: User | None
) -> ChatOnyxBotResponse:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
packets = stream_chat_message_objects(
new_msg_req=new_message_request,
user=onyx_user,
@@ -196,7 +197,7 @@ def handle_regular_answer(
enable_auto_detect_filters=auto_detect_filters,
)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
answer_request = prepare_chat_message_request(
message_text=user_message.message,
user=user,

View File

@@ -251,7 +251,7 @@ class SlackbotHandler:
"""
all_tenants = get_all_tenant_ids()
token: Token[str]
token: Token[str | None]
# 1) Try to acquire locks for new tenants
for tenant_id in all_tenants:
@@ -300,7 +300,7 @@ class SlackbotHandler:
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
bots: list[SlackBot] = []
try:
bots = list(fetch_slack_bots(db_session=db_session))
@@ -340,7 +340,7 @@ class SlackbotHandler:
redis_client = get_redis_client(tenant_id=tenant_id)
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
# Attempt to fetch Slack bots
try:
bots = list(fetch_slack_bots(db_session=db_session))
@@ -564,7 +564,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
channel_name, _ = get_channel_name_from_id(
client=client.web_client, channel_id=channel
)
with get_session_with_tenant(client.tenant_id) as db_session:
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
@@ -788,13 +788,13 @@ def process_message(
client=client.web_client, channel_id=channel
)
token: Token[str] | None = None
token: Token[str | None] | None = None
# Set the current tenant ID at the beginning for all DB calls within this thread
if client.tenant_id:
logger.info(f"Setting tenant ID to {client.tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
try:
with get_session_with_tenant(client.tenant_id) as db_session:
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,

View File

@@ -583,7 +583,7 @@ def slack_usage_report(
logger.warning("Unable to find sender email")
if sender_email is not None:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
onyx_user = get_user_by_email(email=sender_email, db_session=db_session)
optional_telemetry(

View File

@@ -28,6 +28,8 @@ from onyx.configs.app_configs import REDIS_SSL_CERT_REQS
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
from onyx.utils.logger import setup_logger
from shared_configs.configs import DEFAULT_REDIS_PREFIX
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -152,14 +154,10 @@ class RedisPool:
host=REDIS_REPLICA_HOST, ssl=REDIS_SSL
)
def get_client(self, tenant_id: str | None) -> Redis:
if tenant_id is None:
tenant_id = "public"
def get_client(self, tenant_id: str) -> Redis:
return TenantRedis(tenant_id, connection_pool=self._pool)
def get_replica_client(self, tenant_id: str | None) -> Redis:
if tenant_id is None:
tenant_id = "public"
def get_replica_client(self, tenant_id: str) -> Redis:
return TenantRedis(tenant_id, connection_pool=self._replica_pool)
@staticmethod
@@ -221,14 +219,36 @@ redis_pool = RedisPool()
# print(value.decode()) # Output: 'value'
def get_redis_client(*, tenant_id: str | None) -> Redis:
def get_redis_client(
*,
# This argument will be deprecated in the future
tenant_id: str | None = None,
) -> Redis:
if tenant_id is None:
tenant_id = get_current_tenant_id()
return redis_pool.get_client(tenant_id)
def get_redis_replica_client(*, tenant_id: str | None) -> Redis:
def get_redis_replica_client(
*,
# this argument will be deprecated in the future
tenant_id: str | None = None,
) -> Redis:
if tenant_id is None:
tenant_id = get_current_tenant_id()
return redis_pool.get_replica_client(tenant_id)
def get_shared_redis_client() -> Redis:
return redis_pool.get_client(DEFAULT_REDIS_PREFIX)
def get_shared_redis_replica_client() -> Redis:
return redis_pool.get_replica_client(DEFAULT_REDIS_PREFIX)
SSL_CERT_REQS_MAP = {
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,

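With these changes get_redis_client() and get_redis_replica_client() no longer need a tenant argument: when tenant_id is omitted they fall back to the contextvar, and the new shared variants pin the client to DEFAULT_REDIS_PREFIX. A short usage sketch ("tenant_abc" is a placeholder):

from onyx.redis.redis_pool import (
    get_redis_client,
    get_redis_replica_client,
    get_shared_redis_client,
)

r = get_redis_client()                                  # tenant taken from the contextvar
r_replica = get_redis_replica_client()                  # same, but against the replica pool
r_explicit = get_redis_client(tenant_id="tenant_abc")   # still accepted, slated for deprecation
r_shared = get_shared_redis_client()                    # DEFAULT_REDIS_PREFIX-scoped client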
View File

@@ -35,8 +35,6 @@ from onyx.db.connector_credential_pair import (
)
from onyx.db.document import get_document_counts_for_cc_pairs
from onyx.db.document import get_documents_for_cc_pair
from onyx.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -62,6 +60,7 @@ from onyx.server.documents.models import PaginatedReturn
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/manage")
@@ -106,8 +105,9 @@ def get_cc_pair_full_info(
cc_pair_id: int,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> CCPairFullInfo:
tenant_id = get_current_tenant_id()
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id, db_session, user, get_editable=False
)
@@ -172,7 +172,6 @@ def update_cc_pair_status(
status_update_request: CCStatusUpdateRequest,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""This method returns nearly immediately. It simply sets some signals and
optimistically assumes any running background processes will clean themselves up.
@@ -180,6 +179,8 @@ def update_cc_pair_status(
Returns HTTPStatus.OK if everything finished.
"""
tenant_id = get_current_tenant_id()
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
@@ -340,9 +341,9 @@ def prune_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]:
"""Triggers pruning on a particular cc_pair immediately"""
tenant_id = get_current_tenant_id()
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
@@ -356,7 +357,7 @@ def prune_cc_pair(
detail="Connection not found for current user's permissions",
)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if redis_connector.prune.fenced:
@@ -372,7 +373,7 @@ def prune_cc_pair(
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_prune_generator_task(
primary_app, cc_pair, db_session, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
primary_app, cc_pair, db_session, r, tenant_id
)
if not payload_id:
raise HTTPException(
@@ -414,9 +415,9 @@ def sync_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]:
"""Triggers permissions sync on a particular cc_pair immediately"""
tenant_id = get_current_tenant_id()
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
@@ -430,7 +431,7 @@ def sync_cc_pair(
detail="Connection not found for current user's permissions",
)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if redis_connector.permissions.fenced:
@@ -446,7 +447,7 @@ def sync_cc_pair(
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_permissions_sync_task(
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
primary_app, cc_pair_id, r, tenant_id
)
if not payload_id:
raise HTTPException(
@@ -488,9 +489,9 @@ def sync_cc_pair_groups(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]:
"""Triggers group sync on a particular cc_pair immediately"""
tenant_id = get_current_tenant_id()
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
@@ -504,7 +505,7 @@ def sync_cc_pair_groups(
detail="Connection not found for current user's permissions",
)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if redis_connector.external_group_sync.fenced:
@@ -520,7 +521,7 @@ def sync_cc_pair_groups(
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_external_group_sync_task(
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
primary_app, cc_pair_id, r, tenant_id
)
if not payload_id:
raise HTTPException(

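This hunk and the router hunks that follow all apply the same refactor: the tenant_id = Depends(get_current_tenant_id) parameter is removed and handlers call get_current_tenant_id() directly, relying on the tenant middleware to have populated the contextvar for the request. A stripped-down, hypothetical illustration of that shape (not the actual Onyx router; the X-Tenant-ID header and default of "public" are invented for the sketch):

import contextvars

from fastapi import FastAPI, Request

CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_tenant_id", default=None
)


def get_current_tenant_id() -> str:
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
    return tenant_id if tenant_id is not None else "public"


app = FastAPI()


@app.middleware("http")
async def set_tenant_id(request: Request, call_next):
    # Stand-in for the real middleware, which derives the tenant from auth state.
    token = CURRENT_TENANT_ID_CONTEXTVAR.set(request.headers.get("X-Tenant-ID", "public"))
    try:
        return await call_next(request)
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


@app.get("/manage/admin/cc-pair/{cc_pair_id}")
def get_cc_pair(cc_pair_id: int) -> dict:
    # No Depends(get_current_tenant_id) parameter anymore; read the contextvar instead.
    tenant_id = get_current_tenant_id()
    return {"cc_pair_id": cc_pair_id, "tenant_id": tenant_id}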
View File

@@ -79,7 +79,6 @@ from onyx.db.credentials import delete_service_account_credentials
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.deletion_attempt import check_deletion_attempt_is_allowed
from onyx.db.document import get_document_counts_for_cc_pairs
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import IndexingMode
@@ -119,6 +118,7 @@ from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -610,8 +610,8 @@ def get_connector_indexing_status(
get_editable: bool = Query(
False, description="If true, return editable document sets"
),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> list[ConnectorIndexingStatus]:
tenant_id = get_current_tenant_id()
indexing_statuses: list[ConnectorIndexingStatus] = []
if MOCK_CONNECTOR_FILE_PATH:
@@ -774,8 +774,9 @@ def create_connector_from_model(
connector_data: ConnectorUpdateRequest,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> ObjectCreationIdResponse:
tenant_id = get_current_tenant_id()
try:
_validate_connector_allowed(connector_data.source)
@@ -813,15 +814,8 @@ def create_connector_with_mock_credential(
connector_data: ConnectorUpdateRequest,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse:
"""NOTE(rkuo): internally discussed and the consensus is this endpoint
and associate_credential_to_connector should be combined.
The intent of this endpoint is to handle connectors that don't need credentials,
AKA web, file, etc ... but there isn't any reason a single endpoint couldn't
server this purpose.
"""
tenant_id = get_current_tenant_id()
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_object_creation_for_user", None
@@ -952,10 +946,10 @@ def connector_run_once(
run_info: RunConnectorRequest,
_: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse[int]:
"""Used to trigger indexing on a set of cc_pairs associated with a
single connector."""
tenant_id = get_current_tenant_id()
connector_id = run_info.connector_id
specified_credential_ids = run_info.credential_ids

View File

@@ -17,13 +17,13 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import OAuthConnector
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
from onyx.utils.subclasses import find_all_subclasses_in_dir
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -89,9 +89,10 @@ def oauth_authorize(
source: DocumentSource,
desired_return_url: Annotated[str | None, Query()] = None,
_: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> AuthorizeResponse:
"""Initiates the OAuth flow by redirecting to the provider's auth page"""
tenant_id = get_current_tenant_id()
oauth_connectors = _discover_oauth_connectors()
if source not in oauth_connectors:
@@ -140,7 +141,6 @@ def oauth_callback(
state: Annotated[str, Query()],
db_session: Session = Depends(get_session),
user: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> CallbackResponse:
"""Handles the OAuth callback and exchanges the code for tokens"""
oauth_connectors = _discover_oauth_connectors()
@@ -151,7 +151,7 @@ def oauth_callback(
connector_cls = oauth_connectors[source]
# get state from redis
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client = get_redis_client()
oauth_state_bytes = cast(
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
)

View File

@@ -18,7 +18,6 @@ from onyx.auth.users import current_user
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NotificationType
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import StarterMessageModel as StarterMessage
from onyx.db.models import User
@@ -56,6 +55,7 @@ from onyx.server.models import DisplayPriorityRequest
from onyx.tools.utils import is_image_generation_available
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -201,8 +201,9 @@ def create_persona(
persona_upsert_request: PersonaUpsertRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> PersonaSnapshot:
tenant_id = get_current_tenant_id()
prompt_id = (
persona_upsert_request.prompt_ids[0]
if persona_upsert_request.prompt_ids

View File

@@ -20,7 +20,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_for_
from onyx.db.connector_credential_pair import (
update_connector_credential_pair_from_id,
)
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.feedback import fetch_docs_ranked_by_boost_for_user
@@ -39,6 +38,7 @@ from onyx.server.manage.models import BoostUpdateRequest
from onyx.server.manage.models import HiddenUpdateRequest
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
router = APIRouter(prefix="/manage")
logger = setup_logger()
@@ -139,8 +139,9 @@ def create_deletion_attempt_for_connector_id(
connector_credential_pair_identifier: ConnectorCredentialPairIdentifier,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> None:
tenant_id = get_current_tenant_id()
connector_id = connector_credential_pair_identifier.connector_id
credential_id = connector_credential_pair_identifier.credential_id

View File

@@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.configs.constants import MilestoneRecordType
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import ChannelConfig
from onyx.db.models import User
@@ -36,6 +35,7 @@ from onyx.server.manage.models import SlackChannelConfigCreationRequest
from onyx.server.manage.validate_tokens import validate_app_token
from onyx.server.manage.validate_tokens import validate_bot_token
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import get_current_tenant_id
router = APIRouter(prefix="/manage")
@@ -228,8 +228,9 @@ def create_bot(
slack_bot_creation_request: SlackBotCreationRequest,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> SlackBot:
tenant_id = get_current_tenant_id()
validate_app_token(slack_bot_creation_request.app_token)
validate_bot_token(slack_bot_creation_request.bot_token)

View File

@@ -42,8 +42,6 @@ from onyx.configs.constants import AuthType
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.api_key import is_api_key_email_address
from onyx.db.auth import get_total_users_count
from onyx.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import AccessToken
from onyx.db.models import User
@@ -69,6 +67,7 @@ from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter()
@@ -266,13 +265,13 @@ def bulk_invite_users(
) -> int:
"""emails are string validated. If any email fails validation, no emails are
invited and an exception is raised."""
tenant_id = get_current_tenant_id()
if current_user is None:
raise HTTPException(
status_code=400, detail="Auth is disabled, cannot invite users"
)
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
new_invited_emails = []
email: str
@@ -314,7 +313,7 @@ def bulk_invite_users(
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session))
)(tenant_id, get_total_users_count(db_session))
if ENABLE_EMAIL_INVITES:
try:
for email in new_invited_emails:
@@ -341,10 +340,10 @@ def remove_invited_user(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> int:
tenant_id = get_current_tenant_id()
user_emails = get_invited_users()
remaining_users = [user for user in user_emails if user != user_email.user_email]
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)([user_email.user_email], tenant_id)
@@ -354,7 +353,7 @@ def remove_invited_user(
if MULTI_TENANT:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session))
)(tenant_id, get_total_users_count(db_session))
except Exception:
logger.error(
"Request to update number of seats taken in control plane failed. "
@@ -531,8 +530,9 @@ def get_current_token_creation(
def verify_user_logged_in(
user: User | None = Depends(optional_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> UserInfo:
tenant_id = get_current_tenant_id()
# NOTE: this does not use `current_user` / `current_admin_user` because we don't want
# to enforce user verification here - the frontend always wants to get the info about
# the current user regardless of if they are currently verified

View File

@@ -11,7 +11,6 @@ from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_documents_by_cc_pair
from onyx.db.document import get_ingestion_documents
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.db.search_settings import get_active_search_settings
@@ -24,6 +23,7 @@ from onyx.server.onyx_api.models import DocMinimalInfo
from onyx.server.onyx_api.models import IngestionDocument
from onyx.server.onyx_api.models import IngestionResult
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -69,8 +69,9 @@ def upsert_ingestion_doc(
doc_info: IngestionDocument,
_: User | None = Depends(api_key_dep),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> IngestionResult:
tenant_id = get_current_tenant_id()
doc_info.document.from_ingestion_api = True
document = Document.from_base(doc_info.document)

View File

@@ -44,7 +44,6 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.feedback import create_chat_message_feedback
@@ -83,6 +82,7 @@ from onyx.server.query_and_chat.token_limit import check_token_rate_limits
from onyx.utils.headers import get_custom_tool_additional_request_headers
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -376,7 +376,6 @@ def handle_new_chat_message(
user: User | None = Depends(current_chat_accesssible_user),
_rate_limit_check: None = Depends(check_token_rate_limits),
is_connected_func: Callable[[], bool] = Depends(is_connected),
tenant_id: str = Depends(get_current_tenant_id),
) -> StreamingResponse:
"""
This endpoint is both used for all the following purposes:
@@ -398,6 +397,7 @@ def handle_new_chat_message(
Returns:
StreamingResponse: Streams the response to the new chat message.
"""
tenant_id = get_current_tenant_id()
logger.debug(f"Received new chat message: {chat_message_req.message}")
if (
@@ -407,7 +407,7 @@ def handle_new_chat_message(
):
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",

View File

@@ -22,7 +22,6 @@ from onyx.db.chat import get_search_docs_for_chat_message
from onyx.db.chat import get_valid_messages_from_query_sessions
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
@@ -37,6 +36,7 @@ from onyx.server.query_and_chat.models import SearchSessionDetailResponse
from onyx.server.query_and_chat.models import SourceTag
from onyx.server.query_and_chat.models import TagResponse
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -49,8 +49,9 @@ def admin_search(
question: AdminSearchRequest,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> AdminSearchResponse:
tenant_id = get_current_tenant_id()
query = question.query
logger.notice(f"Received admin search query: {query}")
user_acl_filters = build_access_filters_for_user(user, db_session)

View File

@@ -21,7 +21,7 @@ from onyx.db.models import User
from onyx.db.token_limit import fetch_all_global_token_rate_limits
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -41,7 +41,7 @@ def check_token_rate_limits(
versioned_rate_limit_strategy = fetch_versioned_implementation(
"onyx.server.query_and_chat.token_limit", "_check_token_rate_limits"
)
return versioned_rate_limit_strategy(user, CURRENT_TENANT_ID_CONTEXTVAR.get())
return versioned_rate_limit_strategy(user, get_current_tenant_id())
def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
@@ -54,7 +54,7 @@ Global rate limits
def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
global_rate_limits = fetch_all_global_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)

View File

@@ -6,7 +6,7 @@ from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -26,7 +26,7 @@ def load_settings() -> Settings:
logger.error(f"Error loading settings from KV store: {str(e)}")
settings = Settings()
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() if MULTI_TENANT else None
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
redis_client = get_redis_client(tenant_id=tenant_id)
try:
@@ -49,7 +49,7 @@ def load_settings() -> Settings:
def store_settings(settings: Settings) -> None:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() if MULTI_TENANT else None
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
redis_client = get_redis_client(tenant_id=tenant_id)
if settings.anonymous_user_enabled is not None:

View File

@@ -252,7 +252,7 @@ def setup_vespa(
logger.notice("Vespa setup complete.")
return True
except Exception:
logger.notice(
logger.exception(
f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
)
time.sleep(WAIT_SECONDS)

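Switching from logger.notice to logger.exception here means the retry message now carries the active traceback, which stdlib logging only attaches when .exception() (or exc_info=True) is used inside an except block. A tiny stdlib-only demonstration:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    raise ConnectionError("Vespa service not ready")
except Exception:
    logger.error("logged without a traceback")
    logger.exception("logged with the full traceback attached")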
View File

@@ -17,7 +17,7 @@ from requests import JSONDecodeError
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import FileOrigin
from onyx.db.engine import get_session_with_default_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
@@ -205,7 +205,7 @@ class CustomTool(BaseTool):
def _save_and_get_file_references(
self, file_content: bytes | str, content_type: str
) -> List[str]:
with get_session_with_default_tenant() as db_session:
with get_session_with_current_tenant() as db_session:
file_store = get_default_file_store(db_session)
file_id = str(uuid.uuid4())
@@ -328,7 +328,7 @@ class CustomTool(BaseTool):
# Load files from storage
files = []
with get_session_with_default_tenant() as db_session:
with get_session_with_current_tenant() as db_session:
file_store = get_default_file_store(db_session)
for file_id in response.tool_result.file_ids:

View File

@@ -107,7 +107,7 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
# This will always be the case for authenticated API requests
if MULTI_TENANT:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id != POSTGRES_DEFAULT_SCHEMA:
if tenant_id != POSTGRES_DEFAULT_SCHEMA and tenant_id is not None:
# Strip tenant_ prefix and take first 8 chars for cleaner logs
tenant_display = tenant_id.removeprefix(TENANT_ID_PREFIX)
short_tenant = (

View File

@@ -12,7 +12,7 @@ from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.constants import KV_CUSTOMER_UUID_KEY
from onyx.configs.constants import KV_INSTANCE_DOMAIN_KEY
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.models import User
from onyx.key_value_store.factory import get_kv_store
@@ -22,7 +22,7 @@ from onyx.utils.variable_functionality import (
)
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
_DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.onyx.app/anonymous_telemetry"
_CACHED_UUID: str | None = None
@@ -75,7 +75,7 @@ def _get_or_generate_instance_domain() -> str | None: #
try:
_CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
except KvKeyNotFoundError:
with get_session_with_tenant() as db_session:
with get_session_with_current_tenant() as db_session:
first_user = db_session.query(User).first()
if first_user:
_CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1]
@@ -90,12 +90,12 @@ def optional_telemetry(
record_type: RecordType,
data: dict,
user_id: str | None = None,
tenant_id: str | None = None,
tenant_id: str | None = None, # Allows for override of tenant_id
) -> None:
if DISABLE_TELEMETRY:
return
tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
tenant_id = tenant_id or get_current_tenant_id()
try:

View File

@@ -47,7 +47,7 @@ def get_user_id(user_email: str) -> tuple[UUID, str]:
get_tenant_id_for_email(user_email) if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA
)
with get_session_with_tenant(tenant_id) as session:
with get_session_with_tenant(tenant_id=tenant_id) as session:
user = get_user_by_email(user_email, session)
if user is None:
raise ValueError(f"User not found for email: {user_email}")

View File

@@ -41,6 +41,7 @@ from sqlalchemy import and_
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
@@ -64,6 +65,7 @@ from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -473,7 +475,7 @@ def get_document_acls(
def get_current_chunk_count(
document_id: str, index_name: str, tenant_id: str
) -> int | None:
with get_session_with_tenant(tenant_id=tenant_id) as session:
with get_session_with_current_tenant() as session:
return (
session.query(Document.chunk_count)
.filter(Document.id == document_id)
@@ -513,7 +515,7 @@ class VespaDebugging:
# Sample random documents and compare chunk counts
mismatches = []
no_chunks = []
with get_session_with_tenant(tenant_id=self.tenant_id) as session:
with get_session_with_current_tenant() as session:
# Get a sample of random documents
from sqlalchemy import func
@@ -796,6 +798,7 @@ def main() -> None:
args = parser.parse_args()
vespa_debug = VespaDebugging(args.tenant_id)
CURRENT_TENANT_ID_CONTEXTVAR.set(args.tenant_id)
if args.action == "delete-all-documents":
if not args.tenant_id:
parser.error("--tenant-id is required for delete-all-documents action")

View File

@@ -133,6 +133,7 @@ else:
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public"
DEFAULT_REDIS_PREFIX = os.environ.get("DEFAULT_REDIS_PREFIX") or "default"
async def async_return_default_schema(*args: Any, **kwargs: Any) -> str:

View File

@@ -1,10 +1,14 @@
import contextvars
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# Context variable for the current tenant id
CURRENT_TENANT_ID_CONTEXTVAR = contextvars.ContextVar(
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[
str | None
] = contextvars.ContextVar(
"current_tenant_id", default=None if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA
)

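Typing the contextvar as ContextVar[str | None] (defaulting to None only in multi-tenant mode) is what drives the Token[str | None] annotation changes elsewhere in this diff; it also forces readers of the variable to handle an unset tenant explicitly. A hypothetical sketch of how a get_current_tenant_id() accessor could sit on top of it (this is not the repository's actual implementation):

import contextvars

MULTI_TENANT = False                 # assumption for the sketch
POSTGRES_DEFAULT_SCHEMA = "public"   # assumption for the sketch

CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_tenant_id", default=None if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA
)


def get_current_tenant_id() -> str:
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
    if tenant_id is None:
        # In multi-tenant mode an unset tenant is a programming error:
        # middleware or the calling task should have set it first.
        raise RuntimeError("Tenant ID is not set in the current context")
    return tenant_id


token: contextvars.Token[str | None] = CURRENT_TENANT_ID_CONTEXTVAR.set("tenant_abc")
try:
    print(get_current_tenant_id())
finally:
    CURRENT_TENANT_ID_CONTEXTVAR.reset(token)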
View File

@@ -1296,6 +1296,7 @@ export function ChatPage({
getLastSuccessfulMessageId(currMessageHistory) || systemMessage;
const stack = new CurrentMessageFIFO();
updateCurrentMessageFIFO(stack, {
signal: controller.signal,
message: currMessage,

View File

@@ -239,13 +239,13 @@ export function FilterPopup({
};
const isDocumentSetSelected = (docSet: DocumentSet) =>
filterManager.selectedDocumentSets.includes(docSet.id.toString());
filterManager.selectedDocumentSets.includes(docSet.name);
const toggleDocumentSet = (docSet: DocumentSet) => {
filterManager.setSelectedDocumentSets((prev) =>
prev.includes(docSet.id.toString())
? prev.filter((id) => id !== docSet.id.toString())
: [...prev, docSet.id.toString()]
prev.includes(docSet.name)
? prev.filter((id) => id !== docSet.name)
: [...prev, docSet.name]
);
};