mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-27 04:35:50 +00:00
Compare commits
1 Commits
csv_render
...
v0.28.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7f1fa5c208 |
1
.github/workflows/pr-integration-tests.yml
vendored
1
.github/workflows/pr-integration-tests.yml
vendored
@@ -159,6 +159,7 @@ jobs:
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
id: start_docker
|
||||
|
||||
|
||||
@@ -1,13 +1,22 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
|
||||
from ee.onyx.background.task_name_builders import name_chat_ttl_task
|
||||
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
|
||||
from onyx.background.celery.apps.primary import celery_app
|
||||
from onyx.background.task_utils import build_celery_task_wrapper
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import get_chat_sessions_older_than
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import register_task
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -16,18 +25,42 @@ logger = setup_logger()
|
||||
# mark as EE for all tasks in this file
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_chat_ttl_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
old_chat_sessions = get_chat_sessions_older_than(
|
||||
retention_limit_days, db_session
|
||||
)
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PERFORM_TTL_MANAGEMENT_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def perform_ttl_management_task(
|
||||
self: Task, retention_limit_days: int, *, tenant_id: str
|
||||
) -> None:
|
||||
task_id = self.request.id
|
||||
if not task_id:
|
||||
raise RuntimeError("No task id defined for this task; cannot identify it")
|
||||
|
||||
for user_id, session_id in old_chat_sessions:
|
||||
# one session per delete so that we don't blow up if a deletion fails.
|
||||
start_time = datetime.now(tz=timezone.utc)
|
||||
|
||||
user_id: UUID | None = None
|
||||
session_id: UUID | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
# we generally want to move off this, but keeping for now
|
||||
register_task(
|
||||
db_session=db_session,
|
||||
task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
|
||||
task_id=task_id,
|
||||
status=TaskStatus.STARTED,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
old_chat_sessions = get_chat_sessions_older_than(
|
||||
retention_limit_days, db_session
|
||||
)
|
||||
|
||||
for user_id, session_id in old_chat_sessions:
|
||||
# one session per delete so that we don't blow up if a deletion fails.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
@@ -35,11 +68,26 @@ def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) ->
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"delete_chat_session exceptioned. "
|
||||
f"user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=True,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"delete_chat_session exceptioned. "
|
||||
f"user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
#####
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.configs.app_configs import CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
beat_cloud_tasks as base_beat_system_tasks,
|
||||
)
|
||||
@@ -34,7 +35,7 @@ ee_beat_task_templates.extend(
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
@@ -69,7 +70,7 @@ if not MULTI_TENANT:
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
|
||||
@@ -9,7 +9,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def should_perform_chat_ttl_check(
|
||||
retention_limit_days: int | None, db_session: Session
|
||||
retention_limit_days: float | None, db_session: Session
|
||||
) -> bool:
|
||||
# TODO: make this a check for None and add behavior for 0 day TTL
|
||||
if not retention_limit_days:
|
||||
|
||||
@@ -6,7 +6,9 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
QUERY_HISTORY_TASK_NAME_PREFIX = OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK
|
||||
|
||||
|
||||
def name_chat_ttl_task(retention_limit_days: int, tenant_id: str | None = None) -> str:
|
||||
def name_chat_ttl_task(
|
||||
retention_limit_days: float, tenant_id: str | None = None
|
||||
) -> str:
|
||||
return f"chat_ttl_{retention_limit_days}_days"
|
||||
|
||||
|
||||
|
||||
@@ -71,6 +71,14 @@ SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
|
||||
|
||||
|
||||
####
|
||||
# Celery Job Frequency
|
||||
####
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
|
||||
os.environ.get("CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS") or 1
|
||||
) # float for easier testing
|
||||
|
||||
|
||||
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
|
||||
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
|
||||
|
||||
|
||||
@@ -1,124 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
from celery import Task
|
||||
from celery.result import AsyncResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.tasks import mark_task_finished
|
||||
from onyx.db.tasks import mark_task_start
|
||||
from onyx.db.tasks import register_task
|
||||
|
||||
|
||||
QUERY_REPORT_NAME_PREFIX = "query-history"
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable)
|
||||
|
||||
|
||||
def build_run_wrapper(build_name_fn: Callable[..., str]) -> Callable[[T], T]:
|
||||
"""Utility meant to wrap the celery task `run` function in order to
|
||||
automatically update our custom `task_queue_jobs` table appropriately"""
|
||||
|
||||
def wrap_task_fn(task_fn: T) -> T:
|
||||
@wraps(task_fn)
|
||||
def wrapped_task_fn(*args: list, **kwargs: dict) -> Any:
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
task_name = build_name_fn(*args, **kwargs)
|
||||
with Session(engine) as db_session:
|
||||
# mark the task as started
|
||||
mark_task_start(task_name=task_name, db_session=db_session)
|
||||
|
||||
result = None
|
||||
exception = None
|
||||
try:
|
||||
result = task_fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
exception = e
|
||||
|
||||
with Session(engine) as db_session:
|
||||
mark_task_finished(
|
||||
task_name=task_name,
|
||||
db_session=db_session,
|
||||
success=exception is None,
|
||||
)
|
||||
|
||||
if not exception:
|
||||
return result
|
||||
else:
|
||||
raise exception
|
||||
|
||||
return cast(T, wrapped_task_fn)
|
||||
|
||||
return wrap_task_fn
|
||||
|
||||
|
||||
# rough type signature for `apply_async`
|
||||
AA = TypeVar("AA", bound=Callable[..., AsyncResult])
|
||||
|
||||
|
||||
def build_apply_async_wrapper(build_name_fn: Callable[..., str]) -> Callable[[AA], AA]:
|
||||
"""Utility meant to wrap celery `apply_async` function in order to automatically
|
||||
update create an entry in our `task_queue_jobs` table"""
|
||||
|
||||
def wrapper(fn: AA) -> AA:
|
||||
@wraps(fn)
|
||||
def wrapped_fn(
|
||||
args: tuple | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
*other_args: list,
|
||||
**other_kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
# `apply_async` takes in args / kwargs directly as arguments
|
||||
args_for_build_name = args or tuple()
|
||||
kwargs_for_build_name = kwargs or {}
|
||||
task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# register_task must come before fn = apply_async or else the task
|
||||
# might run mark_task_start (and crash) before the task row exists
|
||||
db_task = register_task(task_name, db_session)
|
||||
|
||||
task = fn(args, kwargs, *other_args, **other_kwargs)
|
||||
|
||||
# we update the celery task id for diagnostic purposes
|
||||
# but it isn't currently used by any code
|
||||
db_task.task_id = task.id
|
||||
db_session.commit()
|
||||
|
||||
return task
|
||||
|
||||
return cast(AA, wrapped_fn)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def build_celery_task_wrapper(
|
||||
build_name_fn: Callable[..., str],
|
||||
) -> Callable[[Task], Task]:
|
||||
"""Utility meant to wrap celery task functions in order to automatically
|
||||
update our custom `task_queue_jobs` table appropriately.
|
||||
|
||||
On task creation (e.g. `apply_async`), a row is inserted into the table with
|
||||
status `PENDING`.
|
||||
On task start, the latest row is updated to have status `STARTED`.
|
||||
On task success, the latest row is updated to have status `SUCCESS`.
|
||||
On the task raising an unhandled exception, the latest row is updated to have
|
||||
status `FAILURE`.
|
||||
"""
|
||||
|
||||
def wrap_task(task: Task) -> Task:
|
||||
task.run = build_run_wrapper(build_name_fn)(task.run) # type: ignore
|
||||
task.apply_async = build_apply_async_wrapper(build_name_fn)(task.apply_async) # type: ignore
|
||||
return task
|
||||
|
||||
return wrap_task
|
||||
|
||||
|
||||
def construct_query_history_report_name(
|
||||
task_id: str,
|
||||
) -> str:
|
||||
|
||||
@@ -445,7 +445,11 @@ class OnyxCeleryTask:
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
|
||||
# chat retention
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
PERFORM_TTL_MANAGEMENT_TASK = "perform_ttl_management_task"
|
||||
|
||||
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
|
||||
|
||||
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import cast
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from bs4 import Tag
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -13,7 +12,7 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.file_processing.extract_file_text import load_files_from_zip
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
@@ -69,7 +68,7 @@ class GoogleSitesConnector(LoadConnector):
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
documents: list[Document] = []
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with get_session_context_manager() as db_session:
|
||||
file_content_io = get_default_file_store(db_session).read_file(
|
||||
self.zip_path, mode="b"
|
||||
)
|
||||
|
||||
@@ -20,7 +20,6 @@ from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.types import Lifespan
|
||||
|
||||
from onyx import __version__
|
||||
@@ -46,6 +45,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import POSTGRES_WEB_APP_NAME
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.db.engine import warm_up_connections
|
||||
from onyx.server.api_key.api import router as api_key_router
|
||||
@@ -206,7 +206,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
|
||||
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
|
||||
)
|
||||
engine = SqlEngine.get_engine()
|
||||
SqlEngine.get_engine()
|
||||
|
||||
verify_auth = fetch_versioned_implementation(
|
||||
"onyx.auth.users", "verify_auth_setting"
|
||||
@@ -230,7 +230,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
get_or_generate_uuid()
|
||||
|
||||
# If we are multi-tenant, we need to only set up initial public tables
|
||||
with Session(engine) as db_session:
|
||||
with get_session_context_manager() as db_session:
|
||||
setup_onyx(db_session, POSTGRES_DEFAULT_SCHEMA)
|
||||
else:
|
||||
setup_multitenant_onyx()
|
||||
|
||||
@@ -43,7 +43,8 @@ class Notification(BaseModel):
|
||||
class Settings(BaseModel):
|
||||
"""General settings"""
|
||||
|
||||
maximum_chat_retention_days: int | None = None
|
||||
# is float to allow for fractional days for easier automated testing
|
||||
maximum_chat_retention_days: float | None = None
|
||||
gpu_enabled: bool | None = None
|
||||
application_status: ApplicationStatus = ApplicationStatus.ACTIVE
|
||||
anonymous_user_enabled: bool | None = None
|
||||
|
||||
@@ -174,7 +174,8 @@ class DATestGatingType(str, Enum):
|
||||
class DATestSettings(BaseModel):
|
||||
"""General settings"""
|
||||
|
||||
maximum_chat_retention_days: int | None = None
|
||||
# is float to allow for fractional days for easier automated testing
|
||||
maximum_chat_retention_days: float | None = None
|
||||
gpu_enabled: bool | None = None
|
||||
product_gating: DATestGatingType = DATestGatingType.NONE
|
||||
anonymous_user_enabled: bool | None = None
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.settings import SettingsManager
|
||||
from tests.integration.common_utils.test_models import DATestSettings
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Chat retention tests are enterprise only",
|
||||
)
|
||||
def test_chat_retention(reset: None, admin_user: DATestUser) -> None:
|
||||
"""Test that chat sessions are deleted after the retention period expires."""
|
||||
|
||||
# Set chat retention period to 10 seconds
|
||||
retention_days = 10 / 86400 # 10 seconds in days (10 / 24 / 60 / 60)
|
||||
settings = DATestSettings(maximum_chat_retention_days=retention_days)
|
||||
SettingsManager.update_settings(settings, user_performing_action=admin_user)
|
||||
|
||||
# Create a chat session
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=0,
|
||||
description="Test chat retention",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Send a message
|
||||
ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="This message should be deleted soon",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify the chat session exists
|
||||
chat_history = ChatSessionManager.get_chat_history(
|
||||
chat_session=chat_session,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(chat_history) > 0, "Chat session should have messages"
|
||||
|
||||
# Wait for TTL task to run (give it ~60 seconds)
|
||||
print("Waiting for chat retention TTL task to run...")
|
||||
max_wait_time = 60 # maximum time to wait in seconds
|
||||
start_time = time.time()
|
||||
session_deleted = False
|
||||
|
||||
while not session_deleted and (time.time() - start_time < max_wait_time):
|
||||
# Check if chat session is deleted
|
||||
try:
|
||||
# Attempt to get chat history - this should 404
|
||||
chat_history = ChatSessionManager.get_chat_history(
|
||||
chat_session=chat_session,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# If we got no messages or an empty response, session might be deleted
|
||||
if not chat_history:
|
||||
session_deleted = True
|
||||
break
|
||||
|
||||
except requests.exceptions.HTTPError as e:
|
||||
# If we get a 404 or other error, the session is gone
|
||||
if e.response.status_code in (404, 400):
|
||||
session_deleted = True
|
||||
break
|
||||
raise # Re-raise other errors
|
||||
|
||||
# Wait a bit before checking again
|
||||
time.sleep(5)
|
||||
print(f"Waited {time.time() - start_time:.1f} seconds for chat deletion...")
|
||||
|
||||
# Assert that the chat session was deleted
|
||||
assert session_deleted, "Chat session was not deleted within the expected time"
|
||||
@@ -256,6 +256,8 @@ services:
|
||||
- AWS_REGION_NAME=${AWS_REGION_NAME:-}
|
||||
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
|
||||
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
|
||||
# primarily for testing
|
||||
- CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=${CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS:-}
|
||||
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
|
||||
Reference in New Issue
Block a user