1
0
forked from github/onyx

Add support for db proxy (#4932)

* Split up engine file

* Switch to schema_translate_map

* Fix mass serach/replace

* Remove unused

* Fix mypy

* Fix

* Add back __init__.py

* kg fix for new session management

Adding "<tenant_id>" in front of all views.

* additional kg fix

* better handling

* improve naming

---------

Co-authored-by: joachim-danswer <joachim@danswer.ai>
This commit is contained in:
Chris Weaver
2025-06-23 17:19:07 -07:00
committed by GitHub
parent 8d5136fe8b
commit e7376e9dc2
186 changed files with 769 additions and 672 deletions

View File

@@ -1,12 +1,12 @@
from typing import Any, Literal
from onyx.db.engine import get_iam_auth_token
from onyx.db.engine.iam_auth import get_iam_auth_token
from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy import text
@@ -24,7 +24,7 @@ from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
from onyx.db.engine import SqlEngine
from onyx.db.engine.sql_engine import SqlEngine
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
# hidden! (defaults to level=WARN)

View File

@@ -0,0 +1,136 @@
"""update_kg_trigger_functions
Revision ID: 36e9220ab794
Revises: c9e2cd766c29
Create Date: 2025-06-22 17:33:25.833733
"""
from alembic import op
from sqlalchemy.orm import Session
from sqlalchemy import text
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
# revision identifiers, used by Alembic.
revision = "36e9220ab794"
down_revision = "c9e2cd766c29"
branch_labels = None
depends_on = None
def _get_tenant_contextvar(session: Session) -> str:
"""Get the current schema for the migration"""
current_tenant = session.execute(text("SELECT current_schema()")).scalar()
if isinstance(current_tenant, str):
return current_tenant
else:
raise ValueError("Current tenant is not a string")
def upgrade() -> None:
bind = op.get_bind()
session = Session(bind=bind)
# Create kg_entity trigger to update kg_entity.name and its trigrams
tenant_id = _get_tenant_contextvar(session)
alphanum_pattern = r"[^a-z0-9]+"
truncate_length = 1000
function = "update_kg_entity_name"
op.execute(
text(
f"""
CREATE OR REPLACE FUNCTION "{tenant_id}".{function}()
RETURNS TRIGGER AS $$
DECLARE
name text;
cleaned_name text;
BEGIN
-- Set name to semantic_id if document_id is not NULL
IF NEW.document_id IS NOT NULL THEN
SELECT lower(semantic_id) INTO name
FROM "{tenant_id}".document
WHERE id = NEW.document_id;
ELSE
name = lower(NEW.name);
END IF;
-- Clean name and truncate if too long
cleaned_name = regexp_replace(
name,
'{alphanum_pattern}', '', 'g'
);
IF length(cleaned_name) > {truncate_length} THEN
cleaned_name = left(cleaned_name, {truncate_length});
END IF;
-- Set name and name trigrams
NEW.name = name;
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"""
)
)
trigger = f"{function}_trigger"
op.execute(f'DROP TRIGGER IF EXISTS {trigger} ON "{tenant_id}".kg_entity')
op.execute(
f"""
CREATE TRIGGER {trigger}
BEFORE INSERT OR UPDATE OF name
ON "{tenant_id}".kg_entity
FOR EACH ROW
EXECUTE FUNCTION "{tenant_id}".{function}();
"""
)
# Create kg_entity trigger to update kg_entity.name and its trigrams
function = "update_kg_entity_name_from_doc"
op.execute(
text(
f"""
CREATE OR REPLACE FUNCTION "{tenant_id}".{function}()
RETURNS TRIGGER AS $$
DECLARE
doc_name text;
cleaned_name text;
BEGIN
doc_name = lower(NEW.semantic_id);
-- Clean name and truncate if too long
cleaned_name = regexp_replace(
doc_name,
'{alphanum_pattern}', '', 'g'
);
IF length(cleaned_name) > {truncate_length} THEN
cleaned_name = left(cleaned_name, {truncate_length});
END IF;
-- Set name and name trigrams for all entities referencing this document
UPDATE "{tenant_id}".kg_entity
SET
name = doc_name,
name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
WHERE document_id = NEW.id;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"""
)
)
trigger = f"{function}_trigger"
op.execute(f'DROP TRIGGER IF EXISTS {trigger} ON "{tenant_id}".document')
op.execute(
f"""
CREATE TRIGGER {trigger}
AFTER UPDATE OF semantic_id
ON "{tenant_id}".document
FOR EACH ROW
EXECUTE FUNCTION "{tenant_id}".{function}();
"""
)
def downgrade() -> None:
pass

View File

@@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.schema import SchemaItem
from alembic import context
from onyx.db.engine import build_connection_string
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.models import PublicBase
# this is the Alembic Config object, which provides

View File

@@ -16,7 +16,7 @@ from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import QueryHistoryType
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tasks import delete_task_with_id
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import mark_task_as_started_with_id

View File

@@ -13,7 +13,7 @@ from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import register_task

View File

@@ -6,7 +6,7 @@ from celery import shared_task
from ee.onyx.db.query_history import get_all_query_history_export_tasks
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import delete_task_with_id
from onyx.utils.logger import setup_logger

View File

@@ -13,7 +13,7 @@ from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST

View File

@@ -47,8 +47,8 @@ from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus

View File

@@ -41,7 +41,7 @@ from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus

View File

@@ -19,7 +19,7 @@ from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.models import AvailableTenant
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT

View File

@@ -3,7 +3,7 @@ from ee.onyx.external_permissions.sync_params import get_all_censoring_enabled_s
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from onyx.configs.constants import DocumentSource
from onyx.context.search.pipeline import InferenceChunk
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -22,7 +22,7 @@ def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
for every single chunk.
"""
all_censoring_enabled_sources = get_all_censoring_enabled_sources()
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session)
return {
cc_pair.connector.source

View File

@@ -10,7 +10,7 @@ from ee.onyx.external_permissions.salesforce.utils import (
)
from onyx.configs.app_configs import BLURB_SIZE
from onyx.context.search.models import InferenceChunk
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -44,7 +44,7 @@ def _get_objects_access_for_user_email_from_salesforce(
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
# but subsequent queries for this source are essentially instant
first_doc_id = chunks[0].document_id
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
salesforce_client = get_any_salesforce_client_for_doc_id(
db_session, first_doc_id
)
@@ -217,7 +217,7 @@ def censor_salesforce_chunks(
def _get_objects_access_for_user_email(
object_ids: set[str], user_email: str
) -> dict[str, bool]:
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
external_groups = fetch_external_groups_for_user_email_and_group_ids(
db_session=db_session,
user_email=user_email,

View File

@@ -19,7 +19,7 @@ from ee.onyx.db.analytics import fetch_query_analytics
from ee.onyx.db.analytics import user_can_view_assistant_stats
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.db.engine import get_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
router = APIRouter(prefix="/analytics")

View File

@@ -17,7 +17,7 @@ from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.db.connector_credential_pair import (
get_connector_credential_pair_from_id_for_user,
)
from onyx.db.engine import get_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client

View File

@@ -26,7 +26,7 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.db.engine import get_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store
from onyx.server.utils import BasicAuthenticationError

View File

@@ -13,7 +13,7 @@ from ee.onyx.db.standard_answer import remove_standard_answer
from ee.onyx.db.standard_answer import update_standard_answer
from ee.onyx.db.standard_answer import update_standard_answer_category
from onyx.auth.users import current_admin_user
from onyx.db.engine import get_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.server.manage.models import StandardAnswer
from onyx.server.manage.models import StandardAnswerCategory

View File

@@ -11,7 +11,7 @@ from ee.onyx.auth.users import decode_anonymous_user_jwt_token
from onyx.auth.api_key import extract_tenant_from_api_key_header
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.engine import is_valid_schema_name
from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

View File

@@ -12,10 +12,10 @@ from ee.onyx.server.oauth.slack import SlackOAuth
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import DocumentSource
from onyx.db.engine import get_current_tenant_id
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()

View File

@@ -25,12 +25,12 @@ from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL
from onyx.db.credentials import create_credential
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import update_credential_json
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()

View File

@@ -33,11 +33,11 @@ from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from shared_configs.contextvars import get_current_tenant_id
class GoogleDriveOAuth:

View File

@@ -17,11 +17,11 @@ from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from shared_configs.contextvars import get_current_tenant_id
class SlackOAuth:

View File

@@ -40,7 +40,7 @@ from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine import get_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.llm.factory import get_llms_for_persona
from onyx.natural_language_processing.utils import get_tokenizer

View File

@@ -31,7 +31,7 @@ from onyx.context.search.utils import dedupe_documents
from onyx.context.search.utils import drop_llm_indices
from onyx.context.search.utils import relevant_sections_to_indices
from onyx.db.chat import get_prompt_by_id
from onyx.db.engine import get_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id

View File

@@ -13,7 +13,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.api_key import is_api_key_email_address
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit

View File

@@ -37,7 +37,7 @@ from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.engine import get_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import TaskStatus
from onyx.db.file_record import get_query_history_export_files
from onyx.db.models import ChatSession

View File

@@ -14,7 +14,7 @@ from ee.onyx.db.usage_export import get_usage_report_data
from ee.onyx.db.usage_export import UsageReportMetadata
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
from onyx.auth.users import current_admin_user
from onyx.db.engine import get_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.file_store.constants import STANDARD_CHUNK_SIZE

View File

@@ -27,9 +27,9 @@ from onyx.auth.users import get_user_manager
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_context_manager
from onyx.db.engine import get_session
from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.utils.logger import setup_logger

View File

@@ -19,7 +19,7 @@ from ee.onyx.server.enterprise_settings.store import (
)
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
@@ -235,7 +235,7 @@ def seed_db() -> None:
logger.debug("No seeding configuration file passed")
return
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
if seed_config.llms is not None:
_seed_llms(db_session, seed_config.llms)
if seed_config.personas is not None:

View File

@@ -10,7 +10,7 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger

View File

@@ -18,7 +18,7 @@ from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

View File

@@ -28,8 +28,8 @@ from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider

View File

@@ -8,8 +8,8 @@ from sqlalchemy.schema import CreateSchema
from alembic import command
from alembic.config import Config
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
logger = logging.getLogger(__name__)

View File

@@ -9,7 +9,7 @@ from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail

View File

@@ -5,8 +5,8 @@ from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.invited_users import write_pending_users
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.utils.logger import setup_logger

View File

@@ -9,7 +9,7 @@ from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.db.engine import get_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.token_limit import fetch_all_user_token_rate_limits
from onyx.db.token_limit import insert_user_token_rate_limit

View File

@@ -16,7 +16,7 @@ from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.db.engine import get_session
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger

View File

@@ -78,7 +78,7 @@ def should_continue(state: BasicState) -> str:
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.context.search.models import SearchRequest
from onyx.llm.factory import get_default_llms
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
@@ -87,7 +87,7 @@ if __name__ == "__main__":
compiled_graph = graph.compile()
input = BasicInput(unused=True)
primary_llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
config, _ = get_test_config(
db_session=db_session,
primary_llm=primary_llm,

View File

@@ -4,7 +4,7 @@ from typing import cast
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,

View File

@@ -111,7 +111,7 @@ def answer_query_graph_builder() -> StateGraph:
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
@@ -121,7 +121,7 @@ if __name__ == "__main__":
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)

View File

@@ -238,7 +238,7 @@ def agent_search_graph_builder() -> StateGraph:
if __name__ == "__main__":
pass
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
@@ -246,7 +246,7 @@ if __name__ == "__main__":
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
search_request = SearchRequest(query="Who created Excel?")
graph_config = get_test_config(
db_session, primary_llm, fast_llm, search_request

View File

@@ -109,7 +109,7 @@ def answer_refined_query_graph_builder() -> StateGraph:
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
@@ -119,7 +119,7 @@ if __name__ == "__main__":
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
inputs = SubQuestionAnsweringInput(
question="what can you do with onyx?",
question_id="0_0",

View File

@@ -131,7 +131,7 @@ def expanded_retrieval_graph_builder() -> StateGraph:
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
@@ -142,7 +142,7 @@ if __name__ == "__main__":
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)

View File

@@ -24,7 +24,7 @@ from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.context.search.postprocessing.postprocessing import should_rerank
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.search_settings import get_current_search_settings
from onyx.utils.timing import log_function_time
@@ -60,7 +60,7 @@ def rerank_documents(
allow_agent_reranking = graph_config.behavior.allow_agent_reranking
if rerank_settings is None:
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)

View File

@@ -21,7 +21,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.configs.agent_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
@@ -67,7 +67,7 @@ def retrieve_documents(
callback_container: list[list[InferenceSection]] = []
# new db session to avoid concurrency issues
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
override_kwargs=SearchToolOverrideKwargs(

View File

@@ -19,7 +19,7 @@ from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.db.document import get_kg_doc_info_for_entity_name
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.db.entities import get_entity_name
from onyx.db.entity_type import get_entity_types

View File

@@ -25,7 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.configs.kg_configs import KG_ENTITY_EXTRACTION_TIMEOUT
from onyx.configs.kg_configs import KG_RELATIONSHIP_EXTRACTION_TIMEOUT
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.kg_temp_view import create_views
from onyx.db.kg_temp_view import get_user_view_names
from onyx.db.relationships import get_allowed_relationship_type_pairs
@@ -35,6 +35,7 @@ from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -80,10 +81,12 @@ def extract_ert(
stream_write_step_activities(writer, _KG_STEP_NR)
# Create temporary views. TODO: move into parallel step, if ultimately materialized
kg_views = get_user_view_names(user_email)
tenant_id = get_current_tenant_id()
kg_views = get_user_view_names(user_email, tenant_id)
with get_session_with_current_tenant() as db_session:
create_views(
db_session,
tenant_id=tenant_id,
user_email=user_email,
allowed_docs_view_name=kg_views.allowed_docs_view_name,
kg_relationships_view_name=kg_views.kg_relationships_view_name,

View File

@@ -27,7 +27,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_STRATEGY_GENERATION_TIMEOUT
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.kg.clustering.normalizations import normalize_entities
from onyx.kg.clustering.normalizations import normalize_relationships

View File

@@ -29,7 +29,7 @@ from onyx.configs.kg_configs import KG_SQL_GENERATION_TIMEOUT_OVERRIDE
from onyx.configs.kg_configs import KG_TEMP_ALLOWED_DOCS_VIEW_NAME_PREFIX
from onyx.configs.kg_configs import KG_TEMP_KG_ENTITIES_VIEW_NAME_PREFIX
from onyx.configs.kg_configs import KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX
from onyx.db.engine import get_db_readonly_user_session_with_current_tenant
from onyx.db.engine.sql_engine import get_db_readonly_user_session_with_current_tenant
from onyx.db.kg_temp_view import drop_views
from onyx.llm.interfaces import LLM
from onyx.prompts.kg_prompts import ENTITY_SOURCE_DETECTION_PROMPT

View File

@@ -13,7 +13,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_FILTER_CONSTRUCTION_TIMEOUT
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.entity_type import get_entity_types_with_grounded_source_name
from onyx.kg.utils.formatting_utils import make_entity_id
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT

View File

@@ -16,7 +16,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import SubQueryPiece
from onyx.db.document import get_base_llm_doc_information
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.utils.logger import setup_logger

View File

@@ -28,7 +28,7 @@ from onyx.configs.kg_configs import KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERA
from onyx.configs.kg_configs import KG_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION
from onyx.context.search.enums import SearchType
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
from onyx.tools.tool_implementations.search.search_tool import IndexFilters

View File

@@ -5,7 +5,7 @@ from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,

View File

@@ -33,7 +33,7 @@ from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.chat.models import ToolResponse
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.factory import get_default_llms
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.utils.logger import setup_logger
@@ -195,7 +195,7 @@ if __name__ == "__main__":
query="Do a search to tell me what is the difference between astronomy and astrology?",
)
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
config = get_test_config(db_session, primary_llm, fast_llm, search_request)
assert (
config.persistence is not None

View File

@@ -56,7 +56,7 @@ from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.chat_llm import LLMRateLimitError
@@ -363,7 +363,7 @@ def retrieve_search_docs(
retrieved_docs: list[InferenceSection] = []
# new db session to avoid concurrency issues
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
for tool_response in search_tool.run(
query=question,
override_kwargs=SearchToolOverrideKwargs(

View File

@@ -97,9 +97,9 @@ from onyx.db.auth import get_default_admin_user_emails
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_context_manager
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import User

View File

@@ -26,7 +26,7 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector

View File

@@ -11,8 +11,8 @@ import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST

View File

@@ -12,7 +12,7 @@ from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

View File

@@ -13,7 +13,7 @@ from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

View File

@@ -13,7 +13,7 @@ from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

View File

@@ -15,7 +15,7 @@ from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

View File

@@ -11,7 +11,7 @@ from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_MONITORING_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

View File

@@ -25,8 +25,8 @@ from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import SqlEngine
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.redis.redis_connector_credential_pair import (

View File

@@ -38,7 +38,7 @@ from onyx.db.document import (
)
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType

View File

@@ -56,7 +56,7 @@ from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import set_cc_pair_repeated_error_state
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingMode
from onyx.db.enums import IndexingStatus

View File

@@ -23,8 +23,8 @@ from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus

View File

@@ -17,7 +17,7 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.search_settings import get_current_search_settings
from onyx.kg.clustering.clustering import kg_clustering
from onyx.kg.extractions.extraction_processing import kg_extraction

View File

@@ -4,7 +4,7 @@ from redis.lock import Lock as RedisLock
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.document import check_for_documents_needing_kg_processing
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.kg_config import is_kg_config_settings_enabled_valid
from onyx.db.models import KGEntityExtractionStaging

View File

@@ -8,7 +8,7 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import LLMProvider
from onyx.db.models import ModelConfiguration

View File

@@ -27,10 +27,10 @@ from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType

View File

@@ -15,7 +15,7 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import PostgresAdvisoryLocks
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@shared_task(

View File

@@ -41,7 +41,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_documents_for_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType

View File

@@ -22,7 +22,7 @@ from onyx.db.document import get_document_connector_count
from onyx.db.document import mark_document_as_modified
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.relationships import delete_document_references_from_kg
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.factory import get_default_document_index

View File

@@ -21,7 +21,7 @@ from onyx.db.connector_credential_pair import (
get_connector_credential_pairs_with_user_files,
)
from onyx.db.document import get_document
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair

View File

@@ -35,7 +35,7 @@ from onyx.db.document_set import fetch_document_sets
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import mark_document_set_as_synced
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import DocumentSet

View File

@@ -1,7 +1,7 @@
from sqlalchemy.exc import IntegrityError
from onyx.db.background_error import create_background_error
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
def emit_background_error(

View File

@@ -9,7 +9,7 @@ from onyx.configs.constants import FileOrigin
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.db.engine import get_db_current_time
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
from onyx.db.models import IndexAttempt

View File

@@ -16,7 +16,7 @@ from typing import Literal
from typing import Optional
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX

View File

@@ -34,7 +34,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.constants import CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus

View File

@@ -79,7 +79,7 @@ from onyx.db.chat import reserve_message_id
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.chat import update_chat_session_updated_at_timestamp
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.milestone import update_user_assistant_milestone
@@ -1240,7 +1240,7 @@ def stream_chat_message(
is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
start_time = time.time()
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
objects = stream_chat_message_objects(
new_msg_req=new_msg_req,
user=user,

View File

@@ -222,17 +222,6 @@ try:
except ValueError:
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
# Experimental setting to control idle transactions
POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT = 0 # milliseconds
try:
POSTGRES_IDLE_SESSIONS_TIMEOUT = int(
os.environ.get(
"POSTGRES_IDLE_SESSIONS_TIMEOUT", POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
)
)
except ValueError:
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"

View File

@@ -34,7 +34,7 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext

View File

@@ -23,7 +23,7 @@ from onyx.configs.app_configs import (
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.constants import FileOrigin
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType

View File

@@ -6,7 +6,7 @@ from redis.lock import Lock as RedisLock
from sqlalchemy import select
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import Credential
from onyx.redis.redis_pool import get_redis_client

View File

@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext

View File

@@ -27,7 +27,7 @@ from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import extract_file_text

View File

@@ -12,7 +12,7 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_processing.html_utils import web_html_cleanup
@@ -68,7 +68,7 @@ class GoogleSitesConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
file_content_io = get_default_file_store(db_session).read_file(
self.zip_path, mode="b"
)

View File

@@ -25,7 +25,7 @@ from onyx.context.search.models import MAX_METRICS_CONTENT
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RerankMetricsContainer
from onyx.context.search.models import SearchQuery
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)

View File

@@ -17,8 +17,8 @@ from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.db.api_key import get_api_key_email_pattern
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_context_manager
from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import User

View File

@@ -18,7 +18,7 @@ from onyx.configs.constants import DocumentSource
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import Connector
@@ -148,7 +148,7 @@ def get_connector_credential_pairs_for_user_parallel(
eager_load_credential: bool = False,
eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
return get_connector_credential_pairs_for_user(
db_session,
user,
@@ -208,7 +208,7 @@ def get_cc_pair_groups_for_ids(
def get_cc_pair_groups_for_ids_parallel(
cc_pair_ids: list[int],
) -> list[UserGroup__ConnectorCredentialPair]:
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
return get_cc_pair_groups_for_ids(db_session, cc_pair_ids)

View File

@@ -28,7 +28,7 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.kg_configs import KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.entities import delete_from_kg_entities__no_commit
from onyx.db.entities import delete_from_kg_entities_extraction_staging__no_commit
from onyx.db.enums import AccessType
@@ -299,7 +299,7 @@ def get_document_counts_for_cc_pairs(
def get_document_counts_for_cc_pairs_parallel(
cc_pairs: list[ConnectorCredentialPairIdentifier],
) -> Sequence[tuple[int, int, int]]:
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
return get_document_counts_for_cc_pairs(db_session, cc_pairs)

View File

View File

@@ -0,0 +1,141 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
from typing import AsyncContextManager
import asyncpg # type: ignore
from fastapi import HTTPException
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from onyx.configs.app_configs import POSTGRES_DB
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_POOL_PRE_PING
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
from onyx.configs.app_configs import POSTGRES_USER
from onyx.db.engine.iam_auth import get_iam_auth_token
from onyx.db.engine.iam_auth import ssl_context
from onyx.db.engine.sql_engine import ASYNC_DB_API
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine.sql_engine import USE_IAM_AUTH
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
# Global so we don't create more than one engine per process
_ASYNC_ENGINE: AsyncEngine | None = None
async def get_async_connection() -> Any:
"""
Custom connection function for async engine when using IAM auth.
"""
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
db = POSTGRES_DB
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
# asyncpg requires 'ssl="require"' if SSL needed
return await asyncpg.connect(
user=user, password=token, host=host, port=int(port), database=db, ssl="require"
)
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
app_name = SqlEngine.get_app_name() + "_async"
connection_string = build_connection_string(
db_api=ASYNC_DB_API,
use_iam_auth=USE_IAM_AUTH,
)
connect_args: dict[str, Any] = {}
if app_name:
connect_args["server_settings"] = {"application_name": app_name}
connect_args["ssl"] = ssl_context
engine_kwargs = {
"connect_args": connect_args,
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
"pool_recycle": POSTGRES_POOL_RECYCLE,
}
if POSTGRES_USE_NULL_POOL:
engine_kwargs["poolclass"] = pool.NullPool
else:
engine_kwargs["pool_size"] = POSTGRES_API_SERVER_POOL_SIZE
engine_kwargs["max_overflow"] = POSTGRES_API_SERVER_POOL_OVERFLOW
_ASYNC_ENGINE = create_async_engine(
connection_string,
**engine_kwargs,
)
if USE_IAM_AUTH:
@event.listens_for(_ASYNC_ENGINE.sync_engine, "do_connect")
def provide_iam_token_async(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
# For async engine using asyncpg, we still need to set the IAM token here.
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
cparams["password"] = token
cparams["ssl"] = ssl_context
return _ASYNC_ENGINE
async def get_async_session(
tenant_id: str | None = None,
) -> AsyncGenerator[AsyncSession, None]:
"""For use w/ Depends for *async* FastAPI endpoints.
For standard `async with ... as ...` use, use get_async_session_context_manager.
"""
if tenant_id is None:
tenant_id = get_current_tenant_id()
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
engine = get_sqlalchemy_async_engine()
# no need to use the schema translation map for self-hosted + default schema
if not MULTI_TENANT:
async with AsyncSession(bind=engine, expire_on_commit=False) as session:
yield session
return
# Create connection with schema translation to handle querying the right schema
schema_translate_map = {None: tenant_id}
async with engine.connect() as connection:
connection = await connection.execution_options(
schema_translate_map=schema_translate_map
)
async with AsyncSession(
bind=connection, expire_on_commit=False
) as async_session:
yield async_session
def get_async_session_context_manager(
tenant_id: str | None = None,
) -> AsyncContextManager[AsyncSession]:
return asynccontextmanager(get_async_session)(tenant_id)

View File

@@ -0,0 +1,27 @@
from sqlalchemy import text
from onyx.db.engine.async_sql_engine import get_sqlalchemy_async_engine
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
async def warm_up_connections(
sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
) -> None:
sync_postgres_engine = get_sqlalchemy_engine()
connections = [
sync_postgres_engine.connect() for _ in range(sync_connections_to_warm_up)
]
for conn in connections:
conn.execute(text("SELECT 1"))
for conn in connections:
conn.close()
async_postgres_engine = get_sqlalchemy_async_engine()
async_connections = [
await async_postgres_engine.connect()
for _ in range(async_connections_to_warm_up)
]
for async_conn in async_connections:
await async_conn.execute(text("SELECT 1"))
for async_conn in async_connections:
await async_conn.close()

View File

@@ -0,0 +1,56 @@
import os
import ssl
from typing import Any
import boto3
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.constants import SSL_CERT_FILE
def get_iam_auth_token(
host: str, port: str, user: str, region: str = "us-east-2"
) -> str:
"""
Generate an IAM authentication token using boto3.
"""
client = boto3.client("rds", region_name=region)
token = client.generate_db_auth_token(
DBHostname=host, Port=int(port), DBUsername=user
)
return token
def configure_psycopg2_iam_auth(
cparams: dict[str, Any], host: str, port: str, user: str, region: str
) -> None:
"""
Configure cparams for psycopg2 with IAM token and SSL.
"""
token = get_iam_auth_token(host, port, user, region)
cparams["password"] = token
cparams["sslmode"] = "require"
cparams["sslrootcert"] = SSL_CERT_FILE
def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) -> None:
if USE_IAM_AUTH:
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
region = os.getenv("AWS_REGION_NAME", "us-east-2")
# Configure for psycopg2 with IAM token
configure_psycopg2_iam_auth(cparams, host, port, user, region)
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
"""Create an SSL context if IAM authentication is enabled, else return None."""
if USE_IAM_AUTH:
return ssl.create_default_context(cafile=SSL_CERT_FILE)
return None
ssl_context = create_ssl_context_if_iam()

View File

@@ -1,41 +1,24 @@
import contextlib
import os
import re
import ssl
import threading
import time
from collections.abc import AsyncGenerator
from collections.abc import Generator
from contextlib import asynccontextmanager
from contextlib import contextmanager
from datetime import datetime
from typing import Any
from typing import AsyncContextManager
import asyncpg # type: ignore
import boto3
from fastapi import HTTPException
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy import text
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from onyx.configs.app_configs import POSTGRES_DB
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_IDLE_SESSIONS_TIMEOUT
from onyx.configs.app_configs import POSTGRES_PASSWORD
from onyx.configs.app_configs import POSTGRES_POOL_PRE_PING
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
@@ -43,66 +26,34 @@ from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from onyx.configs.constants import SSL_CERT_FILE
from onyx.db.engine.iam_auth import provide_iam_token
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
# Moved is_valid_schema_name here to avoid circular import
logger = setup_logger()
# Schema name validation (moved here to avoid circular import)
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
def is_valid_schema_name(name: str) -> bool:
return SCHEMA_NAME_REGEX.match(name) is not None
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"
# why isn't this in configs?
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
# Global so we don't create more than one engine per process
_ASYNC_ENGINE: AsyncEngine | None = None
SessionFactory: sessionmaker[Session] | None = None
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
"""Create an SSL context if IAM authentication is enabled, else return None."""
if USE_IAM_AUTH:
return ssl.create_default_context(cafile=SSL_CERT_FILE)
return None
ssl_context = create_ssl_context_if_iam()
def get_iam_auth_token(
host: str, port: str, user: str, region: str = "us-east-2"
) -> str:
"""
Generate an IAM authentication token using boto3.
"""
client = boto3.client("rds", region_name=region)
token = client.generate_db_auth_token(
DBHostname=host, Port=int(port), DBUsername=user
)
return token
def configure_psycopg2_iam_auth(
cparams: dict[str, Any], host: str, port: str, user: str, region: str
) -> None:
"""
Configure cparams for psycopg2 with IAM token and SSL.
"""
token = get_iam_auth_token(host, port, user, region)
cparams["password"] = token
cparams["sslmode"] = "require"
cparams["sslrootcert"] = SSL_CERT_FILE
def build_connection_string(
*,
@@ -176,17 +127,6 @@ if LOG_POSTGRES_CONN_COUNTS:
logger.debug(f"Total connection checkins: {checkin_count}")
def get_db_current_time(db_session: Session) -> datetime:
result = db_session.execute(text("SELECT NOW()")).scalar()
if result is None:
raise ValueError("Database did not return a time")
return result
def is_valid_schema_name(name: str) -> bool:
return SCHEMA_NAME_REGEX.match(name) is not None
class SqlEngine:
_engine: Engine | None = None
_readonly_engine: Engine | None = None
@@ -348,33 +288,6 @@ class SqlEngine:
cls._engine = None
def get_all_tenant_ids() -> list[str]:
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
tenant_ids: list[str]
if not MULTI_TENANT:
return [POSTGRES_DEFAULT_SCHEMA]
with get_session_with_shared_schema() as session:
result = session.execute(
text(
f"""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', '{POSTGRES_DEFAULT_SCHEMA}')"""
)
)
tenant_ids = [row[0] for row in result]
valid_tenants = [
tenant
for tenant in tenant_ids
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
]
return valid_tenants
def get_sqlalchemy_engine() -> Engine:
return SqlEngine.get_engine()
@@ -383,83 +296,10 @@ def get_readonly_sqlalchemy_engine() -> Engine:
return SqlEngine.get_readonly_engine()
async def get_async_connection() -> Any:
"""
Custom connection function for async engine when using IAM auth.
"""
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
db = POSTGRES_DB
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
# asyncpg requires 'ssl="require"' if SSL needed
return await asyncpg.connect(
user=user, password=token, host=host, port=int(port), database=db, ssl="require"
)
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
app_name = SqlEngine.get_app_name() + "_async"
connection_string = build_connection_string(
db_api=ASYNC_DB_API,
use_iam_auth=USE_IAM_AUTH,
)
connect_args: dict[str, Any] = {}
if app_name:
connect_args["server_settings"] = {"application_name": app_name}
connect_args["ssl"] = ssl_context
engine_kwargs = {
"connect_args": connect_args,
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
"pool_recycle": POSTGRES_POOL_RECYCLE,
}
if POSTGRES_USE_NULL_POOL:
engine_kwargs["poolclass"] = pool.NullPool
else:
engine_kwargs["pool_size"] = POSTGRES_API_SERVER_POOL_SIZE
engine_kwargs["max_overflow"] = POSTGRES_API_SERVER_POOL_OVERFLOW
_ASYNC_ENGINE = create_async_engine(
connection_string,
**engine_kwargs,
)
if USE_IAM_AUTH:
@event.listens_for(_ASYNC_ENGINE.sync_engine, "do_connect")
def provide_iam_token_async(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
# For async engine using asyncpg, we still need to set the IAM token here.
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
cparams["password"] = token
cparams["ssl"] = ssl_context
return _ASYNC_ENGINE
engine = get_sqlalchemy_async_engine()
AsyncSessionLocal = sessionmaker( # type: ignore
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
)
@contextmanager
def get_session_with_current_tenant() -> Generator[Session, None, None]:
"""Standard way to get a DB session."""
tenant_id = get_current_tenant_id()
with get_session_with_tenant(tenant_id=tenant_id) as session:
yield session
@@ -473,16 +313,6 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _set_search_path_on_checkout__listener(
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
"""Listener to make sure we ALWAYS set the search path on checkout."""
tenant_id = get_current_tenant_id()
if tenant_id and is_valid_schema_name(tenant_id):
with dbapi_conn.cursor() as cursor:
cursor.execute(f'SET search_path TO "{tenant_id}"')
@contextmanager
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
"""
@@ -490,54 +320,29 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
"""
engine = get_sqlalchemy_engine()
event.listen(engine, "checkout", _set_search_path_on_checkout__listener)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
with engine.connect() as connection:
dbapi_connection = connection.connection
cursor = dbapi_connection.cursor()
try:
# NOTE: don't use `text()` here since we're using the cursor directly
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
except Exception:
raise RuntimeError(f"search_path not set for {tenant_id}")
finally:
cursor.close()
# no need to use the schema translation map for self-hosted + default schema
if not MULTI_TENANT:
with Session(bind=engine, expire_on_commit=False) as session:
yield session
return
try:
# automatically rollback or close
with Session(bind=connection, expire_on_commit=False) as session:
yield session
finally:
# always reset the search path on exit
if MULTI_TENANT:
if not dbapi_connection.is_valid:
logger.warning(
"dbapi_connection is None, likely the original connection expired."
)
return
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
except Exception as e:
logger.warning(f"Failed to reset search path: {e}")
connection.rollback()
finally:
cursor.close()
# Create connection with schema translation to handle querying the right schema
schema_translate_map = {None: tenant_id}
with engine.connect().execution_options(
schema_translate_map=schema_translate_map
) as connection:
with Session(bind=connection, expire_on_commit=False) as session:
yield session
def get_session() -> Generator[Session, None, None]:
"""For use w/ Depends for FastAPI endpoints.
Has some additional validation, and likely should be merged
with get_session_context_manager in the future."""
with get_session_with_current_tenant in the future."""
tenant_id = get_current_tenant_id()
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
raise BasicAuthenticationError(detail="User must authenticate")
@@ -545,109 +350,10 @@ def get_session() -> Generator[Session, None, None]:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
yield db_session
@contextlib.contextmanager
def get_session_context_manager() -> Generator[Session, None, None]:
"""Context manager for database sessions."""
tenant_id = get_current_tenant_id()
with get_session_with_tenant(tenant_id=tenant_id) as session:
yield session
def _set_search_path_on_transaction__listener(
session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
) -> None:
"""Every time a new transaction is started,
set the search_path from the session's info."""
tenant_id = session.info.get("tenant_id")
if tenant_id:
connection.exec_driver_sql(f'SET search_path = "{tenant_id}"')
async def get_async_session(
tenant_id: str | None = None,
) -> AsyncGenerator[AsyncSession, None]:
"""For use w/ Depends for *async* FastAPI endpoints.
For standard `async with ... as ...` use, use get_async_session_context_manager.
"""
if tenant_id is None:
tenant_id = get_current_tenant_id()
engine = get_sqlalchemy_async_engine()
async with AsyncSession(engine, expire_on_commit=False) as async_session:
# IMPORTANT: do NOT remove. The search_path seems to get reset on every `.commit()`
# without this. Do not fully understand why atm
async_session.info["tenant_id"] = tenant_id
event.listen(
async_session.sync_session,
"after_begin",
_set_search_path_on_transaction__listener,
)
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await async_session.execute(
text(
f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# don't need to set the search path for self-hosted + default schema
# this is also true for sync sessions, but just not adding it there for
# now to simplify / not change too much
if MULTI_TENANT or tenant_id != POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
yield async_session
def get_async_session_context_manager(
tenant_id: str | None = None,
) -> AsyncContextManager[AsyncSession]:
return asynccontextmanager(get_async_session)(tenant_id)
async def warm_up_connections(
sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
) -> None:
sync_postgres_engine = get_sqlalchemy_engine()
connections = [
sync_postgres_engine.connect() for _ in range(sync_connections_to_warm_up)
]
for conn in connections:
conn.execute(text("SELECT 1"))
for conn in connections:
conn.close()
async_postgres_engine = get_sqlalchemy_async_engine()
async_connections = [
await async_postgres_engine.connect()
for _ in range(async_connections_to_warm_up)
]
for async_conn in async_connections:
await async_conn.execute(text("SELECT 1"))
for async_conn in async_connections:
await async_conn.close()
def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) -> None:
if USE_IAM_AUTH:
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
region = os.getenv("AWS_REGION_NAME", "us-east-2")
# Configure for psycopg2 with IAM token
configure_psycopg2_iam_auth(cparams, host, port, user, region)
@contextmanager
def get_db_readonly_user_session_with_current_tenant() -> (
Generator[Session, None, None]
@@ -660,35 +366,19 @@ def get_db_readonly_user_session_with_current_tenant() -> (
readonly_engine = get_readonly_sqlalchemy_engine()
event.listen(readonly_engine, "checkout", _set_search_path_on_checkout__listener)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
with readonly_engine.connect() as connection:
dbapi_connection = connection.connection
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
finally:
cursor.close()
# no need to use the schema translation map for self-hosted + default schema
if not MULTI_TENANT:
with Session(readonly_engine, expire_on_commit=False) as session:
yield session
return
# no need to use the schema translation map for self-hosted + default schema
schema_translate_map = {None: tenant_id}
with readonly_engine.connect().execution_options(
schema_translate_map=schema_translate_map
) as connection:
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
except Exception as e:
logger.warning(f"Failed to reset search path: {e}")
connection.rollback()
finally:
cursor.close()
yield session

View File

@@ -0,0 +1,33 @@
from sqlalchemy import text
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
def get_all_tenant_ids() -> list[str]:
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
tenant_ids: list[str]
if not MULTI_TENANT:
return [POSTGRES_DEFAULT_SCHEMA]
with get_session_with_shared_schema() as session:
result = session.execute(
text(
f"""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', '{POSTGRES_DEFAULT_SCHEMA}')"""
)
)
tenant_ids = [row[0] for row in result]
valid_tenants = [
tenant
for tenant in tenant_ids
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
]
return valid_tenants

View File

@@ -0,0 +1,11 @@
from datetime import datetime
from sqlalchemy import text
from sqlalchemy.orm import Session
def get_db_current_time(db_session: Session) -> datetime:
result = db_session.execute(text("SELECT NOW()")).scalar()
if result is None:
raise ValueError("Database did not return a time")
return result

View File

@@ -16,7 +16,7 @@ from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.connectors.models import ConnectorFailure
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.models import ConnectorCredentialPair

View File

@@ -7,24 +7,27 @@ from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.kg_configs import KG_TEMP_ALLOWED_DOCS_VIEW_NAME_PREFIX
from onyx.configs.kg_configs import KG_TEMP_KG_ENTITIES_VIEW_NAME_PREFIX
from onyx.configs.kg_configs import KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant
Base = declarative_base()
def get_user_view_names(user_email: str) -> KGViewNames:
user_email_cleaned = user_email.replace("@", "_").replace(".", "_")
def get_user_view_names(user_email: str, tenant_id: str) -> KGViewNames:
user_email_cleaned = (
user_email.replace("@", "__").replace(".", "_").replace("+", "_")
)
return KGViewNames(
allowed_docs_view_name=f"{KG_TEMP_ALLOWED_DOCS_VIEW_NAME_PREFIX}_{user_email_cleaned}",
kg_relationships_view_name=f"{KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX}_{user_email_cleaned}",
kg_entity_view_name=f"{KG_TEMP_KG_ENTITIES_VIEW_NAME_PREFIX}_{user_email_cleaned}",
allowed_docs_view_name=f'"{tenant_id}".{KG_TEMP_ALLOWED_DOCS_VIEW_NAME_PREFIX}_{user_email_cleaned}',
kg_relationships_view_name=f'"{tenant_id}".{KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX}_{user_email_cleaned}',
kg_entity_view_name=f'"{tenant_id}".{KG_TEMP_KG_ENTITIES_VIEW_NAME_PREFIX}_{user_email_cleaned}',
)
# First, create the view definition
def create_views(
db_session: Session,
tenant_id: str,
user_email: str,
allowed_docs_view_name: str,
kg_relationships_view_name: str,
@@ -37,24 +40,24 @@ def create_views(
CREATE OR REPLACE VIEW {allowed_docs_view_name} AS
WITH kg_used_docs AS (
SELECT document_id as kg_used_doc_id
FROM kg_entity d
FROM "{tenant_id}".kg_entity d
WHERE document_id IS NOT NULL
),
public_docs AS (
SELECT d.id as allowed_doc_id
FROM document d
FROM "{tenant_id}".document d
INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
WHERE d.is_public
),
user_owned_docs AS (
SELECT d.id as allowed_doc_id
FROM document_by_connector_credential_pair d
JOIN credential c ON d.credential_id = c.id
JOIN connector_credential_pair ccp ON
FROM "{tenant_id}".document_by_connector_credential_pair d
JOIN "{tenant_id}".credential c ON d.credential_id = c.id
JOIN "{tenant_id}".connector_credential_pair ccp ON
d.connector_id = ccp.connector_id AND
d.credential_id = ccp.credential_id
JOIN "user" u ON c.user_id = u.id
JOIN "{tenant_id}".user u ON c.user_id = u.id
INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
WHERE ccp.status != 'DELETING'
AND ccp.access_type != 'SYNC'
@@ -62,15 +65,15 @@ def create_views(
),
user_group_accessible_docs AS (
SELECT d.id as allowed_doc_id
FROM document_by_connector_credential_pair d
JOIN connector_credential_pair ccp ON
FROM "{tenant_id}".document_by_connector_credential_pair d
JOIN "{tenant_id}".connector_credential_pair ccp ON
d.connector_id = ccp.connector_id AND
d.credential_id = ccp.credential_id
JOIN user_group__connector_credential_pair ugccp ON
JOIN "{tenant_id}".user_group__connector_credential_pair ugccp ON
ccp.id = ugccp.cc_pair_id
JOIN user__user_group uug ON
JOIN "{tenant_id}".user__user_group uug ON
uug.user_group_id = ugccp.user_group_id
JOIN "user" u ON uug.user_id = u.id
JOIN "{tenant_id}".user u ON uug.user_id = u.id
INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
WHERE kud.kg_used_doc_id IS NOT NULL
AND ccp.status != 'DELETING'
@@ -79,17 +82,17 @@ def create_views(
),
external_user_docs AS (
SELECT d.id as allowed_doc_id
FROM document d
FROM "{tenant_id}".document d
INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
WHERE kud.kg_used_doc_id IS NOT NULL
AND :user_email = ANY(external_user_emails)
),
external_group_docs AS (
SELECT d.id as allowed_doc_id
FROM document d
FROM "{tenant_id}".document d
INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
JOIN user__external_user_group_id ueg ON ueg.external_user_group_id = ANY(d.external_user_group_ids)
JOIN "user" u ON ueg.user_id = u.id
JOIN "{tenant_id}".user__external_user_group_id ueg ON ueg.external_user_group_id = ANY(d.external_user_group_ids)
JOIN "{tenant_id}".user u ON ueg.user_id = u.id
WHERE kud.kg_used_doc_id IS NOT NULL
AND u.email = :user_email
)
@@ -122,11 +125,11 @@ def create_views(
d.doc_updated_at as source_date,
se.attributes as source_entity_attributes,
te.attributes as target_entity_attributes
FROM kg_relationship kgr
FROM "{tenant_id}".kg_relationship kgr
INNER JOIN {allowed_docs_view_name} AD on AD.allowed_doc_id = kgr.source_document
JOIN document d on d.id = kgr.source_document
JOIN kg_entity se on se.id_name = kgr.source_node
JOIN kg_entity te on te.id_name = kgr.target_node
JOIN "{tenant_id}".document d on d.id = kgr.source_document
JOIN "{tenant_id}".kg_entity se on se.id_name = kgr.source_node
JOIN "{tenant_id}".kg_entity te on te.id_name = kgr.target_node
"""
)
@@ -139,9 +142,9 @@ def create_views(
kge.attributes as entity_attributes,
kge.document_id as source_document,
d.doc_updated_at as source_date
FROM kg_entity kge
FROM "{tenant_id}".kg_entity kge
INNER JOIN {allowed_docs_view_name} AD on AD.allowed_doc_id = kge.document_id
JOIN document d on d.id = kge.document_id
JOIN "{tenant_id}".document d on d.id = kge.document_id
"""
)

Some files were not shown because too many files have changed in this diff Show More