Compare commits

...

1 Commit

Author SHA1 Message Date
Chris Weaver
7f1fa5c208 Non default schema fix (#4667)
* Use correct postgres schema

* Remove raw Session() use

* Refactor + add test

* Fix comment
2025-05-07 10:25:09 -07:00
14 changed files with 173 additions and 146 deletions

View File

@@ -159,6 +159,7 @@ jobs:
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker

View File

@@ -1,13 +1,22 @@
from datetime import datetime
from datetime import timezone
from uuid import UUID
from celery import shared_task
from celery import Task
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
from ee.onyx.background.task_name_builders import name_chat_ttl_task
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import register_task
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
@@ -16,18 +25,42 @@ logger = setup_logger()
# mark as EE for all tasks in this file
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
with get_session_with_current_tenant() as db_session:
old_chat_sessions = get_chat_sessions_older_than(
retention_limit_days, db_session
)
@shared_task(
name=OnyxCeleryTask.PERFORM_TTL_MANAGEMENT_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
trail=False,
)
def perform_ttl_management_task(
self: Task, retention_limit_days: int, *, tenant_id: str
) -> None:
task_id = self.request.id
if not task_id:
raise RuntimeError("No task id defined for this task; cannot identify it")
for user_id, session_id in old_chat_sessions:
# one session per delete so that we don't blow up if a deletion fails.
start_time = datetime.now(tz=timezone.utc)
user_id: UUID | None = None
session_id: UUID | None = None
try:
with get_session_with_current_tenant() as db_session:
try:
# we generally want to move off this, but keeping for now
register_task(
db_session=db_session,
task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
task_id=task_id,
status=TaskStatus.STARTED,
start_time=start_time,
)
old_chat_sessions = get_chat_sessions_older_than(
retention_limit_days, db_session
)
for user_id, session_id in old_chat_sessions:
# one session per delete so that we don't blow up if a deletion fails.
with get_session_with_current_tenant() as db_session:
delete_chat_session(
user_id,
session_id,
@@ -35,11 +68,26 @@ def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) ->
include_deleted=True,
hard_delete=True,
)
except Exception:
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
with get_session_with_current_tenant() as db_session:
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=True,
)
except Exception:
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
with get_session_with_current_tenant() as db_session:
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=False,
)
raise
#####

View File

@@ -1,6 +1,7 @@
from datetime import timedelta
from typing import Any
from ee.onyx.configs.app_configs import CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS
from onyx.background.celery.tasks.beat_schedule import (
beat_cloud_tasks as base_beat_system_tasks,
)
@@ -34,7 +35,7 @@ ee_beat_task_templates.extend(
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
@@ -69,7 +70,7 @@ if not MULTI_TENANT:
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,

View File

@@ -9,7 +9,7 @@ logger = setup_logger()
def should_perform_chat_ttl_check(
retention_limit_days: int | None, db_session: Session
retention_limit_days: float | None, db_session: Session
) -> bool:
# TODO: make this a check for None and add behavior for 0 day TTL
if not retention_limit_days:

View File

@@ -6,7 +6,9 @@ from onyx.configs.constants import OnyxCeleryTask
QUERY_HISTORY_TASK_NAME_PREFIX = OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK
def name_chat_ttl_task(retention_limit_days: int, tenant_id: str | None = None) -> str:
def name_chat_ttl_task(
retention_limit_days: float, tenant_id: str | None = None
) -> str:
return f"chat_ttl_{retention_limit_days}_days"

View File

@@ -71,6 +71,14 @@ SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
####
# Celery Job Frequency
####
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
os.environ.get("CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS") or 1
) # float for easier testing
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")

View File

@@ -1,124 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import cast
from typing import TypeVar
from celery import Task
from celery.result import AsyncResult
from sqlalchemy.orm import Session
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.tasks import mark_task_finished
from onyx.db.tasks import mark_task_start
from onyx.db.tasks import register_task
QUERY_REPORT_NAME_PREFIX = "query-history"
T = TypeVar("T", bound=Callable)
def build_run_wrapper(build_name_fn: Callable[..., str]) -> Callable[[T], T]:
"""Utility meant to wrap the celery task `run` function in order to
automatically update our custom `task_queue_jobs` table appropriately"""
def wrap_task_fn(task_fn: T) -> T:
@wraps(task_fn)
def wrapped_task_fn(*args: list, **kwargs: dict) -> Any:
engine = get_sqlalchemy_engine()
task_name = build_name_fn(*args, **kwargs)
with Session(engine) as db_session:
# mark the task as started
mark_task_start(task_name=task_name, db_session=db_session)
result = None
exception = None
try:
result = task_fn(*args, **kwargs)
except Exception as e:
exception = e
with Session(engine) as db_session:
mark_task_finished(
task_name=task_name,
db_session=db_session,
success=exception is None,
)
if not exception:
return result
else:
raise exception
return cast(T, wrapped_task_fn)
return wrap_task_fn
# rough type signature for `apply_async`
AA = TypeVar("AA", bound=Callable[..., AsyncResult])
def build_apply_async_wrapper(build_name_fn: Callable[..., str]) -> Callable[[AA], AA]:
"""Utility meant to wrap celery `apply_async` function in order to automatically
update create an entry in our `task_queue_jobs` table"""
def wrapper(fn: AA) -> AA:
@wraps(fn)
def wrapped_fn(
args: tuple | None = None,
kwargs: dict[str, Any] | None = None,
*other_args: list,
**other_kwargs: dict[str, Any],
) -> Any:
# `apply_async` takes in args / kwargs directly as arguments
args_for_build_name = args or tuple()
kwargs_for_build_name = kwargs or {}
task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name)
with Session(get_sqlalchemy_engine()) as db_session:
# register_task must come before fn = apply_async or else the task
# might run mark_task_start (and crash) before the task row exists
db_task = register_task(task_name, db_session)
task = fn(args, kwargs, *other_args, **other_kwargs)
# we update the celery task id for diagnostic purposes
# but it isn't currently used by any code
db_task.task_id = task.id
db_session.commit()
return task
return cast(AA, wrapped_fn)
return wrapper
def build_celery_task_wrapper(
build_name_fn: Callable[..., str],
) -> Callable[[Task], Task]:
"""Utility meant to wrap celery task functions in order to automatically
update our custom `task_queue_jobs` table appropriately.
On task creation (e.g. `apply_async`), a row is inserted into the table with
status `PENDING`.
On task start, the latest row is updated to have status `STARTED`.
On task success, the latest row is updated to have status `SUCCESS`.
On the task raising an unhandled exception, the latest row is updated to have
status `FAILURE`.
"""
def wrap_task(task: Task) -> Task:
task.run = build_run_wrapper(build_name_fn)(task.run) # type: ignore
task.apply_async = build_apply_async_wrapper(build_name_fn)(task.apply_async) # type: ignore
return task
return wrap_task
def construct_query_history_report_name(
task_id: str,
) -> str:

View File

@@ -445,7 +445,11 @@ class OnyxCeleryTask:
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
# chat retention
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
PERFORM_TTL_MANAGEMENT_TASK = "perform_ttl_management_task"
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"

View File

@@ -5,7 +5,6 @@ from typing import cast
from bs4 import BeautifulSoup
from bs4 import Tag
from sqlalchemy.orm import Session
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@@ -13,7 +12,7 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.engine import get_session_context_manager
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_processing.html_utils import web_html_cleanup
@@ -69,7 +68,7 @@ class GoogleSitesConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_context_manager() as db_session:
file_content_io = get_default_file_store(db_session).read_file(
self.zip_path, mode="b"
)

View File

@@ -20,7 +20,6 @@ from httpx_oauth.clients.google import GoogleOAuth2
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from sqlalchemy.orm import Session
from starlette.types import Lifespan
from onyx import __version__
@@ -46,6 +45,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import POSTGRES_WEB_APP_NAME
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import SqlEngine
from onyx.db.engine import warm_up_connections
from onyx.server.api_key.api import router as api_key_router
@@ -206,7 +206,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
)
engine = SqlEngine.get_engine()
SqlEngine.get_engine()
verify_auth = fetch_versioned_implementation(
"onyx.auth.users", "verify_auth_setting"
@@ -230,7 +230,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
get_or_generate_uuid()
# If we are multi-tenant, we need to only set up initial public tables
with Session(engine) as db_session:
with get_session_context_manager() as db_session:
setup_onyx(db_session, POSTGRES_DEFAULT_SCHEMA)
else:
setup_multitenant_onyx()

View File

@@ -43,7 +43,8 @@ class Notification(BaseModel):
class Settings(BaseModel):
"""General settings"""
maximum_chat_retention_days: int | None = None
# is float to allow for fractional days for easier automated testing
maximum_chat_retention_days: float | None = None
gpu_enabled: bool | None = None
application_status: ApplicationStatus = ApplicationStatus.ACTIVE
anonymous_user_enabled: bool | None = None

View File

@@ -174,7 +174,8 @@ class DATestGatingType(str, Enum):
class DATestSettings(BaseModel):
"""General settings"""
maximum_chat_retention_days: int | None = None
# is float to allow for fractional days for easier automated testing
maximum_chat_retention_days: float | None = None
gpu_enabled: bool | None = None
product_gating: DATestGatingType = DATestGatingType.NONE
anonymous_user_enabled: bool | None = None

View File

@@ -0,0 +1,78 @@
import os
import time
import pytest
import requests
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.settings import SettingsManager
from tests.integration.common_utils.test_models import DATestSettings
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="Chat retention tests are enterprise only",
)
def test_chat_retention(reset: None, admin_user: DATestUser) -> None:
    """Test that chat sessions are deleted after the retention period expires."""
    # Configure an extremely short retention window so the TTL task fires
    # during the test run: 10 seconds expressed in days.
    retention_days = 10 / 86400  # 10 seconds in days (10 / 24 / 60 / 60)
    SettingsManager.update_settings(
        DATestSettings(maximum_chat_retention_days=retention_days),
        user_performing_action=admin_user,
    )

    # Create a chat session and put a message into it.
    chat_session = ChatSessionManager.create(
        persona_id=0,
        description="Test chat retention",
        user_performing_action=admin_user,
    )
    ChatSessionManager.send_message(
        chat_session_id=chat_session.id,
        message="This message should be deleted soon",
        user_performing_action=admin_user,
    )

    # Sanity check: the session must exist before we start waiting.
    chat_history = ChatSessionManager.get_chat_history(
        chat_session=chat_session,
        user_performing_action=admin_user,
    )
    assert len(chat_history) > 0, "Chat session should have messages"

    # Poll until the session is gone or the deadline passes (~60 seconds).
    print("Waiting for chat retention TTL task to run...")
    max_wait_time = 60  # maximum time to wait in seconds
    start_time = time.time()
    session_deleted = False

    while not session_deleted:
        if time.time() - start_time >= max_wait_time:
            break
        try:
            # Fetching history for a deleted session is expected to error out.
            chat_history = ChatSessionManager.get_chat_history(
                chat_session=chat_session,
                user_performing_action=admin_user,
            )
        except requests.exceptions.HTTPError as e:
            # A 404 (or 400) means the session has been removed.
            if e.response.status_code not in (404, 400):
                raise  # propagate unexpected HTTP failures
            session_deleted = True
            break
        else:
            # An empty history also indicates the session was cleaned up.
            if not chat_history:
                session_deleted = True
                break
        # Back off briefly before polling again.
        time.sleep(5)
        print(f"Waited {time.time() - start_time:.1f} seconds for chat deletion...")

    # Assert that the chat session was deleted within the allotted window.
    assert session_deleted, "Chat session was not deleted within the expected time"

View File

@@ -256,6 +256,8 @@ services:
- AWS_REGION_NAME=${AWS_REGION_NAME:-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# primarily for testing
- CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=${CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS:-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro