Compare commits

..

1 Commits

Author SHA1 Message Date
pablodanswer
852c0d7a8c stashing 2025-02-06 16:58:14 -08:00
375 changed files with 4086 additions and 7718 deletions

View File

@@ -65,7 +65,6 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true

View File

@@ -12,32 +12,7 @@ env:
BUILDKIT_PROGRESS: plain
jobs:
# 1) Preliminary job to check if the changed files are relevant
check_model_server_changes:
runs-on: ubuntu-latest
outputs:
changed: ${{ steps.check.outputs.changed }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check if relevant files changed
id: check
run: |
# Default to "false"
echo "changed=false" >> $GITHUB_OUTPUT
# Compare the previous commit (github.event.before) to the current one (github.sha)
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# set changed=true
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
echo "changed=true" >> $GITHUB_OUTPUT
fi
build-amd64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
steps:
@@ -77,8 +52,6 @@ jobs:
provenance: false
build-arm64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
steps:
@@ -118,8 +91,7 @@ jobs:
provenance: false
merge-and-scan:
needs: [build-amd64, build-arm64, check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
needs: [build-amd64, build-arm64]
runs-on: ubuntu-latest
steps:
- name: Login to Docker Hub

View File

@@ -94,19 +94,16 @@ jobs:
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
MULTI_TENANT=true \
AUTH_TYPE=cloud \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack up -d
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker_multi_tenant
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
- name: Run Multi-Tenant Integration Tests
run: |
echo "Waiting for 3 minutes to ensure API server is ready..."
sleep 180
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \
@@ -122,10 +119,6 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \
-e REQUIRE_EMAIL_VERIFICATION=false \
-e DISABLE_TELEMETRY=true \
-e IMAGE_TAG=test \
-e DEV_MODE=true \
onyxdotapp/onyx-integration:test \
/app/tests/integration/multitenant_tests
continue-on-error: true
@@ -133,17 +126,17 @@ jobs:
- name: Check multi-tenant test results
run: |
if [ ${{ steps.run_multitenant_tests.outcome }} == 'failure' ]; then
echo "Multi-tenant integration tests failed. Exiting with error."
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All multi-tenant integration tests passed successfully."
echo "All integration tests passed successfully."
fi
- name: Stop multi-tenant Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack down -v
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Start Docker containers
run: |
@@ -223,30 +216,27 @@ jobs:
echo "All integration tests passed successfully."
fi
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v

View File

@@ -133,4 +133,3 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=onyx-dot-app/onyx&type=Date)](https://star-history.com/#onyx-dot-app/onyx&Date)

View File

@@ -1,32 +0,0 @@
"""set built in to default
Revision ID: 2cdeff6d8c93
Revises: f5437cc136c5
Create Date: 2025-02-11 14:57:51.308775
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "2cdeff6d8c93"
down_revision = "f5437cc136c5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Prior to this migration / point in the codebase history,
# built in personas were implicitly treated as default personas (with no option to change this)
# This migration makes that explicit
op.execute(
"""
UPDATE persona
SET is_default_persona = TRUE
WHERE builtin_persona = TRUE
"""
)
def downgrade() -> None:
pass

View File

@@ -5,6 +5,7 @@ Revises: 47e5bef3a1d7
Create Date: 2024-11-06 13:15:53.302644
"""
import logging
from typing import cast
from alembic import op
import sqlalchemy as sa
@@ -19,8 +20,13 @@ down_revision = "47e5bef3a1d7"
branch_labels: None = None
depends_on: None = None
# Configure logging
logger = logging.getLogger("alembic.runtime.migration")
logger.setLevel(logging.INFO)
def upgrade() -> None:
logger.info(f"{revision}: create_table: slack_bot")
# Create new slack_bot table
op.create_table(
"slack_bot",
@@ -57,6 +63,7 @@ def upgrade() -> None:
)
# Handle existing Slack bot tokens first
logger.info(f"{revision}: Checking for existing Slack bot.")
bot_token = None
app_token = None
first_row_id = None
@@ -64,12 +71,15 @@ def upgrade() -> None:
try:
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
except Exception:
logger.warning("No existing Slack bot tokens found.")
tokens = {}
bot_token = tokens.get("bot_token")
app_token = tokens.get("app_token")
if bot_token and app_token:
logger.info(f"{revision}: Found bot and app tokens.")
session = Session(bind=op.get_bind())
new_slack_bot = SlackBot(
name="Slack Bot (Migrated)",
@@ -160,9 +170,10 @@ def upgrade() -> None:
# Clean up old tokens if they existed
try:
if bot_token and app_token:
logger.info(f"{revision}: Removing old bot and app tokens.")
get_kv_store().delete("slack_bot_tokens_config_key")
except Exception:
pass
logger.warning("tried to delete tokens in dynamic config but failed")
# Rename the table
op.rename_table(
"slack_bot_config__standard_answer_category",
@@ -179,6 +190,8 @@ def upgrade() -> None:
# Drop the table with CASCADE to handle dependent objects
op.execute("DROP TABLE slack_bot_config CASCADE")
logger.info(f"{revision}: Migration complete.")
def downgrade() -> None:
# Recreate the old slack_bot_config table
@@ -260,7 +273,7 @@ def downgrade() -> None:
}
get_kv_store().store("slack_bot_tokens_config_key", tokens)
except Exception:
pass
logger.warning("Failed to save tokens back to KV store")
# Drop the new tables in reverse order
op.drop_table("slack_channel_config")

View File

@@ -1,40 +0,0 @@
"""Add background errors table
Revision ID: f39c5794c10a
Revises: 2cdeff6d8c93
Create Date: 2025-02-12 17:11:14.527876
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f39c5794c10a"
down_revision = "2cdeff6d8c93"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"background_error",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("message", sa.String(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["cc_pair_id"],
["connector_credential_pair.id"],
ondelete="CASCADE",
),
)
def downgrade() -> None:
op.drop_table("background_error")

View File

@@ -1,46 +1,44 @@
from datetime import timedelta
from typing import Any
from onyx.background.celery.tasks.beat_schedule import (
beat_cloud_tasks as base_beat_system_tasks,
)
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
beat_task_templates as base_beat_task_templates,
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
)
from onyx.background.celery.tasks.beat_schedule import generate_cloud_tasks
from onyx.background.celery.tasks.beat_schedule import (
get_tasks_to_schedule as base_get_tasks_to_schedule,
tasks_to_schedule as base_tasks_to_schedule,
)
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT
ee_beat_system_tasks: list[dict] = []
ee_beat_task_templates: list[dict] = []
ee_beat_task_templates.extend(
[
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
ee_cloud_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
},
]
)
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
},
},
]
ee_tasks_to_schedule: list[dict] = []
@@ -67,14 +65,9 @@ if not MULTI_TENANT:
]
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
beat_system_tasks = ee_beat_system_tasks + base_beat_system_tasks
beat_task_templates = ee_beat_task_templates + base_beat_task_templates
cloud_tasks = generate_cloud_tasks(
beat_system_tasks, beat_task_templates, beat_multiplier
)
return cloud_tasks
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_tasks_to_schedule + base_get_tasks_to_schedule()
return ee_tasks_to_schedule + base_tasks_to_schedule

View File

@@ -77,5 +77,3 @@ POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"
GATED_TENANTS_KEY = "gated_tenants"

View File

@@ -15,9 +15,6 @@ def make_persona_private(
group_ids: list[int] | None,
db_session: Session,
) -> None:
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
dedupe the inputs, the commit will exception."""
db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
).delete(synchronize_session="fetch")
@@ -26,22 +23,19 @@ def make_persona_private(
).delete(synchronize_session="fetch")
if user_ids:
user_ids_set = set(user_ids)
for user_id in user_ids_set:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
for user_uuid in user_ids:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_uuid))
create_notification(
user_id=user_id,
user_id=user_uuid,
notif_type=NotificationType.PERSONA_SHARED,
db_session=db_session,
additional_data=PersonaSharedNotificationData(
persona_id=persona_id,
).model_dump(),
)
if group_ids:
group_ids_set = set(group_ids)
for group_id in group_ids_set:
for group_id in group_ids:
db_session.add(
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
)

View File

@@ -365,9 +365,7 @@ def confluence_doc_sync(
slim_docs = []
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents(
callback=callback
):
for doc_batch in confluence_connector.retrieve_all_slim_documents():
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
if callback:
if callback.should_stop():

View File

@@ -1,6 +1,5 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.background.error_logging import emit_background_error
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
@@ -11,7 +10,7 @@ logger = setup_logger()
def _build_group_member_email_map(
confluence_client: OnyxConfluence, cc_pair_id: int
confluence_client: OnyxConfluence,
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user_result in confluence_client.paginated_cql_user_retrieval():
@@ -19,11 +18,8 @@ def _build_group_member_email_map(
user = user_result.get("user", {})
if not user:
msg = f"user result missing user field: {user_result}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
logger.warning(f"user result missing user field: {user_result}")
continue
email = user.get("email")
if not email:
# This field is only present in Confluence Server
@@ -36,12 +32,7 @@ def _build_group_member_email_map(
)
if not email:
# If we still don't have an email, skip this user
msg = f"user result missing email field: {user_result}"
if user.get("type") == "app":
logger.warning(msg)
else:
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
logger.warning(f"user result missing email field: {user_result}")
continue
all_users_groups: set[str] = set()
@@ -51,18 +42,11 @@ def _build_group_member_email_map(
group_member_emails.setdefault(group_id, set()).add(email)
all_users_groups.add(group_id)
if not all_users_groups:
msg = f"No groups found for user with email: {email}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
if not group_member_emails:
logger.warning(f"No groups found for user with email: {email}")
else:
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
if not group_member_emails:
msg = "No groups found for any users."
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
return group_member_emails
@@ -77,7 +61,6 @@ def confluence_group_sync(
group_member_email_map = _build_group_member_email_map(
confluence_client=confluence_client,
cc_pair_id=cc_pair.id,
)
onyx_groups: list[ExternalUserGroup] = []
all_found_emails = set()

View File

@@ -15,7 +15,6 @@ logger = setup_logger()
def _get_slim_doc_generator(
cc_pair: ConnectorCredentialPair,
gmail_connector: GmailConnector,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
current_time = datetime.now(timezone.utc)
start_time = (
@@ -25,9 +24,7 @@ def _get_slim_doc_generator(
)
return gmail_connector.retrieve_all_slim_documents(
start=start_time,
end=current_time.timestamp(),
callback=callback,
start=start_time, end=current_time.timestamp()
)
@@ -43,9 +40,7 @@ def gmail_doc_sync(
gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
gmail_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = _get_slim_doc_generator(
cc_pair, gmail_connector, callback=callback
)
slim_doc_generator = _get_slim_doc_generator(cc_pair, gmail_connector)
document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:

View File

@@ -21,7 +21,6 @@ _PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {}
def _get_slim_doc_generator(
cc_pair: ConnectorCredentialPair,
google_drive_connector: GoogleDriveConnector,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
current_time = datetime.now(timezone.utc)
start_time = (
@@ -31,9 +30,7 @@ def _get_slim_doc_generator(
)
return google_drive_connector.retrieve_all_slim_documents(
start=start_time,
end=current_time.timestamp(),
callback=callback,
start=start_time, end=current_time.timestamp()
)

View File

@@ -20,11 +20,19 @@ def _get_slack_document_ids_and_channels(
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
slim_doc_generator = slack_connector.retrieve_all_slim_documents()
channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)
callback.progress("_get_slack_document_ids_and_channels", 1)
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
@@ -32,14 +40,6 @@ def _get_slack_document_ids_and_channels(
channel_doc_map[channel_id] = []
channel_doc_map[channel_id].append(doc_metadata.id)
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)
callback.progress("_get_slack_document_ids_and_channels", 1)
return channel_doc_map

View File

@@ -64,7 +64,6 @@ async def _get_tenant_id_from_request(
try:
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)
if not token_data:

View File

@@ -83,7 +83,6 @@ def handle_search_request(
user=user,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=False,
db_session=db_session,
bypass_acl=False,
)

View File

@@ -18,16 +18,11 @@ from ee.onyx.server.tenants.anonymous_user_path import (
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
@@ -44,9 +39,12 @@ from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.notification import create_notification
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -128,29 +126,37 @@ async def login_as_anonymous_user(
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
) -> None:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
We gate the product when
1) User has ended free trial without adding payment method
2) User's card has declined
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
tenant_id = product_gating_request.tenant_id
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
settings = load_settings()
settings.product_gating = product_gating_request.product_gating
store_settings(settings)
if product_gating_request.notification:
with get_session_with_tenant(tenant_id) as db_session:
create_notification(None, product_gating_request.notification, db_session)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@router.get("/billing-information")
@router.get("/billing-information", response_model=BillingInformation)
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
) -> BillingInformation:
logger.info("Fetching billing information")
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
return BillingInformation(
**fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
)
@router.post("/create-customer-portal-session")
@@ -163,10 +169,9 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
return_url=f"{WEB_DOMAIN}/admin/cloud-settings",
)
logger.info(portal_session)
return {"url": portal_session.url}
@@ -175,20 +180,6 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,

View File

@@ -6,7 +6,6 @@ import stripe
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger
@@ -15,19 +14,6 @@ stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
def fetch_stripe_checkout_session(tenant_id: str) -> str:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
params = {"tenant_id": tenant_id}
response = requests.post(url, headers=headers, params=params)
response.raise_for_status()
return response.json()["sessionId"]
def fetch_tenant_stripe_information(tenant_id: str) -> dict:
token = generate_data_plane_token()
headers = {
@@ -41,7 +27,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
return response.json()
def fetch_billing_information(tenant_id: str) -> BillingInformation:
def fetch_billing_information(tenant_id: str) -> dict:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
@@ -52,7 +38,7 @@ def fetch_billing_information(tenant_id: str) -> BillingInformation:
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
billing_info = BillingInformation(**response.json())
billing_info = response.json()
return billing_info

View File

@@ -1,8 +1,7 @@
from datetime import datetime
from pydantic import BaseModel
from onyx.server.settings.models import ApplicationStatus
from onyx.configs.constants import NotificationType
from onyx.server.settings.models import GatingType
class CheckoutSessionCreationRequest(BaseModel):
@@ -16,24 +15,15 @@ class CreateTenantRequest(BaseModel):
class ProductGatingRequest(BaseModel):
tenant_id: str
application_status: ApplicationStatus
class SubscriptionStatusResponse(BaseModel):
subscribed: bool
product_gating: GatingType
notification: NotificationType | None = None
class BillingInformation(BaseModel):
stripe_subscription_id: str
status: str
current_period_start: datetime
current_period_end: datetime
number_of_seats: int
cancel_at_period_end: bool
canceled_at: datetime | None
trial_start: datetime | None
trial_end: datetime | None
seats: int
subscription_status: str
billing_start: str
billing_end: str
payment_method_enabled: bool
@@ -58,12 +48,3 @@ class TenantDeletionPayload(BaseModel):
class AnonymousUserPath(BaseModel):
anonymous_user_path: str | None
class ProductGatingResponse(BaseModel):
updated: bool
error: str | None
class SubscriptionSessionResponse(BaseModel):
sessionId: str

View File

@@ -1,51 +0,0 @@
from typing import cast
from ee.onyx.configs.app_configs import GATED_TENANTS_KEY
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.setup import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None:
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
# Store the full status
status_key = f"tenant:{tenant_id}:status"
redis_client.set(status_key, status.value)
# Maintain the GATED_ACCESS set
if status == ApplicationStatus.GATED_ACCESS:
redis_client.sadd(GATED_TENANTS_KEY, tenant_id)
else:
redis_client.srem(GATED_TENANTS_KEY, tenant_id)
def store_product_gating(tenant_id: str, application_status: ApplicationStatus) -> None:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
settings = load_settings()
settings.application_status = application_status
store_settings(settings)
# Store gated tenant information in Redis
update_tenant_gating(tenant_id, application_status)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
except Exception:
logger.exception("Failed to gate product")
raise
def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))

View File

@@ -24,7 +24,6 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
@@ -86,8 +85,7 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)

View File

@@ -28,9 +28,3 @@ class EmbeddingModelTextType:
@staticmethod
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
class GPUStatus:
CUDA = "cuda"
MAC_MPS = "mps"
NONE = "none"

View File

@@ -12,7 +12,6 @@ import voyageai # type: ignore
from cohere import AsyncClient as CohereAsyncClient
from fastapi import APIRouter
from fastapi import HTTPException
from fastapi import Request
from google.oauth2 import service_account # type: ignore
from litellm import aembedding
from litellm.exceptions import RateLimitError
@@ -321,7 +320,6 @@ async def embed_text(
prefix: str | None,
api_url: str | None,
api_version: str | None,
gpu_type: str = "UNKNOWN",
) -> list[Embedding]:
if not all(texts):
logger.error("Empty strings provided for embedding")
@@ -375,11 +373,8 @@ async def embed_text(
elapsed = time.monotonic() - start
logger.info(
f"event=embedding_provider "
f"texts={len(texts)} "
f"chars={total_chars} "
f"provider={provider_type} "
f"elapsed={elapsed:.2f}"
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
f"with provider {provider_type} in {elapsed:.2f}"
)
elif model_name is not None:
logger.info(
@@ -408,14 +403,6 @@ async def embed_text(
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
f"with local model {model_name} in {elapsed:.2f}"
)
logger.info(
f"event=embedding_model "
f"texts={len(texts)} "
f"chars={total_chars} "
f"model={model_name} "
f"gpu={gpu_type} "
f"elapsed={elapsed:.2f}"
)
else:
logger.error("Neither model name nor provider specified for embedding")
raise ValueError(
@@ -468,15 +455,8 @@ async def litellm_rerank(
@router.post("/bi-encoder-embed")
async def route_bi_encoder_embed(
request: Request,
embed_request: EmbedRequest,
) -> EmbedResponse:
return await process_embed_request(embed_request, request.app.state.gpu_type)
async def process_embed_request(
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
embed_request: EmbedRequest,
) -> EmbedResponse:
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
@@ -504,7 +484,6 @@ async def process_embed_request(
api_url=embed_request.api_url,
api_version=embed_request.api_version,
prefix=prefix,
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except RateLimitError as e:

View File

@@ -16,7 +16,6 @@ from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
from model_server.utils import get_gpu_type
from onyx import __version__
from onyx.utils.logger import setup_logger
from shared_configs.configs import INDEXING_ONLY
@@ -59,10 +58,12 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
gpu_type = get_gpu_type()
logger.notice(f"Torch GPU Detection: gpu_type={gpu_type}")
app.state.gpu_type = gpu_type
if torch.cuda.is_available():
logger.notice("CUDA GPU is available")
elif torch.backends.mps.is_available():
logger.notice("Mac MPS is available")
else:
logger.notice("GPU is not available, using CPU")
if TEMP_HF_CACHE_PATH.is_dir():
logger.notice("Moving contents of temp_huggingface to huggingface cache.")

View File

@@ -1,9 +1,7 @@
import torch
from fastapi import APIRouter
from fastapi import Response
from model_server.constants import GPUStatus
from model_server.utils import get_gpu_type
router = APIRouter(prefix="/api")
@@ -13,7 +11,10 @@ async def healthcheck() -> Response:
@router.get("/gpu-status")
async def route_gpu_status() -> dict[str, bool | str]:
gpu_type = get_gpu_type()
gpu_available = gpu_type != GPUStatus.NONE
return {"gpu_available": gpu_available, "type": gpu_type}
async def gpu_status() -> dict[str, bool | str]:
if torch.cuda.is_available():
return {"gpu_available": True, "type": "cuda"}
elif torch.backends.mps.is_available():
return {"gpu_available": True, "type": "mps"}
else:
return {"gpu_available": False, "type": "none"}

View File

@@ -8,9 +8,6 @@ from typing import Any
from typing import cast
from typing import TypeVar
import torch
from model_server.constants import GPUStatus
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -61,12 +58,3 @@ def simple_log_function_time(
return cast(F, wrapped_sync_func)
return decorator
def get_gpu_type() -> str:
if torch.cuda.is_available():
return GPUStatus.CUDA
if torch.backends.mps.is_available():
return GPUStatus.MAC_MPS
return GPUStatus.NONE

View File

@@ -21,11 +21,10 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import retrieval_preprocessing
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.context.search.postprocessing.postprocessing import should_rerank
from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
def rerank_documents(
@@ -40,8 +39,6 @@ def rerank_documents(
# Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections
# If no question defined/question is None in the state, use the original
# question from the search request as query
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
@@ -50,28 +47,39 @@ def rerank_documents(
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
with get_session_context_manager() as db_session:
# we ignore some of the user specified fields since this search is
# internal to agentic search, but we still want to pass through
# persona (for stuff like document sets) and rerank settings
# (to not make an unnecessary db call).
search_request = SearchRequest(
query=question,
persona=graph_config.inputs.search_request.persona,
rerank_settings=graph_config.inputs.search_request.rerank_settings,
)
_search_query = retrieval_preprocessing(
search_request=search_request,
user=graph_config.tooling.search_tool.user, # bit of a hack
llm=graph_config.tooling.fast_llm,
db_session=db_session,
)
# Note that these are passed in values from the API and are overrides which are typically None
rerank_settings = graph_config.inputs.search_request.rerank_settings
# skip section filtering
if rerank_settings is None:
with get_session_context_manager() as db_session:
search_settings = get_current_search_settings(db_session)
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)
if should_rerank(rerank_settings) and len(verified_documents) > 0:
if (
_search_query.rerank_settings
and _search_query.rerank_settings.rerank_model_name
and _search_query.rerank_settings.num_rerank > 0
and len(verified_documents) > 0
):
if len(verified_documents) > 1:
reranked_documents = rerank_sections(
query_str=question,
# if runnable, then rerank_settings is not None
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
_search_query,
verified_documents,
)
else:
logger.warning(
f"{len(verified_documents)} verified document(s) found, skipping reranking"
)
num = "No" if len(verified_documents) == 0 else "One"
logger.warning(f"{num} verified document(s) found, skipping reranking")
reranked_documents = verified_documents
else:
logger.warning("No reranking settings found, using unranked documents")

View File

@@ -23,7 +23,6 @@ from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -68,12 +67,9 @@ def retrieve_documents(
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=not state.base_search,
),
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:

View File

@@ -58,7 +58,6 @@ from onyx.prompts.agent_search import (
)
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
@@ -219,10 +218,7 @@ def get_test_config(
using_tool_calling_llm=using_tool_calling_llm,
)
chat_session_id = (
os.environ.get("ONYX_AS_CHAT_SESSION_ID")
or "00000000-0000-0000-0000-000000000000"
)
chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
assert (
chat_session_id is not None
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
@@ -345,12 +341,8 @@ def retrieve_search_docs(
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=question,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=None,
skip_query_analysis=False,
),
force_no_rerank=True,
alternate_db_session=db_session,
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:

View File

@@ -1,7 +1,7 @@
import smtplib
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from textwrap import dedent
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
@@ -13,150 +13,23 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width" />
<title>{title}</title>
<style>
body, table, td, a {{
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
text-size-adjust: 100%;
margin: 0;
padding: 0;
-webkit-font-smoothing: antialiased;
-webkit-text-size-adjust: none;
}}
body {{
background-color: #f7f7f7;
color: #333;
}}
.body-content {{
color: #333;
}}
.email-container {{
width: 100%;
max-width: 600px;
margin: 0 auto;
background-color: #ffffff;
border-radius: 6px;
overflow: hidden;
border: 1px solid #eaeaea;
}}
.header {{
background-color: #000000;
padding: 20px;
text-align: center;
}}
.header img {{
max-width: 140px;
}}
.body-content {{
padding: 20px 30px;
}}
.title {{
font-size: 20px;
font-weight: bold;
margin: 0 0 10px;
}}
.message {{
font-size: 16px;
line-height: 1.5;
margin: 0 0 20px;
}}
.cta-button {{
display: inline-block;
padding: 12px 20px;
background-color: #000000;
color: #ffffff !important;
text-decoration: none;
border-radius: 4px;
font-weight: 500;
}}
.footer {{
font-size: 13px;
color: #6A7280;
text-align: center;
padding: 20px;
}}
.footer a {{
color: #6b7280;
text-decoration: underline;
}}
</style>
</head>
<body>
<table role="presentation" class="email-container" cellpadding="0" cellspacing="0">
<tr>
<td class="header">
<img
style="background-color: #ffffff; border-radius: 8px;"
src="https://www.onyx.app/logos/customer/onyx.png"
alt="Onyx Logo"
>
</td>
</tr>
<tr>
<td class="body-content">
<h1 class="title">{heading}</h1>
<div class="message">
{message}
</div>
{cta_block}
</td>
</tr>
<tr>
<td class="footer">
© {year} Onyx. All rights reserved.
<br>
Have questions? Join our Slack community <a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA">here</a>.
</td>
</tr>
</table>
</body>
</html>
"""
def build_html_email(
heading: str, message: str, cta_text: str | None = None, cta_link: str | None = None
) -> str:
if cta_text and cta_link:
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
else:
cta_block = ""
return HTML_EMAIL_TEMPLATE.format(
title=heading,
heading=heading,
message=message,
cta_block=cta_block,
year=datetime.now().year,
)
def send_email(
user_email: str,
subject: str,
html_body: str,
text_body: str,
body: str,
mail_from: str = EMAIL_FROM,
) -> None:
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")
msg = MIMEMultipart("alternative")
msg = MIMEMultipart()
msg["Subject"] = subject
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")
msg.attach(part_text)
msg.attach(part_html)
msg.attach(MIMEText(body))
try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
@@ -167,44 +40,26 @@ def send_email(
raise e
def send_subscription_cancellation_email(user_email: str) -> None:
# Example usage of the reusable HTML
subject = "Your Onyx Subscription Has Been Canceled"
heading = "Subscription Canceled"
message = (
"<p>Were sorry to see you go.</p>"
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
"<p>If you change your mind, you can always come back!</p>"
)
cta_text = "Renew Subscription"
cta_link = "https://www.onyx.app/pricing"
html_content = build_html_email(heading, message, cta_text, cta_link)
text_content = (
"We're sorry to see you go.\n"
"Your subscription has been canceled and will end on your next billing date.\n"
"If you change your mind, visit https://www.onyx.app/pricing"
)
send_email(user_email, subject, html_content, text_content)
def send_user_email_invite(user_email: str, current_user: User) -> None:
subject = "Invitation to Join Onyx Organization"
heading = "You've Been Invited!"
message = (
f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
body = dedent(
f"""\
Hello,
You have been invited to join an organization on Onyx.
To join the organization, please visit the following link:
{WEB_DOMAIN}/auth/signup?email={user_email}
You'll be asked to set a password or login with Google to complete your registration.
Best regards,
The Onyx Team
"""
)
cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
html_content = build_html_email(heading, message, cta_text, cta_link)
text_content = (
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
"You'll be asked to set a password or login with Google to complete your registration."
)
send_email(user_email, subject, html_content, text_content)
send_email(user_email, subject, body, current_user.email)
def send_forgot_password_email(
@@ -213,15 +68,13 @@ def send_forgot_password_email(
mail_from: str = EMAIL_FROM,
tenant_id: str | None = None,
) -> None:
# Builds a forgot password email with or without fancy HTML
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if tenant_id:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
html_content = build_html_email("Reset Your Password", message)
text_content = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, html_content, text_content, mail_from)
# Keep search param same name as cookie for simplicity
body = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, body, mail_from)
def send_user_verification_email(
@@ -229,12 +82,7 @@ def send_user_verification_email(
token: str,
mail_from: str = EMAIL_FROM,
) -> None:
# Builds a verification email
subject = "Onyx Email Verification"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
message = (
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
)
html_content = build_html_email("Verify Your Email", message)
text_content = f"Click the following link to verify your email address: {link}"
send_email(user_email, subject, html_content, text_content, mail_from)
body = f"Click the following link to verify your email address: {link}"
send_email(user_email, subject, body, mail_from)

View File

@@ -1,56 +1,41 @@
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init
from celery.utils.log import get_task_logger
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
task_logger = get_task_logger(__name__)
logger = setup_logger(__name__)
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.beat")
class DynamicTenantScheduler(PersistentScheduler):
"""This scheduler is useful because we can dynamically adjust task generation rates
through it."""
RELOAD_INTERVAL = 60
def __init__(self, *args: Any, **kwargs: Any) -> None:
logger.info("Initializing DynamicTenantScheduler")
super().__init__(*args, **kwargs)
self.last_beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
self._reload_interval = timedelta(
seconds=DynamicTenantScheduler.RELOAD_INTERVAL
)
self._reload_interval = timedelta(minutes=2)
self._last_reload = self.app.now() - self._reload_interval
# Let the parent class handle store initialization
self.setup_schedule()
self._try_updating_schedule()
task_logger.info(
f"DynamicTenantScheduler initialized: reload_interval={self._reload_interval}"
)
logger.info(f"Set reload interval to {self._reload_interval}")
def setup_schedule(self) -> None:
logger.info("Setting up initial schedule")
super().setup_schedule()
logger.info("Initial schedule setup complete")
def tick(self) -> float:
retval = super().tick()
@@ -59,35 +44,36 @@ class DynamicTenantScheduler(PersistentScheduler):
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
task_logger.debug("Reload interval reached, initiating task update")
logger.info("Reload interval reached, initiating task update")
try:
self._try_updating_schedule()
except (AttributeError, KeyError):
task_logger.exception("Failed to process task configuration")
except Exception:
task_logger.exception("Unexpected error updating tasks")
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tasks: {str(e)}")
self._last_reload = now
logger.info("Task update completed, reset reload timer")
return retval
def _generate_schedule(
self, tenant_ids: list[str] | list[None], beat_multiplier: float
self, tenant_ids: list[str] | list[None]
) -> dict[str, dict[str, Any]]:
"""Given a list of tenant id's, generates a new beat schedule for celery."""
logger.info("Fetching tasks to schedule")
new_schedule: dict[str, dict[str, Any]] = {}
if MULTI_TENANT:
# cloud tasks are system wide and thus only need to be on the beat schedule
# once for all tenants
# cloud tasks only need the single task beat across all tenants
get_cloud_tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule",
"get_cloud_tasks_to_schedule",
)
cloud_tasks_to_schedule: list[dict[str, Any]] = get_cloud_tasks_to_schedule(
beat_multiplier
)
cloud_tasks_to_schedule: list[
dict[str, Any]
] = get_cloud_tasks_to_schedule()
for task in cloud_tasks_to_schedule:
task_name = task["name"]
cloud_task = {
@@ -96,14 +82,11 @@ class DynamicTenantScheduler(PersistentScheduler):
"kwargs": task.get("kwargs", {}),
}
if options := task.get("options"):
task_logger.debug(f"Adding options to task {task_name}: {options}")
logger.debug(f"Adding options to task {task_name}: {options}")
cloud_task["options"] = options
new_schedule[task_name] = cloud_task
# regular task beats are multiplied across all tenants
# note that currently this just schedules for a single tenant in self hosted
# and doesn't do anything in the cloud because it's much more scalable
# to schedule a single cloud beat task to dispatch per tenant tasks.
get_tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
@@ -112,7 +95,7 @@ class DynamicTenantScheduler(PersistentScheduler):
for tenant_id in tenant_ids:
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
task_logger.debug(
logger.info(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
@@ -121,14 +104,14 @@ class DynamicTenantScheduler(PersistentScheduler):
task_name = task["name"]
tenant_task_name = f"{task['name']}-{tenant_id}"
task_logger.debug(f"Creating task configuration for {tenant_task_name}")
logger.debug(f"Creating task configuration for {tenant_task_name}")
tenant_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
task_logger.debug(
logger.debug(
f"Adding options to task {tenant_task_name}: {options}"
)
tenant_task["options"] = options
@@ -138,57 +121,44 @@ class DynamicTenantScheduler(PersistentScheduler):
def _try_updating_schedule(self) -> None:
"""Only updates the actual beat schedule on the celery app when it changes"""
do_update = False
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
task_logger.debug("_try_updating_schedule starting")
logger.info("_try_updating_schedule starting")
tenant_ids = get_all_tenant_ids()
task_logger.debug(f"Found {len(tenant_ids)} IDs")
logger.info(f"Found {len(tenant_ids)} IDs")
# get current schedule and extract current tenants
current_schedule = self.schedule.items()
# get potential new state
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
task_logger.error(
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
)
# there are no more per tenant beat tasks, so comment this out
# NOTE: we may not actualy need this scheduler any more and should
# test reverting to a regular beat schedule implementation
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)
# current_tenants = set()
# for task_name, _ in current_schedule:
# task_name = cast(str, task_name)
# if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
# continue
# if the schedule or beat multiplier has changed, update
while True:
if beat_multiplier != self.last_beat_multiplier:
do_update = True
break
# if "_" in task_name:
# # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
# # -> "12345678-abcd-efgh-ijkl-12345678"
# current_tenants.add(task_name.split("_")[-1])
# logger.info(f"Found {len(current_tenants)} existing items in schedule")
if not DynamicTenantScheduler._compare_schedules(
current_schedule, new_schedule
):
do_update = True
break
# for tenant_id in tenant_ids:
# if tenant_id not in current_tenants:
# logger.info(f"Processing new tenant: {tenant_id}")
break
new_schedule = self._generate_schedule(tenant_ids)
if not do_update:
# exit early if nothing changed
task_logger.info(
f"_try_updating_schedule - Schedule unchanged: "
f"tasks={len(new_schedule)} "
f"beat_multiplier={beat_multiplier}"
if DynamicTenantScheduler._compare_schedules(current_schedule, new_schedule):
logger.info(
"_try_updating_schedule: Current schedule is up to date, no changes needed"
)
return
# schedule needs updating
task_logger.debug(
logger.info(
"Schedule update required",
extra={
"new_tasks": len(new_schedule),
@@ -215,19 +185,11 @@ class DynamicTenantScheduler(PersistentScheduler):
# Ensure changes are persisted
self.sync()
task_logger.info(
f"_try_updating_schedule - Schedule updated: "
f"prev_num_tasks={len(current_schedule)} "
f"prev_beat_multiplier={self.last_beat_multiplier} "
f"tasks={len(new_schedule)} "
f"beat_multiplier={beat_multiplier}"
)
self.last_beat_multiplier = beat_multiplier
logger.info("_try_updating_schedule: Schedule updated successfully")
@staticmethod
def _compare_schedules(schedule1: dict, schedule2: dict) -> bool:
"""Compare schedules by task name only to determine if an update is needed.
"""Compare schedules to determine if an update is needed.
True if equivalent, False if not."""
current_tasks = set(name for name, _ in schedule1)
new_tasks = set(schedule2.keys())
@@ -239,7 +201,7 @@ class DynamicTenantScheduler(PersistentScheduler):
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
task_logger.info("beat_init signal received.")
logger.info("beat_init signal received.")
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)

View File

@@ -84,10 +84,8 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
EXTRA_CONCURRENCY = 4 # small extra fudge factor for connection limits
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
SqlEngine.init_engine(pool_size=8, max_overflow=0)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
@@ -144,6 +142,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(OnyxRedisConstants.ACTIVE_FENCES)

View File

@@ -1,4 +1,3 @@
import copy
from datetime import timedelta
from typing import Any
@@ -19,203 +18,242 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# hack to slow down task dispatch in the cloud until
# we have a better implementation (backpressure, etc)
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 4
# tasks that run in either self-hosted on cloud
beat_task_templates: list[dict] = []
beat_task_templates.extend(
[
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"schedule": timedelta(minutes=5),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
]
)
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
beat_task_templates.append(
{
"name": "check-for-llm-model-update",
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"schedule": timedelta(hours=1), # Check every hour
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
cloud_task: dict[str, Any] = {}
# constant options for cloud beat task generators
task_schedule: timedelta = task["schedule"]
cloud_task["schedule"] = task_schedule
cloud_task["options"] = {}
cloud_task["options"]["priority"] = OnyxCeleryPriority.HIGHEST
cloud_task["options"]["expires"] = BEAT_EXPIRES_DEFAULT
# settings dependent on the original task
cloud_task["name"] = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_{task['name']}"
cloud_task["task"] = OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR
cloud_task["kwargs"] = {}
cloud_task["kwargs"]["task_name"] = task["task"]
optional_fields = ["queue", "priority", "expires"]
for field in optional_fields:
if field in task["options"]:
cloud_task["kwargs"][field] = task["options"][field]
return cloud_task
# tasks that only run in the cloud and are system wide
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be seen
# by the DynamicTenantScheduler as system wide task and not a per tenant task
beat_cloud_tasks: list[dict] = [
# tasks that only run in the cloud
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
# by the DynamicTenantScheduler
cloud_tasks_to_schedule = [
# cloud specific tasks
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-alembic",
"task": OnyxCeleryTask.CLOUD_MONITOR_ALEMBIC,
"schedule": timedelta(hours=1),
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
"schedule": timedelta(hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
# remaining tasks are cloud generators for per tenant tasks
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-celery-queues",
"task": OnyxCeleryTask.CLOUD_MONITOR_CELERY_QUEUES,
"schedule": timedelta(seconds=30),
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_INDEXING,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-connector-deletion",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-vespa-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-prune",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_PRUNING,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-vespa-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.MONITOR_VESPA_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=30 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-external-group-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-background-processes",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(minutes=5 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.LOW,
},
},
]
# tasks that only run self hosted
if LLM_MODEL_UPDATE_API_URL:
cloud_tasks_to_schedule.append(
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-llm-model-update",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(
hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER
), # Check every hour
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"priority": OnyxCeleryPriority.LOW,
},
}
)
# tasks that run in either self-hosted on cloud
tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
tasks_to_schedule.extend(
[
{
"name": "monitor-celery-queues",
"task": OnyxCeleryTask.MONITOR_CELERY_QUEUES,
"schedule": timedelta(seconds=10),
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"schedule": timedelta(minutes=15),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
]
)
tasks_to_schedule.extend(beat_task_templates)
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
tasks_to_schedule.append(
{
"name": "check-for-llm-model-update",
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"schedule": timedelta(hours=1), # Check every hour
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def generate_cloud_tasks(
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float
) -> list[dict[str, Any]]:
"""
beat_tasks: system wide tasks that can be sent as is
beat_templates: task templates that will be transformed into per tenant tasks via
the cloud_beat_task_generator
beat_multiplier: a multiplier that can be applied on top of the task schedule
to speed up or slow down the task generation rate. useful in production.
Returns a list of cloud tasks, which consists of incoming tasks + tasks generated
from incoming templates.
"""
if beat_multiplier <= 0:
raise ValueError("beat_multiplier must be positive!")
cloud_tasks: list[dict] = []
# generate our tenant aware cloud tasks from the templates
for beat_template in beat_templates:
cloud_task = make_cloud_generator_task(beat_template)
cloud_tasks.append(cloud_task)
# factor in the cloud multiplier for the above
for cloud_task in cloud_tasks:
cloud_task["schedule"] = cloud_task["schedule"] * beat_multiplier
# add the fixed cloud/system beat tasks. No multiplier for these.
cloud_tasks.extend(copy.deepcopy(beat_tasks))
return cloud_tasks
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
return generate_cloud_tasks(beat_cloud_tasks, beat_task_templates, beat_multiplier)
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return cloud_tasks_to_schedule
def get_tasks_to_schedule() -> list[dict[str, Any]]:

View File

@@ -1,14 +1,10 @@
import traceback
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@@ -16,35 +12,18 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import delete_index_attempts
from onyx.db.search_settings import get_all_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from onyx.utils.variable_functionality import noop_fallback
class TaskDependencyError(RuntimeError):
@@ -63,7 +42,6 @@ def check_for_connector_deletion_task(
self: Task, *, tenant_id: str | None
) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -99,18 +77,6 @@ def check_for_connector_deletion_task(
# clear the stop signal if it exists ... no longer needed
redis_connector.stop.set_fence(False)
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
if not r.exists(key_bytes):
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
continue
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -220,7 +186,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
sync_type=SyncType.CONNECTOR_DELETION,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")
pass
except TaskDependencyError:
redis_connector.delete.set_fence(None)
@@ -246,158 +212,3 @@ def try_generate_document_cc_pair_cleanup_tasks(
redis_connector.delete.set_fence(fence_payload)
return tasks_generated
def monitor_connector_deletion_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
fence_data = redis_connector.delete.payload
if not fence_data:
task_logger.warning(
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
)
return
if fence_data.num_tasks is None:
# the fence is setting up but isn't ready yet
return
remaining = redis_connector.delete.get_remaining()
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
)
if remaining > 0:
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=remaining,
)
return
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
)
return
try:
doc_ids = get_document_ids_for_connector_credential_pair(
db_session, cc_pair.connector_id, cc_pair.credential_id
)
if len(doc_ids) > 0:
# NOTE(rkuo): if this happens, documents somehow got added while
# deletion was in progress. Likely a bug gating off pruning and indexing
# work before deletion starts.
task_logger.warning(
"Connector deletion - documents still found after taskset completion. "
"Clearing the current deletion attempt and allowing deletion to restart: "
f"cc_pair={cc_pair_id} "
f"docs_deleted={fence_data.num_tasks} "
f"docs_remaining={len(doc_ids)}"
)
# We don't want to waive off why we get into this state, but resetting
# our attempt and letting the deletion restart is a good way to recover
redis_connector.delete.reset()
raise RuntimeError(
"Connector deletion - documents still found after taskset completion"
)
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"onyx.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair_id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(
"Connector deletion - Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=fence_data.num_tasks,
)
except Exception as e:
db_session.rollback()
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair_id, error_message)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.FAILED,
num_docs_synced=fence_data.num_tasks,
)
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"docs_deleted={fence_data.num_tasks}"
)
redis_connector.delete.reset()

View File

@@ -175,24 +175,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
)
r.set(OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=300)
# use a lookup table to find active fences. We still have to verify the fence
# exists since it is an optimization and not the source of truth.
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
if not r.exists(key_bytes):
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
continue
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -246,15 +228,12 @@ def try_creating_permissions_sync_task(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_PERMISSIONS,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_PERMISSIONS,
)
# set a basic fence to start
redis_connector.permissions.set_active()
@@ -278,10 +257,11 @@ def try_creating_permissions_sync_task(
)
# fill in the celery task id
redis_connector.permissions.set_active()
payload.celery_task_id = result.id
redis_connector.permissions.set_fence(payload)
payload_id = payload.id
payload_id = payload.celery_task_id
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
return None
@@ -310,8 +290,6 @@ def connector_permission_sync_generator_task(
This task assumes that the task has already been properly fenced
"""
payload_id: str | None = None
LoggerContextVars.reset()
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
@@ -354,12 +332,9 @@ def connector_permission_sync_generator_task(
sleep(1)
continue
payload_id = payload.id
logger.info(
f"connector_permission_sync_generator_task - Fence found, continuing...: "
f"fence={redis_connector.permissions.fence_key} "
f"payload_id={payload.id}"
f"fence={redis_connector.permissions.fence_key}"
)
break
@@ -367,7 +342,6 @@ def connector_permission_sync_generator_task(
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
thread_local=False,
)
acquired = lock.acquire(blocking=False)
@@ -439,9 +413,7 @@ def connector_permission_sync_generator_task(
redis_connector.permissions.generator_complete = tasks_generated
except Exception as e:
task_logger.exception(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
)
task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")
redis_connector.permissions.generator_clear()
redis_connector.permissions.taskset_clear()
@@ -451,10 +423,6 @@ def connector_permission_sync_generator_task(
if lock.owned():
lock.release()
task_logger.info(
f"Permission sync finished: cc_pair={cc_pair_id} payload_id={payload.id}"
)
@shared_task(
name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
@@ -478,15 +446,14 @@ def update_external_document_permissions_task(
)
doc_id = document_external_access.doc_id
external_access = document_external_access.external_access
try:
with get_session_with_tenant(tenant_id) as db_session:
# Add the users to the DB if they don't exist
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),
continue_on_error=True,
)
# Then upsert the document's external permissions
# Then we upsert the document's external permissions in postgres
created_new_doc = upsert_document_external_perms(
db_session=db_session,
doc_id=doc_id,
@@ -510,11 +477,11 @@ def update_external_document_permissions_task(
f"action=update_permissions "
f"elapsed={elapsed:.2f}"
)
except Exception:
task_logger.exception(
f"Exception in update_external_document_permissions_task: "
f"connector_id={connector_id} doc_id={doc_id}"
f"connector_id={connector_id} "
f"doc_id={doc_id}"
)
return False
@@ -692,7 +659,7 @@ def validate_permission_sync_fence(
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
# we're active if there are still tasks to run and those tasks all exist in celery
# we're only active if tasks_scanned > 0 and tasks_not_in_celery == 0
if tasks_scanned > 0 and tasks_not_in_celery == 0:
redis_connector.permissions.set_active()
return
@@ -713,8 +680,7 @@ def validate_permission_sync_fence(
"validate_permission_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key} "
f"payload_id={payload.id}"
f"fence={fence_key}"
)
redis_connector.permissions.reset()
@@ -775,7 +741,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
raise
"""Monitoring CCPair permissions utils"""
"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""
def monitor_ccpair_permissions_taskset(

View File

@@ -2,17 +2,15 @@ import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
@@ -26,17 +24,15 @@ from ee.onyx.external_permissions.sync_params import (
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.error_logging import emit_background_error
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_tenant
@@ -53,8 +49,7 @@ from onyx.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -72,26 +67,18 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external group sync is due."""
if cc_pair.access_type != AccessType.SYNC:
task_logger.error(
f"Recieved non-sync CC Pair {cc_pair.id} for external "
f"group sync. Actual access type: {cc_pair.access_type}"
)
return False
# skip external group sync if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
task_logger.debug(
f"Skipping group sync for CC Pair {cc_pair.id} - "
f"CC Pair is being deleted"
)
return False
# If there is not group sync function for the connector, we don't run the sync
# This is fine because all sources dont necessarily have a concept of groups
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
task_logger.debug(
f"Skipping group sync for CC Pair {cc_pair.id} - "
f"no group sync function for {cc_pair.connector.source}"
)
return False
# If the last sync is None, it has never been run so we run the sync
@@ -120,11 +107,11 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
# r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
@@ -133,9 +120,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
task_logger.warning(
f"Failed to acquire beat lock for external group sync: {tenant_id}"
)
return None
try:
@@ -165,32 +149,30 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:
payload_id = try_creating_external_group_sync_task(
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not payload_id:
if not tasks_created:
continue
task_logger.info(
f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}"
)
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
# clear fences that don't have associated celery tasks in progress
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
validate_external_group_sync_fences(
tenant_id, self.app, r, r_replica, r_celery, lock_beat
)
except Exception:
task_logger.exception(
"Exception while validating external group sync fences"
)
# lock_beat.reacquire()
# if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
# # clear any indexing fences that don't have associated celery tasks in progress
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# # or be currently executing
# try:
# validate_external_group_sync_fences(
# tenant_id, self.app, r, r_celery, lock_beat
# )
# except Exception:
# task_logger.exception(
# "Exception while validating external group sync fences"
# )
r.set(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=300)
# r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -209,46 +191,35 @@ def try_creating_external_group_sync_task(
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> str | None:
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
payload_id: str | None = None
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
# Dont kick off a new sync if the previous one is still running
if redis_connector.external_group_sync.fenced:
logger.warning(
f"Skipping external group sync for CC Pair {cc_pair_id} - already running."
)
return None
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")
# Signal active before creating fence
redis_connector.external_group_sync.set_active()
payload = RedisConnectorExternalGroupSyncPayload(
id=make_short_id(),
submitted=datetime.now(timezone.utc),
started=None,
celery_task_id=None,
)
redis_connector.external_group_sync.set_fence(payload)
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
@@ -263,17 +234,27 @@ def try_creating_external_group_sync_task(
priority=OnyxCeleryPriority.HIGH,
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
)
payload.celery_task_id = result.id
redis_connector.external_group_sync.set_fence(payload)
payload_id = payload.id
except Exception:
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
)
return None
finally:
if lock.owned():
lock.release()
return payload_id
return 1
@shared_task(
@@ -304,26 +285,22 @@ def connector_external_group_sync_generator_task(
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
msg = (
raise ValueError(
f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
if not redis_connector.external_group_sync.fenced: # The fence must exist
msg = (
raise ValueError(
f"connector_external_group_sync_generator_task - fence not found: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
payload = redis_connector.external_group_sync.payload # The payload must exist
if not payload:
msg = "connector_external_group_sync_generator_task: payload invalid or not found"
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
raise ValueError(
"connector_external_group_sync_generator_task: payload invalid or not found"
)
if payload.celery_task_id is None:
logger.info(
@@ -335,8 +312,7 @@ def connector_external_group_sync_generator_task(
logger.info(
f"connector_external_group_sync_generator_task - Fence found, continuing...: "
f"fence={redis_connector.external_group_sync.fence_key} "
f"payload_id={payload.id}"
f"fence={redis_connector.external_group_sync.fence_key}"
)
break
@@ -348,9 +324,9 @@ def connector_external_group_sync_generator_task(
acquired = lock.acquire(blocking=False)
if not acquired:
msg = f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
task_logger.error(msg)
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
try:
@@ -371,9 +347,9 @@ def connector_external_group_sync_generator_task(
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
msg = f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
raise ValueError(
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
)
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
@@ -404,9 +380,9 @@ def connector_external_group_sync_generator_task(
sync_status=SyncStatus.SUCCESS,
)
except Exception as e:
msg = f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
task_logger.exception(msg)
emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
@@ -425,41 +401,32 @@ def connector_external_group_sync_generator_task(
if lock.owned():
lock.release()
task_logger.info(
f"External group sync finished: cc_pair={cc_pair_id} payload_id={payload.id}"
)
def validate_external_group_sync_fences(
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
reserved_tasks = celery_get_unacked_task_ids(
reserved_sync_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
# validate all existing external group sync tasks
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorExternalGroupSync.FENCE_PREFIX):
continue
validate_external_group_sync_fence(
tenant_id,
key_bytes,
reserved_tasks,
r_celery,
)
# validate all existing indexing jobs
for key_bytes in r.scan_iter(
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
validate_external_group_sync_fence(
tenant_id,
key_bytes,
reserved_sync_tasks,
r_celery,
db_session,
)
return
@@ -468,6 +435,7 @@ def validate_external_group_sync_fence(
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
@@ -496,11 +464,9 @@ def validate_external_group_sync_fence(
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
msg = (
task_logger.warning(
f"validate_external_group_sync_fence - could not parse id from {fence_key}"
)
emit_background_error(msg)
task_logger.error(msg)
return
cc_pair_id = int(cc_pair_id_str)
@@ -512,28 +478,26 @@ def validate_external_group_sync_fence(
if not redis_connector.external_group_sync.fenced:
return
try:
payload = redis_connector.external_group_sync.payload
except ValidationError:
msg = (
"validate_external_group_sync_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
task_logger.exception(msg)
emit_background_error(msg, cc_pair_id=cc_pair_id)
redis_connector.external_group_sync.reset()
return
payload = redis_connector.external_group_sync.payload
if not payload:
return
if not payload.celery_task_id:
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
# if redis_connector_index.active():
# return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
"validate_external_group_sync_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector.external_group_sync.reset()
return
# OK, there's actually something for us to validate
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
@@ -559,15 +523,11 @@ def validate_external_group_sync_fence(
# return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
emit_background_error(
message=(
"validate_external_group_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key} "
f"payload_id={payload.id}"
),
cc_pair_id=cc_pair_id,
logger.warning(
"validate_external_group_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.external_group_sync.reset()

View File

@@ -6,18 +6,13 @@ from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import Any
from typing import cast
import sentry_sdk
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
@@ -35,7 +30,6 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_ccpair_with_indexing_trigger
@@ -43,7 +37,6 @@ from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingMode
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import mark_attempt_canceled
@@ -54,12 +47,9 @@ from onyx.db.swap_index import check_index_swap
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_utils import is_fence
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
@@ -70,150 +60,6 @@ from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"Connector indexing: could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
if not redis_connector_index.fenced:
return
payload = redis_connector_index.payload
if not payload:
return
elapsed_started_str = None
if payload.started:
elapsed_started = datetime.now(timezone.utc) - payload.started
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
if payload.index_attempt_id is None or payload.celery_task_id is None:
# the task is still setting up
return
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# Verify: if the generator isn't complete, the task must not be in READY state
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
if status_int is None: # inner signal not set ... possible error
task_state = result.state
if (
task_state in READY_STATES
): # outer signal in terminal state ... possible error
# Now double check!
if redis_connector_index.get_completion() is None:
# inner signal still not set (and cannot change when outer result_state is READY)
# Task is finished but generator complete isn't set.
# We have a problem! Worker may have crashed.
task_result = str(result.result)
task_traceback = str(result.traceback)
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
f"celery_task={payload.celery_task_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"result.state={task_state} "
f"result.result={task_result} "
f"result.traceback={task_traceback}"
)
task_logger.warning(msg)
try:
index_attempt = get_index_attempt(
db_session, payload.index_attempt_id
)
if index_attempt:
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
except Exception:
task_logger.exception(
"Connector indexing - Transient exception marking index attempt as failed: "
f"attempt={payload.index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
redis_connector_index.reset()
return
if redis_connector_index.watchdog_signaled():
# if the generator is complete, don't clean up until the watchdog has exited
task_logger.info(
f"Connector indexing - Delaying finalization until watchdog has exited: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
return
status_enum = HTTPStatus(status_int)
task_logger.info(
f"Connector indexing finished: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
redis_connector_index.reset()
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_INDEXING,
soft_time_limit=300,
@@ -245,25 +91,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
try:
locked = True
# SPECIAL 0/3: sync lookup table for active fences
# we want to run this less frequently than the overall task
if not redis_client.exists(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE):
# build a lookup table of existing fences
# this is just a migration concern and should be unnecessary once
# lookup tables are rolled out
for key_bytes in redis_client_replica.scan_iter(
count=SCAN_ITER_COUNT_DEFAULT
):
if is_fence(key_bytes) and not redis_client.sismember(
OnyxRedisConstants.ACTIVE_FENCES, key_bytes
):
logger.warning(f"Adding {key_bytes} to the lookup table.")
redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
# 1/3: KICKOFF
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
@@ -370,8 +197,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
lock_beat.reacquire()
# 2/3: VALIDATE
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
@@ -411,26 +236,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
task_logger.exception("Exception while validating indexing fences")
redis_client.set(OnyxRedisSignals.BLOCK_VALIDATE_INDEXING_FENCES, 1, ex=60)
# 3/3: FINALIZE
lock_beat.reacquire()
keys = cast(
set[Any], redis_client_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES)
)
for key in keys:
key_bytes = cast(bytes, key)
if not redis_client.exists(key_bytes):
redis_client.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
continue
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(
tenant_id, key_bytes, redis_client_replica, db_session
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -618,8 +423,8 @@ def connector_indexing_task(
# define a callback class
callback = IndexingCallback(
os.getppid(),
redis_connector,
redis_connector_index,
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,
r,
)

View File

@@ -99,16 +99,16 @@ class IndexingCallback(IndexingHeartbeatInterface):
def __init__(
self,
parent_pid: int,
redis_connector: RedisConnector,
redis_connector_index: RedisConnectorIndex,
stop_key: str,
generator_progress_key: str,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.parent_pid = parent_pid
self.redis_connector: RedisConnector = redis_connector
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
self.redis_lock: RedisLock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
@@ -120,7 +120,7 @@ class IndexingCallback(IndexingHeartbeatInterface):
self.last_parent_check = time.monotonic()
def should_stop(self) -> bool:
if self.redis_connector.stop.fenced:
if self.redis_client.exists(self.stop_key):
return True
return False
@@ -143,8 +143,6 @@ class IndexingCallback(IndexingHeartbeatInterface):
# self.last_parent_check = now
try:
self.redis_connector.prune.set_active()
current_time = time.monotonic()
if current_time - self.last_lock_monotonic >= (
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
@@ -167,9 +165,7 @@ class IndexingCallback(IndexingHeartbeatInterface):
redis_lock_dump(self.redis_lock, self.redis_client)
raise
self.redis_client.incrby(
self.redis_connector_index.generator_progress_key, amount
)
self.redis_client.incrby(self.generator_progress_key, amount)
def validate_indexing_fence(

View File

@@ -17,8 +17,7 @@ from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues
@@ -421,7 +420,6 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]
- Throughput (docs/min) (only if success)
- Raw start/end times for each sync
"""
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
# Get all sync records that ended in the last hour
@@ -589,10 +587,6 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]
entity = db_session.scalar(
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
)
else:
# Only user groups and document set sync records have
# an associated entity we can use for latency metrics
continue
if entity is None:
task_logger.error(
@@ -723,7 +717,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
@shared_task(
name=OnyxCeleryTask.CLOUD_MONITOR_ALEMBIC,
name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
)
def cloud_check_alembic() -> bool | None:
"""A task to verify that all tenants are on the same alembic revision.
@@ -783,7 +777,7 @@ def cloud_check_alembic() -> bool | None:
tenant_to_revision[tenant_id] = result_scalar
except Exception:
task_logger.error(f"Tenant {tenant_id} has no revision!")
task_logger.warning(f"Tenant {tenant_id} has no revision!")
tenant_to_revision[tenant_id] = ALEMBIC_NULL_REVISION
# get the total count of each revision
@@ -853,55 +847,3 @@ def cloud_check_alembic() -> bool | None:
f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
)
return True
@shared_task(
name=OnyxCeleryTask.CLOUD_MONITOR_CELERY_QUEUES, ignore_result=True, bind=True
)
def cloud_monitor_celery_queues(
self: Task,
) -> None:
return monitor_celery_queues_helper(self)
@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
return monitor_celery_queues_helper(self)
def monitor_celery_queues_helper(
task: Task,
) -> None:
"""A task to monitor all celery queue lengths."""
r_celery = task.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery)
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
n_permissions_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
n_external_group_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
n_permissions_upsert = celery_get_queue_length(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
n_indexing_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"indexing_prefetched={len(n_indexing_prefetched)} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "
f"permissions_sync={n_permissions_sync} "
f"external_group_sync={n_external_group_sync} "
f"permissions_upsert={n_permissions_upsert} "
)

View File

@@ -1,39 +1,28 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import InputType
from onyx.db.connector import mark_ccpair_as_pruned
@@ -46,15 +35,10 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import pruning_ctx
from onyx.utils.logger import setup_logger
@@ -109,8 +93,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
)
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -122,68 +104,32 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
return None
try:
# the entire task needs to run frequently in order to finalize pruning
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
# but pruning only kicks off once per hour
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
cc_pair_ids: list[int] = []
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
continue
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
continue
if not _is_pruning_due(cc_pair):
continue
if not _is_pruning_due(cc_pair):
continue
tasks_created = try_creating_prune_generator_task(
self.app, cc_pair, db_session, r, tenant_id
)
if not tasks_created:
continue
payload_id = try_creating_prune_generator_task(
self.app, cc_pair, db_session, r, tenant_id
)
if not payload_id:
continue
task_logger.info(
f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}"
)
r.set(OnyxRedisSignals.BLOCK_PRUNING, 1, ex=3600)
# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES):
# clear any permission fences that don't have associated celery tasks in progress
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
except Exception:
task_logger.exception("Exception while validating pruning fences")
r.set(OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES, 1, ex=300)
# use a lookup table to find active fences. We still have to verify the fence
# exists since it is an optimization and not the source of truth.
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
if not r.exists(key_bytes):
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
continue
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
task_logger.info(f"Pruning queued: cc_pair={cc_pair.id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -203,7 +149,7 @@ def try_creating_prune_generator_task(
db_session: Session,
r: Redis,
tenant_id: str | None,
) -> str | None:
) -> int | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.
@@ -222,7 +168,7 @@ def try_creating_prune_generator_task(
# we need to serialize starting pruning since it can be triggered either via
# celery beat or manually (API call)
lock: RedisLock = r.lock(
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_prune_generator_task",
timeout=LOCK_TIMEOUT,
)
@@ -254,30 +200,7 @@ def try_creating_prune_generator_task(
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair.id,
sync_type=SyncType.PRUNING,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")
# signal active before the fence is set
redis_connector.prune.set_active()
# set a basic fence to start
payload = RedisConnectorPrunePayload(
id=make_short_id(),
submitted=datetime.now(timezone.utc),
started=None,
celery_task_id=None,
)
redis_connector.prune.set_fence(payload)
result = celery_app.send_task(
celery_app.send_task(
OnyxCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
kwargs=dict(
cc_pair_id=cc_pair.id,
@@ -290,11 +213,16 @@ def try_creating_prune_generator_task(
priority=OnyxCeleryPriority.LOW,
)
# fill in the celery task id
payload.celery_task_id = result.id
redis_connector.prune.set_fence(payload)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=cc_pair.id,
sync_type=SyncType.PRUNING,
)
payload_id = payload.id
# set this only after all tasks have been added
redis_connector.prune.set_fence(True)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
return None
@@ -302,7 +230,7 @@ def try_creating_prune_generator_task(
if lock.owned():
lock.release()
return payload_id
return 1
@shared_task(
@@ -324,8 +252,6 @@ def connector_pruning_generator_task(
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
payload_id: str | None = None
LoggerContextVars.reset()
pruning_ctx_dict = pruning_ctx.get()
@@ -339,46 +265,6 @@ def connector_pruning_generator_task(
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
# before the primary worker can finalize the fence
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
f"connector_prune_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.prune.fence_key}"
)
if not redis_connector.prune.fenced: # The fence must exist
raise ValueError(
f"connector_prune_generator_task - fence not found: "
f"fence={redis_connector.prune.fence_key}"
)
payload = redis_connector.prune.payload # The payload must exist
if not payload:
raise ValueError(
"connector_prune_generator_task: payload invalid or not found"
)
if payload.celery_task_id is None:
logger.info(
f"connector_prune_generator_task - Waiting for fence: "
f"fence={redis_connector.prune.fence_key}"
)
time.sleep(1)
continue
payload_id = payload.id
logger.info(
f"connector_prune_generator_task - Fence found, continuing...: "
f"fence={redis_connector.prune.fence_key} "
f"payload_id={payload.id}"
)
break
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
@@ -408,18 +294,6 @@ def connector_pruning_generator_task(
)
return
payload = redis_connector.prune.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
new_payload = RedisConnectorPrunePayload(
id=payload.id,
submitted=payload.submitted,
started=datetime.now(timezone.utc),
celery_task_id=payload.celery_task_id,
)
redis_connector.prune.set_fence(new_payload)
task_logger.info(
f"Pruning generator running connector: "
f"cc_pair={cc_pair_id} "
@@ -433,13 +307,10 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
search_settings = get_current_search_settings(db_session)
redis_connector_index = redis_connector.new_index(search_settings.id)
callback = IndexingCallback(
0,
redis_connector,
redis_connector_index,
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,
r,
)
@@ -486,9 +357,7 @@ def connector_pruning_generator_task(
redis_connector.prune.generator_complete = tasks_generated
except Exception as e:
task_logger.exception(
f"Pruning exceptioned: cc_pair={cc_pair_id} "
f"connector={connector_id} "
f"payload_id={payload_id}"
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
)
redis_connector.prune.reset()
@@ -497,12 +366,10 @@ def connector_pruning_generator_task(
if lock.owned():
lock.release()
task_logger.info(
f"Pruning generator finished: cc_pair={cc_pair_id} payload_id={payload_id}"
)
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")
"""Monitoring pruning utils"""
"""Monitoring pruning utils, called in monitor_vespa_sync"""
def monitor_ccpair_pruning_taskset(
@@ -548,184 +415,4 @@ def monitor_ccpair_pruning_taskset(
redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(None)
def validate_pruning_fences(
tenant_id: str | None,
r: Redis,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
# building lookup table can be expensive, so we won't bother
# validating until the queue is small
PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN:
return
# the queue for a single pruning generator task
reserved_generator_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
)
# the queue for a reasonably large set of lightweight deletion tasks
queued_upsert_tasks = celery_get_queued_task_ids(
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
)
# Use replica for this because the worst thing that happens
# is that we don't run the validation on this pass
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
continue
validate_pruning_fence(
tenant_id,
key_bytes,
reserved_generator_tasks,
queued_upsert_tasks,
r,
r_celery,
)
lock_beat.reacquire()
return
def validate_pruning_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
queued_tasks: set[str],
r: Redis,
r_celery: Redis,
) -> None:
"""See validate_indexing_fence for an overall idea of validation flows.
queued_tasks: the celery queue of lightweight permission sync tasks
reserved_tasks: prefetched tasks for sync task generator
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_pruning_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.prune.fenced:
return
# in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens
# TODO: add intentional cleanup/abort logic
try:
payload = redis_connector.prune.payload
except ValidationError:
task_logger.exception(
"validate_pruning_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.prune.reset()
return
if not payload:
return
if not payload.celery_task_id:
return
# OK, there's actually something for us to validate
# either the generator task must be in flight or its subtasks must be
found = celery_find_task(
payload.celery_task_id,
OnyxCeleryQueues.CONNECTOR_PRUNING,
r_celery,
)
if found:
# the celery task exists in the redis queue
redis_connector.prune.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within a worker
redis_connector.prune.set_active()
return
# look up every task in the current taskset in the celery queue
# every entry in the taskset should have an associated entry in the celery task queue
# because we get the celery tasks first, the entries in our own pruning taskset
# should be roughly a subset of the tasks in celery
# this check isn't very exact, but should be sufficient over a period of time
# A single successful check over some number of attempts is sufficient.
# TODO: if the number of tasks in celery is much lower than than the taskset length
# we might be able to shortcut the lookup since by definition some of the tasks
# must not exist in celery.
tasks_scanned = 0
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
for member in r.sscan_iter(redis_connector.prune.taskset_key):
tasks_scanned += 1
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
continue
if member_str in reserved_tasks:
continue
tasks_not_in_celery += 1
task_logger.info(
"validate_pruning_fence task check: "
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
# we're active if there are still tasks to run and those tasks all exist in celery
if tasks_scanned > 0 and tasks_not_in_celery == 0:
redis_connector.prune.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector.prune.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
task_logger.warning(
"validate_pruning_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key} "
f"payload_id={payload.id}"
)
redis_connector.prune.reset()
return
redis_connector.prune.set_fence(False)

View File

@@ -8,7 +8,6 @@ from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from tenacity import RetryError
from ee.onyx.server.tenants.product_gating import get_gated_tenants
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
@@ -253,11 +252,7 @@ def cloud_beat_task_generator(
try:
tenant_ids = get_all_tenant_ids()
gated_tenants = get_gated_tenants()
for tenant_id in tenant_ids:
if tenant_id in gated_tenants:
continue
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()
@@ -275,7 +270,6 @@ def cloud_beat_task_generator(
queue=queue,
priority=priority,
expires=expires,
ignore_result=True,
)
except SoftTimeLimitExceeded:
task_logger.info(

View File

@@ -1,5 +1,9 @@
import random
import time
import traceback
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from typing import Any
from typing import cast
@@ -9,6 +13,8 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@@ -16,27 +22,47 @@ from tenacity import RetryError
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
monitor_ccpair_permissions_taskset,
)
from onyx.background.celery.tasks.pruning.tasks import monitor_ccpair_pruning_taskset
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import count_documents_by_needs_sync
from onyx.db.document import get_document
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import delete_document_set
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.document_set import fetch_document_sets
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import mark_document_set_as_synced
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import delete_index_attempts
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.models import DocumentSet
from onyx.db.models import UserGroup
from onyx.db.search_settings import get_active_search_settings
@@ -46,14 +72,20 @@ from onyx.db.sync_record import update_sync_record_status
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_credential_pair import (
RedisGlobalConnectorCredentialPair,
)
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -62,6 +94,7 @@ from onyx.utils.variable_functionality import (
)
from onyx.utils.variable_functionality import global_version
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -81,7 +114,6 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
time_start = time.monotonic()
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@@ -93,7 +125,6 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
return None
try:
# 1/3: KICKOFF
with get_session_with_tenant(tenant_id) as db_session:
try_generate_stale_document_sync_tasks(
self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
@@ -120,8 +151,9 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
# endregion
# check if any user groups are not synced
lock_beat.reacquire()
if global_version.is_ee_version():
lock_beat.reacquire()
try:
fetch_user_groups = fetch_versioned_implementation(
"onyx.db.user_group", "fetch_user_groups"
@@ -147,35 +179,6 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
)
# 2/3: VALIDATE: TODO
# 3/3: FINALIZE
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
if not r.exists(key_bytes):
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
continue
key_str = key_bytes.decode("utf-8")
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
monitor_connector_taskset(r)
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
"onyx.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -336,15 +339,11 @@ def try_generate_document_set_sync_tasks(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
insert_sync_record(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")
insert_sync_record(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
)
# set this only after all tasks have been added
rds.set_fence(tasks_generated)
return tasks_generated
@@ -412,15 +411,11 @@ def try_generate_user_group_sync_tasks(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
insert_sync_record(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")
insert_sync_record(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
)
# set this only after all tasks have been added
rug.set_fence(tasks_generated)
@@ -503,6 +498,475 @@ def monitor_document_set_taskset(
rds.reset()
def monitor_connector_deletion_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
fence_data = redis_connector.delete.payload
if not fence_data:
task_logger.warning(
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
)
return
if fence_data.num_tasks is None:
# the fence is setting up but isn't ready yet
return
remaining = redis_connector.delete.get_remaining()
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
)
if remaining > 0:
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=remaining,
)
return
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
)
return
try:
doc_ids = get_document_ids_for_connector_credential_pair(
db_session, cc_pair.connector_id, cc_pair.credential_id
)
if len(doc_ids) > 0:
# NOTE(rkuo): if this happens, documents somehow got added while
# deletion was in progress. Likely a bug gating off pruning and indexing
# work before deletion starts.
task_logger.warning(
"Connector deletion - documents still found after taskset completion. "
"Clearing the current deletion attempt and allowing deletion to restart: "
f"cc_pair={cc_pair_id} "
f"docs_deleted={fence_data.num_tasks} "
f"docs_remaining={len(doc_ids)}"
)
# We don't want to waive off why we get into this state, but resetting
# our attempt and letting the deletion restart is a good way to recover
redis_connector.delete.reset()
raise RuntimeError(
"Connector deletion - documents still found after taskset completion"
)
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"onyx.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair_id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(
"Connector deletion - Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=fence_data.num_tasks,
)
except Exception as e:
db_session.rollback()
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair_id, error_message)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.FAILED,
num_docs_synced=fence_data.num_tasks,
)
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"docs_deleted={fence_data.num_tasks}"
)
redis_connector.delete.reset()
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"Connector indexing: could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
if not redis_connector_index.fenced:
return
payload = redis_connector_index.payload
if not payload:
return
elapsed_started_str = None
if payload.started:
elapsed_started = datetime.now(timezone.utc) - payload.started
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
if payload.index_attempt_id is None or payload.celery_task_id is None:
# the task is still setting up
return
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# Verify: if the generator isn't complete, the task must not be in READY state
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
if status_int is None: # inner signal not set ... possible error
task_state = result.state
if (
task_state in READY_STATES
): # outer signal in terminal state ... possible error
# Now double check!
if redis_connector_index.get_completion() is None:
# inner signal still not set (and cannot change when outer result_state is READY)
# Task is finished but generator complete isn't set.
# We have a problem! Worker may have crashed.
task_result = str(result.result)
task_traceback = str(result.traceback)
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
f"celery_task={payload.celery_task_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"result.state={task_state} "
f"result.result={task_result} "
f"result.traceback={task_traceback}"
)
task_logger.warning(msg)
try:
index_attempt = get_index_attempt(
db_session, payload.index_attempt_id
)
if index_attempt:
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
except Exception:
task_logger.exception(
"Connector indexing - Transient exception marking index attempt as failed: "
f"attempt={payload.index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
redis_connector_index.reset()
return
if redis_connector_index.watchdog_signaled():
# if the generator is complete, don't clean up until the watchdog has exited
task_logger.info(
f"Connector indexing - Delaying finalization until watchdog has exited: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
return
status_enum = HTTPStatus(status_int)
task_logger.info(
f"Connector indexing finished: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
redis_connector_index.reset()
@shared_task(
name=OnyxCeleryTask.MONITOR_VESPA_SYNC,
ignore_result=True,
soft_time_limit=300,
bind=True,
)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
"""This is a celery beat task that monitors and finalizes various long running tasks.
The name monitor_vespa_sync is a bit of a misnomer since it checks many different tasks
now. Should change that at some point.
It scans for fence values and then gets the counts of any associated tasksets.
For many tasks, the count is 0, that means all tasks finished and we should clean up.
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
Returns True if the task actually did work, False if it exited early to prevent overlap
"""
task_logger.info(f"monitor_vespa_sync starting: tenant={tenant_id}")
time_start = time.monotonic()
r = get_redis_client(tenant_id=tenant_id)
# Replica usage notes
#
# False negatives are OK. (aka fail to to see a key that exists on the master).
# We simply skip the monitoring work and it will be caught on the next pass.
#
# False positives are not OK, and are possible if we clear a fence on the master and
# then read from the replica. In this case, monitoring work could be done on a fence
# that no longer exists. To avoid this, we scan from the replica, but double check
# the result on the master.
r_replica = get_redis_replica_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return None
try:
# print current queue lengths
time.monotonic()
# we don't need every tenant polling redis for this info.
if not MULTI_TENANT or random.randint(1, 10) == 10:
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
n_sync = celery_get_queue_length(
OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery
)
n_deletion = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
)
n_pruning = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
)
n_permissions_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
n_external_group_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
n_permissions_upsert = celery_get_queue_length(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"indexing_prefetched={len(prefetched)} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "
f"permissions_sync={n_permissions_sync} "
f"external_group_sync={n_external_group_sync} "
f"permissions_upsert={n_permissions_upsert} "
)
# we want to run this less frequently than the overall task
if not r.exists(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE):
# build a lookup table of existing fences
# this is just a migration concern and should be unnecessary once
# lookup tables are rolled out
for key_bytes in r_replica.scan_iter(count=SCAN_ITER_COUNT_DEFAULT):
if is_fence(key_bytes) and not r.sismember(
OnyxRedisConstants.ACTIVE_FENCES, key_bytes
):
logger.warning(f"Adding {key_bytes} to the lookup table.")
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
r.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
# use a lookup table to find active fences. We still have to verify the fence
# exists since it is an optimization and not the source of truth.
keys = cast(set[Any], r.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
if not r.exists(key_bytes):
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
continue
key_str = key_bytes.decode("utf-8")
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
monitor_connector_taskset(r)
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
"onyx.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
elif key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
elif key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
elif key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
elif key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
else:
pass
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
return False
except Exception:
task_logger.exception("monitor_vespa_sync exceptioned.")
return False
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"monitor_vespa_sync - Lock not owned on completion: "
f"tenant={tenant_id}"
# f"timings={timings}"
)
redis_lock_dump(lock_beat, r)
time_elapsed = time.monotonic() - time_start
task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
return True
@shared_task(
name=OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
bind=True,
@@ -600,3 +1064,23 @@ def vespa_metadata_sync_task(
self.retry(exc=e, countdown=countdown)
return True
def is_fence(key_bytes: bytes) -> bool:
key_str = key_bytes.decode("utf-8")
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
return True
if key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
return True
if key_str.startswith(RedisUserGroup.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
return True
return False

View File

@@ -1,13 +0,0 @@
from onyx.db.background_error import create_background_error
from onyx.db.engine import get_session_with_tenant
def emit_background_error(
message: str,
cc_pair_id: int | None = None,
) -> None:
"""Currently just saves a row in the background_errors table.
In the future, could create notifications based on the severity."""
with get_session_with_tenant() as db_session:
create_background_error(db_session, message, cc_pair_id)

View File

@@ -140,7 +140,6 @@ def build_citations_user_message(
context_docs: list[LlmDoc] | list[InferenceChunk],
all_doc_useful: bool,
history_message: str = "",
context_type: str = "context documents",
) -> HumanMessage:
multilingual_expansion = get_multilingual_expansion()
task_prompt_with_reminder = build_task_prompt_reminders(
@@ -157,7 +156,6 @@ def build_citations_user_message(
optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT
user_prompt = CITATIONS_PROMPT.format(
context_type=context_type,
optional_ignore_statement=optional_ignore,
context_docs_str=context_docs_str,
task_prompt=task_prompt_with_reminder,
@@ -167,7 +165,6 @@ def build_citations_user_message(
else:
# if no context docs provided, assume we're in the tool calling flow
user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
context_type=context_type,
task_prompt=task_prompt_with_reminder,
user_query=query,
history_block=history_block,

View File

@@ -107,9 +107,9 @@ CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 3600 # 1 hour (in seconds)
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 3600 # 1 hour (in seconds)
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
@@ -263,11 +263,6 @@ class PostgresAdvisoryLocks(Enum):
class OnyxCeleryQueues:
# "celery" is the default queue defined by celery and also the queue
# we are running in the primary worker to run system tasks
# Tasks running in this queue should be designed specifically to run quickly
PRIMARY = "celery"
# Light queue
VESPA_METADATA_SYNC = "vespa_metadata_sync"
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
@@ -298,6 +293,7 @@ class OnyxRedisLocks:
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
"da_lock:check_connector_external_group_sync_beat"
)
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
@@ -323,8 +319,6 @@ class OnyxRedisSignals:
BLOCK_VALIDATE_PERMISSION_SYNC_FENCES = (
"signal:block_validate_permission_sync_fences"
)
BLOCK_PRUNING = "signal:block_pruning"
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
@@ -346,18 +340,12 @@ ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud"
# the tenant id we use for system level redis operations
ONYX_CLOUD_TENANT_ID = "cloud"
# the redis namespace for runtime variables
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
class OnyxCeleryTask:
DEFAULT = "celery"
CLOUD_BEAT_TASK_GENERATOR = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_generate_beat_tasks"
CLOUD_MONITOR_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_alembic"
CLOUD_MONITOR_CELERY_QUEUES = (
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues"
)
CLOUD_CHECK_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_alembic"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
@@ -367,8 +355,8 @@ class OnyxCeleryTask:
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (

View File

@@ -65,25 +65,10 @@ class AirtableConnector(LoadConnector):
base_id: str,
table_name_or_id: str,
treat_all_non_attachment_fields_as_metadata: bool = False,
view_id: str | None = None,
share_id: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
"""Initialize an AirtableConnector.
Args:
base_id: The ID of the Airtable base to connect to
table_name_or_id: The name or ID of the table to index
treat_all_non_attachment_fields_as_metadata: If True, all fields except attachments will be treated as metadata.
If False, only fields with types in DEFAULT_METADATA_FIELD_TYPES will be treated as metadata.
view_id: Optional ID of a specific view to use
share_id: Optional ID of a "share" to use for generating record URLs (https://airtable.com/developers/web/api/list-shares)
batch_size: Number of records to process in each batch
"""
self.base_id = base_id
self.table_name_or_id = table_name_or_id
self.view_id = view_id
self.share_id = share_id
self.batch_size = batch_size
self._airtable_client: AirtableApi | None = None
self.treat_all_non_attachment_fields_as_metadata = (
@@ -100,39 +85,6 @@ class AirtableConnector(LoadConnector):
raise AirtableClientNotSetUpError()
return self._airtable_client
@classmethod
def _get_record_url(
cls,
base_id: str,
table_id: str,
record_id: str,
share_id: str | None,
view_id: str | None,
field_id: str | None = None,
attachment_id: str | None = None,
) -> str:
"""Constructs the URL for a record, optionally including field and attachment IDs
Full possible structure is:
https://airtable.com/BASE_ID/SHARE_ID/TABLE_ID/VIEW_ID/RECORD_ID/FIELD_ID/ATTACHMENT_ID
"""
# If we have a shared link, use that view for better UX
if share_id:
base_url = f"https://airtable.com/{base_id}/{share_id}/{table_id}"
else:
base_url = f"https://airtable.com/{base_id}/{table_id}"
if view_id:
base_url = f"{base_url}/{view_id}"
base_url = f"{base_url}/{record_id}"
if field_id and attachment_id:
return f"{base_url}/{field_id}/{attachment_id}?blocks=hide"
return base_url
def _extract_field_values(
self,
field_id: str,
@@ -158,10 +110,8 @@ class AirtableConnector(LoadConnector):
if field_type == "multipleRecordLinks":
return []
# Get the base URL for this record
default_link = self._get_record_url(
base_id, table_id, record_id, self.share_id, self.view_id or view_id
)
# default link to use for non-attachment fields
default_link = f"https://airtable.com/{base_id}/{table_id}/{record_id}"
if field_type == "multipleAttachments":
attachment_texts: list[tuple[str, str]] = []
@@ -215,16 +165,17 @@ class AirtableConnector(LoadConnector):
extension=file_ext,
)
if attachment_text:
# Use the helper method to construct attachment URLs
attachment_link = self._get_record_url(
base_id,
table_id,
record_id,
self.share_id,
self.view_id or view_id,
field_id,
attachment_id,
)
# slightly nicer loading experience if we can specify the view ID
if view_id:
attachment_link = (
f"https://airtable.com/{base_id}/{table_id}/{view_id}/{record_id}"
f"/{field_id}/{attachment_id}?blocks=hide"
)
else:
attachment_link = (
f"https://airtable.com/{base_id}/{table_id}/{record_id}"
f"/{field_id}/{attachment_id}?blocks=hide"
)
attachment_texts.append(
(f"{filename}:\n{attachment_text}", attachment_link)
)

View File

@@ -27,7 +27,6 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -320,7 +319,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
doc_metadata_list: list[SlimDocument] = []
@@ -388,12 +386,4 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
)
callback.progress("retrieve_all_slim_documents", 1)
yield doc_metadata_list

View File

@@ -30,7 +30,6 @@ from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -322,7 +321,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
self,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
query = _build_time_range_query(time_range_start, time_range_end)
doc_batch = []
@@ -345,15 +343,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
if len(doc_batch) > SLIM_BATCH_SIZE:
yield doc_batch
doc_batch = []
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
)
callback.progress("retrieve_all_slim_documents", 1)
if doc_batch:
yield doc_batch
@@ -379,10 +368,9 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
try:
yield from self._fetch_slim_threads(start, end, callback=callback)
yield from self._fetch_slim_threads(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e

View File

@@ -42,7 +42,6 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -565,7 +564,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
slim_batch = []
for file in self._fetch_drive_items(
@@ -578,26 +576,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if len(slim_batch) >= SLIM_BATCH_SIZE:
yield slim_batch
slim_batch = []
if callback:
if callback.should_stop():
raise RuntimeError(
"_extract_slim_docs_from_google_drive: Stop signal detected"
)
callback.progress("_extract_slim_docs_from_google_drive", 1)
yield slim_batch
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
try:
yield from self._extract_slim_docs_from_google_drive(
start, end, callback=callback
)
yield from self._extract_slim_docs_from_google_drive(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
SecondsSinceUnixEpoch = float
@@ -64,7 +63,6 @@ class SlimConnector(BaseConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
raise NotImplementedError

View File

@@ -91,7 +91,6 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
f"&response_type=code"
f"&scope=read"
f"&state={state}"
f"&prompt=consent" # prompts user for access; allows choosing workspace
)
@classmethod

View File

@@ -29,7 +29,6 @@ from onyx.connectors.onyx_jira.utils import build_jira_url
from onyx.connectors.onyx_jira.utils import extract_jira_project
from onyx.connectors.onyx_jira.utils import extract_text_from_adf
from onyx.connectors.onyx_jira.utils import get_comment_strs
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -145,8 +144,7 @@ def fetch_jira_issues_batch(
id=page_url,
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
title=f"{issue.key} {issue.fields.summary}",
semantic_identifier=issue.fields.summary,
doc_updated_at=time_str_to_utc(issue.fields.updated),
primary_owners=list(people) or None,
# TODO add secondary_owners (commenters) if needed
@@ -247,7 +245,6 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
jql = f"project = {self.quoted_jira_project}"

View File

@@ -21,7 +21,6 @@ from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -177,7 +176,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
doc_metadata_list: list[SlimDocument] = []
for parent_object_type in self.parent_object_list:

View File

@@ -21,7 +21,6 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -243,7 +242,6 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
slim_doc_batch: list[SlimDocument] = []
for post_id in get_all_post_ids(self.slab_bot_token):

View File

@@ -27,7 +27,6 @@ from onyx.connectors.slack.utils import get_message_link
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.utils import make_slack_api_call_w_retries
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -99,7 +98,6 @@ def get_channel_messages(
channel: dict[str, Any],
oldest: str | None = None,
latest: str | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> Generator[list[MessageType], None, None]:
"""Get all messages in a channel"""
# join so that the bot can access messages
@@ -117,11 +115,6 @@ def get_channel_messages(
oldest=oldest,
latest=latest,
):
if callback:
if callback.should_stop():
raise RuntimeError("get_channel_messages: Stop signal detected")
callback.progress("get_channel_messages", 0)
yield cast(list[MessageType], result["messages"])
@@ -332,7 +325,6 @@ def _get_all_doc_ids(
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
"""
Get all document ids in the workspace, channel by channel
@@ -350,7 +342,6 @@ def _get_all_doc_ids(
channel_message_batches = get_channel_messages(
client=client,
channel=channel,
callback=callback,
)
message_ts_set: set[str] = set()
@@ -399,7 +390,6 @@ class SlackPollConnector(PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")
@@ -408,7 +398,6 @@ class SlackPollConnector(PollConnector, SlimConnector):
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
callback=callback,
)
def poll_source(

View File

@@ -39,6 +39,19 @@ def get_message_link(
return permalink
def _make_slack_api_call_logged(
call: Callable[..., SlackResponse],
) -> Callable[..., SlackResponse]:
@wraps(call)
def logged_call(**kwargs: Any) -> SlackResponse:
logger.debug(f"Making call to Slack API '{call.__name__}' with args '{kwargs}'")
result = call(**kwargs)
logger.debug(f"Call to Slack API '{call.__name__}' returned '{result}'")
return result
return logged_call
def _make_slack_api_call_paginated(
call: Callable[..., SlackResponse],
) -> Callable[..., Generator[dict[str, Any], None, None]]:
@@ -114,14 +127,18 @@ def make_slack_api_rate_limited(
def make_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(make_slack_api_rate_limited(call))(**kwargs)
return basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)(**kwargs)
def make_paginated_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return _make_slack_api_call_paginated(
basic_retry_wrapper(make_slack_api_rate_limited(call))
basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)
)(**kwargs)

View File

@@ -20,7 +20,6 @@ from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.retry_wrapper import retry_builder
@@ -406,7 +405,6 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
slim_doc_batch: list[SlimDocument] = []
if self.content_type == "articles":

View File

@@ -51,7 +51,6 @@ class SearchPipeline:
user: User | None,
llm: LLM,
fast_llm: LLM,
skip_query_analysis: bool,
db_session: Session,
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
retrieval_metrics_callback: (
@@ -62,13 +61,10 @@ class SearchPipeline:
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
prompt_config: PromptConfig | None = None,
):
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
self.search_request = search_request
self.user = user
self.llm = llm
self.fast_llm = fast_llm
self.skip_query_analysis = skip_query_analysis
self.db_session = db_session
self.bypass_acl = bypass_acl
self.retrieval_metrics_callback = retrieval_metrics_callback
@@ -110,7 +106,6 @@ class SearchPipeline:
search_request=self.search_request,
user=self.user,
llm=self.llm,
skip_query_analysis=self.skip_query_analysis,
db_session=self.db_session,
bypass_acl=self.bypass_acl,
)
@@ -165,12 +160,6 @@ class SearchPipeline:
that have a corresponding chunk.
This step should be fast for any document index implementation.
Current implementation timing is approximately broken down in timing as:
- 200 ms to get the embedding of the query
- 15 ms to get chunks from the document index
- possibly more to get additional surrounding chunks
- possibly more for query expansion (multilingual)
"""
if self._retrieved_sections is not None:
return self._retrieved_sections

View File

@@ -15,7 +15,6 @@ from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import MAX_METRICS_CONTENT
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RerankMetricsContainer
from onyx.context.search.models import SearchQuery
from onyx.document_index.document_index_utils import (
@@ -78,8 +77,7 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
@log_function_time(print_only=True)
def semantic_reranking(
query_str: str,
rerank_settings: RerankingDetails,
query: SearchQuery,
chunks: list[InferenceChunk],
model_min: int = CROSS_ENCODER_RANGE_MIN,
model_max: int = CROSS_ENCODER_RANGE_MAX,
@@ -90,9 +88,11 @@ def semantic_reranking(
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
"""
assert (
rerank_settings.rerank_model_name
), "Reranking flow cannot run without a specific model"
rerank_settings = query.rerank_settings
if not rerank_settings or not rerank_settings.rerank_model_name:
# Should never reach this part of the flow without reranking settings
raise RuntimeError("Reranking flow should not be running")
chunks_to_rerank = chunks[: rerank_settings.num_rerank]
@@ -107,7 +107,7 @@ def semantic_reranking(
f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
for chunk in chunks_to_rerank
]
sim_scores_floats = cross_encoder.predict(query=query_str, passages=passages)
sim_scores_floats = cross_encoder.predict(query=query.query, passages=passages)
# Old logic to handle multiple cross-encoders preserved but not used
sim_scores = [numpy.array(sim_scores_floats)]
@@ -165,20 +165,8 @@ def semantic_reranking(
return list(ranked_chunks), list(ranked_indices)
def should_rerank(rerank_settings: RerankingDetails | None) -> bool:
"""Based on the RerankingDetails model, only run rerank if the following conditions are met:
- rerank_model_name is not None
- num_rerank is greater than 0
"""
if not rerank_settings:
return False
return bool(rerank_settings.rerank_model_name and rerank_settings.num_rerank > 0)
def rerank_sections(
query_str: str,
rerank_settings: RerankingDetails,
query: SearchQuery,
sections_to_rerank: list[InferenceSection],
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> list[InferenceSection]:
@@ -193,13 +181,16 @@ def rerank_sections(
"""
chunks_to_rerank = [section.center_chunk for section in sections_to_rerank]
if not query.rerank_settings:
# Should never reach this part of the flow without reranking settings
raise RuntimeError("Reranking settings not found")
ranked_chunks, _ = semantic_reranking(
query_str=query_str,
rerank_settings=rerank_settings,
query=query,
chunks=chunks_to_rerank,
rerank_metrics_callback=rerank_metrics_callback,
)
lower_chunks = chunks_to_rerank[rerank_settings.num_rerank :]
lower_chunks = chunks_to_rerank[query.rerank_settings.num_rerank :]
# Scores from rerank cannot be meaningfully combined with scores without rerank
# However the ordering is still important
@@ -269,13 +260,16 @@ def search_postprocessing(
rerank_task_id = None
sections_yielded = False
if should_rerank(search_query.rerank_settings):
if (
search_query.rerank_settings
and search_query.rerank_settings.rerank_model_name
and search_query.rerank_settings.num_rerank > 0
):
post_processing_tasks.append(
FunctionCall(
rerank_sections,
(
search_query.query,
search_query.rerank_settings, # Cannot be None here
search_query,
retrieved_sections,
rerank_metrics_callback,
),

View File

@@ -50,11 +50,11 @@ def retrieval_preprocessing(
search_request: SearchRequest,
user: User | None,
llm: LLM,
skip_query_analysis: bool,
db_session: Session,
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
base_recency_decay: float = BASE_RECENCY_DECAY,
bypass_acl: bool = False,
skip_query_analysis: bool = False,
base_recency_decay: float = BASE_RECENCY_DECAY,
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
) -> SearchQuery:
"""Logic is as follows:
Any global disables apply first
@@ -146,7 +146,7 @@ def retrieval_preprocessing(
is_keyword, extracted_keywords = (
parallel_results[run_query_analysis.result_id]
if run_query_analysis
else (False, None)
else (None, None)
)
all_query_terms = query.split()

View File

@@ -1,10 +0,0 @@
from sqlalchemy.orm import Session
from onyx.db.models import BackgroundError
def create_background_error(
db_session: Session, message: str, cc_pair_id: int | None
) -> None:
db_session.add(BackgroundError(message=message, cc_pair_id=cc_pair_id))
db_session.commit()

View File

@@ -152,7 +152,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if not specified, all assistants are shown
temperature_override_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)
@@ -483,10 +483,6 @@ class ConnectorCredentialPair(Base):
primaryjoin="foreign(ConnectorCredentialPair.creator_id) == remote(User.id)",
)
background_errors: Mapped[list["BackgroundError"]] = relationship(
"BackgroundError", back_populates="cc_pair", cascade="all, delete-orphan"
)
class Document(Base):
__tablename__ = "document"
@@ -2119,31 +2115,6 @@ class StandardAnswer(Base):
)
class BackgroundError(Base):
"""Important background errors. Serves to:
1. Ensure that important logs are kept around and not lost on rotation/container restarts
2. A trail for high-signal events so that the debugger doesn't need to remember/know every
possible relevant log line.
"""
__tablename__ = "background_error"
id: Mapped[int] = mapped_column(primary_key=True)
message: Mapped[str] = mapped_column(String)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
# option to link the error to a specific CC Pair
cc_pair_id: Mapped[int | None] = mapped_column(
ForeignKey("connector_credential_pair.id", ondelete="CASCADE"), nullable=True
)
cc_pair: Mapped["ConnectorCredentialPair | None"] = relationship(
"ConnectorCredentialPair", back_populates="background_errors"
)
"""Tables related to Permission Sync"""

View File

@@ -204,14 +204,6 @@ def create_update_persona(
if not all_prompt_ids:
raise ValueError("No prompt IDs provided")
# Default persona validation
if create_persona_request.is_default_persona:
if not create_persona_request.is_public:
raise ValueError("Cannot make a default persona non public")
if user and user.role != UserRole.ADMIN:
raise ValueError("Only admins can make a default persona")
persona = upsert_persona(
persona_id=persona_id,
user=user,
@@ -236,7 +228,6 @@ def create_update_persona(
num_chunks=create_persona_request.num_chunks,
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
is_default_persona=create_persona_request.is_default_persona,
)
versioned_make_persona_private = fetch_versioned_implementation(
@@ -518,7 +509,6 @@ def upsert_persona(
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.labels = labels or []
existing_persona.is_default_persona = is_default_persona
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
@@ -599,23 +589,6 @@ def delete_old_default_personas(
db_session.commit()
def update_persona_is_default(
persona_id: int,
is_default: bool,
db_session: Session,
user: User | None = None,
) -> None:
persona = fetch_persona_by_id_for_user(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
if not persona.is_public:
persona.is_public = True
persona.is_default_persona = is_default
db_session.commit()
def update_persona_visibility(
persona_id: int,
is_visible: bool,

View File

@@ -6,7 +6,6 @@ from fastapi import HTTPException
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from sqlalchemy.sql import expression
from sqlalchemy.sql.elements import ColumnElement
@@ -275,7 +274,7 @@ def _generate_ext_permissioned_user(email: str) -> User:
def batch_add_ext_perm_user_if_not_exists(
db_session: Session, emails: list[str], continue_on_error: bool = False
db_session: Session, emails: list[str]
) -> list[User]:
lower_emails = [email.lower() for email in emails]
found_users, missing_lower_emails = _get_users_by_emails(db_session, lower_emails)
@@ -284,23 +283,10 @@ def batch_add_ext_perm_user_if_not_exists(
for email in missing_lower_emails:
new_users.append(_generate_ext_permissioned_user(email=email))
try:
db_session.add_all(new_users)
db_session.commit()
except IntegrityError:
db_session.rollback()
if not continue_on_error:
raise
for user in new_users:
try:
db_session.add(user)
db_session.commit()
except IntegrityError:
db_session.rollback()
continue
# Fetch all users again to ensure we have the most up-to-date list
all_users, _ = _get_users_by_emails(db_session, lower_emails)
return all_users
db_session.add_all(new_users)
db_session.commit()
return found_users + new_users
def delete_user_from_db(

View File

@@ -17,7 +17,6 @@ from uuid import UUID
import httpx # type: ignore
import requests # type: ignore
from retry import retry
from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.configs.chat_configs import NUM_RETURNED_HITS
@@ -550,11 +549,6 @@ class VespaIndex(DocumentIndex):
time.monotonic() - update_start,
)
@retry(
tries=3,
delay=1,
backoff=2,
)
def _update_single_chunk(
self,
doc_chunk_id: UUID,
@@ -565,7 +559,6 @@ class VespaIndex(DocumentIndex):
) -> None:
"""
Update a single "chunk" (document) in Vespa using its chunk ID.
Retries if we encounter transient HTTPStatusError (e.g., overload).
"""
update_dict: dict[str, dict] = {"fields": {}}
@@ -574,11 +567,13 @@ class VespaIndex(DocumentIndex):
update_dict["fields"][BOOST] = {"assign": fields.boost}
if fields.document_sets is not None:
# WeightedSet<string> needs a map { item: weight, ... }
update_dict["fields"][DOCUMENT_SETS] = {
"assign": {document_set: 1 for document_set in fields.document_sets}
}
if fields.access is not None:
# Similar to above
update_dict["fields"][ACCESS_CONTROL_LIST] = {
"assign": {acl_entry: 1 for acl_entry in fields.access.to_acl()}
}
@@ -590,10 +585,7 @@ class VespaIndex(DocumentIndex):
logger.error("Update request received but nothing to update.")
return
vespa_url = (
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}"
"?create=true"
)
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}?create=true"
try:
resp = http_client.put(
@@ -603,11 +595,8 @@ class VespaIndex(DocumentIndex):
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(
f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). "
f"Details: {e.response.text}"
)
# Re-raise so the @retry decorator will catch and retry
error_message = f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). Details: {e.response.text}"
logger.error(error_message)
raise
def update_single(

View File

@@ -146,23 +146,6 @@ def _index_vespa_chunk(
title = document.get_title_for_document_index()
metadata_json = document.metadata
cleaned_metadata_json: dict[str, str | list[str]] = {}
for key, value in metadata_json.items():
cleaned_key = remove_invalid_unicode_chars(key)
if isinstance(value, list):
cleaned_metadata_json[cleaned_key] = [
remove_invalid_unicode_chars(item) for item in value
]
else:
cleaned_metadata_json[cleaned_key] = remove_invalid_unicode_chars(value)
metadata_list = document.get_metadata_str_attributes()
if metadata_list:
metadata_list = [
remove_invalid_unicode_chars(metadata) for metadata in metadata_list
]
vespa_document_fields = {
DOCUMENT_ID: document.id,
CHUNK_ID: chunk.chunk_id,
@@ -183,10 +166,10 @@ def _index_vespa_chunk(
SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
SECTION_CONTINUATION: chunk.section_continuation,
LARGE_CHUNK_REFERENCE_IDS: chunk.large_chunk_reference_ids,
METADATA: json.dumps(cleaned_metadata_json),
METADATA: json.dumps(document.metadata),
# Save as a list for efficient extraction as an Attribute
METADATA_LIST: metadata_list,
METADATA_SUFFIX: remove_invalid_unicode_chars(chunk.metadata_suffix_keyword),
METADATA_LIST: chunk.source_document.get_metadata_str_attributes(),
METADATA_SUFFIX: chunk.metadata_suffix_keyword,
EMBEDDINGS: embeddings_name_vector_map,
TITLE_EMBEDDING: chunk.title_embedding,
DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),

View File

@@ -27,7 +27,6 @@ from langchain_core.prompt_values import PromptValue
from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.configs.chat_configs import QA_TIMEOUT
from onyx.configs.model_configs import (
DISABLE_LITELLM_STREAMING,
)
@@ -36,7 +35,6 @@ from onyx.configs.model_configs import LITELLM_EXTRA_BODY
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import model_is_reasoning_model
from onyx.server.utils import mask_string
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
@@ -231,15 +229,15 @@ class DefaultMultiLLM(LLM):
def __init__(
self,
api_key: str | None,
timeout: int,
model_provider: str,
model_name: str,
timeout: int | None = None,
api_base: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
max_output_tokens: int | None = None,
custom_llm_provider: str | None = None,
temperature: float | None = None,
temperature: float = GEN_AI_TEMPERATURE,
custom_config: dict[str, str] | None = None,
extra_headers: dict[str, str] | None = None,
extra_body: dict | None = LITELLM_EXTRA_BODY,
@@ -247,16 +245,9 @@ class DefaultMultiLLM(LLM):
long_term_logger: LongTermLogger | None = None,
):
self._timeout = timeout
if timeout is None:
if model_is_reasoning_model(model_name):
self._timeout = QA_TIMEOUT * 10 # Reasoning models are slow
else:
self._timeout = QA_TIMEOUT
self._temperature = GEN_AI_TEMPERATURE if temperature is None else temperature
self._model_provider = model_provider
self._model_version = model_name
self._temperature = temperature
self._api_key = api_key
self._deployment_name = deployment_name
self._api_base = api_base
@@ -396,14 +387,9 @@ class DefaultMultiLLM(LLM):
self._record_call(processed_prompt)
try:
print(
"model is",
f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
)
return litellm.completion(
mock_response=MOCK_LLM_RESPONSE,
# model choice
# model="openai/gpt-4",
model=f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
# NOTE: have to pass in None instead of empty string for these
# otherwise litellm can have some issues with bedrock

View File

@@ -2,6 +2,7 @@ from typing import Any
from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.chat_configs import QA_TIMEOUT
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine import get_session_context_manager
@@ -87,8 +88,8 @@ def get_llms_for_persona(
def get_default_llms(
timeout: int | None = None,
temperature: float | None = None,
timeout: int = QA_TIMEOUT,
temperature: float = GEN_AI_TEMPERATURE,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> tuple[LLM, LLM]:
@@ -137,7 +138,7 @@ def get_llm(
api_version: str | None = None,
custom_config: dict[str, str] | None = None,
temperature: float | None = None,
timeout: int | None = None,
timeout: int = QA_TIMEOUT,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> LLM:

View File

@@ -29,11 +29,11 @@ OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
"o3-mini",
"o1-mini",
"o1",
"o1-preview",
"o1-2024-12-17",
"gpt-4",
"gpt-4o",
"gpt-4o-mini",
"o1-preview",
"gpt-4-turbo",
"gpt-4-turbo-preview",
"gpt-4-1106-preview",

View File

@@ -543,14 +543,3 @@ def model_supports_image_input(model_name: str, model_provider: str) -> bool:
f"Failed to get model object for {model_provider}/{model_name}"
)
return False
def model_is_reasoning_model(model_name: str) -> bool:
_REASONING_MODEL_NAMES = [
"o1",
"o1-mini",
"o3-mini",
"deepseek-reasoner",
"deepseek-r1",
]
return model_name.lower() in _REASONING_MODEL_NAMES

View File

@@ -99,7 +99,7 @@ def _check_tokenizer_cache(
if not tokenizer:
logger.info(
f"Falling back to default embedding model tokenizer: {DOCUMENT_ENCODER_MODEL}"
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
)
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)

View File

@@ -91,7 +91,7 @@ SAMPLE RESPONSE:
# similar to the chat flow, but with the option of including a
# "conversation history" block
CITATIONS_PROMPT = f"""
Refer to the following {{context_type}} when responding to me.{DEFAULT_IGNORE_STATEMENT}
Refer to the following context documents when responding to me.{DEFAULT_IGNORE_STATEMENT}
CONTEXT:
{GENERAL_SEP_PAT}
@@ -108,7 +108,7 @@ CONTEXT:
# NOTE: need to add the extra line about "getting right to the point" since the
# tool calling models from OpenAI tend to be more verbose
CITATIONS_PROMPT_FOR_TOOL_CALLING = f"""
Refer to the provided {{context_type}} when responding to me.{DEFAULT_IGNORE_STATEMENT} \
Refer to the provided context documents when responding to me.{DEFAULT_IGNORE_STATEMENT} \
You should always get right to the point, and never use extraneous language.
{{history_block}}{{task_prompt}}

View File

@@ -120,7 +120,6 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
ignore_result=True,
)
num_tasks_sent += 1

View File

@@ -132,7 +132,6 @@ class RedisConnectorDelete:
queue=OnyxCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
ignore_result=True,
)
async_results.append(result)

View File

@@ -11,7 +11,6 @@ from redis.lock import Lock as RedisLock
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -50,7 +49,7 @@ class RedisConnectorPermissionSync:
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT * 2
ACTIVE_TTL = 3600
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
@@ -81,8 +80,7 @@ class RedisConnectorPermissionSync:
def get_active_task_count(self) -> int:
"""Count of active permission sync tasks"""
count = 0
for _ in self.redis.sscan_iter(
OnyxRedisConstants.ACTIVE_FENCES,
for _ in self.redis.scan_iter(
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
@@ -196,7 +194,6 @@ class RedisConnectorPermissionSync:
queue=OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
task_id=custom_task_id,
priority=OnyxCeleryPriority.HIGH,
ignore_result=True,
)
async_results.append(result)

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from typing import cast
import redis
@@ -7,12 +8,10 @@ from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.configs.constants import OnyxRedisConstants
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
class RedisConnectorExternalGroupSyncPayload(BaseModel):
id: str
submitted: datetime
started: datetime | None
celery_task_id: str | None
@@ -38,12 +37,6 @@ class RedisConnectorExternalGroupSync:
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorexternalgroupsync_taskset
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorexternalgroupsync+sub
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
@@ -57,7 +50,6 @@ class RedisConnectorExternalGroupSync:
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
@@ -74,8 +66,7 @@ class RedisConnectorExternalGroupSync:
def get_active_task_count(self) -> int:
"""Count of active external group syncing tasks"""
count = 0
for _ in self.redis.sscan_iter(
OnyxRedisConstants.ACTIVE_FENCES,
for _ in self.redis.scan_iter(
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
@@ -92,11 +83,10 @@ class RedisConnectorExternalGroupSync:
@property
def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
# read related data and evaluate/print task progress
fence_raw = self.redis.get(self.fence_key)
if fence_raw is None:
fence_bytes = cast(Any, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_bytes = cast(bytes, fence_raw)
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorExternalGroupSyncPayload.model_validate_json(
cast(str, fence_str)
@@ -109,26 +99,10 @@ class RedisConnectorExternalGroupSync:
payload: RedisConnectorExternalGroupSyncPayload | None,
) -> None:
if not payload:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, payload.model_dump_json())
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
def set_active(self) -> None:
"""This sets a signal to keep the permissioning flow from getting cleaned up within
the expiration time.
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
@property
def generator_complete(self) -> int | None:
@@ -164,8 +138,6 @@ class RedisConnectorExternalGroupSync:
pass
def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.active_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
@@ -180,9 +152,6 @@ class RedisConnectorExternalGroupSync:
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorExternalGroupSync.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorExternalGroupSync.TASKSET_PREFIX + "*"):
r.delete(key)

View File

@@ -1,16 +1,13 @@
import time
from datetime import datetime
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -19,13 +16,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
class RedisConnectorPrunePayload(BaseModel):
id: str
submitted: datetime
started: datetime | None
celery_task_id: str | None
class RedisConnectorPrune:
"""Manages interactions with redis for pruning tasks. Should only be accessed
through RedisConnector."""
@@ -46,12 +36,6 @@ class RedisConnectorPrune:
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpruning_taskset
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpruning+sub
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = CELERY_PRUNING_LOCK_TIMEOUT * 2
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
@@ -65,7 +49,6 @@ class RedisConnectorPrune:
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
@@ -82,10 +65,8 @@ class RedisConnectorPrune:
def get_active_task_count(self) -> int:
"""Count of active pruning tasks"""
count = 0
for _ in self.redis.sscan_iter(
OnyxRedisConstants.ACTIVE_FENCES,
RedisConnectorPrune.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
for key in self.redis.scan_iter(
RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
count += 1
return count
@@ -97,44 +78,15 @@ class RedisConnectorPrune:
return False
@property
def payload(self) -> RedisConnectorPrunePayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorPrunePayload.model_validate_json(cast(str, fence_str))
return payload
def set_fence(
self,
payload: RedisConnectorPrunePayload | None,
) -> None:
if not payload:
def set_fence(self, value: bool) -> None:
if not value:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, payload.model_dump_json())
self.redis.set(self.fence_key, 0)
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
def set_active(self) -> None:
"""This sets a signal to keep the permissioning flow from getting cleaned up within
the expiration time.
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
@property
def generator_complete(self) -> int | None:
"""the fence payload is an int representing the starting number of
@@ -202,7 +154,6 @@ class RedisConnectorPrune:
queue=OnyxCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
ignore_result=True,
)
async_results.append(result)
@@ -211,7 +162,6 @@ class RedisConnectorPrune:
def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.active_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
@@ -226,9 +176,6 @@ class RedisConnectorPrune:
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorPrune.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPrune.TASKSET_PREFIX + "*"):
r.delete(key)

View File

@@ -1,29 +0,0 @@
from onyx.redis.redis_connector_credential_pair import (
RedisGlobalConnectorCredentialPair,
)
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_usergroup import RedisUserGroup
def is_fence(key_bytes: bytes) -> bool:
key_str = key_bytes.decode("utf-8")
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
return True
if key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
return True
if key_str.startswith(RedisUserGroup.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
return True
return False

View File

@@ -162,11 +162,6 @@ def load_personas_from_yaml(
else persona.get("is_visible")
),
db_session=db_session,
is_default_persona=(
existing_persona.is_default_persona
if existing_persona is not None
else persona.get("is_default_persona", False)
),
)

View File

@@ -41,7 +41,6 @@ personas:
icon_color: "#6FB1FF"
display_priority: 0
is_visible: true
is_default_persona: true
starter_messages:
- name: "Give me an overview of what's here"
message: "Sample some documents and tell me what you find."
@@ -67,7 +66,6 @@ personas:
icon_color: "#FF6F6F"
display_priority: 1
is_visible: true
is_default_persona: true
starter_messages:
- name: "Summarize a document"
message: "If I have provided a document please summarize it for me. If not, please ask me to upload a document either by dragging it into the input bar or clicking the +file icon."
@@ -93,7 +91,6 @@ personas:
icon_color: "#6FFF8D"
display_priority: 2
is_visible: false
is_default_persona: true
starter_messages:
- name: "Document Search"
message: "Hi! Could you help me find information about our team structure and reporting lines from our internal documents?"
@@ -120,7 +117,6 @@ personas:
image_generation: true
display_priority: 3
is_visible: true
is_default_persona: true
starter_messages:
- name: "Create visuals for a presentation"
message: "Generate someone presenting a graph which clearly demonstrates an upwards trajectory."

View File

@@ -22,8 +22,6 @@ from onyx.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.connector_credential_pair import add_credential_to_connector
from onyx.db.connector_credential_pair import (
get_connector_credential_pair_from_id_for_user,
@@ -230,13 +228,6 @@ def update_cc_pair_status(
db_session.commit()
# this speeds up the start of indexing by firing the check immediately
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
kwargs=dict(tenant_id=tenant_id),
priority=OnyxCeleryPriority.HIGH,
)
return JSONResponse(
status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)}
)
@@ -368,17 +359,15 @@ def prune_cc_pair(
f"credential={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_prune_generator_task(
tasks_created = try_creating_prune_generator_task(
primary_app, cc_pair, db_session, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
)
if not payload_id:
if not tasks_created:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Pruning task creation failed.",
)
logger.info(f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}")
return StatusResponse(
success=True,
message="Successfully created the pruning task.",
@@ -516,17 +505,15 @@ def sync_cc_pair_groups(
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_external_group_sync_task(
tasks_created = try_creating_external_group_sync_task(
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
)
if not payload_id:
if not tasks_created:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="External group sync task creation failed.",
)
logger.info(f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}")
return StatusResponse(
success=True,
message="Successfully created the external group sync task.",
@@ -553,14 +540,7 @@ def associate_credential_to_connector(
metadata: ConnectorCredentialPairMetadata,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse[int]:
"""NOTE(rkuo): internally discussed and the consensus is this endpoint
and create_connector_with_mock_credential should be combined.
The intent of this endpoint is to handle connectors that actually need credentials.
"""
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
@@ -583,18 +563,6 @@ def associate_credential_to_connector(
groups=metadata.groups,
)
# trigger indexing immediately
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
priority=OnyxCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},
)
logger.info(
f"associate_credential_to_connector - running check_for_indexing: "
f"cc_pair={response.data}"
)
return response
except IntegrityError as e:
logger.error(f"IntegrityError: {e}")

View File

@@ -804,14 +804,6 @@ def create_connector_with_mock_credential(
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse:
"""NOTE(rkuo): internally discussed and the consensus is this endpoint
and associate_credential_to_connector should be combined.
The intent of this endpoint is to handle connectors that don't need credentials,
AKA web, file, etc ... but there isn't any reason a single endpoint couldn't
server this purpose.
"""
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
@@ -849,18 +841,6 @@ def create_connector_with_mock_credential(
groups=connector_data.groups,
)
# trigger indexing immediately
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
priority=OnyxCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},
)
logger.info(
f"create_connector_with_mock_credential - running check_for_indexing: "
f"cc_pair={response.data}"
)
create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",
@@ -1025,8 +1005,6 @@ def connector_run_once(
kwargs={"tenant_id": tenant_id},
)
logger.info("connector_run_once - running check_for_indexing")
msg = f"Marked {num_triggers} index attempts with indexing triggers."
return StatusResponse(
success=True,

View File

@@ -179,10 +179,12 @@ def oauth_callback(
db_session=db_session,
)
# TODO: use a library for url handling
sep = "&" if "?" in desired_return_url else "?"
return CallbackResponse(
redirect_url=f"{desired_return_url}{sep}credentialId={credential.id}"
redirect_url=(
f"{desired_return_url}?credentialId={credential.id}"
if "?" not in desired_return_url
else f"{desired_return_url}&credentialId={credential.id}"
)
)

View File

@@ -6,15 +6,11 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.document_set import check_document_sets_are_public
from onyx.db.document_set import fetch_all_document_sets_for_user
from onyx.db.document_set import insert_document_set
from onyx.db.document_set import mark_document_set_as_to_be_deleted
from onyx.db.document_set import update_document_set
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.server.features.document_set.models import CheckDocSetPublicRequest
@@ -33,7 +29,6 @@ def create_document_set(
document_set_creation_request: DocumentSetCreationRequest,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> int:
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_object_creation_for_user", None
@@ -51,13 +46,6 @@ def create_document_set(
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,
)
return document_set_db_model.id
@@ -66,7 +54,6 @@ def patch_document_set(
document_set_update_request: DocumentSetUpdateRequest,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> None:
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_object_creation_for_user", None
@@ -85,19 +72,12 @@ def patch_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,
)
@router.delete("/admin/document-set/{document_set_id}")
def delete_document_set(
document_set_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> None:
try:
mark_document_set_as_to_be_deleted(
@@ -108,12 +88,6 @@ def delete_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,
)
"""Endpoints for non-admins"""

View File

@@ -32,7 +32,6 @@ from onyx.db.persona import get_personas_for_user
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import mark_persona_as_not_deleted
from onyx.db.persona import update_all_personas_display_priority
from onyx.db.persona import update_persona_is_default
from onyx.db.persona import update_persona_label
from onyx.db.persona import update_persona_public_status
from onyx.db.persona import update_persona_shared_users
@@ -57,6 +56,7 @@ from onyx.tools.utils import is_image_generation_available
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
logger = setup_logger()
@@ -72,10 +72,6 @@ class IsPublicRequest(BaseModel):
is_public: bool
class IsDefaultRequest(BaseModel):
is_default_persona: bool
@admin_router.patch("/{persona_id}/visible")
def patch_persona_visibility(
persona_id: int,
@@ -110,25 +106,6 @@ def patch_user_presona_public_status(
raise HTTPException(status_code=403, detail=str(e))
@admin_router.patch("/{persona_id}/default")
def patch_persona_default_status(
persona_id: int,
is_default_request: IsDefaultRequest,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try:
update_persona_is_default(
persona_id=persona_id,
is_default=is_default_request.is_default_persona,
db_session=db_session,
user=user,
)
except ValueError as e:
logger.exception("Failed to update persona default status")
raise HTTPException(status_code=403, detail=str(e))
@admin_router.put("/display-priority")
def patch_persona_display_priority(
display_priority_request: DisplayPriorityRequest,

View File

@@ -76,7 +76,6 @@ def gpt_search(
user=None,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=True,
db_session=db_session,
).reranked_sections

View File

@@ -197,11 +197,6 @@ def create_deletion_attempt_for_connector_id(
kwargs={"tenant_id": tenant_id},
)
logger.info(
f"create_deletion_attempt_for_connector_id - running check_for_connector_deletion: "
f"cc_pair={cc_pair.id}"
)
if cc_pair.connector.source == DocumentSource.FILE:
connector = cc_pair.connector
file_store = get_default_file_store(db_session)

View File

@@ -34,7 +34,6 @@ from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import optional_user
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
@@ -287,7 +286,7 @@ def bulk_invite_users(
detail=f"Invalid email address: {email} - {str(e)}",
)
if MULTI_TENANT and not DEV_MODE:
if MULTI_TENANT:
try:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning", "add_users_to_tenant", None

View File

@@ -12,10 +12,10 @@ class PageType(str, Enum):
SEARCH = "search"
class ApplicationStatus(str, Enum):
PAYMENT_REMINDER = "payment_reminder"
GATED_ACCESS = "gated_access"
ACTIVE = "active"
class GatingType(str, Enum):
FULL = "full" # Complete restriction of access to the product or service
PARTIAL = "partial" # Full access but warning (no credit card on file)
NONE = "none" # No restrictions, full access to all features
class Notification(BaseModel):
@@ -43,7 +43,7 @@ class Settings(BaseModel):
maximum_chat_retention_days: int | None = None
gpu_enabled: bool | None = None
application_status: ApplicationStatus = ApplicationStatus.ACTIVE
product_gating: GatingType = GatingType.NONE
anonymous_user_enabled: bool | None = None
pro_search_disabled: bool | None = None
auto_scroll: bool | None = None

View File

@@ -34,7 +34,7 @@ Now respond to the following:
""".strip()
class BaseTool(Tool[None]):
class BaseTool(Tool):
def build_next_prompt(
self,
prompt_builder: "AnswerPromptBuilder",

View File

@@ -1,14 +1,11 @@
from collections.abc import Callable
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session
from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection
class ToolResponse(BaseModel):
@@ -60,15 +57,5 @@ class SearchQueryInfo(BaseModel):
recency_bias_multiplier: float
class SearchToolOverrideKwargs(BaseModel):
force_no_rerank: bool
alternate_db_session: Session | None
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
skip_query_analysis: bool
class Config:
arbitrary_types_allowed = True
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"

View File

@@ -1,9 +1,7 @@
import abc
from collections.abc import Generator
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
@@ -16,10 +14,7 @@ if TYPE_CHECKING:
from onyx.tools.models import ToolResponse
OVERRIDE_T = TypeVar("OVERRIDE_T")
class Tool(abc.ABC, Generic[OVERRIDE_T]):
class Tool(abc.ABC):
@property
@abc.abstractmethod
def name(self) -> str:
@@ -62,9 +57,7 @@ class Tool(abc.ABC, Generic[OVERRIDE_T]):
"""Actual execution of the tool"""
@abc.abstractmethod
def run(
self, override_kwargs: OVERRIDE_T | None = None, **llm_kwargs: Any
) -> Generator["ToolResponse", None, None]:
def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]:
raise NotImplementedError
@abc.abstractmethod

View File

@@ -74,7 +74,6 @@ class CustomToolCallSummary(BaseModel):
tool_result: Any # The response data
# override_kwargs is not supported for custom tools
class CustomTool(BaseTool):
def __init__(
self,
@@ -236,9 +235,7 @@ class CustomTool(BaseTool):
"""Actual execution of the tool"""
def run(
self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any
) -> Generator[ToolResponse, None, None]:
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
request_body = kwargs.get(REQUEST_BODY)
path_params = {}

View File

@@ -79,8 +79,7 @@ class ImageShape(str, Enum):
LANDSCAPE = "landscape"
# override_kwargs is not supported for image generation tools
class ImageGenerationTool(Tool[None]):
class ImageGenerationTool(Tool):
_NAME = "run_image_generation"
_DESCRIPTION = "Generate an image from a prompt."
_DISPLAY_NAME = "Image Generation"
@@ -256,9 +255,7 @@ class ImageGenerationTool(Tool[None]):
"An error occurred during image generation. Please try again later."
)
def run(
self, override_kwargs: None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
prompt = cast(str, kwargs["prompt"])
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
format = self.output_format

View File

@@ -106,8 +106,7 @@ def internet_search_response_to_search_docs(
]
# override_kwargs is not supported for internet search tools
class InternetSearchTool(Tool[None]):
class InternetSearchTool(Tool):
_NAME = "run_internet_search"
_DISPLAY_NAME = "Internet Search"
_DESCRIPTION = "Perform an internet search for up-to-date information."
@@ -243,9 +242,7 @@ class InternetSearchTool(Tool[None]):
],
)
def run(
self, override_kwargs: None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
query = cast(str, kwargs["internet_search_query"])
results = self._perform_search(query)
@@ -282,5 +279,4 @@ class InternetSearchTool(Tool[None]):
using_tool_calling_llm=using_tool_calling_llm,
answer_style_config=self.answer_style_config,
prompt_config=self.prompt_config,
context_type="internet search results",
)

View File

@@ -39,7 +39,6 @@ from onyx.secondary_llm_flows.choose_search import check_if_need_search
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
@@ -78,7 +77,7 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ
"""
class SearchTool(Tool[SearchToolOverrideKwargs]):
class SearchTool(Tool):
_NAME = "run_search"
_DISPLAY_NAME = "Search Tool"
_DESCRIPTION = SEARCH_TOOL_DESCRIPTION
@@ -276,19 +275,14 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
def run(
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]:
query = cast(str, llm_kwargs["query"])
force_no_rerank = False
alternate_db_session = None
retrieved_sections_callback = None
skip_query_analysis = False
if override_kwargs:
force_no_rerank = override_kwargs.force_no_rerank
alternate_db_session = override_kwargs.alternate_db_session
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
skip_query_analysis = override_kwargs.skip_query_analysis
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
query = cast(str, kwargs["query"])
force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False))
alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None))
retrieved_sections_callback = cast(
Callable[[list[InferenceSection]], None],
kwargs.get("retrieved_sections_callback"),
)
if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
@@ -330,7 +324,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
user=self.user,
llm=self.llm,
fast_llm=self.fast_llm,
skip_query_analysis=skip_query_analysis,
bypass_acl=self.bypass_acl,
db_session=alternate_db_session or self.db_session,
prompt_config=self.prompt_config,

View File

@@ -25,7 +25,6 @@ def build_next_prompt_for_search_like_tool(
using_tool_calling_llm: bool,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
context_type: str = "context documents",
) -> AnswerPromptBuilder:
if not using_tool_calling_llm:
final_context_docs_response = next(
@@ -59,7 +58,6 @@ def build_next_prompt_for_search_like_tool(
else False
),
history_message=prompt_builder.single_message_history or "",
context_type=context_type,
)
)

View File

@@ -86,10 +86,7 @@ def run_functions_in_parallel(
Executes a list of FunctionCalls in parallel and stores the results in a dictionary where the keys
are the result_id of the FunctionCall and the values are the results of the call.
"""
results: dict[str, Any] = {}
if len(function_calls) == 0:
return results
results = {}
with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
future_to_id = {

View File

@@ -256,28 +256,16 @@ def get_documents_for_tenant_connector(
def search_for_document(
index_name: str,
document_id: str | None = None,
tenant_id: str | None = None,
max_hits: int | None = 10,
index_name: str, document_id: str, max_hits: int | None = 10
) -> List[Dict[str, Any]]:
yql_query = f"select * from sources {index_name}"
conditions = []
if document_id is not None:
conditions.append(f'document_id contains "{document_id}"')
if tenant_id is not None:
conditions.append(f'tenant_id contains "{tenant_id}"')
if conditions:
yql_query += " where " + " and ".join(conditions)
yql_query = (
f'select * from sources {index_name} where document_id contains "{document_id}"'
)
params: dict[str, Any] = {"yql": yql_query}
if max_hits is not None:
params["hits"] = max_hits
with get_vespa_http_client() as client:
response = client.get(f"{SEARCH_ENDPOINT}search/", params=params)
response = client.get(f"{SEARCH_ENDPOINT}/search/", params=params)
response.raise_for_status()
result = response.json()
documents = result.get("root", {}).get("children", [])
@@ -594,15 +582,8 @@ class VespaDebugging:
) -> None:
update_document(self.tenant_id, connector_id, doc_id, fields)
def delete_documents_for_tenant(self, count: int | None = None) -> None:
if not self.tenant_id:
raise Exception("Tenant ID is not set")
delete_documents_for_tenant(self.index_name, self.tenant_id, count=count)
def search_for_document(
self, document_id: str | None = None, tenant_id: str | None = None
) -> List[Dict[str, Any]]:
return search_for_document(self.index_name, document_id, tenant_id)
def search_for_document(self, document_id: str) -> List[Dict[str, Any]]:
return search_for_document(self.index_name, document_id)
def delete_document(self, connector_id: int, doc_id: str) -> None:
# Delete a document.
@@ -619,147 +600,6 @@ class VespaDebugging:
get_document_acls(self.tenant_id, cc_pair_id, n)
def delete_where(
index_name: str,
selection: str,
cluster: str = "default",
bucket_space: str | None = None,
continuation: str | None = None,
time_chunk: str | None = None,
timeout: str | None = None,
tracelevel: int | None = None,
) -> None:
"""
Removes visited documents in `cluster` where the given selection
is true, using Vespa's 'delete where' endpoint.
:param index_name: Typically <namespace>/<document-type> from your schema
:param selection: The selection string, e.g., "true" or "foo contains 'bar'"
:param cluster: The name of the cluster where documents reside
:param bucket_space: e.g. 'global' or 'default'
:param continuation: For chunked visits
:param time_chunk: If you want to chunk the visit by time
:param timeout: e.g. '10s'
:param tracelevel: Increase for verbose logs
"""
# Using index_name of form <namespace>/<document-type>, e.g. "nomic_ai_nomic_embed_text_v1"
# This route ends with "/docid/" since the actual ID is not specified — we rely on "selection".
path = f"/document/v1/{index_name}/docid/"
params = {
"cluster": cluster,
"selection": selection,
}
# Optional parameters
if bucket_space is not None:
params["bucketSpace"] = bucket_space
if continuation is not None:
params["continuation"] = continuation
if time_chunk is not None:
params["timeChunk"] = time_chunk
if timeout is not None:
params["timeout"] = timeout
if tracelevel is not None:
params["tracelevel"] = tracelevel # type: ignore
with get_vespa_http_client() as client:
url = f"{VESPA_APPLICATION_ENDPOINT}{path}"
logger.info(f"Performing 'delete where' on {url} with selection={selection}...")
response = client.delete(url, params=params)
# (Optionally, you can keep fetching `continuation` from the JSON response
# if you have more documents to delete in chunks.)
response.raise_for_status() # will raise HTTPError if not 2xx
logger.info(f"Delete where completed with status: {response.status_code}")
print(f"Delete where completed with status: {response.status_code}")
def delete_documents_for_tenant(
index_name: str,
tenant_id: str,
route: str | None = None,
condition: str | None = None,
timeout: str | None = None,
tracelevel: int | None = None,
count: int | None = None,
) -> None:
"""
For the given tenant_id and index_name (often in the form <namespace>/<document-type>),
find documents via search_for_document, then delete them one at a time using Vespa's
/document/v1/<namespace>/<document-type>/docid/<document-id> endpoint.
:param index_name: Typically <namespace>/<document-type> from your schema
:param tenant_id: The tenant to match in your Vespa search
:param route: Optional route parameter for delete
:param condition: Optional conditional remove
:param timeout: e.g. '10s'
:param tracelevel: Increase for verbose logs
"""
deleted_count = 0
while True:
# Search for documents with the given tenant_id
docs = search_for_document(
index_name=index_name,
document_id=None,
tenant_id=tenant_id,
max_hits=100, # Fetch in batches of 100
)
if not docs:
logger.info("No more documents found to delete.")
break
with get_vespa_http_client() as client:
for doc in docs:
if count is not None and deleted_count >= count:
logger.info(f"Reached maximum delete limit of {count} documents.")
return
fields = doc.get("fields", {})
doc_id_value = fields.get("document_id") or fields.get("documentid")
tenant_id = fields.get("tenant_id")
if tenant_id != tenant_id:
raise Exception("Tenant ID mismatch")
if not doc_id_value:
logger.warning(
"Skipping a document that has no document_id in 'fields'."
)
continue
url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_id_value}"
params = {}
if condition:
params["condition"] = condition
if route:
params["route"] = route
if timeout:
params["timeout"] = timeout
if tracelevel is not None:
params["tracelevel"] = str(tracelevel)
response = client.delete(url, params=params)
if response.status_code == 200:
logger.info(f"Successfully deleted doc_id={doc_id_value}")
deleted_count += 1
else:
logger.error(
f"Failed to delete doc_id={doc_id_value}, "
f"status={response.status_code}, response={response.text}"
)
print(
f"Could not delete doc_id={doc_id_value}. "
f"Status={response.status_code}, response={response.text}"
)
raise Exception(
f"Could not delete doc_id={doc_id_value}. "
f"Status={response.status_code}, response={response.text}"
)
logger.info(f"Deleted {deleted_count} documents in total.")
def main() -> None:
parser = argparse.ArgumentParser(description="Vespa debugging tool")
parser.add_argument(
@@ -772,7 +612,6 @@ def main() -> None:
"update",
"delete",
"get_acls",
"delete-all-documents",
],
required=True,
help="Action to perform",
@@ -787,20 +626,11 @@ def main() -> None:
parser.add_argument(
"--fields", help="Fields to update, in JSON format (for update)"
)
parser.add_argument(
"--count",
type=int,
help="Maximum number of documents to delete (for delete-all-documents)",
)
args = parser.parse_args()
vespa_debug = VespaDebugging(args.tenant_id)
if args.action == "delete-all-documents":
if not args.tenant_id:
parser.error("--tenant-id is required for delete-all-documents action")
vespa_debug.delete_documents_for_tenant(count=args.count)
elif args.action == "config":
if args.action == "config":
vespa_debug.print_config()
elif args.action == "connect":
vespa_debug.check_connectivity()

Some files were not shown because too many files have changed in this diff Show More