Compare commits

..

2 Commits

Author SHA1 Message Date
hagen-danswer
b4abbbfe00 Update conftest.py 2025-01-08 14:01:28 -08:00
hagen-danswer
cf9caf79ff Fixing google drive texts 2025-01-08 13:27:59 -08:00
520 changed files with 13032 additions and 31348 deletions

View File

@@ -1,14 +1,11 @@
## Description
[Provide a brief description of the changes in this PR]
## How Has This Been Tested?
## How Has This Been Tested?
[Describe the tests you ran to verify your changes]
## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
- [ ] [Optional] Override Linear Check

View File

@@ -67,7 +67,6 @@ jobs:
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NODE_OPTIONS=--max-old-space-size=8192
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -118,6 +118,6 @@ jobs:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
image-ref: docker.io/onyxdotapp/onyx-model-server:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
timeout: "10m"

View File

@@ -60,8 +60,6 @@ jobs:
push: true
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -8,8 +8,6 @@ on: push
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
MOCK_LLM_RESPONSE: true
jobs:
playwright-tests:

View File

@@ -21,10 +21,10 @@ jobs:
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.17.0
version: v3.14.4
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.7.0
uses: helm/chart-testing-action@v2.6.1
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
@@ -37,6 +37,22 @@ jobs:
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# rkuo: I don't think we need python?
# - name: Set up Python
# uses: actions/setup-python@v5
# with:
# python-version: '3.11'
# cache: 'pip'
# cache-dependency-path: |
# backend/requirements/default.txt
# backend/requirements/dev.txt
# backend/requirements/model_server.txt
# - run: |
# python -m pip install --upgrade pip
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'
@@ -46,7 +62,7 @@ jobs:
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.12.0
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
if: steps.list-changed.outputs.changed == 'true'

View File

@@ -1,29 +0,0 @@
name: Ensure PR references Linear
on:
pull_request:
types: [opened, edited, reopened, synchronize]
jobs:
linear-check:
runs-on: ubuntu-latest
steps:
- name: Check PR body for Linear link or override
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
# Looking for "https://linear.app" in the body
if echo "$PR_BODY" | grep -qE "https://linear\.app"; then
echo "Found a Linear link. Check passed."
exit 0
fi
# Looking for a checked override: "[x] Override Linear Check"
if echo "$PR_BODY" | grep -q "\[x\].*Override Linear Check"; then
echo "Override box is checked. Check passed."
exit 0
fi
# Otherwise, fail the run
echo "No Linear link or override found in the PR description."
exit 1

View File

@@ -39,12 +39,6 @@ env:
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
# Sharepoint
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/

View File

@@ -5,8 +5,6 @@
# For local dev, often user Authentication is not needed
AUTH_TYPE=disabled
# Skip warm up for dev
SKIP_WARM_UP=True
# Always keep these on for Dev
# Logs all model prompts to stdout
@@ -29,7 +27,6 @@ REQUIRE_EMAIL_VERIFICATION=False
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
OPENAI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
GEN_AI_MODEL_VERSION=gpt-4o
FAST_GEN_AI_MODEL_VERSION=gpt-4o

View File

@@ -28,7 +28,6 @@
"Celery heavy",
"Celery indexing",
"Celery beat",
"Celery monitoring",
],
"presentation": {
"group": "1",
@@ -52,8 +51,7 @@
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat",
"Celery monitoring",
"Celery beat"
],
"presentation": {
"group": "1",
@@ -271,31 +269,6 @@
},
"consoleTitle": "Celery indexing Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery beat",
"type": "debugpy",

View File

@@ -17,10 +17,9 @@ Before starting, make sure the Docker Daemon is running.
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
4. CD into web, run "npm i" followed by npm run dev.
5. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
6. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
7. Use the debug toolbar to step through code, inspect variables, etc.
4. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
6. Use the debug toolbar to step through code, inspect variables, etc.
## Features

View File

@@ -119,7 +119,7 @@ There are two editions of Onyx:
- Whitelabeling
- API key authentication
- Encryption of secrets
- And many more! Checkout [our website](https://www.onyx.app/) for the latest.
- Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
To try the Onyx Enterprise Edition:

View File

@@ -9,10 +9,8 @@ founders@onyx.app for more information. Please visit https://github.com/onyx-dot
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.8-dev
# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true" \
DO_NOT_TRACK="true"
DANSWER_RUNNING_IN_DOCKER="true"
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"

View File

@@ -1,29 +0,0 @@
"""add shortcut option for users
Revision ID: 027381bce97c
Revises: 6fc7886d665d
Create Date: 2025-01-14 12:14:00.814390
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "027381bce97c"
down_revision = "6fc7886d665d"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"shortcut_enabled", sa.Boolean(), nullable=False, server_default="false"
),
)
def downgrade() -> None:
op.drop_column("user", "shortcut_enabled")

View File

@@ -1,36 +0,0 @@
"""add index to index_attempt.time_created
Revision ID: 0f7ff6d75b57
Revises: 369644546676
Create Date: 2025-01-10 14:01:14.067144
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "0f7ff6d75b57"
down_revision = "fec3db967bf7"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_index(
op.f("ix_index_attempt_status"),
"index_attempt",
["status"],
unique=False,
)
op.create_index(
op.f("ix_index_attempt_time_created"),
"index_attempt",
["time_created"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_index_attempt_time_created"), table_name="index_attempt")
op.drop_index(op.f("ix_index_attempt_status"), table_name="index_attempt")

View File

@@ -1,36 +0,0 @@
"""add chat session specific temperature override
Revision ID: 2f80c6a2550f
Revises: 33ea50e88f24
Create Date: 2025-01-31 10:30:27.289646
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2f80c6a2550f"
down_revision = "33ea50e88f24"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
)
op.add_column(
"user",
sa.Column(
"temperature_override_enabled",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
def downgrade() -> None:
op.drop_column("chat_session", "temperature_override")
op.drop_column("user", "temperature_override_enabled")

View File

@@ -1,80 +0,0 @@
"""foreign key input prompts
Revision ID: 33ea50e88f24
Revises: a6df6b88ef81
Create Date: 2025-01-29 10:54:22.141765
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "33ea50e88f24"
down_revision = "a6df6b88ef81"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Safely drop constraints if exists
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
"""
)
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
"""
)
# Recreate with ON DELETE CASCADE
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Drop the new FKs with ondelete
op.drop_constraint(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
op.drop_constraint(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
# Recreate them without cascading
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
)

View File

@@ -1,35 +0,0 @@
"""add composite index for index attempt time updated
Revision ID: 369644546676
Revises: 2955778aa44c
Create Date: 2025-01-08 15:38:17.224380
"""
from alembic import op
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = "369644546676"
down_revision = "2955778aa44c"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_index(
"ix_index_attempt_ccpair_search_settings_time_updated",
"index_attempt",
[
"connector_credential_pair_id",
"search_settings_id",
text("time_updated DESC"),
],
unique=False,
)
def downgrade() -> None:
op.drop_index(
"ix_index_attempt_ccpair_search_settings_time_updated",
table_name="index_attempt",
)

View File

@@ -1,59 +0,0 @@
"""add back input prompts
Revision ID: 3c6531f32351
Revises: aeda5f2df4f6
Create Date: 2025-01-13 12:49:51.705235
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "3c6531f32351"
down_revision = "aeda5f2df4f6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"inputprompt",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("prompt", sa.String(), nullable=False),
sa.Column("content", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.Column("is_public", sa.Boolean(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"inputprompt__user",
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
sa.Column(
"user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
),
sa.Column("disabled", sa.Boolean(), nullable=False, default=False),
sa.ForeignKeyConstraint(
["input_prompt_id"],
["inputprompt.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
)
def downgrade() -> None:
op.drop_table("inputprompt__user")
op.drop_table("inputprompt")

View File

@@ -40,6 +40,6 @@ def upgrade() -> None:
def downgrade() -> None:
op.drop_constraint("persona_category_id_fkey", "persona", type_="foreignkey")
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
op.drop_column("persona", "category_id")
op.drop_table("persona_category")

View File

@@ -1,37 +0,0 @@
"""lowercase_user_emails
Revision ID: 4d58345da04a
Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041
"""
from alembic import op
from sqlalchemy.sql import text
# revision identifiers, used by Alembic.
revision = "4d58345da04a"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get database connection
connection = op.get_bind()
# Update all user emails to lowercase
connection.execute(
text(
"""
UPDATE "user"
SET email = LOWER(email)
WHERE email != LOWER(email)
"""
)
)
def downgrade() -> None:
# Cannot restore original case of emails
pass

View File

@@ -1,80 +0,0 @@
"""make categories labels and many to many
Revision ID: 6fc7886d665d
Revises: 3c6531f32351
Create Date: 2025-01-13 18:12:18.029112
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6fc7886d665d"
down_revision = "3c6531f32351"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Rename persona_category table to persona_label
op.rename_table("persona_category", "persona_label")
# Create the new association table
op.create_table(
"persona__persona_label",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("persona_label_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(
["persona_label_id"],
["persona_label.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("persona_id", "persona_label_id"),
)
# Copy existing relationships to the new table
op.execute(
"""
INSERT INTO persona__persona_label (persona_id, persona_label_id)
SELECT id, category_id FROM persona WHERE category_id IS NOT NULL
"""
)
# Remove the old category_id column from persona table
op.drop_column("persona", "category_id")
def downgrade() -> None:
# Rename persona_label table back to persona_category
op.rename_table("persona_label", "persona_category")
# Add back the category_id column to persona table
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"persona_category_id_fkey",
"persona",
"persona_category",
["category_id"],
["id"],
)
# Copy the first label relationship back to the persona table
op.execute(
"""
UPDATE persona
SET category_id = (
SELECT persona_label_id
FROM persona__persona_label
WHERE persona__persona_label.persona_id = persona.id
LIMIT 1
)
"""
)
# Drop the association table
op.drop_table("persona__persona_label")

View File

@@ -1,72 +0,0 @@
"""Add SyncRecord
Revision ID: 97dbb53fa8c8
Revises: 369644546676
Create Date: 2025-01-11 19:39:50.426302
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "97dbb53fa8c8"
down_revision = "be2ab2aa50ee"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"sync_record",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("entity_id", sa.Integer(), nullable=False),
sa.Column(
"sync_type",
sa.Enum(
"DOCUMENT_SET",
"USER_GROUP",
"CONNECTOR_DELETION",
name="synctype",
native_enum=False,
length=40,
),
nullable=False,
),
sa.Column(
"sync_status",
sa.Enum(
"IN_PROGRESS",
"SUCCESS",
"FAILED",
"CANCELED",
name="syncstatus",
native_enum=False,
length=40,
),
nullable=False,
),
sa.Column("num_docs_synced", sa.Integer(), nullable=False),
sa.Column("sync_start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("sync_end_time", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
# Add index for fetch_latest_sync_record query
op.create_index(
"ix_sync_record_entity_id_sync_type_sync_start_time",
"sync_record",
["entity_id", "sync_type", "sync_start_time"],
)
# Add index for cleanup_sync_records query
op.create_index(
"ix_sync_record_entity_id_sync_type_sync_status",
"sync_record",
["entity_id", "sync_type", "sync_status"],
)
def downgrade() -> None:
op.drop_index("ix_sync_record_entity_id_sync_type_sync_status")
op.drop_index("ix_sync_record_entity_id_sync_type_sync_start_time")
op.drop_table("sync_record")

View File

@@ -1,29 +0,0 @@
"""remove recent assistants
Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_column("user", "recent_assistants")
def downgrade() -> None:
op.add_column(
"user",
sa.Column(
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
),
)

View File

@@ -1,27 +0,0 @@
"""add pinned assistants
Revision ID: aeda5f2df4f6
Revises: c5eae4a75a1b
Create Date: 2025-01-09 16:04:10.770636
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "aeda5f2df4f6"
down_revision = "c5eae4a75a1b"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user", sa.Column("pinned_assistants", postgresql.JSONB(), nullable=True)
)
op.execute('UPDATE "user" SET pinned_assistants = chosen_assistants')
def downgrade() -> None:
op.drop_column("user", "pinned_assistants")

View File

@@ -1,38 +0,0 @@
"""fix_capitalization
Revision ID: be2ab2aa50ee
Revises: 369644546676
Create Date: 2025-01-10 13:13:26.228960
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "be2ab2aa50ee"
down_revision = "369644546676"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE document
SET
external_user_group_ids = ARRAY(
SELECT LOWER(unnest(external_user_group_ids))
),
last_modified = NOW()
WHERE
external_user_group_ids IS NOT NULL
AND external_user_group_ids::text[] <> ARRAY(
SELECT LOWER(unnest(external_user_group_ids))
)::text[]
"""
)
def downgrade() -> None:
# No way to cleanly persist the bad state through an upgrade/downgrade
# cycle, so we just pass
pass

View File

@@ -1,36 +0,0 @@
"""Add chat_message__standard_answer table
Revision ID: c5eae4a75a1b
Revises: 0f7ff6d75b57
Create Date: 2025-01-15 14:08:49.688998
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c5eae4a75a1b"
down_revision = "0f7ff6d75b57"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"chat_message__standard_answer",
sa.Column("chat_message_id", sa.Integer(), nullable=False),
sa.Column("standard_answer_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["chat_message_id"],
["chat_message.id"],
),
sa.ForeignKeyConstraint(
["standard_answer_id"],
["standard_answer.id"],
),
sa.PrimaryKeyConstraint("chat_message_id", "standard_answer_id"),
)
def downgrade() -> None:
op.drop_table("chat_message__standard_answer")

View File

@@ -1,48 +0,0 @@
"""Add has_been_indexed to DocumentByConnectorCredentialPair
Revision ID: c7bf5721733e
Revises: fec3db967bf7
Create Date: 2025-01-13 12:39:05.831693
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c7bf5721733e"
down_revision = "027381bce97c"
branch_labels = None
depends_on = None
def upgrade() -> None:
# assume all existing rows have been indexed, no better approach
op.add_column(
"document_by_connector_credential_pair",
sa.Column("has_been_indexed", sa.Boolean(), nullable=True),
)
op.execute(
"UPDATE document_by_connector_credential_pair SET has_been_indexed = TRUE"
)
op.alter_column(
"document_by_connector_credential_pair",
"has_been_indexed",
nullable=False,
)
# Add index to optimize get_document_counts_for_cc_pairs query pattern
op.create_index(
"idx_document_cc_pair_counts",
"document_by_connector_credential_pair",
["connector_id", "credential_id", "has_been_indexed"],
unique=False,
)
def downgrade() -> None:
# Remove the index first before removing the column
op.drop_index(
"idx_document_cc_pair_counts",
table_name="document_by_connector_credential_pair",
)
op.drop_column("document_by_connector_credential_pair", "has_been_indexed")

View File

@@ -1,33 +0,0 @@
"""add passthrough auth to tool
Revision ID: f1ca58b2f2ec
Revises: c7bf5721733e
Create Date: 2024-03-19
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "f1ca58b2f2ec"
down_revision: Union[str, None] = "c7bf5721733e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add passthrough_auth column to tool table with default value of False
op.add_column(
"tool",
sa.Column(
"passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
),
)
def downgrade() -> None:
# Remove passthrough_auth column from tool table
op.drop_column("tool", "passthrough_auth")

View File

@@ -1,41 +0,0 @@
"""Add time_updated to UserGroup and DocumentSet
Revision ID: fec3db967bf7
Revises: 97dbb53fa8c8
Create Date: 2025-01-12 15:49:02.289100
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "fec3db967bf7"
down_revision = "97dbb53fa8c8"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"document_set",
sa.Column(
"time_last_modified_by_user",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.add_column(
"user_group",
sa.Column(
"time_last_modified_by_user",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
def downgrade() -> None:
op.drop_column("user_group", "time_last_modified_by_user")
op.drop_column("document_set", "time_last_modified_by_user")

View File

@@ -32,7 +32,6 @@ def perform_ttl_management_task(
@celery_app.task(
name="check_ttl_management_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str | None) -> None:
@@ -57,7 +56,6 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
@celery_app.task(
name="autogenerate_usage_report_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:

View File

@@ -1,73 +1,24 @@
from datetime import timedelta
from typing import Any
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
)
from onyx.background.celery.tasks.beat_schedule import (
tasks_to_schedule as base_tasks_to_schedule,
)
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT
ee_cloud_tasks_to_schedule = [
ee_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
},
"name": "autogenerate_usage_report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
},
},
]
ee_tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
ee_tasks_to_schedule = [
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_tasks_to_schedule + base_tasks_to_schedule

View File

@@ -8,9 +8,6 @@ from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import mark_user_group_as_synced
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from onyx.background.celery.apps.app_base import task_logger
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
@@ -46,59 +43,24 @@ def monitor_usergroup_taskset(
f"User group sync progress: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
)
if count > 0:
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=count,
)
return
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
if user_group:
usergroup_name = user_group.name
try:
if user_group.is_up_for_deletion:
# this prepare should have been run when the deletion was scheduled,
# but run it again to be sure we're ready to go
mark_user_group_as_synced(db_session, user_group)
prepare_user_group_for_deletion(db_session, usergroup_id)
delete_user_group(db_session=db_session, user_group=user_group)
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial_count,
)
task_logger.info(
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
)
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial_count,
)
task_logger.info(
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
)
except Exception as e:
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.FAILED,
num_docs_synced=initial_count,
if user_group.is_up_for_deletion:
# this prepare should have been run when the deletion was scheduled,
# but run it again to be sure we're ready to go
mark_user_group_as_synced(db_session, user_group)
prepare_user_group_for_deletion(db_session, usergroup_id)
delete_user_group(db_session=db_session, user_group=user_group)
task_logger.info(
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
)
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
task_logger.info(
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
)
raise e
rug.reset()

View File

@@ -4,20 +4,6 @@ import os
# Applicable for OIDC Auth
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
# Applicable for OIDC Auth, allows you to override the scopes that
# are requested from the OIDC provider. Currently used when passing
# over access tokens to tool calls and the tool needs more scopes
OIDC_SCOPE_OVERRIDE: list[str] | None = None
_OIDC_SCOPE_OVERRIDE = os.environ.get("OIDC_SCOPE_OVERRIDE")
if _OIDC_SCOPE_OVERRIDE:
try:
OIDC_SCOPE_OVERRIDE = [
scope.strip() for scope in _OIDC_SCOPE_OVERRIDE.split(",")
]
except Exception:
pass
# Applicable for SAML Auth
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_config"

View File

@@ -345,8 +345,7 @@ def fetch_assistant_unique_users_total(
def user_can_view_assistant_stats(
db_session: Session, user: User | None, assistant_id: int
) -> bool:
# If user is None and auth is disabled, assume the user is an admin
# If user is None, assume the user is an admin or auth is disabled
if user is None or user.role == UserRole.ADMIN:
return True

View File

@@ -5,7 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.access.utils import prefix_group_w_source
from onyx.configs.constants import DocumentSource
from onyx.db.models import Document as DbDocument
@@ -25,7 +25,7 @@ def upsert_document_external_perms__no_commit(
).first()
prefixed_external_groups = [
build_ext_group_name_for_onyx(
prefix_group_w_source(
ext_group_name=group_id,
source=source_type,
)
@@ -66,7 +66,7 @@ def upsert_document_external_perms(
).first()
prefixed_external_groups: set[str] = {
build_ext_group_name_for_onyx(
prefix_group_w_source(
ext_group_name=group_id,
source=source_type,
)

View File

@@ -6,9 +6,8 @@ from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.access.utils import prefix_group_w_source
from onyx.configs.constants import DocumentSource
from onyx.db.models import User
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.users import get_user_by_email
@@ -62,10 +61,8 @@ def replace_user__ext_group_for_cc_pair(
all_group_member_emails.add(user_email)
# batch add users if they don't exist and get their ids
all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
# NOTE: this function handles case sensitivity for emails
emails=list(all_group_member_emails),
all_group_members = batch_add_ext_perm_user_if_not_exists(
db_session=db_session, emails=list(all_group_member_emails)
)
delete_user__ext_group_for_cc_pair__no_commit(
@@ -87,14 +84,12 @@ def replace_user__ext_group_for_cc_pair(
f" with email {user_email} not found"
)
continue
external_group_id = build_ext_group_name_for_onyx(
ext_group_name=external_group.id,
source=source,
)
new_external_permissions.append(
User__ExternalUserGroupId(
user_id=user_id,
external_user_group_id=external_group_id,
external_user_group_id=prefix_group_w_source(
external_group.id, source
),
cc_pair_id=cc_pair_id,
)
)

View File

@@ -1,138 +1,27 @@
from collections.abc import Sequence
from datetime import datetime
import datetime
from typing import Literal
from sqlalchemy import asc
from sqlalchemy import BinaryExpression
from sqlalchemy import ColumnElement
from sqlalchemy import desc
from sqlalchemy import distinct
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import case
from sqlalchemy.sql import func
from sqlalchemy.sql import select
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.expression import UnaryExpression
from onyx.configs.constants import QAFeedbackType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessageFeedback
from onyx.db.models import ChatSession
def _build_filter_conditions(
start_time: datetime | None,
end_time: datetime | None,
feedback_filter: QAFeedbackType | None,
) -> list[ColumnElement]:
"""
Helper function to build all filter conditions for chat sessions.
Filters by start and end time, feedback type, and any sessions without messages.
start_time: Date from which to filter
end_time: Date to which to filter
feedback_filter: Feedback type to filter by
Returns: List of filter conditions
"""
conditions = []
if start_time is not None:
conditions.append(ChatSession.time_created >= start_time)
if end_time is not None:
conditions.append(ChatSession.time_created <= end_time)
if feedback_filter is not None:
feedback_subq = (
select(ChatMessage.chat_session_id)
.join(ChatMessageFeedback)
.group_by(ChatMessage.chat_session_id)
.having(
case(
(
case(
{literal(feedback_filter == QAFeedbackType.LIKE): True},
else_=False,
),
func.bool_and(ChatMessageFeedback.is_positive),
),
(
case(
{literal(feedback_filter == QAFeedbackType.DISLIKE): True},
else_=False,
),
func.bool_and(func.not_(ChatMessageFeedback.is_positive)),
),
else_=func.bool_or(ChatMessageFeedback.is_positive)
& func.bool_or(func.not_(ChatMessageFeedback.is_positive)),
)
)
)
conditions.append(ChatSession.id.in_(feedback_subq))
return conditions
def get_total_filtered_chat_sessions_count(
db_session: Session,
start_time: datetime | None,
end_time: datetime | None,
feedback_filter: QAFeedbackType | None,
) -> int:
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
stmt = (
select(func.count(distinct(ChatSession.id)))
.select_from(ChatSession)
.filter(*conditions)
)
return db_session.scalar(stmt) or 0
def get_page_of_chat_sessions(
start_time: datetime | None,
end_time: datetime | None,
db_session: Session,
page_num: int,
page_size: int,
feedback_filter: QAFeedbackType | None = None,
) -> Sequence[ChatSession]:
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
subquery = (
select(ChatSession.id)
.filter(*conditions)
.order_by(desc(ChatSession.time_created), ChatSession.id)
.limit(page_size)
.offset(page_num * page_size)
.subquery()
)
stmt = (
select(ChatSession)
.join(subquery, ChatSession.id == subquery.c.id)
.outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id)
.options(
joinedload(ChatSession.user),
joinedload(ChatSession.persona),
contains_eager(ChatSession.messages).joinedload(
ChatMessage.chat_message_feedbacks
),
)
.order_by(
desc(ChatSession.time_created),
ChatSession.id,
asc(ChatMessage.id), # Ensure chronological message order
)
)
return db_session.scalars(stmt).unique().all()
SortByOptions = Literal["time_sent"]
def fetch_chat_sessions_eagerly_by_time(
start: datetime,
end: datetime,
start: datetime.datetime,
end: datetime.datetime,
db_session: Session,
limit: int | None = 500,
initial_time: datetime | None = None,
initial_time: datetime.datetime | None = None,
) -> list[ChatSession]:
time_order: UnaryExpression = desc(ChatSession.time_created)
message_order: UnaryExpression = asc(ChatMessage.id)

View File

@@ -7,7 +7,6 @@ from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.constants import TokenRateLimitScope
from onyx.db.models import TokenRateLimit
from onyx.db.models import TokenRateLimit__UserGroup
@@ -21,11 +20,10 @@ from onyx.server.token_rate_limits.models import TokenRateLimitArgs
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
# If user is None, assume the user is an admin or auth is disabled
if user is None or user.role == UserRole.ADMIN:
return stmt
stmt = stmt.distinct()
TRLimit_UG = aliased(TokenRateLimit__UserGroup)
User__UG = aliased(User__UserGroup)
@@ -48,12 +46,6 @@ def _add_user_filters(
that the user isn't a curator for
- if we are not editing, we show all token_rate_limits in the groups the user curates
"""
# If user is None, this is an anonymous user and we should only show public token_rate_limits
if user is None:
where_clause = TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
return stmt.where(where_clause)
where_clause = User__UG.user_id == user.id
if user.role == UserRole.CURATOR and get_editable:
where_clause &= User__UG.is_curator == True # noqa: E712
@@ -111,10 +103,10 @@ def insert_user_group_token_rate_limit(
return token_limit
def fetch_user_group_token_rate_limits_for_user(
def fetch_user_group_token_rate_limits(
db_session: Session,
group_id: int,
user: User | None,
user: User | None = None,
enabled_only: bool = False,
ordered: bool = True,
get_editable: bool = True,

View File

@@ -374,9 +374,7 @@ def _add_user_group__cc_pair_relationships__no_commit(
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
db_user_group = UserGroup(
name=user_group.name, time_last_modified_by_user=func.now()
)
db_user_group = UserGroup(name=user_group.name)
db_session.add(db_user_group)
db_session.flush() # give the group an ID
@@ -632,10 +630,6 @@ def update_user_group(
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
).unique()
_validate_curator_status__no_commit(db_session, list(removed_users))
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()
db_session.commit()
return db_user_group
@@ -705,10 +699,7 @@ def delete_user_group_cc_pair_relationship__no_commit(
connector_credential_pair_id matches the given cc_pair_id.
Should be used very carefully (only for connectors that are being deleted)."""
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
raise ValueError(f"Connector Credential Pair '{cc_pair_id}' does not exist")

View File

@@ -13,7 +13,6 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -25,9 +24,7 @@ _REQUEST_PAGINATION_LIMIT = 5000
def _get_server_space_permissions(
confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
space_permissions = confluence_client.get_all_space_permissions_server(
space_key=space_key
)
space_permissions = confluence_client.get_space_permissions(space_key=space_key)
viewspace_permissions = []
for permission_category in space_permissions:
@@ -70,13 +67,6 @@ def _get_server_space_permissions(
else:
logger.warning(f"Email for user {user_name} not found in Confluence")
if not user_emails and not group_names:
logger.warning(
"No user emails or group names found in Confluence space permissions"
f"\nSpace key: {space_key}"
f"\nSpace permissions: {space_permissions}"
)
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_names,
@@ -258,7 +248,6 @@ def _fetch_all_page_restrictions(
slim_docs: list[SlimDocument],
space_permissions_by_space_key: dict[str, ExternalAccess],
is_cloud: bool,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
For all pages, if a page has restrictions, then use those restrictions.
@@ -267,12 +256,6 @@ def _fetch_all_page_restrictions(
document_restrictions: list[DocExternalAccess] = []
for slim_doc in slim_docs:
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
if slim_doc.perm_sync_data is None:
raise ValueError(
f"No permission sync data found for document {slim_doc.id}"
@@ -342,7 +325,7 @@ def _fetch_all_page_restrictions(
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -367,12 +350,6 @@ def confluence_doc_sync(
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents():
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync", 1)
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
@@ -381,5 +358,4 @@ def confluence_doc_sync(
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
is_cloud=is_cloud,
callback=callback,
)

View File

@@ -14,8 +14,6 @@ def _build_group_member_email_map(
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user_result in confluence_client.paginated_cql_user_retrieval():
logger.debug(f"Processing groups for user: {user_result}")
user = user_result.get("user", {})
if not user:
logger.warning(f"user result missing user field: {user_result}")
@@ -35,17 +33,10 @@ def _build_group_member_email_map(
logger.warning(f"user result missing email field: {user_result}")
continue
all_users_groups: set[str] = set()
for group in confluence_client.paginated_groups_by_user_retrieval(user):
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
group_id = group["name"]
group_member_emails.setdefault(group_id, set()).add(email)
all_users_groups.add(group_id)
if not group_member_emails:
logger.warning(f"No groups found for user with email: {email}")
else:
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
return group_member_emails

View File

@@ -6,7 +6,6 @@ from onyx.access.models import ExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -29,7 +28,7 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -45,12 +44,6 @@ def gmail_doc_sync(
document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gmail_doc_sync: Stop signal detected")
callback.progress("gmail_doc_sync", 1)
if slim_doc.perm_sync_data is None:
logger.warning(f"No permissions found for document {slim_doc.id}")
continue

View File

@@ -10,7 +10,6 @@ from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -43,22 +42,24 @@ def _fetch_permissions_for_permission_ids(
if not permission_info or not doc_id:
return []
# Check cache first for all permission IDs
permissions = [
_PERMISSION_ID_PERMISSION_MAP[pid]
for pid in permission_ids
if pid in _PERMISSION_ID_PERMISSION_MAP
]
# If we found all permissions in cache, return them
if len(permissions) == len(permission_ids):
return permissions
owner_email = permission_info.get("owner_email")
drive_service = get_drive_service(
creds=google_drive_connector.creds,
user_email=(owner_email or google_drive_connector.primary_admin_email),
)
# Otherwise, fetch all permissions and update cache
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
@@ -68,6 +69,7 @@ def _fetch_permissions_for_permission_ids(
)
permissions_for_doc_id = []
# Update cache and return all permissions
for permission in fetched_permissions:
permissions_for_doc_id.append(permission)
_PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
@@ -118,18 +120,15 @@ def _get_permissions_from_slim_doc(
elif permission_type == "anyone":
public = True
drive_id = permission_info.get("drive_id")
group_ids = group_emails | ({drive_id} if drive_id is not None else set())
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_ids,
external_user_group_ids=group_emails,
is_public=public,
)
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -147,12 +146,6 @@ def gdrive_doc_sync(
document_external_accesses = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
callback.progress("gdrive_doc_sync", 1)
ext_access = _get_permissions_from_slim_doc(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,

View File

@@ -1,127 +1,16 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.resources import AdminService
from onyx.connectors.google_utils.resources import get_admin_service
from onyx.connectors.google_utils.resources import get_drive_service
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_drive_members(
google_drive_connector: GoogleDriveConnector,
) -> dict[str, tuple[set[str], set[str]]]:
"""
This builds a map of drive ids to their members (group and user emails).
E.g. {
"drive_id_1": ({"group_email_1"}, {"user_email_1", "user_email_2"}),
"drive_id_2": ({"group_email_3"}, {"user_email_3"}),
}
"""
drive_ids = google_drive_connector.get_all_drive_ids()
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]] = {}
drive_service = get_drive_service(
google_drive_connector.creds,
google_drive_connector.primary_admin_email,
)
for drive_id in drive_ids:
group_emails: set[str] = set()
user_emails: set[str] = set()
for permission in execute_paginated_retrieval(
drive_service.permissions().list,
list_key="permissions",
fileId=drive_id,
fields="permissions(emailAddress, type)",
supportsAllDrives=True,
):
if permission["type"] == "group":
group_emails.add(permission["emailAddress"])
elif permission["type"] == "user":
user_emails.add(permission["emailAddress"])
drive_id_to_members_map[drive_id] = (group_emails, user_emails)
return drive_id_to_members_map
def _get_all_groups(
admin_service: AdminService,
google_domain: str,
) -> set[str]:
"""
This gets all the group emails.
"""
group_emails: set[str] = set()
for group in execute_paginated_retrieval(
admin_service.groups().list,
list_key="groups",
domain=google_domain,
fields="groups(email)",
):
group_emails.add(group["email"])
return group_emails
def _map_group_email_to_member_emails(
admin_service: AdminService,
group_emails: set[str],
) -> dict[str, set[str]]:
"""
This maps group emails to their member emails.
"""
group_to_member_map: dict[str, set[str]] = {}
for group_email in group_emails:
group_member_emails: set[str] = set()
for member in execute_paginated_retrieval(
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email)",
):
group_member_emails.add(member["email"])
group_to_member_map[group_email] = group_member_emails
return group_to_member_map
def _build_onyx_groups(
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]],
group_email_to_member_emails_map: dict[str, set[str]],
) -> list[ExternalUserGroup]:
onyx_groups: list[ExternalUserGroup] = []
# Convert all drive member definitions to onyx groups
# This is because having drive level access means you have
# irrevocable access to all the files in the drive.
for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items():
all_member_emails: set[str] = user_emails
for group_email in group_emails:
all_member_emails.update(group_email_to_member_emails_map[group_email])
onyx_groups.append(
ExternalUserGroup(
id=drive_id,
user_emails=list(all_member_emails),
)
)
# Convert all group member definitions to onyx groups
for group_email, member_emails in group_email_to_member_emails_map.items():
onyx_groups.append(
ExternalUserGroup(
id=group_email,
user_emails=list(member_emails),
)
)
return onyx_groups
def gdrive_group_sync(
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
# Initialize connector and build credential/service objects
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
)
@@ -130,23 +19,34 @@ def gdrive_group_sync(
google_drive_connector.creds, google_drive_connector.primary_admin_email
)
# Get all drive members
drive_id_to_members_map = _get_drive_members(google_drive_connector)
onyx_groups: list[ExternalUserGroup] = []
for group in execute_paginated_retrieval(
admin_service.groups().list,
list_key="groups",
domain=google_drive_connector.google_domain,
fields="groups(email)",
):
# The id is the group email
group_email = group["email"]
# Get all group emails
all_group_emails = _get_all_groups(
admin_service, google_drive_connector.google_domain
)
# Gather group member emails
group_member_emails: list[str] = []
for member in execute_paginated_retrieval(
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email)",
):
group_member_emails.append(member["email"])
# Map group emails to their members
group_email_to_member_emails_map = _map_group_email_to_member_emails(
admin_service, all_group_emails
)
if not group_member_emails:
continue
# Convert the maps to onyx groups
onyx_groups = _build_onyx_groups(
drive_id_to_members_map=drive_id_to_members_map,
group_email_to_member_emails_map=group_email_to_member_emails_map,
)
onyx_groups.append(
ExternalUserGroup(
id=group_email,
user_emails=list(group_member_emails),
)
)
return onyx_groups

View File

@@ -161,10 +161,7 @@ def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Sales
cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id]
if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(f"CC pair {cc_pair_id} not found")
credential_json = cc_pair.credential.credential_json

View File

@@ -7,7 +7,6 @@ from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -15,7 +14,7 @@ logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> dict[str, list[str]]:
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
@@ -25,14 +24,6 @@ def _get_slack_document_ids_and_channels(
channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)
callback.progress("_get_slack_document_ids_and_channels", 1)
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
@@ -123,7 +114,7 @@ def _fetch_channel_permissions(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -136,7 +127,7 @@ def slack_doc_sync(
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
channel_doc_map = _get_slack_document_ids_and_channels(
cc_pair=cc_pair, callback=callback
cc_pair=cc_pair,
)
workspace_permissions = _fetch_workspace_permissions(
user_id_to_email_map=user_id_to_email_map,

View File

@@ -15,13 +15,11 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
[
ConnectorCredentialPair,
IndexingHeartbeatInterface | None,
],
list[DocExternalAccess],
]

View File

@@ -1,9 +1,7 @@
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import BASE_SCOPES
from httpx_oauth.clients.openid import OpenID
from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
@@ -90,13 +88,7 @@ def get_application() -> FastAPI:
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
OpenID(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
OPENID_CONFIG_URL,
# BASE_SCOPES is the same as not setting this
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
),
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
auth_backend,
USER_AUTH_SECRET,
associate_by_email=True,

View File

@@ -150,9 +150,9 @@ def _handle_standard_answers(
db_session=db_session,
description="",
user_id=None,
persona_id=(
slack_channel_config.persona.id if slack_channel_config.persona else 0
),
persona_id=slack_channel_config.persona.id
if slack_channel_config.persona
else 0,
onyxbot_flow=True,
slack_thread_id=slack_thread_id,
)
@@ -182,7 +182,7 @@ def _handle_standard_answers(
formatted_answers.append(formatted_answer)
answer_message = "\n\n".join(formatted_answers)
chat_message = create_new_chat_message(
_ = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=new_user_message,
prompt_id=prompt.id if prompt else None,
@@ -191,13 +191,8 @@ def _handle_standard_answers(
message_type=MessageType.ASSISTANT,
error=None,
db_session=db_session,
commit=False,
commit=True,
)
# attach the standard answers to the chat message
chat_message.standard_answers = [
standard_answer for standard_answer, _ in matching_standard_answers
]
db_session.commit()
update_emote_react(
emoji=DANSWER_REACT_EMOJI,

View File

@@ -1,23 +1,19 @@
import csv
import io
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Literal
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time
from ee.onyx.db.query_history import get_page_of_chat_sessions
from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
from ee.onyx.server.query_history.models import ChatSessionMinimal
from ee.onyx.server.query_history.models import ChatSessionSnapshot
from ee.onyx.server.query_history.models import MessageSnapshot
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.chat.chat_utils import create_chat_chain
@@ -27,15 +23,257 @@ from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.engine import get_session
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import User
from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
router = APIRouter()
class AbridgedSearchDoc(BaseModel):
"""A subset of the info present in `SearchDoc`"""
document_id: str
semantic_identifier: str
link: str | None
class MessageSnapshot(BaseModel):
message: str
message_type: MessageType
documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
time_created: datetime
@classmethod
def build(cls, message: ChatMessage) -> "MessageSnapshot":
latest_messages_feedback_obj = (
message.chat_message_feedbacks[-1]
if len(message.chat_message_feedbacks) > 0
else None
)
feedback_type = (
(
QAFeedbackType.LIKE
if latest_messages_feedback_obj.is_positive
else QAFeedbackType.DISLIKE
)
if latest_messages_feedback_obj
else None
)
feedback_text = (
latest_messages_feedback_obj.feedback_text
if latest_messages_feedback_obj
else None
)
return cls(
message=message.message,
message_type=message.message_type,
documents=[
AbridgedSearchDoc(
document_id=document.document_id,
semantic_identifier=document.semantic_id,
link=document.link,
)
for document in message.search_docs
],
feedback_type=feedback_type,
feedback_text=feedback_text,
time_created=message.time_sent,
)
class ChatSessionMinimal(BaseModel):
id: UUID
user_email: str
name: str | None
first_user_message: str
first_ai_message: str
assistant_id: int | None
assistant_name: str | None
time_created: datetime
feedback_type: QAFeedbackType | Literal["mixed"] | None
flow_type: SessionType
conversation_length: int
class ChatSessionSnapshot(BaseModel):
id: UUID
user_email: str
name: str | None
messages: list[MessageSnapshot]
assistant_id: int | None
assistant_name: str | None
time_created: datetime
flow_type: SessionType
class QuestionAnswerPairSnapshot(BaseModel):
chat_session_id: UUID
# 1-indexed message number in the chat_session
# e.g. the first message pair in the chat_session is 1, the second is 2, etc.
message_pair_num: int
user_message: str
ai_response: str
retrieved_documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
persona_name: str | None
user_email: str
time_created: datetime
flow_type: SessionType
@classmethod
def from_chat_session_snapshot(
cls,
chat_session_snapshot: ChatSessionSnapshot,
) -> list["QuestionAnswerPairSnapshot"]:
message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
for ind in range(1, len(chat_session_snapshot.messages), 2):
message_pairs.append(
(
chat_session_snapshot.messages[ind - 1],
chat_session_snapshot.messages[ind],
)
)
return [
cls(
chat_session_id=chat_session_snapshot.id,
message_pair_num=ind + 1,
user_message=user_message.message,
ai_response=ai_message.message,
retrieved_documents=ai_message.documents,
feedback_type=ai_message.feedback_type,
feedback_text=ai_message.feedback_text,
persona_name=chat_session_snapshot.assistant_name,
user_email=get_display_email(chat_session_snapshot.user_email),
time_created=user_message.time_created,
flow_type=chat_session_snapshot.flow_type,
)
for ind, (user_message, ai_message) in enumerate(message_pairs)
]
def to_json(self) -> dict[str, str | None]:
return {
"chat_session_id": str(self.chat_session_id),
"message_pair_num": str(self.message_pair_num),
"user_message": self.user_message,
"ai_response": self.ai_response,
"retrieved_documents": "|".join(
[
doc.link or doc.semantic_identifier
for doc in self.retrieved_documents
]
),
"feedback_type": self.feedback_type.value if self.feedback_type else "",
"feedback_text": self.feedback_text or "",
"persona_name": self.persona_name,
"user_email": self.user_email,
"time_created": str(self.time_created),
"flow_type": self.flow_type,
}
def determine_flow_type(chat_session: ChatSession) -> SessionType:
return SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
def fetch_and_process_chat_session_history_minimal(
db_session: Session,
start: datetime,
end: datetime,
feedback_filter: QAFeedbackType | None = None,
limit: int | None = 500,
) -> list[ChatSessionMinimal]:
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=start, end=end, db_session=db_session, limit=limit
)
minimal_sessions = []
for chat_session in chat_sessions:
if not chat_session.messages:
continue
first_user_message = next(
(
message.message
for message in chat_session.messages
if message.message_type == MessageType.USER
),
"",
)
first_ai_message = next(
(
message.message
for message in chat_session.messages
if message.message_type == MessageType.ASSISTANT
),
"",
)
has_positive_feedback = any(
feedback.is_positive
for message in chat_session.messages
for feedback in message.chat_message_feedbacks
)
has_negative_feedback = any(
not feedback.is_positive
for message in chat_session.messages
for feedback in message.chat_message_feedbacks
)
feedback_type: QAFeedbackType | Literal["mixed"] | None = (
"mixed"
if has_positive_feedback and has_negative_feedback
else QAFeedbackType.LIKE
if has_positive_feedback
else QAFeedbackType.DISLIKE
if has_negative_feedback
else None
)
if feedback_filter:
if feedback_filter == QAFeedbackType.LIKE and not has_positive_feedback:
continue
if feedback_filter == QAFeedbackType.DISLIKE and not has_negative_feedback:
continue
flow_type = determine_flow_type(chat_session)
minimal_sessions.append(
ChatSessionMinimal(
id=chat_session.id,
user_email=get_display_email(
chat_session.user.email if chat_session.user else None
),
name=chat_session.description,
first_user_message=first_user_message,
first_ai_message=first_ai_message,
assistant_id=chat_session.persona_id,
assistant_name=(
chat_session.persona.name if chat_session.persona else None
),
time_created=chat_session.time_created,
feedback_type=feedback_type,
flow_type=flow_type,
conversation_length=len(
[
m
for m in chat_session.messages
if m.message_type != MessageType.SYSTEM
]
),
)
)
return minimal_sessions
def fetch_and_process_chat_session_history(
db_session: Session,
start: datetime,
@@ -81,7 +319,7 @@ def snapshot_from_chat_session(
except RuntimeError:
return None
flow_type = SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
flow_type = determine_flow_type(chat_session)
return ChatSessionSnapshot(
id=chat_session.id,
@@ -133,38 +371,22 @@ def get_user_chat_sessions(
@router.get("/admin/chat-session-history")
def get_chat_session_history(
page_num: int = Query(0, ge=0),
page_size: int = Query(10, ge=1),
feedback_type: QAFeedbackType | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
start: datetime | None = None,
end: datetime | None = None,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[ChatSessionMinimal]:
page_of_chat_sessions = get_page_of_chat_sessions(
page_num=page_num,
page_size=page_size,
) -> list[ChatSessionMinimal]:
return fetch_and_process_chat_session_history_minimal(
db_session=db_session,
start_time=start_time,
end_time=end_time,
start=start
or (
datetime.now(tz=timezone.utc) - timedelta(days=30)
), # default is 30d lookback
end=end or datetime.now(tz=timezone.utc),
feedback_filter=feedback_type,
)
total_filtered_chat_sessions_count = get_total_filtered_chat_sessions_count(
db_session=db_session,
start_time=start_time,
end_time=end_time,
feedback_filter=feedback_type,
)
return PaginatedReturn(
items=[
ChatSessionMinimal.from_chat_session(chat_session)
for chat_session in page_of_chat_sessions
],
total_items=total_filtered_chat_sessions_count,
)
@router.get("/admin/chat-session-history/{chat_session_id}")
def get_chat_session_admin(

View File

@@ -1,218 +0,0 @@
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
from onyx.auth.users import get_display_email
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import SessionType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
class AbridgedSearchDoc(BaseModel):
"""A subset of the info present in `SearchDoc`"""
document_id: str
semantic_identifier: str
link: str | None
class MessageSnapshot(BaseModel):
id: int
message: str
message_type: MessageType
documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
time_created: datetime
@classmethod
def build(cls, message: ChatMessage) -> "MessageSnapshot":
latest_messages_feedback_obj = (
message.chat_message_feedbacks[-1]
if len(message.chat_message_feedbacks) > 0
else None
)
feedback_type = (
(
QAFeedbackType.LIKE
if latest_messages_feedback_obj.is_positive
else QAFeedbackType.DISLIKE
)
if latest_messages_feedback_obj
else None
)
feedback_text = (
latest_messages_feedback_obj.feedback_text
if latest_messages_feedback_obj
else None
)
return cls(
id=message.id,
message=message.message,
message_type=message.message_type,
documents=[
AbridgedSearchDoc(
document_id=document.document_id,
semantic_identifier=document.semantic_id,
link=document.link,
)
for document in message.search_docs
],
feedback_type=feedback_type,
feedback_text=feedback_text,
time_created=message.time_sent,
)
class ChatSessionMinimal(BaseModel):
id: UUID
user_email: str
name: str | None
first_user_message: str
first_ai_message: str
assistant_id: int | None
assistant_name: str | None
time_created: datetime
feedback_type: QAFeedbackType | None
flow_type: SessionType
conversation_length: int
@classmethod
def from_chat_session(cls, chat_session: ChatSession) -> "ChatSessionMinimal":
first_user_message = next(
(
message.message
for message in chat_session.messages
if message.message_type == MessageType.USER
),
"",
)
first_ai_message = next(
(
message.message
for message in chat_session.messages
if message.message_type == MessageType.ASSISTANT
),
"",
)
list_of_message_feedbacks = [
feedback.is_positive
for message in chat_session.messages
for feedback in message.chat_message_feedbacks
]
session_feedback_type = None
if list_of_message_feedbacks:
if all(list_of_message_feedbacks):
session_feedback_type = QAFeedbackType.LIKE
elif not any(list_of_message_feedbacks):
session_feedback_type = QAFeedbackType.DISLIKE
else:
session_feedback_type = QAFeedbackType.MIXED
return cls(
id=chat_session.id,
user_email=get_display_email(
chat_session.user.email if chat_session.user else None
),
name=chat_session.description,
first_user_message=first_user_message,
first_ai_message=first_ai_message,
assistant_id=chat_session.persona_id,
assistant_name=(
chat_session.persona.name if chat_session.persona else None
),
time_created=chat_session.time_created,
feedback_type=session_feedback_type,
flow_type=SessionType.SLACK
if chat_session.onyxbot_flow
else SessionType.CHAT,
conversation_length=len(
[
message
for message in chat_session.messages
if message.message_type != MessageType.SYSTEM
]
),
)
class ChatSessionSnapshot(BaseModel):
id: UUID
user_email: str
name: str | None
messages: list[MessageSnapshot]
assistant_id: int | None
assistant_name: str | None
time_created: datetime
flow_type: SessionType
class QuestionAnswerPairSnapshot(BaseModel):
chat_session_id: UUID
# 1-indexed message number in the chat_session
# e.g. the first message pair in the chat_session is 1, the second is 2, etc.
message_pair_num: int
user_message: str
ai_response: str
retrieved_documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
persona_name: str | None
user_email: str
time_created: datetime
flow_type: SessionType
@classmethod
def from_chat_session_snapshot(
cls,
chat_session_snapshot: ChatSessionSnapshot,
) -> list["QuestionAnswerPairSnapshot"]:
message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
for ind in range(1, len(chat_session_snapshot.messages), 2):
message_pairs.append(
(
chat_session_snapshot.messages[ind - 1],
chat_session_snapshot.messages[ind],
)
)
return [
cls(
chat_session_id=chat_session_snapshot.id,
message_pair_num=ind + 1,
user_message=user_message.message,
ai_response=ai_message.message,
retrieved_documents=ai_message.documents,
feedback_type=ai_message.feedback_type,
feedback_text=ai_message.feedback_text,
persona_name=chat_session_snapshot.assistant_name,
user_email=get_display_email(chat_session_snapshot.user_email),
time_created=user_message.time_created,
flow_type=chat_session_snapshot.flow_type,
)
for ind, (user_message, ai_message) in enumerate(message_pairs)
]
def to_json(self) -> dict[str, str | None]:
return {
"chat_session_id": str(self.chat_session_id),
"message_pair_num": str(self.message_pair_num),
"user_message": self.user_message,
"ai_response": self.ai_response,
"retrieved_documents": "|".join(
[
doc.link or doc.semantic_identifier
for doc in self.retrieved_documents
]
),
"feedback_type": self.feedback_type.value if self.feedback_type else "",
"feedback_text": self.feedback_text or "",
"persona_name": self.persona_name,
"user_email": self.user_email,
"time_created": str(self.time_created),
"flow_type": self.flow_type,
}

View File

@@ -24,7 +24,7 @@ from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
from onyx.db.persona import upsert_persona
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.server.features.persona.models import CreatePersonaRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.settings.models import Settings
from onyx.server.settings.store import store_settings as store_base_settings
@@ -57,7 +57,7 @@ class SeedConfiguration(BaseModel):
llms: list[LLMProviderUpsertRequest] | None = None
admin_user_emails: list[str] | None = None
seeded_logo_path: str | None = None
personas: list[PersonaUpsertRequest] | None = None
personas: list[CreatePersonaRequest] | None = None
settings: Settings | None = None
enterprise_settings: EnterpriseSettings | None = None
@@ -128,7 +128,7 @@ def _seed_llms(
)
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
if personas:
logger.notice("Seeding Personas")
for persona in personas:

View File

@@ -111,7 +111,6 @@ async def login_as_anonymous_user(
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie("fastapiusersauth")
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,

View File

@@ -5,7 +5,7 @@ from fastapi import Depends
from sqlalchemy.orm import Session
from ee.onyx.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
@@ -51,10 +51,8 @@ def get_group_token_limit_settings(
) -> list[TokenRateLimitDisplay]:
return [
TokenRateLimitDisplay.from_db(token_rate_limit)
for token_rate_limit in fetch_user_group_token_rate_limits_for_user(
db_session=db_session,
group_id=group_id,
user=user,
for token_rate_limit in fetch_user_group_token_rate_limits(
db_session, group_id, user
)
]

View File

@@ -58,7 +58,6 @@ class UserGroup(BaseModel):
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential
),
access_type=cc_pair_relationship.cc_pair.access_type,
)
for cc_pair_relationship in user_group_model.cc_pair_relationships
if cc_pair_relationship.is_current

View File

@@ -19,9 +19,6 @@ def prefix_external_group(ext_group_name: str) -> str:
return f"external_group:{ext_group_name}"
def build_ext_group_name_for_onyx(ext_group_name: str, source: DocumentSource) -> str:
"""
External groups may collide across sources, every source needs its own prefix.
NOTE: the name is lowercased to handle case sensitivity for group names
"""
return f"{source.value}_{ext_group_name}".lower()
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
"""External groups may collide across sources, every source needs its own prefix."""
return f"{source.value.upper()}_{ext_group_name}"

View File

@@ -42,17 +42,8 @@ class UserCreate(schemas.BaseUserCreate):
tenant_id: str | None = None
class UserUpdateWithRole(schemas.BaseUserUpdate):
role: UserRole
class UserUpdate(schemas.BaseUserUpdate):
"""
Role updates are not allowed through the user update endpoint for security reasons
Role changes should be handled through a separate, admin-only process
"""
class AuthBackend(str, Enum):
REDIS = "redis"
POSTGRES = "postgres"

View File

@@ -33,8 +33,6 @@ from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication import RedisStrategy
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.exceptions import UserAlreadyExists
from fastapi_users.jwt import decode_jwt
from fastapi_users.jwt import generate_jwt
@@ -54,15 +52,13 @@ from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import AuthBackend
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserUpdateWithRole
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
from onyx.auth.schemas import UserUpdate
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
@@ -78,7 +74,6 @@ from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.api_key import fetch_user_for_api_key
from onyx.db.auth import get_access_token_db
from onyx.db.auth import get_default_admin_user_emails
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
@@ -87,7 +82,6 @@ from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.db.users import get_user_by_email
@@ -215,7 +209,7 @@ def verify_email_domain(email: str) -> None:
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
async def create(
@@ -245,8 +239,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
referral_source=referral_source,
request=request,
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
@@ -265,16 +261,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.role.is_web_login() and user_create.role.is_web_login():
user_update = UserUpdateWithRole(
user_update = UserUpdate(
password=user_create.password,
is_verified=user_create.is_verified,
role=user_create.role,
)
user = await self.update(user_update, user)
else:
@@ -282,6 +278,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
@@ -583,14 +580,6 @@ def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
"""
A custom strategy that fetches the actual async Redis connection inside each method.
@@ -599,7 +588,7 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
def __init__(
self,
lifetime_seconds: Optional[int] = SESSION_EXPIRE_TIME_SECONDS,
lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
):
self.lifetime_seconds = lifetime_seconds
@@ -648,16 +637,9 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
await redis.delete(f"{self.key_prefix}{token}")
if AUTH_BACKEND == AuthBackend.REDIS:
auth_backend = AuthenticationBackend(
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
)
elif AUTH_BACKEND == AuthBackend.POSTGRES:
auth_backend = AuthenticationBackend(
name="postgres", transport=cookie_transport, get_strategy=get_database_strategy
)
else:
raise ValueError(f"Invalid auth backend: {AUTH_BACKEND}")
auth_backend = AuthenticationBackend(
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
)
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):

View File

@@ -20,11 +20,10 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.httpx.httpx_pool import HttpxPool
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_delete import RedisConnectorDelete
@@ -101,10 +100,6 @@ def on_task_postrun(
if not task_id:
return
if task.name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
# this is a cloud / all tenant task ... no postrun is needed
return
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
@@ -166,40 +161,14 @@ def on_task_postrun(
return
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""
# NOTE(rkuo): start method "fork" is unsafe and we really need it to be "spawn"
# But something is blocking set_start_method from working in the cloud unless
# force=True. so we use force=True as a fallback.
all_start_methods: list[str] = multiprocessing.get_all_start_methods()
logger.info(f"Multiprocessing all start methods: {all_start_methods}")
try:
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
except Exception:
logger.info(
"Multiprocessing set_start_method exceptioned. Trying force=True..."
)
try:
multiprocessing.set_start_method(
"spawn", force=True
) # fork is unsafe, set to spawn
except Exception:
logger.info(
"Multiprocessing set_start_method force=True exceptioned even with force=True."
)
logger.info(
f"Multiprocessing selected start method: {multiprocessing.get_start_method()}"
)
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
r = get_redis_client(tenant_id=None)
@@ -281,6 +250,51 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
return
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Vespa: Readiness probe starting.")
while True:
try:
client = get_vespa_http_client()
response = client.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()
if response_dict["status"]["code"] == "up":
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Vespa: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Vespa: Readiness probe succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
@@ -318,8 +332,6 @@ def on_worker_ready(sender: Any, **kwargs: Any) -> None:
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
HttpxPool.close_all()
if not celery_is_worker_primary(sender):
return
@@ -468,13 +480,3 @@ def reset_tenant_id(
) -> None:
"""Signal handler to reset tenant ID in context var after task ends."""
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""
if not wait_for_vespa_with_timeout():
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
raise WorkerShutdown(msg)

View File

@@ -13,7 +13,6 @@ from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
@@ -29,7 +28,7 @@ class DynamicTenantScheduler(PersistentScheduler):
self._last_reload = self.app.now() - self._reload_interval
# Let the parent class handle store initialization
self.setup_schedule()
self._try_updating_schedule()
self._update_tenant_tasks()
logger.info(f"Set reload interval to {self._reload_interval}")
def setup_schedule(self) -> None:
@@ -45,158 +44,105 @@ class DynamicTenantScheduler(PersistentScheduler):
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reload interval reached, initiating task update")
try:
self._try_updating_schedule()
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tasks: {str(e)}")
self._update_tenant_tasks()
self._last_reload = now
logger.info("Task update completed, reset reload timer")
return retval
def _generate_schedule(
self, tenant_ids: list[str] | list[None]
) -> dict[str, dict[str, Any]]:
"""Given a list of tenant id's, generates a new beat schedule for celery."""
logger.info("Fetching tasks to schedule")
def _update_tenant_tasks(self) -> None:
logger.info("Starting task update process")
try:
logger.info("Fetching all IDs")
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} IDs")
new_schedule: dict[str, dict[str, Any]] = {}
if MULTI_TENANT:
# cloud tasks only need the single task beat across all tenants
get_cloud_tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule",
"get_cloud_tasks_to_schedule",
logger.info("Fetching tasks to schedule")
tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
cloud_tasks_to_schedule: list[
dict[str, Any]
] = get_cloud_tasks_to_schedule()
for task in cloud_tasks_to_schedule:
task_name = task["name"]
cloud_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": task.get("kwargs", {}),
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
cloud_task["options"] = options
new_schedule[task_name] = cloud_task
new_beat_schedule: dict[str, dict[str, Any]] = {}
# regular task beats are multiplied across all tenants
get_tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
current_schedule = self.schedule.items()
tasks_to_schedule: list[dict[str, Any]] = get_tasks_to_schedule()
existing_tenants = set()
for task_name, _ in current_schedule:
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
logger.info(f"Found {len(existing_tenants)} existing items in schedule")
for tenant_id in tenant_ids:
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
logger.info(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
for task in tasks_to_schedule:
task_name = task["name"]
tenant_task_name = f"{task['name']}-{tenant_id}"
logger.debug(f"Creating task configuration for {tenant_task_name}")
tenant_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
logger.debug(
f"Adding options to task {tenant_task_name}: {options}"
for tenant_id in tenant_ids:
if (
IGNORED_SYNCING_TENANT_LIST
and tenant_id in IGNORED_SYNCING_TENANT_LIST
):
logger.info(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
tenant_task["options"] = options
new_schedule[tenant_task_name] = tenant_task
continue
return new_schedule
if tenant_id not in existing_tenants:
logger.info(f"Processing new item: {tenant_id}")
def _try_updating_schedule(self) -> None:
"""Only updates the actual beat schedule on the celery app when it changes"""
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
logger.debug(f"Creating task configuration for {task_name}")
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
new_task["options"] = options
new_beat_schedule[task_name] = new_task
logger.info("_try_updating_schedule starting")
if self._should_update_schedule(current_schedule, new_beat_schedule):
logger.info(
"Schedule update required",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
)
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} IDs")
# Create schedule entries
entries = {}
for name, entry in new_beat_schedule.items():
entries[name] = self.Entry(
name=name,
app=self.app,
task=entry["task"],
schedule=entry["schedule"],
options=entry.get("options", {}),
kwargs=entry.get("kwargs", {}),
)
# get current schedule and extract current tenants
current_schedule = self.schedule.items()
# Update the schedule using the scheduler's methods
self.schedule.clear()
self.schedule.update(entries)
# there are no more per tenant beat tasks, so comment this out
# NOTE: we may not actualy need this scheduler any more and should
# test reverting to a regular beat schedule implementation
# Ensure changes are persisted
self.sync()
# current_tenants = set()
# for task_name, _ in current_schedule:
# task_name = cast(str, task_name)
# if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
# continue
logger.info("Schedule update completed successfully")
else:
logger.info("Schedule is up to date, no changes needed")
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tasks: {str(e)}")
# if "_" in task_name:
# # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
# # -> "12345678-abcd-efgh-ijkl-12345678"
# current_tenants.add(task_name.split("_")[-1])
# logger.info(f"Found {len(current_tenants)} existing items in schedule")
# for tenant_id in tenant_ids:
# if tenant_id not in current_tenants:
# logger.info(f"Processing new tenant: {tenant_id}")
new_schedule = self._generate_schedule(tenant_ids)
if DynamicTenantScheduler._compare_schedules(current_schedule, new_schedule):
logger.info(
"_try_updating_schedule: Current schedule is up to date, no changes needed"
)
return
logger.info(
"Schedule update required",
extra={
"new_tasks": len(new_schedule),
"current_tasks": len(current_schedule),
},
)
# Create schedule entries
entries = {}
for name, entry in new_schedule.items():
entries[name] = self.Entry(
name=name,
app=self.app,
task=entry["task"],
schedule=entry["schedule"],
options=entry.get("options", {}),
kwargs=entry.get("kwargs", {}),
)
# Update the schedule using the scheduler's methods
self.schedule.clear()
self.schedule.update(entries)
# Ensure changes are persisted
self.sync()
logger.info("_try_updating_schedule: Schedule updated successfully")
@staticmethod
def _compare_schedules(schedule1: dict, schedule2: dict) -> bool:
"""Compare schedules to determine if an update is needed.
True if equivalent, False if not."""
current_tasks = set(name for name, _ in schedule1)
new_tasks = set(schedule2.keys())
if current_tasks != new_tasks:
return False
return True
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
logger.debug("Comparing current and new schedules")
current_tasks = set(name for name, _ in current_schedule)
new_tasks = set(new_schedule.keys())
needs_update = current_tasks != new_tasks
logger.debug(f"Schedule update needed: {needs_update}")
return needs_update
@beat_init.connect

View File

@@ -1,9 +1,9 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
@@ -49,20 +49,21 @@ def on_task_postrun(
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
SqlEngine.init_engine(pool_size=4, max_overflow=12)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -1,9 +1,9 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
@@ -50,25 +50,26 @@ def on_task_postrun(
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
# rkuo: Transient errors keep happening in the indexing watchdog threads.
# "SSL connection has been closed unexpectedly"
# actually setting the spawn method in the cloud fixes 95% of these.
# setting pre ping might help even more, but not worrying about that yet
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
# rkuo: been seeing transient connection exceptions here, so upping the connection count
# from just concurrency/concurrency to concurrency/concurrency*2
SqlEngine.init_engine(
pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -1,24 +1,21 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
@@ -52,33 +49,21 @@ def on_task_postrun(
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
if MANAGED_VESPA:
httpx_init_vespa_pool(
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
ssl_cert=VESPA_CLOUD_CERT_PATH,
ssl_key=VESPA_CLOUD_KEY_PATH,
)
else:
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -1,95 +0,0 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_MONITORING_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.monitoring")
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_MONITORING_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=3)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.monitoring",
]
)

View File

@@ -1,4 +1,5 @@
import logging
import multiprocessing
from typing import Any
from typing import cast
@@ -6,7 +7,6 @@ from celery import bootsteps # type: ignore
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.exceptions import WorkerShutdown
from celery.signals import celeryd_init
from celery.signals import worker_init
@@ -17,7 +17,7 @@ from redis.lock import Lock as RedisLock
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.tasks.indexing.utils import (
from onyx.background.celery.tasks.indexing.tasks import (
get_unfenced_index_attempt_ids,
)
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
@@ -73,20 +73,21 @@ def on_task_postrun(
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
logger.info("Running as the primary celery worker.")
@@ -134,7 +135,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
raise WorkerShutdown("Primary worker lock could not be acquired!")
# tacking on our own user data to the sender
sender.primary_worker_lock = lock # type: ignore
sender.primary_worker_lock = lock
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)

View File

@@ -91,28 +91,6 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
return False
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
"""This is a redis specific way to build a list of tasks in a queue.
This helps us read the queue once and then efficiently look for missing tasks
in the queue.
"""
task_set: set[str] = set()
for priority in range(len(OnyxCeleryPriority)):
queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
for task in tasks:
task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
task_id = task_dict.get("headers", {}).get("id")
if task_id:
task_set.add(task_id)
return task_set
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
"""Returns a list of current workers containing name_filter, or all workers if
name_filter is None.

View File

@@ -1,13 +1,10 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
import httpx
from sqlalchemy.orm import Session
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
@@ -17,10 +14,8 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import Document
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import TaskStatus
from onyx.db.models import TaskQueueState
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.server.documents.models import DeletionAttemptSnapshot
@@ -46,21 +41,14 @@ def _get_deletion_status(
return None
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if redis_connector.delete.fenced:
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.STARTED,
)
if not redis_connector.delete.fenced:
return None
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.PENDING,
)
return None
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.STARTED,
)
def get_deletion_attempt_snapshot(
@@ -158,25 +146,3 @@ def celery_is_worker_primary(worker: Any) -> bool:
return True
return False
def httpx_init_vespa_pool(
max_keepalive_connections: int,
timeout: int = VESPA_REQUEST_TIMEOUT,
ssl_cert: str | None = None,
ssl_key: str | None = None,
) -> None:
httpx_cert = None
httpx_verify = False
if ssl_cert and ssl_key:
httpx_cert = cast(tuple[str, str], (ssl_cert, ssl_key))
httpx_verify = True
HttpxPool.init_client(
name="vespa",
cert=httpx_cert,
verify=httpx_verify,
timeout=timeout,
http2=False,
limits=httpx.Limits(max_keepalive_connections=max_keepalive_connections),
)

View File

@@ -1,21 +0,0 @@
import onyx.background.celery.configs.base as shared_config
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Monitoring worker specific settings
worker_concurrency = 1 # Single worker is sufficient for monitoring
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -2,259 +2,106 @@ from datetime import timedelta
from typing import Any
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT
# choosing 15 minutes because it roughly gives us enough time to process many tasks
# we might be able to reduce this greatly if we can run a unified
# loop across all tenants rather than tasks per tenant
# we set expires because it isn't necessary to queue up these tasks
# it's only important that they run relatively regularly
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# hack to slow down task dispatch in the cloud until
# we have a better implementation (backpressure, etc)
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 8
# tasks that only run in the cloud
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
# by the DynamicTenantScheduler
cloud_tasks_to_schedule = [
# cloud specific tasks
# we set expires because it isn't necessary to queue up these tasks
# it's only important that they run relatively regularly
tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
"schedule": timedelta(hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
# remaining tasks are cloud generators for per tenant tasks
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_INDEXING,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-connector-deletion",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-vespa-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"name": "check-for-prune",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-prune",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"name": "kombu-message-cleanup",
"task": OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
"schedule": timedelta(seconds=3600),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"priority": OnyxCeleryPriority.LOWEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_PRUNING,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-vespa-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.MONITOR_VESPA_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=30 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-external-group-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-background-processes",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(minutes=5 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.LOW,
},
},
]
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
cloud_tasks_to_schedule.append(
tasks_to_schedule.append(
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-llm-model-update",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(
hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER
), # Check every hour
"name": "check-for-llm-model-update",
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"schedule": timedelta(hours=1), # Check every hour
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
# tasks that run in either self-hosted on cloud
tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
tasks_to_schedule.extend(
[
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"schedule": timedelta(minutes=15),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
]
)
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
tasks_to_schedule.append(
{
"name": "check-for-llm-model-update",
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"schedule": timedelta(hours=1), # Check every hour
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return cloud_tasks_to_schedule
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return tasks_to_schedule

View File

@@ -10,17 +10,14 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncType
from onyx.db.search_settings import get_all_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
from onyx.redis.redis_pool import get_redis_client
@@ -33,7 +30,6 @@ class TaskDependencyError(RuntimeError):
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -45,7 +41,7 @@ def check_for_connector_deletion_task(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
@@ -117,21 +113,11 @@ def try_generate_document_cc_pair_cleanup_tasks(
# we need to load the state of the object inside the fence
# to avoid a race condition with db.commit/fence deletion
# at the end of this taskset
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
return None
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
# there should be no in-progress sync records if this is up to date
# clean it up just in case things got into a bad state
cleanup_sync_records(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
)
return None
# set a basic fence to start
@@ -178,13 +164,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
)
if tasks_generated is None:
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
)
except TaskDependencyError:
redis_connector.delete.set_fence(None)
raise

View File

@@ -3,18 +3,14 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from time import sleep
from typing import cast
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
@@ -25,46 +21,31 @@ from ee.onyx.external_permissions.sync_params import (
)
from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyncPayload
from onyx.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
)
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -76,9 +57,6 @@ LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
"""Jobs / utils for kicking off doc permissions sync tasks."""
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external doc permissions sync is due."""
@@ -113,21 +91,15 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
# TODO(rkuo): merge into check function after lookup table for fences is added
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
@@ -144,32 +116,14 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
if _is_external_doc_permissions_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:
payload_id = try_creating_permissions_sync_task(
tasks_created = try_creating_permissions_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not payload_id:
if not tasks_created:
continue
task_logger.info(
f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}"
)
# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES):
# clear any permission fences that don't have associated celery tasks in progress
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
validate_permission_sync_fences(tenant_id, r, r_celery, lock_beat)
except Exception:
task_logger.exception(
"Exception while validating permission sync fences"
)
r.set(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=60)
task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -188,15 +142,13 @@ def try_creating_permissions_sync_task(
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> str | None:
"""Returns a randomized payload id on success.
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
LOCK_TIMEOUT = 30
payload_id: str | None = None
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
@@ -221,25 +173,6 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_PERMISSIONS,
)
# set a basic fence to start
redis_connector.permissions.set_active()
payload = RedisConnectorPermissionSyncPayload(
id=make_short_id(),
submitted=datetime.now(timezone.utc),
started=None,
celery_task_id=None,
)
redis_connector.permissions.set_fence(payload)
result = app.send_task(
OnyxCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
kwargs=dict(
@@ -251,12 +184,12 @@ def try_creating_permissions_sync_task(
priority=OnyxCeleryPriority.HIGH,
)
# fill in the celery task id
redis_connector.permissions.set_active()
payload.celery_task_id = result.id
redis_connector.permissions.set_fence(payload)
# set a basic fence to start
payload = RedisConnectorPermissionSyncPayload(
started=None, celery_task_id=result.id
)
payload_id = payload.celery_task_id
redis_connector.permissions.set_fence(payload)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
return None
@@ -264,7 +197,7 @@ def try_creating_permissions_sync_task(
if lock.owned():
lock.release()
return payload_id
return 1
@shared_task(
@@ -285,8 +218,6 @@ def connector_permission_sync_generator_task(
This task assumes that the task has already been properly fenced
"""
LoggerContextVars.reset()
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
doc_permission_sync_ctx_dict["request_id"] = self.request.id
@@ -348,10 +279,7 @@ def connector_permission_sync_generator_task(
try:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
@@ -374,17 +302,12 @@ def connector_permission_sync_generator_task(
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
new_payload = RedisConnectorPermissionSyncPayload(
id=payload.id,
submitted=payload.submitted,
started=datetime.now(timezone.utc),
celery_task_id=payload.celery_task_id,
)
redis_connector.permissions.set_fence(new_payload)
callback = PermissionSyncCallback(redis_connector, lock, r)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(
cc_pair, callback
)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
task_logger.info(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
@@ -434,8 +357,6 @@ def update_external_document_permissions_task(
connector_id: int,
credential_id: int,
) -> bool:
start = time.monotonic()
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
@@ -465,330 +386,10 @@ def update_external_document_permissions_task(
document_ids=[doc_id],
)
elapsed = time.monotonic() - start
task_logger.info(
f"connector_id={connector_id} "
f"doc={doc_id} "
f"action=update_permissions "
f"elapsed={elapsed:.2f}"
logger.debug(
f"Successfully synced postgres document permissions for {doc_id}"
)
return True
except Exception:
task_logger.exception(
f"Exception in update_external_document_permissions_task: "
f"connector_id={connector_id} "
f"doc_id={doc_id}"
)
logger.exception("Error Syncing Document Permissions")
return False
return True
def validate_permission_sync_fences(
tenant_id: str | None,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
# building lookup table can be expensive, so we won't bother
# validating until the queue is small
PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024
queue_len = celery_get_queue_length(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN:
return
queued_upsert_tasks = celery_get_queued_task_ids(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
reserved_generator_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
# validate all existing indexing jobs
for key_bytes in r.scan_iter(
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
lock_beat.reacquire()
validate_permission_sync_fence(
tenant_id,
key_bytes,
queued_upsert_tasks,
reserved_generator_tasks,
r,
r_celery,
)
return
def validate_permission_sync_fence(
tenant_id: str | None,
key_bytes: bytes,
queued_tasks: set[str],
reserved_tasks: set[str],
r: Redis,
r_celery: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
queued_tasks: the celery queue of lightweight permission sync tasks
reserved_tasks: prefetched tasks for sync task generator
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_permission_sync_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.permissions.fenced:
return
# in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens
# TODO: add intentional cleanup/abort logic
try:
payload = redis_connector.permissions.payload
except ValidationError:
task_logger.exception(
"validate_permission_sync_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.permissions.reset()
return
if not payload:
return
if not payload.celery_task_id:
return
# OK, there's actually something for us to validate
# either the generator task must be in flight or its subtasks must be
found = celery_find_task(
payload.celery_task_id,
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
r_celery,
)
if found:
# the celery task exists in the redis queue
redis_connector.permissions.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within a worker
redis_connector.permissions.set_active()
return
# look up every task in the current taskset in the celery queue
# every entry in the taskset should have an associated entry in the celery task queue
# because we get the celery tasks first, the entries in our own permissions taskset
# should be roughly a subset of the tasks in celery
# this check isn't very exact, but should be sufficient over a period of time
# A single successful check over some number of attempts is sufficient.
# TODO: if the number of tasks in celery is much lower than than the taskset length
# we might be able to shortcut the lookup since by definition some of the tasks
# must not exist in celery.
tasks_scanned = 0
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
for member in r.sscan_iter(redis_connector.permissions.taskset_key):
tasks_scanned += 1
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
continue
if member_str in reserved_tasks:
continue
tasks_not_in_celery += 1
task_logger.info(
"validate_permission_sync_fence task check: "
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
if tasks_not_in_celery == 0:
redis_connector.permissions.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector.permissions.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
task_logger.warning(
"validate_permission_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.permissions.reset()
return
class PermissionSyncCallback(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
def __init__(
self,
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.redis_connector: RedisConnector = redis_connector
self.redis_lock: RedisLock = redis_lock
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = "PermissionSyncCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
def should_stop(self) -> bool:
if self.redis_connector.stop.fenced:
return True
return False
def progress(self, tag: str, amount: int) -> None:
try:
self.redis_connector.permissions.set_active()
current_time = time.monotonic()
if current_time - self.last_lock_monotonic >= (
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
):
self.redis_lock.reacquire()
self.last_lock_reacquire = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
self.last_tag = tag
except LockError:
logger.exception(
f"PermissionSyncCallback - lock.reacquire exceptioned: "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
redis_lock_dump(self.redis_lock, self.redis_client)
raise
"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""
def monitor_ccpair_permissions_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.permissions.fenced:
return
initial = redis_connector.permissions.generator_complete
if initial is None:
return
try:
payload = redis_connector.permissions.payload
except ValidationError:
task_logger.exception(
"Permissions sync payload failed to validate. "
"Schema may have been updated."
)
return
if not payload:
return
remaining = redis_connector.permissions.get_remaining()
task_logger.info(
f"Permissions sync progress: "
f"cc_pair={cc_pair_id} "
f"id={payload.id} "
f"remaining={remaining} "
f"initial={initial}"
)
if remaining > 0:
return
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), payload.started)
task_logger.info(
f"Permissions sync finished: "
f"cc_pair={cc_pair_id} "
f"id={payload.id} "
f"num_synced={initial}"
)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_PERMISSIONS,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial,
)
redis_connector.permissions.reset()

View File

@@ -1,4 +1,3 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -10,7 +9,6 @@ from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
@@ -22,12 +20,9 @@ from ee.onyx.external_permissions.sync_params import (
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
)
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -38,18 +33,12 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from onyx.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -102,20 +91,15 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
# r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
@@ -147,7 +131,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
if _is_external_group_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
@@ -156,23 +139,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
continue
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
# we want to run this less frequently than the overall task
# lock_beat.reacquire()
# if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
# # clear any indexing fences that don't have associated celery tasks in progress
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# # or be currently executing
# try:
# validate_external_group_sync_fences(
# tenant_id, self.app, r, r_celery, lock_beat
# )
# except Exception:
# task_logger.exception(
# "Exception while validating external group sync fences"
# )
# r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -215,12 +181,6 @@ def try_creating_external_group_sync_task(
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
payload = RedisConnectorExternalGroupSyncPayload(
submitted=datetime.now(timezone.utc),
started=None,
celery_task_id=None,
)
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
result = app.send_task(
@@ -234,17 +194,13 @@ def try_creating_external_group_sync_task(
priority=OnyxCeleryPriority.HIGH,
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
)
payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)
payload.celery_task_id = result.id
redis_connector.external_group_sync.set_fence(payload)
except Exception:
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
@@ -271,7 +227,7 @@ def connector_external_group_sync_generator_task(
tenant_id: str | None,
) -> None:
"""
External group sync task for a given connector credential pair
Permission sync task that handles external group syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
@@ -279,65 +235,22 @@ def connector_external_group_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
# before the primary worker can finalize the fence
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
if not redis_connector.external_group_sync.fenced: # The fence must exist
raise ValueError(
f"connector_external_group_sync_generator_task - fence not found: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
payload = redis_connector.external_group_sync.payload # The payload must exist
if not payload:
raise ValueError(
"connector_external_group_sync_generator_task: payload invalid or not found"
)
if payload.celery_task_id is None:
logger.info(
f"connector_external_group_sync_generator_task - Waiting for fence: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
time.sleep(1)
continue
logger.info(
f"connector_external_group_sync_generator_task - Fence found, continuing...: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
break
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
try:
payload.started = datetime.now(timezone.utc)
redis_connector.external_group_sync.set_fence(payload)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
@@ -372,26 +285,11 @@ def connector_external_group_sync_generator_task(
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
sync_status=SyncStatus.SUCCESS,
)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
sync_status=SyncStatus.FAILED,
)
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
raise e
@@ -400,135 +298,3 @@ def connector_external_group_sync_generator_task(
redis_connector.external_group_sync.set_fence(None)
if lock.owned():
lock.release()
def validate_external_group_sync_fences(
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
reserved_sync_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
# validate all existing indexing jobs
for key_bytes in r.scan_iter(
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
validate_external_group_sync_fence(
tenant_id,
key_bytes,
reserved_sync_tasks,
r_celery,
db_session,
)
return
def validate_external_group_sync_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_external_group_sync_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.external_group_sync.fenced:
return
payload = redis_connector.external_group_sync.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
# if redis_connector_index.active():
# return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
"validate_external_group_sync_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector.external_group_sync.reset()
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
if found:
# the celery task exists in the redis queue
# redis_connector_index.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within the indexing worker
# redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
# if redis_connector_index.active():
# return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
"validate_external_group_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.external_group_sync.reset()
return

File diff suppressed because it is too large Load Diff

View File

@@ -1,522 +0,0 @@
import time
from datetime import datetime
from datetime import timezone
import redis
from celery import Celery
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import create_index_attempt
from onyx.db.index_attempt import delete_index_attempt
from onyx.db.index_attempt import get_all_index_attempts_by_status
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down
# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue
# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.
# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue
if attempt.status != attempt_2.status:
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
class IndexingCallback(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
def __init__(
self,
parent_pid: int,
stop_key: str,
generator_progress_key: str,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.parent_pid = parent_pid
self.redis_lock: RedisLock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = "IndexingCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
self.last_parent_check = time.monotonic()
def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
return True
return False
def progress(self, tag: str, amount: int) -> None:
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
# so leave this code in until we're ready to test it.
# if self.parent_pid:
# # check if the parent pid is alive so we aren't running as a zombie
# now = time.monotonic()
# if now - self.last_parent_check > IndexingCallback.PARENT_CHECK_INTERVAL:
# try:
# # this is unintuitive, but it checks if the parent pid is still running
# os.kill(self.parent_pid, 0)
# except Exception:
# logger.exception("IndexingCallback - parent pid check exceptioned")
# raise
# self.last_parent_check = now
try:
current_time = time.monotonic()
if current_time - self.last_lock_monotonic >= (
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
):
self.redis_lock.reacquire()
self.last_lock_reacquire = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
self.last_tag = tag
except LockError:
logger.exception(
f"IndexingCallback - lock.reacquire exceptioned: "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
redis_lock_dump(self.redis_lock, self.redis_client)
raise
self.redis_client.incrby(self.generator_progress_key, amount)
def validate_indexing_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# check to see if the fence/payload exists
if not redis_connector_index.fenced:
return
payload = redis_connector_index.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
if redis_connector_index.active():
return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector_index.reset()
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
if found:
# the celery task exists in the redis queue
redis_connector_index.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within the indexing worker
redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector_index.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
f"index_attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"fence={fence_key}"
)
if payload.index_attempt_id:
try:
mark_attempt_failed(
payload.index_attempt_id,
db_session,
"validate_indexing_fence - Canceling index attempt due to missing celery tasks: "
f"index_attempt={payload.index_attempt_id}",
)
except Exception:
logger.exception(
"validate_indexing_fence - Exception while marking index attempt as failed: "
f"index_attempt={payload.index_attempt_id}",
)
redis_connector_index.reset()
return
def validate_indexing_fences(
tenant_id: str | None,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
"""Validates all indexing fences for this tenant ... aka makes sure
indexing tasks sent to celery are still in flight.
"""
reserved_indexing_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
# Use replica for this because the worst thing that happens
# is that we don't run the validation on this pass
for key_bytes in r_replica.scan_iter(
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,
reserved_indexing_tasks,
r_celery,
db_session,
)
return
def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
"""Checks various global settings and past indexing attempts to determine if
we should try to start indexing the cc pair / search setting combination.
Note that tactical checks such as preventing overlap with a currently running task
are not handled here.
Return True if we should try to index, False if not.
"""
connector = cc_pair.connector
# uncomment for debugging
# task_logger.info(f"_should_index: "
# f"cc_pair={cc_pair.id} "
# f"connector={cc_pair.connector_id} "
# f"refresh_freq={connector.refresh_freq}")
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
return False
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if (
search_settings_instance.status == IndexModelStatus.PRESENT
and secondary_index_building
):
return False
# When switching over models, always index at least once
if search_settings_instance.status == IndexModelStatus.FUTURE:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if (
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
): # Ingestion API
return False
return True
# If the connector is paused or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if (
not cc_pair.status.is_active()
or connector.id == 0
or connector.source == DocumentSource.INGESTION_API
):
return False
if search_settings_primary:
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index:
return True
if connector.refresh_freq is None:
return False
current_db_time = get_db_current_time(db_session)
time_since_index = current_db_time - last_index.time_updated
if time_since_index.total_seconds() < connector.refresh_freq:
return False
return True
def try_creating_indexing_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger indexing immediately.
"""
LOCK_TIMEOUT = 30
index_attempt_id: int | None = None
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
# skip if already indexing
if redis_connector_index.fenced:
return None
# skip indexing if the cc_pair is deleting
if redis_connector.delete.fenced:
return None
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
redis_connector_index.generator_clear()
# set a basic fence to start
payload = RedisConnectorIndexPayload(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=None,
)
redis_connector_index.set_active()
redis_connector_index.set_fence(payload)
# create the index attempt for tracking purposes
# code elsewhere checks for index attempts without an associated redis key
# and cleans them up
# therefore we must create the attempt and the task after the fence goes up
index_attempt_id = create_index_attempt(
cc_pair.id,
search_settings.id,
from_beginning=reindex,
db_session=db_session,
)
custom_task_id = redis_connector_index.generate_generator_task_id()
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
tenant_id=tenant_id,
),
queue=OnyxCeleryQueues.CONNECTOR_INDEXING,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
)
if not result:
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
# now fill out the fence with the rest of the data
redis_connector_index.set_active()
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
if index_attempt_id is not None:
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
return None
finally:
if lock.owned():
lock.release()
return index_attempt_id

View File

@@ -14,16 +14,8 @@ from onyx.db.models import LLMProvider
def _process_model_list_response(model_list_json: Any) -> list[str]:
# Handle case where response is wrapped in a "data" field
if isinstance(model_list_json, dict):
if "data" in model_list_json:
model_list_json = model_list_json["data"]
elif "models" in model_list_json:
model_list_json = model_list_json["models"]
else:
raise ValueError(
"Invalid response from API - expected dict with 'data' or "
f"'models' field, got {type(model_list_json)}"
)
if isinstance(model_list_json, dict) and "data" in model_list_json:
model_list_json = model_list_json["data"]
if not isinstance(model_list_json, list):
raise ValueError(
@@ -35,18 +27,11 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
for item in model_list_json:
if isinstance(item, str):
model_names.append(item)
elif isinstance(item, dict):
if "model_name" in item:
model_names.append(item["model_name"])
elif "id" in item:
model_names.append(item["id"])
else:
raise ValueError(
f"Invalid item in model list - expected dict with model_name or id, got {type(item)}"
)
elif isinstance(item, dict) and "model_name" in item:
model_names.append(item["model_name"])
else:
raise ValueError(
f"Invalid item in model list - expected string or dict, got {type(item)}"
f"Invalid item in model list - expected string or dict with model_name, got {type(item)}"
)
return model_names
@@ -54,7 +39,6 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,

View File

@@ -1,829 +0,0 @@
import json
import time
from collections.abc import Callable
from datetime import timedelta
from itertools import islice
from typing import Any
from typing import Literal
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import BaseModel
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import IndexAttempt
from onyx.db.models import SyncRecord
from onyx.db.models import UserGroup
from onyx.db.search_settings import get_active_search_settings_list
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60 # 6 minutes
_CONNECTOR_INDEX_ATTEMPT_START_LATENCY_KEY_FMT = (
"monitoring_connector_index_attempt_start_latency:{cc_pair_id}:{index_attempt_id}"
)
_CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT = (
"monitoring_connector_index_attempt_run_success:{cc_pair_id}:{index_attempt_id}"
)
_FINAL_METRIC_KEY_FMT = "sync_final_metrics:{sync_type}:{entity_id}:{sync_record_id}"
_SYNC_START_LATENCY_KEY_FMT = (
"sync_start_latency:{sync_type}:{entity_id}:{sync_record_id}"
)
_CONNECTOR_START_TIME_KEY_FMT = "connector_start_time:{cc_pair_id}:{index_attempt_id}"
_CONNECTOR_END_TIME_KEY_FMT = "connector_end_time:{cc_pair_id}:{index_attempt_id}"
_SYNC_START_TIME_KEY_FMT = "sync_start_time:{sync_type}:{entity_id}:{sync_record_id}"
_SYNC_END_TIME_KEY_FMT = "sync_end_time:{sync_type}:{entity_id}:{sync_record_id}"
def _mark_metric_as_emitted(redis_std: Redis, key: str) -> None:
"""Mark a metric as having been emitted by setting a Redis key with expiration"""
redis_std.set(key, "1", ex=24 * 60 * 60) # Expire after 1 day
def _has_metric_been_emitted(redis_std: Redis, key: str) -> bool:
"""Check if a metric has been emitted by checking for existence of Redis key"""
return bool(redis_std.exists(key))
class Metric(BaseModel):
key: str | None # only required if we need to store that we have emitted this metric
name: str
value: Any
tags: dict[str, str]
def log(self) -> None:
"""Log the metric in a standardized format"""
data = {
"metric": self.name,
"value": self.value,
"tags": self.tags,
}
task_logger.info(json.dumps(data))
def emit(self, tenant_id: str | None) -> None:
# Convert value to appropriate type based on the input value
bool_value = None
float_value = None
int_value = None
string_value = None
# NOTE: have to do bool first, since `isinstance(True, int)` is true
# e.g. bool is a subclass of int
if isinstance(self.value, bool):
bool_value = self.value
elif isinstance(self.value, int):
int_value = self.value
elif isinstance(self.value, float):
float_value = self.value
elif isinstance(self.value, str):
string_value = self.value
else:
task_logger.error(
f"Invalid metric value type: {type(self.value)} "
f"({self.value}) for metric {self.name}."
)
return
# don't send None values over the wire
data = {
k: v
for k, v in {
"metric_name": self.name,
"float_value": float_value,
"int_value": int_value,
"string_value": string_value,
"bool_value": bool_value,
"tags": self.tags,
}.items()
if v is not None
}
task_logger.info(f"Emitting metric: {data}")
optional_telemetry(
record_type=RecordType.METRIC,
data=data,
tenant_id=tenant_id,
)
def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
"""Collect metrics about queue lengths for different Celery queues"""
metrics = []
queue_mappings = {
"celery_queue_length": "celery",
"indexing_queue_length": "indexing",
"sync_queue_length": "sync",
"deletion_queue_length": "deletion",
"pruning_queue_length": "pruning",
"permissions_sync_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
"external_group_sync_queue_length": OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
"permissions_upsert_queue_length": OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
}
for name, queue in queue_mappings.items():
metrics.append(
Metric(
key=None,
name=name,
value=celery_get_queue_length(queue, redis_celery),
tags={"queue": name},
)
)
return metrics
def _build_connector_start_latency_metric(
cc_pair: ConnectorCredentialPair,
recent_attempt: IndexAttempt,
second_most_recent_attempt: IndexAttempt | None,
redis_std: Redis,
) -> Metric | None:
if not recent_attempt.time_started:
return None
# check if we already emitted a metric for this index attempt
metric_key = _CONNECTOR_INDEX_ATTEMPT_START_LATENCY_KEY_FMT.format(
cc_pair_id=cc_pair.id,
index_attempt_id=recent_attempt.id,
)
if _has_metric_been_emitted(redis_std, metric_key):
task_logger.info(
f"Skipping metric for connector {cc_pair.connector.id} "
f"index attempt {recent_attempt.id} because it has already been "
"emitted"
)
return None
# Connector start latency
# first run case - we should start as soon as it's created
if not second_most_recent_attempt:
desired_start_time = cc_pair.connector.time_created
else:
if not cc_pair.connector.refresh_freq:
task_logger.error(
"Found non-initial index attempt for connector "
"without refresh_freq. This should never happen."
)
return None
desired_start_time = second_most_recent_attempt.time_updated + timedelta(
seconds=cc_pair.connector.refresh_freq
)
start_latency = (recent_attempt.time_started - desired_start_time).total_seconds()
task_logger.info(
f"Start latency for index attempt {recent_attempt.id}: {start_latency:.2f}s "
f"(desired: {desired_start_time}, actual: {recent_attempt.time_started})"
)
job_id = build_job_id("connector", str(cc_pair.id), str(recent_attempt.id))
return Metric(
key=metric_key,
name="connector_start_latency",
value=start_latency,
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
)
def _build_connector_final_metrics(
cc_pair: ConnectorCredentialPair,
recent_attempts: list[IndexAttempt],
redis_std: Redis,
) -> list[Metric]:
"""
Final metrics for connector index attempts:
- Boolean success/fail metric
- If success, emit:
* duration (seconds)
* doc_count
"""
metrics = []
for attempt in recent_attempts:
metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
cc_pair_id=cc_pair.id,
index_attempt_id=attempt.id,
)
if _has_metric_been_emitted(redis_std, metric_key):
task_logger.info(
f"Skipping final metrics for connector {cc_pair.connector.id} "
f"index attempt {attempt.id}, already emitted."
)
continue
# We only emit final metrics if the attempt is in a terminal state
if attempt.status not in [
IndexingStatus.SUCCESS,
IndexingStatus.FAILED,
IndexingStatus.CANCELED,
]:
# Not finished; skip
continue
job_id = build_job_id("connector", str(cc_pair.id), str(attempt.id))
success = attempt.status == IndexingStatus.SUCCESS
metrics.append(
Metric(
key=metric_key, # We'll mark the same key for any final metrics
name="connector_run_succeeded",
value=success,
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
"status": attempt.status.value,
},
)
)
if success:
# Make sure we have valid time_started
if attempt.time_started and attempt.time_updated:
duration_seconds = (
attempt.time_updated - attempt.time_started
).total_seconds()
metrics.append(
Metric(
key=None, # No need for a new key, or you can reuse the same if you prefer
name="connector_index_duration_seconds",
value=duration_seconds,
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
)
)
else:
task_logger.error(
f"Index attempt {attempt.id} succeeded but has missing time "
f"(time_started={attempt.time_started}, time_updated={attempt.time_updated})."
)
# For doc counts, choose whichever field is more relevant
doc_count = attempt.total_docs_indexed or 0
metrics.append(
Metric(
key=None,
name="connector_index_doc_count",
value=doc_count,
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
)
)
return metrics
def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
"""Collect metrics about connector runs from the past hour"""
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
# Get all connector credential pairs
cc_pairs = db_session.scalars(select(ConnectorCredentialPair)).all()
# Might be more than one search setting, or just one
active_search_settings_list = get_active_search_settings_list(db_session)
metrics = []
# If you want to process each cc_pair against each search setting:
for cc_pair in cc_pairs:
for search_settings in active_search_settings_list:
recent_attempts = (
db_session.query(IndexAttempt)
.filter(
IndexAttempt.connector_credential_pair_id == cc_pair.id,
IndexAttempt.search_settings_id == search_settings.id,
)
.order_by(IndexAttempt.time_created.desc())
.limit(2)
.all()
)
if not recent_attempts:
continue
most_recent_attempt = recent_attempts[0]
second_most_recent_attempt = (
recent_attempts[1] if len(recent_attempts) > 1 else None
)
if one_hour_ago > most_recent_attempt.time_created:
continue
# Build a job_id for correlation
job_id = build_job_id(
"connector", str(cc_pair.id), str(most_recent_attempt.id)
)
# Add raw start time metric if available
if most_recent_attempt.time_started:
start_time_key = _CONNECTOR_START_TIME_KEY_FMT.format(
cc_pair_id=cc_pair.id,
index_attempt_id=most_recent_attempt.id,
)
metrics.append(
Metric(
key=start_time_key,
name="connector_start_time",
value=most_recent_attempt.time_started.timestamp(),
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
)
)
# Add raw end time metric if available and in terminal state
if (
most_recent_attempt.status.is_terminal()
and most_recent_attempt.time_updated
):
end_time_key = _CONNECTOR_END_TIME_KEY_FMT.format(
cc_pair_id=cc_pair.id,
index_attempt_id=most_recent_attempt.id,
)
metrics.append(
Metric(
key=end_time_key,
name="connector_end_time",
value=most_recent_attempt.time_updated.timestamp(),
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
)
)
# Connector start latency
start_latency_metric = _build_connector_start_latency_metric(
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
)
if start_latency_metric:
metrics.append(start_latency_metric)
# Connector run success/failure
final_metrics = _build_connector_final_metrics(
cc_pair, recent_attempts, redis_std
)
metrics.extend(final_metrics)
return metrics
def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
"""
Collect metrics for document set and group syncing:
- Success/failure status
- Start latency (for doc sets / user groups)
- Duration & doc count (only if success)
- Throughput (docs/min) (only if success)
- Raw start/end times for each sync
"""
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
# Get all sync records that ended in the last hour
recent_sync_records = db_session.scalars(
select(SyncRecord)
.where(SyncRecord.sync_end_time.isnot(None))
.where(SyncRecord.sync_end_time >= one_hour_ago)
.order_by(SyncRecord.sync_end_time.desc())
).all()
task_logger.info(
f"Collecting sync metrics for {len(recent_sync_records)} sync records"
)
metrics = []
for sync_record in recent_sync_records:
# Build a job_id for correlation
job_id = build_job_id("sync_record", str(sync_record.id))
# Add raw start time metric
start_time_key = _SYNC_START_TIME_KEY_FMT.format(
sync_type=sync_record.sync_type,
entity_id=sync_record.entity_id,
sync_record_id=sync_record.id,
)
metrics.append(
Metric(
key=start_time_key,
name="sync_start_time",
value=sync_record.sync_start_time.timestamp(),
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
# Add raw end time metric if available
if sync_record.sync_end_time:
end_time_key = _SYNC_END_TIME_KEY_FMT.format(
sync_type=sync_record.sync_type,
entity_id=sync_record.entity_id,
sync_record_id=sync_record.id,
)
metrics.append(
Metric(
key=end_time_key,
name="sync_end_time",
value=sync_record.sync_end_time.timestamp(),
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
# Emit a SUCCESS/FAIL boolean metric
# Use a single Redis key to avoid re-emitting final metrics
final_metric_key = _FINAL_METRIC_KEY_FMT.format(
sync_type=sync_record.sync_type,
entity_id=sync_record.entity_id,
sync_record_id=sync_record.id,
)
if not _has_metric_been_emitted(redis_std, final_metric_key):
# Evaluate success
sync_succeeded = sync_record.sync_status == SyncStatus.SUCCESS
metrics.append(
Metric(
key=final_metric_key,
name="sync_run_succeeded",
value=sync_succeeded,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
"status": str(sync_record.sync_status),
},
)
)
# If successful, emit additional metrics
if sync_succeeded:
if sync_record.sync_end_time and sync_record.sync_start_time:
duration_seconds = (
sync_record.sync_end_time - sync_record.sync_start_time
).total_seconds()
else:
task_logger.error(
f"Invalid times for sync record {sync_record.id}: "
f"start={sync_record.sync_start_time}, end={sync_record.sync_end_time}"
)
duration_seconds = None
doc_count = sync_record.num_docs_synced or 0
sync_speed = None
if duration_seconds and duration_seconds > 0:
duration_mins = duration_seconds / 60.0
sync_speed = (
doc_count / duration_mins if duration_mins > 0 else None
)
# Emit duration, doc count, speed
if duration_seconds is not None:
metrics.append(
Metric(
key=final_metric_key,
name="sync_duration_seconds",
value=duration_seconds,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
else:
task_logger.error(
f"Invalid sync record {sync_record.id} with no duration"
)
metrics.append(
Metric(
key=final_metric_key,
name="sync_doc_count",
value=doc_count,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
if sync_speed is not None:
metrics.append(
Metric(
key=final_metric_key,
name="sync_speed_docs_per_min",
value=sync_speed,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
else:
task_logger.error(
f"Invalid sync record {sync_record.id} with no duration"
)
# Emit start latency
start_latency_key = _SYNC_START_LATENCY_KEY_FMT.format(
sync_type=sync_record.sync_type,
entity_id=sync_record.entity_id,
sync_record_id=sync_record.id,
)
if not _has_metric_been_emitted(redis_std, start_latency_key):
# Get the entity's last update time based on sync type
entity: DocumentSet | UserGroup | None = None
if sync_record.sync_type == SyncType.DOCUMENT_SET:
entity = db_session.scalar(
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
)
elif sync_record.sync_type == SyncType.USER_GROUP:
entity = db_session.scalar(
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
)
if entity is None:
task_logger.error(
f"Sync record of type {sync_record.sync_type} doesn't have an entity "
f"associated with it (id={sync_record.entity_id}). Skipping start latency metric."
)
# Calculate start latency in seconds:
# (actual sync start) - (last modified time)
if (
entity is not None
and entity.time_last_modified_by_user
and sync_record.sync_start_time
):
start_latency = (
sync_record.sync_start_time - entity.time_last_modified_by_user
).total_seconds()
if start_latency < 0:
task_logger.error(
f"Negative start latency for sync record {sync_record.id} "
f"(start={sync_record.sync_start_time}, entity_modified={entity.time_last_modified_by_user})"
)
continue
metrics.append(
Metric(
key=start_latency_key,
name="sync_start_latency_seconds",
value=start_latency,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
return metrics
def build_job_id(
job_type: Literal["connector", "sync_record"],
primary_id: str,
secondary_id: str | None = None,
) -> str:
if job_type == "connector":
if secondary_id is None:
raise ValueError(
"secondary_id (attempt_id) is required for connector job_type"
)
return f"connector:{primary_id}:attempt:{secondary_id}"
elif job_type == "sync_record":
return f"sync_record:{primary_id}"
@shared_task(
name=OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
ignore_result=True,
soft_time_limit=_MONITORING_SOFT_TIME_LIMIT,
time_limit=_MONITORING_TIME_LIMIT,
queue=OnyxCeleryQueues.MONITORING,
bind=True,
)
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
"""Collect and emit metrics about background processes.
This task runs periodically to gather metrics about:
- Queue lengths for different Celery queues
- Connector run metrics (start latency, success rate)
- Syncing speed metrics
- Worker status and task counts
"""
if tenant_id is not None:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
task_logger.info("Starting background monitoring")
r = get_redis_client(tenant_id=tenant_id)
lock_monitoring: RedisLock = r.lock(
OnyxRedisLocks.MONITOR_BACKGROUND_PROCESSES_LOCK,
timeout=_MONITORING_SOFT_TIME_LIMIT,
)
# these tasks should never overlap
if not lock_monitoring.acquire(blocking=False):
task_logger.info("Skipping monitoring task because it is already running")
return None
try:
# Get Redis client for Celery broker
redis_celery = self.app.broker_connection().channel().client # type: ignore
redis_std = get_redis_client(tenant_id=tenant_id)
# Define metric collection functions and their dependencies
metric_functions: list[Callable[[], list[Metric]]] = [
lambda: _collect_queue_metrics(redis_celery),
lambda: _collect_connector_metrics(db_session, redis_std),
lambda: _collect_sync_metrics(db_session, redis_std),
]
# Collect and log each metric
with get_session_with_tenant(tenant_id) as db_session:
for metric_fn in metric_functions:
metrics = metric_fn()
for metric in metrics:
# double check to make sure we aren't double-emitting metrics
if metric.key is None or not _has_metric_been_emitted(
redis_std, metric.key
):
metric.log()
metric.emit(tenant_id)
if metric.key is not None:
_mark_metric_as_emitted(redis_std, metric.key)
task_logger.info("Successfully collected background metrics")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception as e:
task_logger.exception("Error collecting background process metrics")
raise e
finally:
if lock_monitoring.owned():
lock_monitoring.release()
task_logger.info("Background monitoring task finished")
@shared_task(
name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
)
def cloud_check_alembic() -> bool | None:
"""A task to verify that all tenants are on the same alembic revision.
This check is expected to fail if a cloud alembic migration is currently running
across all tenants.
TODO: have the cloud migration script set an activity signal that this check
uses to know it doesn't make sense to run a check at the present time.
"""
time_start = time.monotonic()
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.CLOUD_CHECK_ALEMBIC_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
last_lock_time = time.monotonic()
tenant_to_revision: dict[str, str | None] = {}
revision_counts: dict[str, int] = {}
out_of_date_tenants: dict[str, str | None] = {}
top_revision: str = ""
try:
# map each tenant_id to its revision
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()
last_lock_time = current_time
if tenant_id is None:
continue
with get_session_with_tenant(tenant_id=None) as session:
result = session.execute(
text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
)
result_scalar: str | None = result.scalar_one_or_none()
tenant_to_revision[tenant_id] = result_scalar
# get the total count of each revision
for k, v in tenant_to_revision.items():
if v is None:
continue
revision_counts[v] = revision_counts.get(v, 0) + 1
# get the revision with the most counts
sorted_revision_counts = sorted(
revision_counts.items(), key=lambda item: item[1], reverse=True
)
if len(sorted_revision_counts) == 0:
task_logger.error(
f"cloud_check_alembic - No revisions found for {len(tenant_ids)} tenant ids!"
)
else:
top_revision, _ = sorted_revision_counts[0]
# build a list of out of date tenants
for k, v in tenant_to_revision.items():
if v == top_revision:
continue
out_of_date_tenants[k] = v
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during cloud alembic check")
raise
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error("cloud_check_alembic - Lock not owned on completion")
redis_lock_dump(lock_beat, redis_client)
if len(out_of_date_tenants) > 0:
task_logger.error(
f"Found out of date tenants: "
f"num_out_of_date_tenants={len(out_of_date_tenants)} "
f"num_tenants={len(tenant_ids)} "
f"revision={top_revision}"
)
for k, v in islice(out_of_date_tenants.items(), 5):
task_logger.info(f"Out of date tenant: tenant={k} revision={v}")
else:
task_logger.info(
f"All tenants are up to date: num_tenants={len(tenant_ids)} revision={top_revision}"
)
time_elapsed = time.monotonic() - time_start
task_logger.info(
f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
)
return True

View File

@@ -13,11 +13,11 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.tasks import IndexingCallback
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -25,30 +25,21 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import InputType
from onyx.db.connector import mark_ccpair_as_pruned
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_documents_for_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import pruning_ctx
from onyx.utils.logger import setup_logger
logger = setup_logger()
"""Jobs / utils for kicking off pruning tasks."""
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if pruning is due.
@@ -87,7 +78,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_PRUNING,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -96,7 +86,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
@@ -113,10 +103,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
continue
@@ -213,14 +200,6 @@ def try_creating_prune_generator_task(
priority=OnyxCeleryPriority.LOW,
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=cc_pair.id,
sync_type=SyncType.PRUNING,
)
# set this only after all tasks have been added
redis_connector.prune.set_fence(True)
except Exception:
@@ -252,8 +231,6 @@ def connector_pruning_generator_task(
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
LoggerContextVars.reset()
pruning_ctx_dict = pruning_ctx.get()
pruning_ctx_dict["cc_pair_id"] = cc_pair_id
pruning_ctx_dict["request_id"] = self.request.id
@@ -367,52 +344,3 @@ def connector_pruning_generator_task(
lock.release()
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")
"""Monitoring pruning utils, called in monitor_vespa_sync"""
def monitor_ccpair_pruning_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.prune.fenced:
return
initial = redis_connector.prune.generator_complete
if initial is None:
return
remaining = redis_connector.prune.get_remaining()
task_logger.info(
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Connector pruning finished: cc_pair={cc_pair_id} num_pruned={initial}"
)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.PRUNING,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial,
)
redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(False)

View File

@@ -28,35 +28,13 @@ class RetryDocumentIndex:
wait=wait_random_exponential(multiplier=1, max=MAX_WAIT),
stop=stop_after_delay(STOP_AFTER),
)
def delete_single(
self,
doc_id: str,
*,
tenant_id: str | None,
chunk_count: int | None,
) -> int:
return self.index.delete_single(
doc_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
)
def delete_single(self, doc_id: str) -> int:
return self.index.delete_single(doc_id)
@retry(
retry=retry_if_exception_type(httpx.ReadTimeout),
wait=wait_random_exponential(multiplier=1, max=MAX_WAIT),
stop=stop_after_delay(STOP_AFTER),
)
def update_single(
self,
doc_id: str,
*,
tenant_id: str | None,
chunk_count: int | None,
fields: VespaDocumentFields,
) -> int:
return self.index.update_single(
doc_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
fields=fields,
)
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> int:
return self.index.update_single(doc_id, fields)

View File

@@ -1,40 +1,27 @@
import time
from http import HTTPStatus
import httpx
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from tenacity import RetryError
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.document import delete_document_by_connector_credential_pair__no_commit
from onyx.db.document import delete_documents_complete__no_commit
from onyx.db.document import fetch_chunk_count_for_document
from onyx.db.document import get_document
from onyx.db.document import get_document_connector_count
from onyx.db.document import mark_document_as_modified
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_tenant
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
@@ -75,18 +62,14 @@ def document_by_cc_pair_cleanup_task(
"""
task_logger.debug(f"Task start: doc={document_id}")
start = time.monotonic()
try:
with get_session_with_tenant(tenant_id) as db_session:
action = "skip"
chunks_affected = 0
active_search_settings = get_active_search_settings(db_session)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
doc_index = get_default_document_index(
active_search_settings.primary,
active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
retry_index = RetryDocumentIndex(doc_index)
@@ -97,13 +80,7 @@ def document_by_cc_pair_cleanup_task(
# delete it from vespa and the db
action = "delete"
chunk_count = fetch_chunk_count_for_document(document_id, db_session)
chunks_affected = retry_index.delete_single(
document_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
)
chunks_affected = retry_index.delete_single(document_id)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],
@@ -133,12 +110,7 @@ def document_by_cc_pair_cleanup_task(
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
)
chunks_affected = retry_index.update_single(document_id, fields=fields)
# there are still other cc_pair references to the doc, so just resync to Vespa
delete_document_by_connector_credential_pair__no_commit(
@@ -156,13 +128,11 @@ def document_by_cc_pair_cleanup_task(
db_session.commit()
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action={action} "
f"refcount={count} "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
f"chunks={chunks_affected}"
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
@@ -217,78 +187,3 @@ def document_by_cc_pair_cleanup_task(
return False
return True
@shared_task(
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
ignore_result=True,
trail=False,
bind=True,
)
def cloud_beat_task_generator(
self: Task,
task_name: str,
queue: str = OnyxCeleryTask.DEFAULT,
priority: int = OnyxCeleryPriority.MEDIUM,
expires: int = BEAT_EXPIRES_DEFAULT,
) -> bool | None:
"""a lightweight task used to kick off individual beat tasks per tenant."""
time_start = time.monotonic()
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_beat: RedisLock = redis_client.lock(
f"{OnyxRedisLocks.CLOUD_BEAT_TASK_GENERATOR_LOCK}:{task_name}",
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
last_lock_time = time.monotonic()
try:
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()
last_lock_time = current_time
# needed in the cloud
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
continue
self.app.send_task(
task_name,
kwargs=dict(
tenant_id=tenant_id,
),
queue=queue,
priority=priority,
expires=expires,
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during cloud_beat_task_generator")
finally:
if not lock_beat.owned():
task_logger.error(
"cloud_beat_task_generator - Lock not owned on completion"
)
redis_lock_dump(lock_beat, redis_client)
else:
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.info(
f"cloud_beat_task_generator finished: "
f"task={task_name} "
f"num_tenants={len(tenant_ids)} "
f"elapsed={time_elapsed:.2f}"
)
return True

View File

@@ -1,7 +1,6 @@
import random
import time
import traceback
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
@@ -24,10 +23,6 @@ from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
monitor_ccpair_permissions_taskset,
)
from onyx.background.celery.tasks.pruning.tasks import monitor_ccpair_pruning_taskset
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
@@ -38,6 +33,8 @@ from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector import mark_ccpair_as_pruned
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
@@ -56,29 +53,24 @@ from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import mark_document_set_as_synced
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import delete_index_attempts
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.models import DocumentSet
from onyx.db.models import UserGroup
from onyx.db.search_settings import get_active_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
)
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_usergroup import RedisUserGroup
@@ -98,7 +90,6 @@ logger = setup_logger()
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -287,21 +278,11 @@ def try_generate_document_set_sync_tasks(
# don't generate sync tasks if we're up to date
# race condition with the monitor/cleanup function if we use a cached result!
document_set = get_document_set_by_id(
db_session=db_session,
document_set_id=document_set_id,
)
document_set = get_document_set_by_id(db_session, document_set_id)
if not document_set:
return None
if document_set.is_up_to_date:
# there should be no in-progress sync records if this is up to date
# clean it up just in case things got into a bad state
cleanup_sync_records(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
)
return None
# add tasks to celery and build up the task set to monitor in redis
@@ -330,13 +311,6 @@ def try_generate_document_set_sync_tasks(
f"document_set={document_set.id} tasks_generated={tasks_generated}"
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
)
# set this only after all tasks have been added
rds.set_fence(tasks_generated)
return tasks_generated
@@ -358,9 +332,8 @@ def try_generate_user_group_sync_tasks(
return None
# race condition with the monitor/cleanup function if we use a cached result!
fetch_user_group = cast(
Callable[[Session, int], UserGroup | None],
fetch_versioned_implementation("onyx.db.user_group", "fetch_user_group"),
fetch_user_group = fetch_versioned_implementation(
"onyx.db.user_group", "fetch_user_group"
)
usergroup = fetch_user_group(db_session, usergroup_id)
@@ -368,13 +341,6 @@ def try_generate_user_group_sync_tasks(
return None
if usergroup.is_up_to_date:
# there should be no in-progress sync records if this is up to date
# clean it up just in case things got into a bad state
cleanup_sync_records(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
)
return None
# add tasks to celery and build up the task set to monitor in redis
@@ -402,16 +368,8 @@ def try_generate_user_group_sync_tasks(
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
)
# set this only after all tasks have been added
rug.set_fence(tasks_generated)
return tasks_generated
@@ -461,13 +419,6 @@ def monitor_document_set_taskset(
f"remaining={count} initial={initial_count}"
)
if count > 0:
update_sync_record_status(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=count,
)
return
document_set = cast(
@@ -486,13 +437,6 @@ def monitor_document_set_taskset(
task_logger.info(
f"Successfully synced document set: document_set={document_set_id}"
)
update_sync_record_status(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial_count,
)
rds.reset()
@@ -526,21 +470,10 @@ def monitor_connector_deletion_taskset(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
)
if remaining > 0:
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=remaining,
)
return
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
@@ -612,29 +545,11 @@ def monitor_connector_deletion_taskset(
)
db_session.delete(connector)
db_session.commit()
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=fence_data.num_tasks,
)
except Exception as e:
db_session.rollback()
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair_id, error_message)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.FAILED,
num_docs_synced=fence_data.num_tasks,
)
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
@@ -652,6 +567,83 @@ def monitor_connector_deletion_taskset(
redis_connector.delete.reset()
def monitor_ccpair_pruning_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.prune.fenced:
return
initial = redis_connector.prune.generator_complete
if initial is None:
return
remaining = redis_connector.prune.get_remaining()
task_logger.info(
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
)
redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(False)
def monitor_ccpair_permissions_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.permissions.fenced:
return
initial = redis_connector.permissions.generator_complete
if initial is None:
return
remaining = redis_connector.permissions.get_remaining()
task_logger.info(
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
redis_connector.permissions.reset()
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
@@ -660,7 +652,7 @@ def monitor_ccpair_indexing_taskset(
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"Connector indexing: could not parse composite_id from {fence_key}"
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
)
return
@@ -710,7 +702,6 @@ def monitor_ccpair_indexing_taskset(
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# Verify: if the generator isn't complete, the task must not be in READY state
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
@@ -756,7 +747,7 @@ def monitor_ccpair_indexing_taskset(
)
except Exception:
task_logger.exception(
"Connector indexing - Transient exception marking index attempt as failed: "
"monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: "
f"attempt={payload.index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
@@ -766,20 +757,6 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
return
if redis_connector_index.watchdog_signaled():
# if the generator is complete, don't clean up until the watchdog has exited
task_logger.info(
f"Connector indexing - Delaying finalization until watchdog has exited: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
return
status_enum = HTTPStatus(status_int)
task_logger.info(
@@ -796,20 +773,11 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
@shared_task(
name=OnyxCeleryTask.MONITOR_VESPA_SYNC,
ignore_result=True,
soft_time_limit=300,
bind=True,
)
@shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
"""This is a celery beat task that monitors and finalizes various long running tasks.
The name monitor_vespa_sync is a bit of a misnomer since it checks many different tasks
now. Should change that at some point.
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
For many tasks, the count is 0, that means all tasks finished and we should clean up.
If the count is 0, that means all tasks finished and we should clean up.
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
@@ -825,17 +793,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
# Replica usage notes
#
# False negatives are OK. (aka fail to to see a key that exists on the master).
# We simply skip the monitoring work and it will be caught on the next pass.
#
# False positives are not OK, and are possible if we clear a fence on the master and
# then read from the replica. In this case, monitoring work could be done on a fence
# that no longer exists. To avoid this, we scan from the replica, but double check
# the result on the master.
r_replica = get_redis_replica_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
@@ -895,19 +852,17 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
# scan and monitor activity to completion
phase_start = time.monotonic()
lock_beat.reacquire()
if r_replica.exists(RedisConnectorCredentialPair.get_fence_key()):
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
timings["connector"] = time.monotonic() - phase_start
timings["connector_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisConnectorDelete.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
lock_beat.reacquire()
timings["connector_deletion"] = time.monotonic() - phase_start
@@ -917,82 +872,70 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisDocumentSet.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["documentset"] = time.monotonic() - phase_start
timings["documentset_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisUserGroup.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
"onyx.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
"onyx.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["usergroup"] = time.monotonic() - phase_start
timings["usergroup_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["pruning"] = time.monotonic() - phase_start
timings["pruning_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["indexing"] = time.monotonic() - phase_start
timings["indexing_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
if r.exists(key_bytes):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["permissions"] = time.monotonic() - phase_start
timings["permissions_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
return False
except Exception:
task_logger.exception("monitor_vespa_sync exceptioned.")
return False
finally:
if lock_beat.owned():
lock_beat.release()
@@ -1019,15 +962,11 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
def vespa_metadata_sync_task(
self: Task, document_id: str, tenant_id: str | None
) -> bool:
start = time.monotonic()
try:
with get_session_with_tenant(tenant_id) as db_session:
active_search_settings = get_active_search_settings(db_session)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
doc_index = get_default_document_index(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
retry_index = RetryDocumentIndex(doc_index)
@@ -1053,12 +992,7 @@ def vespa_metadata_sync_task(
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
)
chunks_affected = retry_index.update_single(document_id, fields)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
@@ -1073,16 +1007,9 @@ def vespa_metadata_sync_task(
# r = get_redis_client(tenant_id=tenant_id)
# r.delete(redis_syncing_key)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.warning(

View File

@@ -1,15 +0,0 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.monitoring import celery_app
return celery_app
app = get_app()

View File

@@ -4,10 +4,9 @@ not follow the expected behavior, etc.
NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
import multiprocessing as mp
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing.context import SpawnProcess
from multiprocessing import Process
from typing import Any
from typing import Literal
from typing import Optional
@@ -47,9 +46,7 @@ def _initializer(
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# Initialize a new engine with desired parameters
SqlEngine.init_engine(
pool_size=4, max_overflow=12, pool_recycle=60, pool_pre_ping=True
)
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
# Proceed with executing the target function
return func(*args, **kwargs)
@@ -66,7 +63,7 @@ class SimpleJob:
"""Drop in replacement for `dask.distributed.Future`"""
id: int
process: Optional["SpawnProcess"] = None
process: Optional["Process"] = None
def cancel(self) -> bool:
return self.release()
@@ -134,10 +131,7 @@ class SimpleJobClient:
job_id = self.job_id_counter
self.job_id_counter += 1
# this approach allows us to always "spawn" a new process regardless of
# get_start_method's current setting
ctx = mp.get_context("spawn")
process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
process = Process(target=_run_in_process, args=(func, args), daemon=True)
job = SimpleJob(id=job_id, process=process)
process.start()

View File

@@ -4,7 +4,6 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.background.indexing.checkpointing import get_time_windows_for_index_attempt
@@ -12,7 +11,6 @@ from onyx.background.indexing.tracer import OnyxTracer
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.factory import instantiate_connector
@@ -23,19 +21,16 @@ from onyx.db.connector_credential_pair import get_last_successful_attempt_time
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.document_index.factory import get_default_document_index
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
@@ -80,8 +75,7 @@ def _get_connector_runner(
# it will never succeed
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=attempt.connector_credential_pair.id,
attempt.connector_credential_pair.id, db_session
)
if cc_pair and cc_pair.status == ConnectorCredentialPairStatus.ACTIVE:
update_connector_credential_pair(
@@ -102,17 +96,10 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
for doc in doc_batch:
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
if cleaned_doc.title and "\x00" in cleaned_doc.title:
logger.warning(
f"NUL characters found in document title: {cleaned_doc.title}"
)
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
@@ -128,9 +115,6 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
)
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
return cleaned_batch
@@ -140,21 +124,9 @@ class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
class RunIndexingContext(BaseModel):
index_name: str
cc_pair_id: int
connector_id: int
credential_id: int
source: DocumentSource
earliest_index_time: float
from_beginning: bool
is_primary: bool
search_settings_status: IndexModelStatus
def _run_indexing(
db_session: Session,
index_attempt_id: int,
index_attempt: IndexAttempt,
tenant_id: str | None,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
@@ -168,77 +140,61 @@ def _run_indexing(
"""
start_time = time.time()
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_start:
raise ValueError(
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
)
if index_attempt_start.search_settings is None:
raise ValueError(
"Search settings must be set for indexing. This should not be possible."
)
# search_settings = index_attempt_start.search_settings
db_connector = index_attempt_start.connector_credential_pair.connector
db_credential = index_attempt_start.connector_credential_pair.credential
ctx = RunIndexingContext(
index_name=index_attempt_start.search_settings.index_name,
cc_pair_id=index_attempt_start.connector_credential_pair.id,
connector_id=db_connector.id,
credential_id=db_credential.id,
source=db_connector.source,
earliest_index_time=(
db_connector.indexing_start.timestamp()
if db_connector.indexing_start
else 0
),
from_beginning=index_attempt_start.from_beginning,
# Only update cc-pair status for primary index jobs
# Secondary index syncs at the end when swapping
is_primary=(
index_attempt_start.search_settings.status == IndexModelStatus.PRESENT
),
search_settings_status=index_attempt_start.search_settings.status,
if index_attempt.search_settings is None:
raise ValueError(
"Search settings must be set for indexing. This should not be possible."
)
last_successful_index_time = (
ctx.earliest_index_time
if ctx.from_beginning
else get_last_successful_attempt_time(
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
earliest_index=ctx.earliest_index_time,
search_settings=index_attempt_start.search_settings,
db_session=db_session_temp,
)
)
search_settings = index_attempt.search_settings
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=index_attempt_start.search_settings,
callback=callback,
)
index_name = search_settings.index_name
# Only update cc-pair status for primary index jobs
# Secondary index syncs at the end when swapping
is_primary = search_settings.status == IndexModelStatus.PRESENT
# Indexing is only done into one index at a time
document_index = get_default_document_index(
index_attempt_start.search_settings,
None,
httpx_client=HttpxPool.get("vespa"),
primary_index_name=index_name, secondary_index_name=None
)
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings,
callback=callback,
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt_id,
attempt_id=index_attempt.id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE)
),
db_session=db_session,
tenant_id=tenant_id,
callback=callback,
)
db_cc_pair = index_attempt.connector_credential_pair
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
earliest_index_time = (
db_connector.indexing_start.timestamp() if db_connector.indexing_start else 0
)
last_successful_index_time = (
earliest_index_time
if index_attempt.from_beginning
else get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
earliest_index=earliest_index_time,
search_settings=index_attempt.search_settings,
db_session=db_session,
)
)
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = OnyxTracer()
@@ -246,8 +202,8 @@ def _run_indexing(
tracer.snap()
index_attempt_md = IndexAttemptMetadata(
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
connector_id=db_connector.id,
credential_id=db_credential.id,
)
batch_num = 0
@@ -263,31 +219,21 @@ def _run_indexing(
source_type=db_connector.source,
)
):
cc_pair_loop: ConnectorCredentialPair | None = None
index_attempt_loop: IndexAttempt | None = None
try:
window_start = max(
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
datetime(1970, 1, 1, tzinfo=timezone.utc),
)
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt_loop_start = get_index_attempt(
db_session_temp, index_attempt_id
)
if not index_attempt_loop_start:
raise RuntimeError(
f"Index attempt {index_attempt_id} not found in DB."
)
connector_runner = _get_connector_runner(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)
connector_runner = _get_connector_runner(
db_session=db_session_temp,
attempt=index_attempt_loop_start,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)
all_connector_doc_ids: set[str] = set()
tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
@@ -302,38 +248,24 @@ def _run_indexing(
raise ConnectorStopSignal("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
with get_session_with_tenant(tenant_id) as db_session_temp:
cc_pair_loop = get_connector_credential_pair_from_id(
db_session_temp,
ctx.cc_pair_id,
db_session.refresh(db_cc_pair)
if (
(
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
and search_settings.status != IndexModelStatus.FUTURE
)
if not cc_pair_loop:
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
# if it's deleting, we don't care if this is a secondary index
or db_cc_pair.status == ConnectorCredentialPairStatus.DELETING
):
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
if (
(
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
and ctx.search_settings_status != IndexModelStatus.FUTURE
)
# if it's deleting, we don't care if this is a secondary index
or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING
):
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
index_attempt_loop = get_index_attempt(
db_session_temp, index_attempt_id
db_session.refresh(index_attempt)
if index_attempt.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt.status}"
)
if not index_attempt_loop:
raise RuntimeError(
f"Index attempt {index_attempt_id} not found in DB."
)
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
)
batch_description = []
@@ -357,15 +289,16 @@ def _run_indexing(
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
index_pipeline_result = indexing_pipeline(
new_docs, total_batch_chunks = indexing_pipeline(
document_batch=doc_batch_cleaned,
index_attempt_metadata=index_attempt_md,
)
batch_num += 1
net_doc_change += index_pipeline_result.new_docs
chunk_count += index_pipeline_result.total_chunks
document_count += index_pipeline_result.total_docs
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch_cleaned)
all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
@@ -374,19 +307,18 @@ def _run_indexing(
# be inaccurate
db_session.commit()
# This new value is updated every batch, so UI can refresh per batch update
with get_session_with_tenant(tenant_id) as db_session_temp:
update_docs_indexed(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
docs_removed_from_index=0,
)
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
db_session=db_session,
index_attempt=index_attempt,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
docs_removed_from_index=0,
)
tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0
@@ -399,35 +331,33 @@ def _run_indexing(
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
run_end_dt = window_end
if ctx.is_primary:
with get_session_with_tenant(tenant_id) as db_session_temp:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
run_dt=run_end_dt,
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
run_dt=run_end_dt,
)
except Exception as e:
logger.exception(
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
)
if isinstance(e, ConnectorStopSignal):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
mark_attempt_canceled(
index_attempt.id,
db_session,
reason=str(e),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
@@ -442,29 +372,23 @@ def _run_indexing(
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or (
cc_pair_loop is not None and not cc_pair_loop.status.is_active()
)
or (
index_attempt_loop is not None
and index_attempt_loop.status != IndexingStatus.IN_PROGRESS
)
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
mark_attempt_failed(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
@@ -487,58 +411,56 @@ def _run_indexing(
index_attempt_md.num_exceptions > 0
and index_attempt_md.num_exceptions >= batch_num
):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason="All batches exceptioned.",
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
)
raise Exception(
f"Connector failed - All batches exceptioned: batches={batch_num}"
mark_attempt_failed(
index_attempt.id,
db_session,
failure_reason="All batches exceptioned.",
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector_credential_pair.connector.id,
credential_id=index_attempt.connector_credential_pair.credential.id,
)
raise Exception(
f"Connector failed - All batches exceptioned: batches={batch_num}"
)
elapsed_time = time.time() - start_time
with get_session_with_tenant(tenant_id) as db_session_temp:
if index_attempt_md.num_exceptions == 0:
mark_attempt_succeeded(index_attempt_id, db_session_temp)
if index_attempt_md.num_exceptions == 0:
mark_attempt_succeeded(index_attempt, db_session)
create_milestone_and_report(
user=None,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
properties=None,
db_session=db_session_temp,
)
create_milestone_and_report(
user=None,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
properties=None,
db_session=db_session,
)
logger.info(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
else:
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
logger.info(
f"Connector completed with some errors: "
f"exceptions={index_attempt_md.num_exceptions} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
f"elapsed={elapsed_time:.2f}s"
)
logger.info(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
else:
mark_attempt_partially_succeeded(index_attempt, db_session)
logger.info(
f"Connector completed with some errors: "
f"exceptions={index_attempt_md.num_exceptions} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
f"elapsed={elapsed_time:.2f}s"
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
run_dt=run_end_dt,
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
run_dt=run_end_dt,
)
def run_indexing_entrypoint(
@@ -558,35 +480,27 @@ def run_indexing_entrypoint(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_tenant(tenant_id) as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
connector_name = attempt.connector_credential_pair.connector.name
connector_config = (
attempt.connector_credential_pair.connector.connector_specific_config
logger.info(
f"Indexing starting{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
credential_id = attempt.connector_credential_pair.credential_id
logger.info(
f"Indexing starting{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
_run_indexing(db_session, attempt, tenant_id, callback)
with get_session_with_tenant(tenant_id) as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
logger.info(
f"Indexing finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
logger.info(
f"Indexing finished{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
except Exception as e:
logger.exception(
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"

View File

@@ -12,10 +12,10 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.prompt_builder.build import AnswerPromptBuilder
from onyx.chat.prompt_builder.build import default_build_system_message
from onyx.chat.prompt_builder.build import default_build_user_message
from onyx.chat.prompt_builder.build import LLMCall
from onyx.chat.stream_processing.answer_response_handler import (
CitationResponseHandler,
)
@@ -212,6 +212,19 @@ class Answer:
current_llm_call
) or ([], [])
# Quotes are no longer supported
# answer_handler: AnswerResponseHandler
# if self.answer_style_config.citation_config:
# answer_handler = CitationResponseHandler(
# context_docs=search_result,
# doc_id_to_rank_map=map_document_id_order(search_result),
# )
# elif self.answer_style_config.quotes_config:
# answer_handler = QuotesResponseHandler(
# context_docs=search_result,
# )
# else:
# raise ValueError("No answer style config provided")
answer_handler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
@@ -252,13 +265,11 @@ class Answer:
user_query=self.question,
prompt_config=self.prompt_config,
files=self.latest_query_files,
single_message_history=self.single_message_history,
),
message_history=self.message_history,
llm_config=self.llm.config,
raw_user_query=self.question,
raw_user_uploaded_files=self.latest_query_files or [],
single_message_history=self.single_message_history,
raw_user_text=self.question,
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)

View File

@@ -25,7 +25,7 @@ from onyx.db.models import Persona
from onyx.db.models import Prompt
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.prompts import get_prompts_by_ids
from onyx.db.persona import get_prompts_by_ids
from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.server.query_and_chat.models import CreateChatMessageRequest

View File

@@ -7,7 +7,7 @@ from langchain_core.messages import BaseMessage
from onyx.chat.models import ResponsePart
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.prompt_builder.build import LLMCall
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler

View File

@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
@@ -260,8 +261,13 @@ class CitationConfig(BaseModel):
all_docs_useful: bool = False
class QuotesConfig(BaseModel):
pass
class AnswerStyleConfig(BaseModel):
citation_config: CitationConfig
citation_config: CitationConfig | None = None
quotes_config: QuotesConfig | None = None
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
@@ -270,6 +276,20 @@ class AnswerStyleConfig(BaseModel):
# right now, only used by the simple chat API
structured_response_format: dict | None = None
@model_validator(mode="after")
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
if self.citation_config is None and self.quotes_config is None:
raise ValueError(
"One of `citation_config` or `quotes_config` must be provided"
)
if self.citation_config is not None and self.quotes_config is not None:
raise ValueError(
"Only one of `citation_config` or `quotes_config` must be provided"
)
return self
class PromptConfig(BaseModel):
"""Final representation of the Prompt configuration passed

View File

@@ -254,7 +254,6 @@ def _get_force_search_settings(
and new_msg_req.retrieval_options.run_search
== OptionalSearchSetting.ALWAYS,
new_msg_req.search_doc_ids,
new_msg_req.query_override is not None,
DISABLE_LLM_CHOOSE_SEARCH,
]
)
@@ -303,11 +302,6 @@ def stream_chat_message_objects(
enforce_chat_session_id_for_search_docs: bool = True,
bypass_acl: bool = False,
include_contexts: bool = False,
# a string which represents the history of a conversation. Used in cases like
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
# messages.
# NOTE: is not stored in the database at all.
single_message_history: str | None = None,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -426,7 +420,9 @@ def stream_chat_message_objects(
)
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(search_settings, None)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
)
# Every chat Session begins with an empty root message
root_message = get_or_create_root_message(
@@ -498,6 +494,14 @@ def stream_chat_message_objects(
f"existing assistant message id: {existing_assistant_message_id}"
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
# leads to worst search quality
if not history_msgs:
new_msg_req.query_override = (
new_msg_req.query_override or new_msg_req.message
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
@@ -703,7 +707,6 @@ def stream_chat_message_objects(
],
tools=tools,
force_use_tool=_get_force_search_settings(new_msg_req, tools),
single_message_history=single_message_history,
)
reference_db_search_docs = None

View File

@@ -15,12 +15,10 @@ from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_message_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
@@ -32,45 +30,30 @@ def default_build_system_message(
prompt_config: PromptConfig,
) -> SystemMessage | None:
system_prompt = prompt_config.system_prompt.strip()
tag_handled_prompt = handle_onyx_date_awareness(
system_prompt,
prompt_config,
add_additional_info_if_no_tag=prompt_config.datetime_aware,
)
if prompt_config.datetime_aware:
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
if not tag_handled_prompt:
if not system_prompt:
return None
return SystemMessage(content=tag_handled_prompt)
system_msg = SystemMessage(content=system_prompt)
return system_msg
def default_build_user_message(
user_query: str,
prompt_config: PromptConfig,
files: list[InMemoryChatFile] = [],
single_message_history: str | None = None,
user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
) -> HumanMessage:
history_block = (
HISTORY_BLOCK.format(history_str=single_message_history)
if single_message_history
else ""
)
user_prompt = (
CHAT_USER_CONTEXT_FREE_PROMPT.format(
history_block=history_block,
task_prompt=prompt_config.task_prompt,
user_query=user_query,
task_prompt=prompt_config.task_prompt, user_query=user_query
)
if prompt_config.task_prompt
else user_query
)
user_prompt = user_prompt.strip()
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
user_msg = HumanMessage(
content=build_content_with_imgs(tag_handled_prompt, files)
if files
else tag_handled_prompt
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
)
return user_msg
@@ -81,8 +64,7 @@ class AnswerPromptBuilder:
user_message: HumanMessage,
message_history: list[PreviousMessage],
llm_config: LLMConfig,
raw_user_query: str,
raw_user_uploaded_files: list[InMemoryChatFile],
raw_user_text: str,
single_message_history: str | None = None,
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)
@@ -91,7 +73,6 @@ class AnswerPromptBuilder:
provider_type=llm_config.model_provider,
model_name=llm_config.model_name,
)
self.llm_config = llm_config
self.llm_tokenizer_encode_func = cast(
Callable[[str], list[int]], llm_tokenizer.encode
)
@@ -100,29 +81,21 @@ class AnswerPromptBuilder:
(
self.message_history,
self.history_token_cnts,
) = translate_history_to_basemessages(
message_history,
exclude_images=not model_supports_image_input(
self.llm_config.model_name,
self.llm_config.model_provider,
),
)
) = translate_history_to_basemessages(message_history)
# for cases where like the QA flow where we want to condense the chat history
# into a single message rather than a sequence of User / Assistant messages
self.single_message_history = single_message_history
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(
user_message,
self.llm_tokenizer_encode_func,
),
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
)
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
# used for building a new prompt after a tool-call
self.raw_user_query = raw_user_query
self.raw_user_uploaded_files = raw_user_uploaded_files
self.single_message_history = single_message_history
self.raw_user_message = raw_user_text
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:

View File

@@ -1,13 +1,12 @@
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session
from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from onyx.context.search.models import InferenceChunk
from onyx.db.models import Persona
from onyx.db.prompts import get_default_prompt
from onyx.db.persona import get_default_prompt__read_only
from onyx.db.search_settings import get_multilingual_expansion
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
@@ -21,9 +20,9 @@ from onyx.prompts.constants import DEFAULT_IGNORE_STATEMENT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import build_complete_context_str
from onyx.prompts.prompt_utils import build_task_prompt_reminders
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
from onyx.prompts.token_counts import (
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
@@ -98,12 +97,11 @@ def compute_max_document_tokens(
def compute_max_document_tokens_for_persona(
db_session: Session,
persona: Persona,
actual_user_input: str | None = None,
max_llm_token_override: int | None = None,
) -> int:
prompt = persona.prompts[0] if persona.prompts else get_default_prompt(db_session)
prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only()
return compute_max_document_tokens(
prompt_config=PromptConfig.from_model(prompt),
llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
@@ -127,11 +125,10 @@ def build_citations_system_message(
system_prompt = prompt_config.system_prompt.strip()
if prompt_config.include_citations:
system_prompt += REQUIRE_CITATION_STATEMENT
tag_handled_prompt = handle_onyx_date_awareness(
system_prompt, prompt_config, add_additional_info_if_no_tag=True
)
if prompt_config.datetime_aware:
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
return SystemMessage(content=tag_handled_prompt)
return SystemMessage(content=system_prompt)
def build_citations_user_message(
@@ -147,7 +144,9 @@ def build_citations_user_message(
)
history_block = (
HISTORY_BLOCK.format(history_str=history_message) if history_message else ""
HISTORY_BLOCK.format(history_str=history_message) + "\n"
if history_message
else ""
)
query, img_urls = message_to_prompt_and_imgs(message)

View File

@@ -9,8 +9,8 @@ from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.prompts.direct_qa_prompts import CONTEXT_BLOCK
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.direct_qa_prompts import JSON_PROMPT
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import build_complete_context_str
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
def _build_strong_llm_quotes_prompt(
@@ -39,11 +39,10 @@ def _build_strong_llm_quotes_prompt(
language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
).strip()
tag_handled_prompt = handle_onyx_date_awareness(
full_prompt, prompt, add_additional_info_if_no_tag=True
)
if prompt.datetime_aware:
full_prompt = add_date_time_to_prompt(prompt_str=full_prompt)
return HumanMessage(content=tag_handled_prompt)
return HumanMessage(content=full_prompt)
def build_quotes_user_message(

View File

@@ -7,11 +7,30 @@ from onyx.db.models import ChatMessage
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
def build_dummy_prompt(
system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
if retrieval_disabled:
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
return PARAMATERIZED_PROMPT.format(
context_docs_str="<CONTEXT_DOCS>",
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
def translate_onyx_msg_to_langchain(
msg: ChatMessage | PreviousMessage,
exclude_images: bool = False,
) -> BaseMessage:
files: list[InMemoryChatFile] = []
@@ -19,9 +38,7 @@ def translate_onyx_msg_to_langchain(
# attached. Just ignore them for now.
if not isinstance(msg, ChatMessage):
files = msg.files
content = build_content_with_imgs(
msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
)
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
@@ -35,12 +52,9 @@ def translate_onyx_msg_to_langchain(
def translate_history_to_basemessages(
history: list[ChatMessage] | list["PreviousMessage"],
exclude_images: bool = False,
) -> tuple[list[BaseMessage], list[int]]:
history_basemessages = [
translate_onyx_msg_to_langchain(msg, exclude_images)
for msg in history
if msg.token_count != 0
translate_onyx_msg_to_langchain(msg) for msg in history if msg.token_count != 0
]
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
return history_basemessages, history_token_counts

View File

@@ -5,7 +5,7 @@ from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall
from onyx.chat.models import ResponsePart
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.prompt_builder.build import LLMCall
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.message import build_tool_message
@@ -62,7 +62,7 @@ class ToolResponseHandler:
llm_call.force_use_tool.args
if llm_call.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=llm_call.prompt_builder.raw_user_query,
query=llm_call.prompt_builder.raw_user_message,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
force_run=True,
@@ -76,7 +76,7 @@ class ToolResponseHandler:
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=llm_call.tools,
query=llm_call.prompt_builder.raw_user_query,
query=llm_call.prompt_builder.raw_user_message,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
)
@@ -95,7 +95,7 @@ class ToolResponseHandler:
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=llm_call.prompt_builder.raw_message_history,
query=llm_call.prompt_builder.raw_user_query,
query=llm_call.prompt_builder.raw_user_message,
llm=llm,
)
if available_tools_and_args

View File

@@ -3,7 +3,6 @@ import os
import urllib.parse
from typing import cast
from onyx.auth.schemas import AuthBackend
from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
@@ -18,7 +17,6 @@ APP_PORT = 8080
# prefix from requests directed towards the API server. In these cases, set this to `/api`
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
SKIP_WARM_UP = os.environ.get("SKIP_WARM_UP", "").lower() == "true"
#####
# User Facing Features Configs
@@ -56,12 +54,12 @@ MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)
AUTH_BACKEND = AuthBackend(os.environ.get("AUTH_BACKEND") or AuthBackend.REDIS.value)
REDIS_AUTH_EXPIRE_TIME_SECONDS = int(
os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS")
or os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS")
or 86400 * 7
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
# Default request timeout, mostly used by connectors
@@ -93,12 +91,6 @@ OAUTH_CLIENT_SECRET = (
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
# Duration (in seconds) for which the FastAPI Users JWT token remains valid in the user's browser.
# By default, this is set to match the Redis expiry time for consistency.
AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
os.environ.get("AUTH_COOKIE_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
# for basic auth
REQUIRE_EMAIL_VERIFICATION = (
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
@@ -200,8 +192,6 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
# this assumes that other redis settings remain the same as the primary
REDIS_REPLICA_HOST = os.environ.get("REDIS_REPLICA_HOST") or REDIS_HOST
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
@@ -478,12 +468,6 @@ INDEXING_SIZE_WARNING_THRESHOLD = int(
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)
# Enable multi-threaded embedding model calls for parallel processing
# Note: only applies for API-based embedding models
INDEXING_EMBEDDING_MODEL_NUM_THREADS = int(
os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1
)
# During an indexing attempt, specifies the number of batches which are allowed to
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)
@@ -617,8 +601,3 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
# Set to true to mock LLM responses for testing purposes
MOCK_LLM_RESPONSE = (
os.environ.get("MOCK_LLM_RESPONSE") if os.environ.get("MOCK_LLM_RESPONSE") else None
)

View File

@@ -1,6 +1,6 @@
import os
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
PERSONAS_YAML = "./onyx/seeding/personas.yaml"

View File

@@ -47,7 +47,6 @@ POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
@@ -79,8 +78,6 @@ KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"
# NOTE: we use this timeout / 4 in various places to refresh a lock
# might be worth separating this timeout into separate timeouts for each situation
CELERY_GENERIC_BEAT_LOCK_TIMEOUT = 120
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
@@ -200,7 +197,6 @@ class SessionType(str, Enum):
class QAFeedbackType(str, Enum):
LIKE = "like" # User likes the answer, used for metrics
DISLIKE = "dislike" # User dislikes the answer, used for metrics
MIXED = "mixed" # User likes some answers and dislikes other, used for chat session metrics
class SearchFeedbackType(str, Enum):
@@ -264,9 +260,6 @@ class OnyxCeleryQueues:
# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"
# Monitoring queue
MONITORING = "monitoring"
class OnyxRedisLocks:
PRIMARY_WORKER = "da_lock:primary_worker"
@@ -281,7 +274,6 @@ class OnyxRedisLocks:
"da_lock:check_connector_external_group_sync_beat"
)
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
"da_lock:connector_doc_permissions_sync"
@@ -294,14 +286,9 @@ class OnyxRedisLocks:
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
CLOUD_BEAT_TASK_GENERATOR_LOCK = "da_lock:cloud_beat_task_generator"
CLOUD_CHECK_ALEMBIC_BEAT_LOCK = "da_lock:cloud_check_alembic"
class OnyxRedisSignals:
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
VALIDATE_EXTERNAL_GROUP_SYNC_FENCES = "signal:validate_external_group_sync_fences"
VALIDATE_PERMISSION_SYNC_FENCES = "signal:validate_permission_sync_fences"
class OnyxCeleryPriority(int, Enum):
@@ -312,19 +299,7 @@ class OnyxCeleryPriority(int, Enum):
LOWEST = auto()
# a prefix used to distinguish system wide tasks in the cloud
ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud"
# the tenant id we use for system level redis operations
ONYX_CLOUD_TENANT_ID = "cloud"
class OnyxCeleryTask:
DEFAULT = "celery"
CLOUD_BEAT_TASK_GENERATOR = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_generate_beat_tasks"
CLOUD_CHECK_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_alembic"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
CHECK_FOR_INDEXING = "check_for_indexing"
@@ -332,10 +307,7 @@ class OnyxCeleryTask:
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"

View File

@@ -1,7 +1,3 @@
import contextvars
from concurrent.futures import as_completed
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Any
@@ -24,9 +20,9 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# NOTE: all are made lowercase to avoid case sensitivity issues
# These field types are considered metadata by default when
# treat_all_non_attachment_fields_as_metadata is False
DEFAULT_METADATA_FIELD_TYPES = {
# these are the field types that are considered metadata rather
# than sections
_METADATA_FIELD_TYPES = {
"singlecollaborator",
"collaborator",
"createdby",
@@ -64,42 +60,21 @@ class AirtableConnector(LoadConnector):
self,
base_id: str,
table_name_or_id: str,
treat_all_non_attachment_fields_as_metadata: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.base_id = base_id
self.table_name_or_id = table_name_or_id
self.batch_size = batch_size
self._airtable_client: AirtableApi | None = None
self.treat_all_non_attachment_fields_as_metadata = (
treat_all_non_attachment_fields_as_metadata
)
self.airtable_client: AirtableApi | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._airtable_client = AirtableApi(credentials["airtable_access_token"])
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
return None
@property
def airtable_client(self) -> AirtableApi:
if not self._airtable_client:
raise AirtableClientNotSetUpError()
return self._airtable_client
def _extract_field_values(
self,
field_id: str,
field_name: str,
field_info: Any,
field_type: str,
base_id: str,
table_id: str,
view_id: str | None,
record_id: str,
) -> list[tuple[str, str]]:
def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
"""
Extract value(s) + links from a field regardless of its type.
Attachments are represented as multiple sections, and therefore
returned as a list of tuples (value, link).
Extract value(s) from a field regardless of its type.
Returns either a single string or list of strings for attachments.
"""
if field_info is None:
return []
@@ -110,11 +85,8 @@ class AirtableConnector(LoadConnector):
if field_type == "multipleRecordLinks":
return []
# default link to use for non-attachment fields
default_link = f"https://airtable.com/{base_id}/{table_id}/{record_id}"
if field_type == "multipleAttachments":
attachment_texts: list[tuple[str, str]] = []
attachment_texts: list[str] = []
for attachment in field_info:
url = attachment.get("url")
filename = attachment.get("filename", "")
@@ -127,37 +99,16 @@ class AirtableConnector(LoadConnector):
backoff=2,
max_delay=10,
)
def get_attachment_with_retry(url: str, record_id: str) -> bytes | None:
try:
attachment_response = requests.get(url)
attachment_response.raise_for_status()
def get_attachment_with_retry(url: str) -> bytes | None:
attachment_response = requests.get(url)
if attachment_response.status_code == 200:
return attachment_response.content
except requests.exceptions.HTTPError as e:
if e.response.status_code == 410:
logger.info(f"Refreshing attachment for {filename}")
# Re-fetch the record to get a fresh URL
refreshed_record = self.airtable_client.table(
base_id, table_id
).get(record_id)
for refreshed_attachment in refreshed_record["fields"][
field_name
]:
if refreshed_attachment.get("filename") == filename:
new_url = refreshed_attachment.get("url")
if new_url:
attachment_response = requests.get(new_url)
attachment_response.raise_for_status()
return attachment_response.content
return None
logger.error(f"Failed to refresh attachment for {filename}")
raise
attachment_content = get_attachment_with_retry(url, record_id)
attachment_content = get_attachment_with_retry(url)
if attachment_content:
try:
file_ext = get_file_ext(filename)
attachment_id = attachment["id"]
attachment_text = extract_file_text(
BytesIO(attachment_content),
filename,
@@ -165,20 +116,7 @@ class AirtableConnector(LoadConnector):
extension=file_ext,
)
if attachment_text:
# slightly nicer loading experience if we can specify the view ID
if view_id:
attachment_link = (
f"https://airtable.com/{base_id}/{table_id}/{view_id}/{record_id}"
f"/{field_id}/{attachment_id}?blocks=hide"
)
else:
attachment_link = (
f"https://airtable.com/{base_id}/{table_id}/{record_id}"
f"/{field_id}/{attachment_id}?blocks=hide"
)
attachment_texts.append(
(f"{filename}:\n{attachment_text}", attachment_link)
)
attachment_texts.append(f"{filename}:\n{attachment_text}")
except Exception as e:
logger.warning(
f"Failed to process attachment {filename}: {str(e)}"
@@ -193,31 +131,23 @@ class AirtableConnector(LoadConnector):
combined.append(collab_name)
if collab_email:
combined.append(f"({collab_email})")
return [(" ".join(combined) if combined else str(field_info), default_link)]
return [" ".join(combined) if combined else str(field_info)]
if isinstance(field_info, list):
return [(item, default_link) for item in field_info]
return [str(item) for item in field_info]
return [(str(field_info), default_link)]
return [str(field_info)]
def _should_be_metadata(self, field_type: str) -> bool:
"""Determine if a field type should be treated as metadata.
When treat_all_non_attachment_fields_as_metadata is True, all fields except
attachments are treated as metadata. Otherwise, only fields with types listed
in DEFAULT_METADATA_FIELD_TYPES are treated as metadata."""
if self.treat_all_non_attachment_fields_as_metadata:
return field_type.lower() != "multipleattachments"
return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES
"""Determine if a field type should be treated as metadata."""
return field_type.lower() in _METADATA_FIELD_TYPES
def _process_field(
self,
field_id: str,
field_name: str,
field_info: Any,
field_type: str,
table_id: str,
view_id: str | None,
record_id: str,
) -> tuple[list[Section], dict[str, Any]]:
"""
@@ -235,22 +165,12 @@ class AirtableConnector(LoadConnector):
return [], {}
# Get the value(s) for the field
field_value_and_links = self._extract_field_values(
field_id=field_id,
field_name=field_name,
field_info=field_info,
field_type=field_type,
base_id=self.base_id,
table_id=table_id,
view_id=view_id,
record_id=record_id,
)
if len(field_value_and_links) == 0:
field_values = self._get_field_value(field_info, field_type)
if len(field_values) == 0:
return [], {}
# Determine if it should be metadata or a section
if self._should_be_metadata(field_type):
field_values = [value for value, _ in field_value_and_links]
if len(field_values) > 1:
return [], {field_name: field_values}
return [], {field_name: field_values[0]}
@@ -258,7 +178,7 @@ class AirtableConnector(LoadConnector):
# Otherwise, create relevant sections
sections = [
Section(
link=link,
link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
text=(
f"{field_name}:\n"
"------------------------\n"
@@ -266,7 +186,7 @@ class AirtableConnector(LoadConnector):
"------------------------"
),
)
for text, link in field_value_and_links
for text in field_values
]
return sections, {}
@@ -275,7 +195,7 @@ class AirtableConnector(LoadConnector):
record: RecordDict,
table_schema: TableSchema,
primary_field_name: str | None,
) -> Document | None:
) -> Document:
"""Process a single Airtable record into a Document.
Args:
@@ -299,35 +219,23 @@ class AirtableConnector(LoadConnector):
primary_field_value = (
fields.get(primary_field_name) if primary_field_name else None
)
view_id = table_schema.views[0].id if table_schema.views else None
for field_schema in table_schema.fields:
field_name = field_schema.name
field_val = fields.get(field_name)
field_type = field_schema.type
logger.debug(
f"Processing field '{field_name}' of type '{field_type}' "
f"for record '{record_id}'."
)
field_sections, field_metadata = self._process_field(
field_id=field_schema.id,
field_name=field_name,
field_info=field_val,
field_type=field_type,
table_id=table_id,
view_id=view_id,
record_id=record_id,
)
sections.extend(field_sections)
metadata.update(field_metadata)
if not sections:
logger.warning(f"No sections found for record {record_id}")
return None
semantic_id = (
f"{table_name}: {primary_field_value}"
if primary_field_value
@@ -364,47 +272,18 @@ class AirtableConnector(LoadConnector):
primary_field_name = field.name
break
logger.info(f"Starting to process Airtable records for {table.name}.")
record_documents: list[Document] = []
for record in records:
document = self._process_record(
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
)
record_documents.append(document)
# Process records in parallel batches using ThreadPoolExecutor
PARALLEL_BATCH_SIZE = 8
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
if len(record_documents) >= self.batch_size:
yield record_documents
record_documents = []
# Process records in batches
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
record_documents: list[Document] = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit batch tasks
future_to_record: dict[Future, RecordDict] = {}
for record in batch_records:
# Capture the current context so that the thread gets the current tenant ID
current_context = contextvars.copy_context()
future_to_record[
executor.submit(
current_context.run,
self._process_record,
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
)
] = record
# Wait for all tasks in this batch to complete
for future in as_completed(future_to_record):
record = future_to_record[future]
try:
document = future.result()
if document:
record_documents.append(document)
except Exception as e:
logger.exception(f"Failed to process record {record['id']}")
raise e
yield record_documents
record_documents = []
# Yield any remaining records
if record_documents:
yield record_documents

View File

@@ -232,29 +232,20 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
}
# Get labels
label_dicts = (
confluence_object.get("metadata", {}).get("labels", {}).get("results", [])
)
page_labels = [label.get("name") for label in label_dicts if label.get("name")]
label_dicts = confluence_object["metadata"]["labels"]["results"]
page_labels = [label["name"] for label in label_dicts]
if page_labels:
doc_metadata["labels"] = page_labels
# Get last modified and author email
version_dict = confluence_object.get("version", {})
last_modified = (
datetime_from_string(version_dict.get("when"))
if version_dict.get("when")
else None
)
author_email = version_dict.get("by", {}).get("email")
title = confluence_object.get("title", "Untitled Document")
last_modified = datetime_from_string(confluence_object["version"]["when"])
author_email = confluence_object["version"].get("by", {}).get("email")
return Document(
id=object_url,
sections=[Section(link=object_url, text=object_text)],
source=DocumentSource.CONFLUENCE,
semantic_identifier=title,
semantic_identifier=confluence_object["title"],
doc_updated_at=last_modified,
primary_owners=(
[BasicExpertInfo(email=author_email)] if author_email else None

View File

@@ -121,7 +121,6 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
_DEFAULT_PAGINATION_LIMIT = 1000
_MINIMUM_PAGINATION_LIMIT = 50
class OnyxConfluence(Confluence):
@@ -135,6 +134,32 @@ class OnyxConfluence(Confluence):
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
self._wrap_methods()
def get_current_user(self, expand: str | None = None) -> Any:
"""
Implements a method that isn't in the third party client.
Get information about the current user
:param expand: OPTIONAL expand for get status of user.
Possible param is "status". Results are "Active, Deactivated"
:return: Returns the user details
"""
from atlassian.errors import ApiPermissionError # type:ignore
url = "rest/api/user/current"
params = {}
if expand:
params["expand"] = expand
try:
response = self.get(url, params=params)
except HTTPError as e:
if e.response.status_code == 403:
raise ApiPermissionError(
"The calling user does not have permission", reason=e
)
raise
return response
def _wrap_methods(self) -> None:
"""
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
@@ -179,41 +204,24 @@ class OnyxConfluence(Confluence):
# If the problematic expansion is in the url, replace it
# with the replacement expansion and try again
# If that fails, raise the error
if _PROBLEMATIC_EXPANSIONS in url_suffix:
logger.warning(
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
" and trying again."
)
url_suffix = url_suffix.replace(
_PROBLEMATIC_EXPANSIONS,
_REPLACEMENT_EXPANSIONS,
)
continue
if (
raw_response.status_code == 500
and limit > _MINIMUM_PAGINATION_LIMIT
):
new_limit = limit // 2
logger.warning(
if _PROBLEMATIC_EXPANSIONS not in url_suffix:
logger.exception(
f"Error in confluence call to {url_suffix} \n"
f"Raw Response Text: {raw_response.text} \n"
f"Full Response: {raw_response.__dict__} \n"
f"Error: {e} \n"
f"Reducing limit from {limit} to {new_limit} and trying again."
)
url_suffix = url_suffix.replace(
f"limit={limit}", f"limit={new_limit}"
)
limit = new_limit
continue
raise e
logger.exception(
f"Error in confluence call to {url_suffix} \n"
f"Raw Response Text: {raw_response.text} \n"
f"Full Response: {raw_response.__dict__} \n"
f"Error: {e} \n"
logger.warning(
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
" and trying again."
)
raise e
url_suffix = url_suffix.replace(
_PROBLEMATIC_EXPANSIONS,
_REPLACEMENT_EXPANSIONS,
)
continue
try:
next_response = raw_response.json()
@@ -328,62 +336,6 @@ class OnyxConfluence(Confluence):
group_name = quote(group_name)
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def get_all_space_permissions_server(
self,
space_key: str,
) -> list[dict[str, Any]]:
"""
This is a confluence server specific method that can be used to
fetch the permissions of a space.
This is better logging than calling the get_space_permissions method
because it returns a jsonrpc response.
TODO: Make this call these endpoints for newer confluence versions:
- /rest/api/space/{spaceKey}/permissions
- /rest/api/space/{spaceKey}/permissions/anonymous
"""
url = "rpc/json-rpc/confluenceservice-v2"
data = {
"jsonrpc": "2.0",
"method": "getSpacePermissionSets",
"id": 7,
"params": [space_key],
}
response = self.post(url, data=data)
logger.debug(f"jsonrpc response: {response}")
if not response.get("result"):
logger.warning(
f"No jsonrpc response for space permissions for space {space_key}"
f"\nResponse: {response}"
)
return response.get("result", [])
def get_current_user(self, expand: str | None = None) -> Any:
"""
Implements a method that isn't in the third party client.
Get information about the current user
:param expand: OPTIONAL expand for get status of user.
Possible param is "status". Results are "Active, Deactivated"
:return: Returns the user details
"""
from atlassian.errors import ApiPermissionError # type:ignore
url = "rest/api/user/current"
params = {}
if expand:
params["expand"] = expand
try:
response = self.get(url, params=params)
except HTTPError as e:
if e.response.status_code == 403:
raise ApiPermissionError(
"The calling user does not have permission", reason=e
)
raise
return response
def _validate_connector_configuration(
credentials: dict[str, Any],

Some files were not shown because too many files have changed in this diff Show More