mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 15:55:45 +00:00
Compare commits
3 Commits
error_supp
...
bugfix/imp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b7f40a21c | ||
|
|
3578f47263 | ||
|
|
229defe7bf |
7
.github/pull_request_template.md
vendored
7
.github/pull_request_template.md
vendored
@@ -1,14 +1,11 @@
|
||||
## Description
|
||||
|
||||
[Provide a brief description of the changes in this PR]
|
||||
|
||||
## How Has This Been Tested?
|
||||
|
||||
## How Has This Been Tested?
|
||||
[Describe the tests you ran to verify your changes]
|
||||
|
||||
|
||||
## Backporting (check the box to trigger backport action)
|
||||
|
||||
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
|
||||
|
||||
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
|
||||
- [ ] [Optional] Override Linear Check
|
||||
|
||||
@@ -67,7 +67,6 @@ jobs:
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
@@ -60,8 +60,6 @@ jobs:
|
||||
push: true
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
2
.github/workflows/pr-chromatic-tests.yml
vendored
2
.github/workflows/pr-chromatic-tests.yml
vendored
@@ -8,8 +8,6 @@ on: push
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
MOCK_LLM_RESPONSE: true
|
||||
|
||||
jobs:
|
||||
playwright-tests:
|
||||
|
||||
22
.github/workflows/pr-helm-chart-testing.yml
vendored
22
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -21,10 +21,10 @@ jobs:
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4.2.0
|
||||
with:
|
||||
version: v3.17.0
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.7.0
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
@@ -37,6 +37,22 @@ jobs:
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# rkuo: I don't think we need python?
|
||||
# - name: Set up Python
|
||||
# uses: actions/setup-python@v5
|
||||
# with:
|
||||
# python-version: '3.11'
|
||||
# cache: 'pip'
|
||||
# cache-dependency-path: |
|
||||
# backend/requirements/default.txt
|
||||
# backend/requirements/dev.txt
|
||||
# backend/requirements/model_server.txt
|
||||
# - run: |
|
||||
# python -m pip install --upgrade pip
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
@@ -46,7 +62,7 @@ jobs:
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.12.0
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
|
||||
29
.github/workflows/pr-linear-check.yml
vendored
29
.github/workflows/pr-linear-check.yml
vendored
@@ -1,29 +0,0 @@
|
||||
name: Ensure PR references Linear
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, reopened, synchronize]
|
||||
|
||||
jobs:
|
||||
linear-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check PR body for Linear link or override
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
# Looking for "https://linear.app" in the body
|
||||
if echo "$PR_BODY" | grep -qE "https://linear\.app"; then
|
||||
echo "Found a Linear link. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Looking for a checked override: "[x] Override Linear Check"
|
||||
if echo "$PR_BODY" | grep -q "\[x\].*Override Linear Check"; then
|
||||
echo "Override box is checked. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Otherwise, fail the run
|
||||
echo "No Linear link or override found in the PR description."
|
||||
exit 1
|
||||
@@ -39,12 +39,6 @@ env:
|
||||
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
|
||||
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
|
||||
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
|
||||
# Sharepoint
|
||||
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
|
||||
1
.vscode/env_template.txt
vendored
1
.vscode/env_template.txt
vendored
@@ -29,7 +29,6 @@ REQUIRE_EMAIL_VERIFICATION=False
|
||||
|
||||
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
|
||||
GEN_AI_API_KEY=<REPLACE THIS>
|
||||
OPENAI_API_KEY=<REPLACE THIS>
|
||||
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
|
||||
GEN_AI_MODEL_VERSION=gpt-4o
|
||||
FAST_GEN_AI_MODEL_VERSION=gpt-4o
|
||||
|
||||
29
.vscode/launch.template.jsonc
vendored
29
.vscode/launch.template.jsonc
vendored
@@ -28,7 +28,6 @@
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
@@ -52,8 +51,7 @@
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
@@ -271,31 +269,6 @@
|
||||
},
|
||||
"consoleTitle": "Celery indexing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
|
||||
@@ -17,10 +17,9 @@ Before starting, make sure the Docker Daemon is running.
|
||||
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
|
||||
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
|
||||
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
|
||||
4. CD into web, run "npm i" followed by npm run dev.
|
||||
5. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
|
||||
6. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
7. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
4. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
|
||||
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
6. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
|
||||
## Features
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ There are two editions of Onyx:
|
||||
- Whitelabeling
|
||||
- API key authentication
|
||||
- Encryption of secrets
|
||||
- And many more! Checkout [our website](https://www.onyx.app/) for the latest.
|
||||
- Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
|
||||
|
||||
To try the Onyx Enterprise Edition:
|
||||
|
||||
|
||||
@@ -9,10 +9,8 @@ founders@onyx.app for more information. Please visit https://github.com/onyx-dot
|
||||
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.8-dev
|
||||
# DO_NOT_TRACK is used to disable telemetry for Unstructured
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
DO_NOT_TRACK="true"
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
|
||||
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
"""add shortcut option for users
|
||||
|
||||
Revision ID: 027381bce97c
|
||||
Revises: 6fc7886d665d
|
||||
Create Date: 2025-01-14 12:14:00.814390
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "027381bce97c"
|
||||
down_revision = "6fc7886d665d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"shortcut_enabled", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "shortcut_enabled")
|
||||
@@ -1,36 +0,0 @@
|
||||
"""add index to index_attempt.time_created
|
||||
|
||||
Revision ID: 0f7ff6d75b57
|
||||
Revises: 369644546676
|
||||
Create Date: 2025-01-10 14:01:14.067144
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0f7ff6d75b57"
|
||||
down_revision = "fec3db967bf7"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_index(
|
||||
op.f("ix_index_attempt_status"),
|
||||
"index_attempt",
|
||||
["status"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
op.f("ix_index_attempt_time_created"),
|
||||
"index_attempt",
|
||||
["time_created"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f("ix_index_attempt_time_created"), table_name="index_attempt")
|
||||
|
||||
op.drop_index(op.f("ix_index_attempt_status"), table_name="index_attempt")
|
||||
@@ -1,36 +0,0 @@
|
||||
"""add chat session specific temperature override
|
||||
|
||||
Revision ID: 2f80c6a2550f
|
||||
Revises: 33ea50e88f24
|
||||
Create Date: 2025-01-31 10:30:27.289646
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2f80c6a2550f"
|
||||
down_revision = "33ea50e88f24"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"temperature_override_enabled",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_session", "temperature_override")
|
||||
op.drop_column("user", "temperature_override_enabled")
|
||||
@@ -1,80 +0,0 @@
|
||||
"""foreign key input prompts
|
||||
|
||||
Revision ID: 33ea50e88f24
|
||||
Revises: a6df6b88ef81
|
||||
Create Date: 2025-01-29 10:54:22.141765
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "33ea50e88f24"
|
||||
down_revision = "a6df6b88ef81"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Safely drop constraints if exists
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE inputprompt__user
|
||||
DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE inputprompt__user
|
||||
DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
|
||||
"""
|
||||
)
|
||||
|
||||
# Recreate with ON DELETE CASCADE
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_input_prompt_id_fkey",
|
||||
"inputprompt__user",
|
||||
"inputprompt",
|
||||
["input_prompt_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_user_id_fkey",
|
||||
"inputprompt__user",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the new FKs with ondelete
|
||||
op.drop_constraint(
|
||||
"inputprompt__user_input_prompt_id_fkey",
|
||||
"inputprompt__user",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"inputprompt__user_user_id_fkey",
|
||||
"inputprompt__user",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Recreate them without cascading
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_input_prompt_id_fkey",
|
||||
"inputprompt__user",
|
||||
"inputprompt",
|
||||
["input_prompt_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_user_id_fkey",
|
||||
"inputprompt__user",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -1,59 +0,0 @@
|
||||
"""add back input prompts
|
||||
|
||||
Revision ID: 3c6531f32351
|
||||
Revises: aeda5f2df4f6
|
||||
Create Date: 2025-01-13 12:49:51.705235
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3c6531f32351"
|
||||
down_revision = "aeda5f2df4f6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"inputprompt",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("prompt", sa.String(), nullable=False),
|
||||
sa.Column("content", sa.String(), nullable=False),
|
||||
sa.Column("active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_public", sa.Boolean(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"inputprompt__user",
|
||||
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
|
||||
),
|
||||
sa.Column("disabled", sa.Boolean(), nullable=False, default=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["input_prompt_id"],
|
||||
["inputprompt.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("inputprompt__user")
|
||||
op.drop_table("inputprompt")
|
||||
@@ -40,6 +40,6 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("persona_category_id_fkey", "persona", type_="foreignkey")
|
||||
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
|
||||
op.drop_column("persona", "category_id")
|
||||
op.drop_table("persona_category")
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
"""lowercase_user_emails
|
||||
|
||||
Revision ID: 4d58345da04a
|
||||
Revises: f1ca58b2f2ec
|
||||
Create Date: 2025-01-29 07:48:46.784041
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4d58345da04a"
|
||||
down_revision = "f1ca58b2f2ec"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Get database connection
|
||||
connection = op.get_bind()
|
||||
|
||||
# Update all user emails to lowercase
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET email = LOWER(email)
|
||||
WHERE email != LOWER(email)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Cannot restore original case of emails
|
||||
pass
|
||||
@@ -1,80 +0,0 @@
|
||||
"""make categories labels and many to many
|
||||
|
||||
Revision ID: 6fc7886d665d
|
||||
Revises: 3c6531f32351
|
||||
Create Date: 2025-01-13 18:12:18.029112
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6fc7886d665d"
|
||||
down_revision = "3c6531f32351"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Rename persona_category table to persona_label
|
||||
op.rename_table("persona_category", "persona_label")
|
||||
|
||||
# Create the new association table
|
||||
op.create_table(
|
||||
"persona__persona_label",
|
||||
sa.Column("persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column("persona_label_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_label_id"],
|
||||
["persona_label.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("persona_id", "persona_label_id"),
|
||||
)
|
||||
|
||||
# Copy existing relationships to the new table
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO persona__persona_label (persona_id, persona_label_id)
|
||||
SELECT id, category_id FROM persona WHERE category_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# Remove the old category_id column from persona table
|
||||
op.drop_column("persona", "category_id")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Rename persona_label table back to persona_category
|
||||
op.rename_table("persona_label", "persona_category")
|
||||
|
||||
# Add back the category_id column to persona table
|
||||
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
"persona_category_id_fkey",
|
||||
"persona",
|
||||
"persona_category",
|
||||
["category_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Copy the first label relationship back to the persona table
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET category_id = (
|
||||
SELECT persona_label_id
|
||||
FROM persona__persona_label
|
||||
WHERE persona__persona_label.persona_id = persona.id
|
||||
LIMIT 1
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the association table
|
||||
op.drop_table("persona__persona_label")
|
||||
@@ -1,72 +0,0 @@
|
||||
"""Add SyncRecord
|
||||
|
||||
Revision ID: 97dbb53fa8c8
|
||||
Revises: 369644546676
|
||||
Create Date: 2025-01-11 19:39:50.426302
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "97dbb53fa8c8"
|
||||
down_revision = "be2ab2aa50ee"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"sync_record",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("entity_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"sync_type",
|
||||
sa.Enum(
|
||||
"DOCUMENT_SET",
|
||||
"USER_GROUP",
|
||||
"CONNECTOR_DELETION",
|
||||
name="synctype",
|
||||
native_enum=False,
|
||||
length=40,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"sync_status",
|
||||
sa.Enum(
|
||||
"IN_PROGRESS",
|
||||
"SUCCESS",
|
||||
"FAILED",
|
||||
"CANCELED",
|
||||
name="syncstatus",
|
||||
native_enum=False,
|
||||
length=40,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("num_docs_synced", sa.Integer(), nullable=False),
|
||||
sa.Column("sync_start_time", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("sync_end_time", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Add index for fetch_latest_sync_record query
|
||||
op.create_index(
|
||||
"ix_sync_record_entity_id_sync_type_sync_start_time",
|
||||
"sync_record",
|
||||
["entity_id", "sync_type", "sync_start_time"],
|
||||
)
|
||||
|
||||
# Add index for cleanup_sync_records query
|
||||
op.create_index(
|
||||
"ix_sync_record_entity_id_sync_type_sync_status",
|
||||
"sync_record",
|
||||
["entity_id", "sync_type", "sync_status"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_sync_record_entity_id_sync_type_sync_status")
|
||||
op.drop_index("ix_sync_record_entity_id_sync_type_sync_start_time")
|
||||
op.drop_table("sync_record")
|
||||
@@ -1,29 +0,0 @@
|
||||
"""remove recent assistants
|
||||
|
||||
Revision ID: a6df6b88ef81
|
||||
Revises: 4d58345da04a
|
||||
Create Date: 2025-01-29 10:25:52.790407
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a6df6b88ef81"
|
||||
down_revision = "4d58345da04a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("user", "recent_assistants")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
|
||||
),
|
||||
)
|
||||
@@ -1,27 +0,0 @@
|
||||
"""add pinned assistants
|
||||
|
||||
Revision ID: aeda5f2df4f6
|
||||
Revises: c5eae4a75a1b
|
||||
Create Date: 2025-01-09 16:04:10.770636
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "aeda5f2df4f6"
|
||||
down_revision = "c5eae4a75a1b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user", sa.Column("pinned_assistants", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
op.execute('UPDATE "user" SET pinned_assistants = chosen_assistants')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "pinned_assistants")
|
||||
@@ -1,38 +0,0 @@
|
||||
"""fix_capitalization
|
||||
|
||||
Revision ID: be2ab2aa50ee
|
||||
Revises: 369644546676
|
||||
Create Date: 2025-01-10 13:13:26.228960
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "be2ab2aa50ee"
|
||||
down_revision = "369644546676"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE document
|
||||
SET
|
||||
external_user_group_ids = ARRAY(
|
||||
SELECT LOWER(unnest(external_user_group_ids))
|
||||
),
|
||||
last_modified = NOW()
|
||||
WHERE
|
||||
external_user_group_ids IS NOT NULL
|
||||
AND external_user_group_ids::text[] <> ARRAY(
|
||||
SELECT LOWER(unnest(external_user_group_ids))
|
||||
)::text[]
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# No way to cleanly persist the bad state through an upgrade/downgrade
|
||||
# cycle, so we just pass
|
||||
pass
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Add chat_message__standard_answer table
|
||||
|
||||
Revision ID: c5eae4a75a1b
|
||||
Revises: 0f7ff6d75b57
|
||||
Create Date: 2025-01-15 14:08:49.688998
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c5eae4a75a1b"
|
||||
down_revision = "0f7ff6d75b57"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"chat_message__standard_answer",
|
||||
sa.Column("chat_message_id", sa.Integer(), nullable=False),
|
||||
sa.Column("standard_answer_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["chat_message_id"],
|
||||
["chat_message.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["standard_answer_id"],
|
||||
["standard_answer.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("chat_message_id", "standard_answer_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("chat_message__standard_answer")
|
||||
@@ -1,48 +0,0 @@
|
||||
"""Add has_been_indexed to DocumentByConnectorCredentialPair
|
||||
|
||||
Revision ID: c7bf5721733e
|
||||
Revises: fec3db967bf7
|
||||
Create Date: 2025-01-13 12:39:05.831693
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c7bf5721733e"
|
||||
down_revision = "027381bce97c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# assume all existing rows have been indexed, no better approach
|
||||
op.add_column(
|
||||
"document_by_connector_credential_pair",
|
||||
sa.Column("has_been_indexed", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE document_by_connector_credential_pair SET has_been_indexed = TRUE"
|
||||
)
|
||||
op.alter_column(
|
||||
"document_by_connector_credential_pair",
|
||||
"has_been_indexed",
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Add index to optimize get_document_counts_for_cc_pairs query pattern
|
||||
op.create_index(
|
||||
"idx_document_cc_pair_counts",
|
||||
"document_by_connector_credential_pair",
|
||||
["connector_id", "credential_id", "has_been_indexed"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the index first before removing the column
|
||||
op.drop_index(
|
||||
"idx_document_cc_pair_counts",
|
||||
table_name="document_by_connector_credential_pair",
|
||||
)
|
||||
op.drop_column("document_by_connector_credential_pair", "has_been_indexed")
|
||||
@@ -1,33 +0,0 @@
|
||||
"""add passthrough auth to tool
|
||||
|
||||
Revision ID: f1ca58b2f2ec
|
||||
Revises: c7bf5721733e
|
||||
Create Date: 2024-03-19
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "f1ca58b2f2ec"
|
||||
down_revision: Union[str, None] = "c7bf5721733e"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add passthrough_auth column to tool table with default value of False
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column(
|
||||
"passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove passthrough_auth column from tool table
|
||||
op.drop_column("tool", "passthrough_auth")
|
||||
@@ -1,41 +0,0 @@
|
||||
"""Add time_updated to UserGroup and DocumentSet
|
||||
|
||||
Revision ID: fec3db967bf7
|
||||
Revises: 97dbb53fa8c8
|
||||
Create Date: 2025-01-12 15:49:02.289100
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "fec3db967bf7"
|
||||
down_revision = "97dbb53fa8c8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"document_set",
|
||||
sa.Column(
|
||||
"time_last_modified_by_user",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user_group",
|
||||
sa.Column(
|
||||
"time_last_modified_by_user",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user_group", "time_last_modified_by_user")
|
||||
op.drop_column("document_set", "time_last_modified_by_user")
|
||||
@@ -32,7 +32,6 @@ def perform_ttl_management_task(
|
||||
|
||||
@celery_app.task(
|
||||
name="check_ttl_management_task",
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
@@ -57,7 +56,6 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
|
||||
@celery_app.task(
|
||||
name="autogenerate_usage_report_task",
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
|
||||
|
||||
@@ -1,73 +1,24 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
|
||||
)
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
tasks_to_schedule as base_tasks_to_schedule,
|
||||
)
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
ee_cloud_tasks_to_schedule = [
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
},
|
||||
"name": "autogenerate_usage_report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
ee_tasks_to_schedule: list[dict] = []
|
||||
|
||||
if not MULTI_TENANT:
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return ee_tasks_to_schedule + base_tasks_to_schedule
|
||||
|
||||
@@ -8,9 +8,6 @@ from ee.onyx.db.user_group import fetch_user_group
|
||||
from ee.onyx.db.user_group import mark_user_group_as_synced
|
||||
from ee.onyx.db.user_group import prepare_user_group_for_deletion
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -46,59 +43,24 @@ def monitor_usergroup_taskset(
|
||||
f"User group sync progress: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=count,
|
||||
)
|
||||
return
|
||||
|
||||
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
|
||||
if user_group:
|
||||
usergroup_name = user_group.name
|
||||
try:
|
||||
if user_group.is_up_for_deletion:
|
||||
# this prepare should have been run when the deletion was scheduled,
|
||||
# but run it again to be sure we're ready to go
|
||||
mark_user_group_as_synced(db_session, user_group)
|
||||
prepare_user_group_for_deletion(db_session, usergroup_id)
|
||||
delete_user_group(db_session=db_session, user_group=user_group)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
|
||||
)
|
||||
else:
|
||||
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
sync_status=SyncStatus.FAILED,
|
||||
num_docs_synced=initial_count,
|
||||
if user_group.is_up_for_deletion:
|
||||
# this prepare should have been run when the deletion was scheduled,
|
||||
# but run it again to be sure we're ready to go
|
||||
mark_user_group_as_synced(db_session, user_group)
|
||||
prepare_user_group_for_deletion(db_session, usergroup_id)
|
||||
delete_user_group(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(
|
||||
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
|
||||
)
|
||||
else:
|
||||
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(
|
||||
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
rug.reset()
|
||||
|
||||
@@ -4,20 +4,6 @@ import os
|
||||
# Applicable for OIDC Auth
|
||||
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
|
||||
|
||||
# Applicable for OIDC Auth, allows you to override the scopes that
|
||||
# are requested from the OIDC provider. Currently used when passing
|
||||
# over access tokens to tool calls and the tool needs more scopes
|
||||
OIDC_SCOPE_OVERRIDE: list[str] | None = None
|
||||
_OIDC_SCOPE_OVERRIDE = os.environ.get("OIDC_SCOPE_OVERRIDE")
|
||||
|
||||
if _OIDC_SCOPE_OVERRIDE:
|
||||
try:
|
||||
OIDC_SCOPE_OVERRIDE = [
|
||||
scope.strip() for scope in _OIDC_SCOPE_OVERRIDE.split(",")
|
||||
]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Applicable for SAML Auth
|
||||
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_config"
|
||||
|
||||
|
||||
@@ -345,8 +345,7 @@ def fetch_assistant_unique_users_total(
|
||||
def user_can_view_assistant_stats(
|
||||
db_session: Session, user: User | None, assistant_id: int
|
||||
) -> bool:
|
||||
# If user is None and auth is disabled, assume the user is an admin
|
||||
|
||||
# If user is None, assume the user is an admin or auth is disabled
|
||||
if user is None or user.role == UserRole.ADMIN:
|
||||
return True
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.access.utils import prefix_group_w_source
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import Document as DbDocument
|
||||
|
||||
@@ -25,7 +25,7 @@ def upsert_document_external_perms__no_commit(
|
||||
).first()
|
||||
|
||||
prefixed_external_groups = [
|
||||
build_ext_group_name_for_onyx(
|
||||
prefix_group_w_source(
|
||||
ext_group_name=group_id,
|
||||
source=source_type,
|
||||
)
|
||||
@@ -66,7 +66,7 @@ def upsert_document_external_perms(
|
||||
).first()
|
||||
|
||||
prefixed_external_groups: set[str] = {
|
||||
build_ext_group_name_for_onyx(
|
||||
prefix_group_w_source(
|
||||
ext_group_name=group_id,
|
||||
source=source_type,
|
||||
)
|
||||
|
||||
@@ -6,9 +6,8 @@ from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.access.utils import prefix_group_w_source
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__ExternalUserGroupId
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from onyx.db.users import get_user_by_email
|
||||
@@ -62,10 +61,8 @@ def replace_user__ext_group_for_cc_pair(
|
||||
all_group_member_emails.add(user_email)
|
||||
|
||||
# batch add users if they don't exist and get their ids
|
||||
all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
# NOTE: this function handles case sensitivity for emails
|
||||
emails=list(all_group_member_emails),
|
||||
all_group_members = batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session, emails=list(all_group_member_emails)
|
||||
)
|
||||
|
||||
delete_user__ext_group_for_cc_pair__no_commit(
|
||||
@@ -87,14 +84,12 @@ def replace_user__ext_group_for_cc_pair(
|
||||
f" with email {user_email} not found"
|
||||
)
|
||||
continue
|
||||
external_group_id = build_ext_group_name_for_onyx(
|
||||
ext_group_name=external_group.id,
|
||||
source=source,
|
||||
)
|
||||
new_external_permissions.append(
|
||||
User__ExternalUserGroupId(
|
||||
user_id=user_id,
|
||||
external_user_group_id=external_group_id,
|
||||
external_user_group_id=prefix_group_w_source(
|
||||
external_group.id, source
|
||||
),
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,138 +1,27 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
import datetime
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import asc
|
||||
from sqlalchemy import BinaryExpression
|
||||
from sqlalchemy import ColumnElement
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import distinct
|
||||
from sqlalchemy.orm import contains_eager
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import case
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.sql import select
|
||||
from sqlalchemy.sql.expression import literal
|
||||
from sqlalchemy.sql.expression import UnaryExpression
|
||||
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatMessageFeedback
|
||||
from onyx.db.models import ChatSession
|
||||
|
||||
|
||||
def _build_filter_conditions(
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
feedback_filter: QAFeedbackType | None,
|
||||
) -> list[ColumnElement]:
|
||||
"""
|
||||
Helper function to build all filter conditions for chat sessions.
|
||||
Filters by start and end time, feedback type, and any sessions without messages.
|
||||
start_time: Date from which to filter
|
||||
end_time: Date to which to filter
|
||||
feedback_filter: Feedback type to filter by
|
||||
Returns: List of filter conditions
|
||||
"""
|
||||
conditions = []
|
||||
|
||||
if start_time is not None:
|
||||
conditions.append(ChatSession.time_created >= start_time)
|
||||
if end_time is not None:
|
||||
conditions.append(ChatSession.time_created <= end_time)
|
||||
|
||||
if feedback_filter is not None:
|
||||
feedback_subq = (
|
||||
select(ChatMessage.chat_session_id)
|
||||
.join(ChatMessageFeedback)
|
||||
.group_by(ChatMessage.chat_session_id)
|
||||
.having(
|
||||
case(
|
||||
(
|
||||
case(
|
||||
{literal(feedback_filter == QAFeedbackType.LIKE): True},
|
||||
else_=False,
|
||||
),
|
||||
func.bool_and(ChatMessageFeedback.is_positive),
|
||||
),
|
||||
(
|
||||
case(
|
||||
{literal(feedback_filter == QAFeedbackType.DISLIKE): True},
|
||||
else_=False,
|
||||
),
|
||||
func.bool_and(func.not_(ChatMessageFeedback.is_positive)),
|
||||
),
|
||||
else_=func.bool_or(ChatMessageFeedback.is_positive)
|
||||
& func.bool_or(func.not_(ChatMessageFeedback.is_positive)),
|
||||
)
|
||||
)
|
||||
)
|
||||
conditions.append(ChatSession.id.in_(feedback_subq))
|
||||
|
||||
return conditions
|
||||
|
||||
|
||||
def get_total_filtered_chat_sessions_count(
|
||||
db_session: Session,
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
feedback_filter: QAFeedbackType | None,
|
||||
) -> int:
|
||||
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
|
||||
stmt = (
|
||||
select(func.count(distinct(ChatSession.id)))
|
||||
.select_from(ChatSession)
|
||||
.filter(*conditions)
|
||||
)
|
||||
return db_session.scalar(stmt) or 0
|
||||
|
||||
|
||||
def get_page_of_chat_sessions(
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
db_session: Session,
|
||||
page_num: int,
|
||||
page_size: int,
|
||||
feedback_filter: QAFeedbackType | None = None,
|
||||
) -> Sequence[ChatSession]:
|
||||
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
|
||||
|
||||
subquery = (
|
||||
select(ChatSession.id)
|
||||
.filter(*conditions)
|
||||
.order_by(desc(ChatSession.time_created), ChatSession.id)
|
||||
.limit(page_size)
|
||||
.offset(page_num * page_size)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
stmt = (
|
||||
select(ChatSession)
|
||||
.join(subquery, ChatSession.id == subquery.c.id)
|
||||
.outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id)
|
||||
.options(
|
||||
joinedload(ChatSession.user),
|
||||
joinedload(ChatSession.persona),
|
||||
contains_eager(ChatSession.messages).joinedload(
|
||||
ChatMessage.chat_message_feedbacks
|
||||
),
|
||||
)
|
||||
.order_by(
|
||||
desc(ChatSession.time_created),
|
||||
ChatSession.id,
|
||||
asc(ChatMessage.id), # Ensure chronological message order
|
||||
)
|
||||
)
|
||||
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
SortByOptions = Literal["time_sent"]
|
||||
|
||||
|
||||
def fetch_chat_sessions_eagerly_by_time(
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
start: datetime.datetime,
|
||||
end: datetime.datetime,
|
||||
db_session: Session,
|
||||
limit: int | None = 500,
|
||||
initial_time: datetime | None = None,
|
||||
initial_time: datetime.datetime | None = None,
|
||||
) -> list[ChatSession]:
|
||||
time_order: UnaryExpression = desc(ChatSession.time_created)
|
||||
message_order: UnaryExpression = asc(ChatMessage.id)
|
||||
|
||||
@@ -7,7 +7,6 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.constants import TokenRateLimitScope
|
||||
from onyx.db.models import TokenRateLimit
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
@@ -21,8 +20,8 @@ from onyx.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
def _add_user_filters(
|
||||
stmt: Select, user: User | None, get_editable: bool = True
|
||||
) -> Select:
|
||||
# If user is None and auth is disabled, assume the user is an admin
|
||||
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
|
||||
# If user is None, assume the user is an admin or auth is disabled
|
||||
if user is None or user.role == UserRole.ADMIN:
|
||||
return stmt
|
||||
|
||||
stmt = stmt.distinct()
|
||||
@@ -48,12 +47,6 @@ def _add_user_filters(
|
||||
that the user isn't a curator for
|
||||
- if we are not editing, we show all token_rate_limits in the groups the user curates
|
||||
"""
|
||||
|
||||
# If user is None, this is an anonymous user and we should only show public token_rate_limits
|
||||
if user is None:
|
||||
where_clause = TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
|
||||
return stmt.where(where_clause)
|
||||
|
||||
where_clause = User__UG.user_id == user.id
|
||||
if user.role == UserRole.CURATOR and get_editable:
|
||||
where_clause &= User__UG.is_curator == True # noqa: E712
|
||||
@@ -111,10 +104,10 @@ def insert_user_group_token_rate_limit(
|
||||
return token_limit
|
||||
|
||||
|
||||
def fetch_user_group_token_rate_limits_for_user(
|
||||
def fetch_user_group_token_rate_limits(
|
||||
db_session: Session,
|
||||
group_id: int,
|
||||
user: User | None,
|
||||
user: User | None = None,
|
||||
enabled_only: bool = False,
|
||||
ordered: bool = True,
|
||||
get_editable: bool = True,
|
||||
|
||||
@@ -374,9 +374,7 @@ def _add_user_group__cc_pair_relationships__no_commit(
|
||||
|
||||
|
||||
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
|
||||
db_user_group = UserGroup(
|
||||
name=user_group.name, time_last_modified_by_user=func.now()
|
||||
)
|
||||
db_user_group = UserGroup(name=user_group.name)
|
||||
db_session.add(db_user_group)
|
||||
db_session.flush() # give the group an ID
|
||||
|
||||
@@ -632,10 +630,6 @@ def update_user_group(
|
||||
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
|
||||
).unique()
|
||||
_validate_curator_status__no_commit(db_session, list(removed_users))
|
||||
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -705,10 +699,7 @@ def delete_user_group_cc_pair_relationship__no_commit(
|
||||
connector_credential_pair_id matches the given cc_pair_id.
|
||||
|
||||
Should be used very carefully (only for connectors that are being deleted)."""
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
raise ValueError(f"Connector Credential Pair '{cc_pair_id}' does not exist")
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -258,7 +257,6 @@ def _fetch_all_page_restrictions(
|
||||
slim_docs: list[SlimDocument],
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess],
|
||||
is_cloud: bool,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
For all pages, if a page has restrictions, then use those restrictions.
|
||||
@@ -267,12 +265,6 @@ def _fetch_all_page_restrictions(
|
||||
document_restrictions: list[DocExternalAccess] = []
|
||||
|
||||
for slim_doc in slim_docs:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("confluence_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
|
||||
|
||||
if slim_doc.perm_sync_data is None:
|
||||
raise ValueError(
|
||||
f"No permission sync data found for document {slim_doc.id}"
|
||||
@@ -342,7 +334,7 @@ def _fetch_all_page_restrictions(
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -367,12 +359,6 @@ def confluence_doc_sync(
|
||||
logger.debug("Fetching all slim documents from confluence")
|
||||
for doc_batch in confluence_connector.retrieve_all_slim_documents():
|
||||
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("confluence_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("confluence_doc_sync", 1)
|
||||
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
logger.debug("Fetching all page restrictions for space")
|
||||
@@ -381,5 +367,4 @@ def confluence_doc_sync(
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
is_cloud=is_cloud,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
@@ -14,8 +14,6 @@ def _build_group_member_email_map(
|
||||
) -> dict[str, set[str]]:
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
for user_result in confluence_client.paginated_cql_user_retrieval():
|
||||
logger.debug(f"Processing groups for user: {user_result}")
|
||||
|
||||
user = user_result.get("user", {})
|
||||
if not user:
|
||||
logger.warning(f"user result missing user field: {user_result}")
|
||||
@@ -35,17 +33,10 @@ def _build_group_member_email_map(
|
||||
logger.warning(f"user result missing email field: {user_result}")
|
||||
continue
|
||||
|
||||
all_users_groups: set[str] = set()
|
||||
for group in confluence_client.paginated_groups_by_user_retrieval(user):
|
||||
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
|
||||
group_id = group["name"]
|
||||
group_member_emails.setdefault(group_id, set()).add(email)
|
||||
all_users_groups.add(group_id)
|
||||
|
||||
if not group_member_emails:
|
||||
logger.warning(f"No groups found for user with email: {email}")
|
||||
else:
|
||||
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
|
||||
|
||||
return group_member_emails
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -29,7 +28,7 @@ def _get_slim_doc_generator(
|
||||
|
||||
|
||||
def gmail_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -45,12 +44,6 @@ def gmail_doc_sync(
|
||||
document_external_access: list[DocExternalAccess] = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("gmail_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("gmail_doc_sync", 1)
|
||||
|
||||
if slim_doc.perm_sync_data is None:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
continue
|
||||
|
||||
@@ -10,7 +10,6 @@ from onyx.connectors.google_utils.resources import get_drive_service
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -43,22 +42,24 @@ def _fetch_permissions_for_permission_ids(
|
||||
if not permission_info or not doc_id:
|
||||
return []
|
||||
|
||||
# Check cache first for all permission IDs
|
||||
permissions = [
|
||||
_PERMISSION_ID_PERMISSION_MAP[pid]
|
||||
for pid in permission_ids
|
||||
if pid in _PERMISSION_ID_PERMISSION_MAP
|
||||
]
|
||||
|
||||
# If we found all permissions in cache, return them
|
||||
if len(permissions) == len(permission_ids):
|
||||
return permissions
|
||||
|
||||
owner_email = permission_info.get("owner_email")
|
||||
|
||||
drive_service = get_drive_service(
|
||||
creds=google_drive_connector.creds,
|
||||
user_email=(owner_email or google_drive_connector.primary_admin_email),
|
||||
)
|
||||
|
||||
# Otherwise, fetch all permissions and update cache
|
||||
fetched_permissions = execute_paginated_retrieval(
|
||||
retrieval_function=drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
@@ -68,6 +69,7 @@ def _fetch_permissions_for_permission_ids(
|
||||
)
|
||||
|
||||
permissions_for_doc_id = []
|
||||
# Update cache and return all permissions
|
||||
for permission in fetched_permissions:
|
||||
permissions_for_doc_id.append(permission)
|
||||
_PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
|
||||
@@ -118,18 +120,15 @@ def _get_permissions_from_slim_doc(
|
||||
elif permission_type == "anyone":
|
||||
public = True
|
||||
|
||||
drive_id = permission_info.get("drive_id")
|
||||
group_ids = group_emails | ({drive_id} if drive_id is not None else set())
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_ids,
|
||||
external_user_group_ids=group_emails,
|
||||
is_public=public,
|
||||
)
|
||||
|
||||
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -147,12 +146,6 @@ def gdrive_doc_sync(
|
||||
document_external_accesses = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("gdrive_doc_sync", 1)
|
||||
|
||||
ext_access = _get_permissions_from_slim_doc(
|
||||
google_drive_connector=google_drive_connector,
|
||||
slim_doc=slim_doc,
|
||||
|
||||
@@ -1,127 +1,16 @@
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
from onyx.connectors.google_utils.resources import AdminService
|
||||
from onyx.connectors.google_utils.resources import get_admin_service
|
||||
from onyx.connectors.google_utils.resources import get_drive_service
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_drive_members(
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
) -> dict[str, tuple[set[str], set[str]]]:
|
||||
"""
|
||||
This builds a map of drive ids to their members (group and user emails).
|
||||
E.g. {
|
||||
"drive_id_1": ({"group_email_1"}, {"user_email_1", "user_email_2"}),
|
||||
"drive_id_2": ({"group_email_3"}, {"user_email_3"}),
|
||||
}
|
||||
"""
|
||||
drive_ids = google_drive_connector.get_all_drive_ids()
|
||||
|
||||
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]] = {}
|
||||
drive_service = get_drive_service(
|
||||
google_drive_connector.creds,
|
||||
google_drive_connector.primary_admin_email,
|
||||
)
|
||||
|
||||
for drive_id in drive_ids:
|
||||
group_emails: set[str] = set()
|
||||
user_emails: set[str] = set()
|
||||
for permission in execute_paginated_retrieval(
|
||||
drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
fileId=drive_id,
|
||||
fields="permissions(emailAddress, type)",
|
||||
supportsAllDrives=True,
|
||||
):
|
||||
if permission["type"] == "group":
|
||||
group_emails.add(permission["emailAddress"])
|
||||
elif permission["type"] == "user":
|
||||
user_emails.add(permission["emailAddress"])
|
||||
drive_id_to_members_map[drive_id] = (group_emails, user_emails)
|
||||
return drive_id_to_members_map
|
||||
|
||||
|
||||
def _get_all_groups(
|
||||
admin_service: AdminService,
|
||||
google_domain: str,
|
||||
) -> set[str]:
|
||||
"""
|
||||
This gets all the group emails.
|
||||
"""
|
||||
group_emails: set[str] = set()
|
||||
for group in execute_paginated_retrieval(
|
||||
admin_service.groups().list,
|
||||
list_key="groups",
|
||||
domain=google_domain,
|
||||
fields="groups(email)",
|
||||
):
|
||||
group_emails.add(group["email"])
|
||||
return group_emails
|
||||
|
||||
|
||||
def _map_group_email_to_member_emails(
|
||||
admin_service: AdminService,
|
||||
group_emails: set[str],
|
||||
) -> dict[str, set[str]]:
|
||||
"""
|
||||
This maps group emails to their member emails.
|
||||
"""
|
||||
group_to_member_map: dict[str, set[str]] = {}
|
||||
for group_email in group_emails:
|
||||
group_member_emails: set[str] = set()
|
||||
for member in execute_paginated_retrieval(
|
||||
admin_service.members().list,
|
||||
list_key="members",
|
||||
groupKey=group_email,
|
||||
fields="members(email)",
|
||||
):
|
||||
group_member_emails.add(member["email"])
|
||||
|
||||
group_to_member_map[group_email] = group_member_emails
|
||||
return group_to_member_map
|
||||
|
||||
|
||||
def _build_onyx_groups(
|
||||
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]],
|
||||
group_email_to_member_emails_map: dict[str, set[str]],
|
||||
) -> list[ExternalUserGroup]:
|
||||
onyx_groups: list[ExternalUserGroup] = []
|
||||
|
||||
# Convert all drive member definitions to onyx groups
|
||||
# This is because having drive level access means you have
|
||||
# irrevocable access to all the files in the drive.
|
||||
for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items():
|
||||
all_member_emails: set[str] = user_emails
|
||||
for group_email in group_emails:
|
||||
all_member_emails.update(group_email_to_member_emails_map[group_email])
|
||||
onyx_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=drive_id,
|
||||
user_emails=list(all_member_emails),
|
||||
)
|
||||
)
|
||||
|
||||
# Convert all group member definitions to onyx groups
|
||||
for group_email, member_emails in group_email_to_member_emails_map.items():
|
||||
onyx_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=group_email,
|
||||
user_emails=list(member_emails),
|
||||
)
|
||||
)
|
||||
|
||||
return onyx_groups
|
||||
|
||||
|
||||
def gdrive_group_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
# Initialize connector and build credential/service objects
|
||||
google_drive_connector = GoogleDriveConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
@@ -130,23 +19,34 @@ def gdrive_group_sync(
|
||||
google_drive_connector.creds, google_drive_connector.primary_admin_email
|
||||
)
|
||||
|
||||
# Get all drive members
|
||||
drive_id_to_members_map = _get_drive_members(google_drive_connector)
|
||||
onyx_groups: list[ExternalUserGroup] = []
|
||||
for group in execute_paginated_retrieval(
|
||||
admin_service.groups().list,
|
||||
list_key="groups",
|
||||
domain=google_drive_connector.google_domain,
|
||||
fields="groups(email)",
|
||||
):
|
||||
# The id is the group email
|
||||
group_email = group["email"]
|
||||
|
||||
# Get all group emails
|
||||
all_group_emails = _get_all_groups(
|
||||
admin_service, google_drive_connector.google_domain
|
||||
)
|
||||
# Gather group member emails
|
||||
group_member_emails: list[str] = []
|
||||
for member in execute_paginated_retrieval(
|
||||
admin_service.members().list,
|
||||
list_key="members",
|
||||
groupKey=group_email,
|
||||
fields="members(email)",
|
||||
):
|
||||
group_member_emails.append(member["email"])
|
||||
|
||||
# Map group emails to their members
|
||||
group_email_to_member_emails_map = _map_group_email_to_member_emails(
|
||||
admin_service, all_group_emails
|
||||
)
|
||||
if not group_member_emails:
|
||||
continue
|
||||
|
||||
# Convert the maps to onyx groups
|
||||
onyx_groups = _build_onyx_groups(
|
||||
drive_id_to_members_map=drive_id_to_members_map,
|
||||
group_email_to_member_emails_map=group_email_to_member_emails_map,
|
||||
)
|
||||
onyx_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=group_email,
|
||||
user_emails=list(group_member_emails),
|
||||
)
|
||||
)
|
||||
|
||||
return onyx_groups
|
||||
|
||||
@@ -161,10 +161,7 @@ def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Sales
|
||||
|
||||
cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id]
|
||||
if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"CC pair {cc_pair_id} not found")
|
||||
credential_json = cc_pair.credential.credential_json
|
||||
|
||||
@@ -7,7 +7,6 @@ from onyx.connectors.slack.connector import get_channels
|
||||
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.connector import SlackPollConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -15,7 +14,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _get_slack_document_ids_and_channels(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> dict[str, list[str]]:
|
||||
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
@@ -25,14 +24,6 @@ def _get_slack_document_ids_and_channels(
|
||||
channel_doc_map: dict[str, list[str]] = {}
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"_get_slack_document_ids_and_channels: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("_get_slack_document_ids_and_channels", 1)
|
||||
|
||||
if doc_metadata.perm_sync_data is None:
|
||||
continue
|
||||
channel_id = doc_metadata.perm_sync_data["channel_id"]
|
||||
@@ -123,7 +114,7 @@ def _fetch_channel_permissions(
|
||||
|
||||
|
||||
def slack_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -136,7 +127,7 @@ def slack_doc_sync(
|
||||
)
|
||||
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
|
||||
channel_doc_map = _get_slack_document_ids_and_channels(
|
||||
cc_pair=cc_pair, callback=callback
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
workspace_permissions = _fetch_workspace_permissions(
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
|
||||
@@ -15,13 +15,11 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
|
||||
# Defining the input/output types for the sync functions
|
||||
DocSyncFuncType = Callable[
|
||||
[
|
||||
ConnectorCredentialPair,
|
||||
IndexingHeartbeatInterface | None,
|
||||
],
|
||||
list[DocExternalAccess],
|
||||
]
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from httpx_oauth.clients.openid import BASE_SCOPES
|
||||
from httpx_oauth.clients.openid import OpenID
|
||||
|
||||
from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
|
||||
from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
|
||||
from ee.onyx.server.analytics.api import router as analytics_router
|
||||
from ee.onyx.server.auth_check import check_ee_router_auth
|
||||
@@ -90,13 +88,7 @@ def get_application() -> FastAPI:
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
OpenID(
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_CLIENT_SECRET,
|
||||
OPENID_CONFIG_URL,
|
||||
# BASE_SCOPES is the same as not setting this
|
||||
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
|
||||
),
|
||||
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
|
||||
auth_backend,
|
||||
USER_AUTH_SECRET,
|
||||
associate_by_email=True,
|
||||
|
||||
@@ -150,9 +150,9 @@ def _handle_standard_answers(
|
||||
db_session=db_session,
|
||||
description="",
|
||||
user_id=None,
|
||||
persona_id=(
|
||||
slack_channel_config.persona.id if slack_channel_config.persona else 0
|
||||
),
|
||||
persona_id=slack_channel_config.persona.id
|
||||
if slack_channel_config.persona
|
||||
else 0,
|
||||
onyxbot_flow=True,
|
||||
slack_thread_id=slack_thread_id,
|
||||
)
|
||||
@@ -182,7 +182,7 @@ def _handle_standard_answers(
|
||||
formatted_answers.append(formatted_answer)
|
||||
answer_message = "\n\n".join(formatted_answers)
|
||||
|
||||
chat_message = create_new_chat_message(
|
||||
_ = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=new_user_message,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
@@ -191,13 +191,8 @@ def _handle_standard_answers(
|
||||
message_type=MessageType.ASSISTANT,
|
||||
error=None,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
commit=True,
|
||||
)
|
||||
# attach the standard answers to the chat message
|
||||
chat_message.standard_answers = [
|
||||
standard_answer for standard_answer, _ in matching_standard_answers
|
||||
]
|
||||
db_session.commit()
|
||||
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
|
||||
@@ -1,23 +1,19 @@
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time
|
||||
from ee.onyx.db.query_history import get_page_of_chat_sessions
|
||||
from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
|
||||
from ee.onyx.server.query_history.models import ChatSessionMinimal
|
||||
from ee.onyx.server.query_history.models import ChatSessionSnapshot
|
||||
from ee.onyx.server.query_history.models import MessageSnapshot
|
||||
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_display_email
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
@@ -27,15 +23,257 @@ from onyx.configs.constants import SessionType
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import User
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.query_and_chat.models import ChatSessionDetails
|
||||
from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class AbridgedSearchDoc(BaseModel):
|
||||
"""A subset of the info present in `SearchDoc`"""
|
||||
|
||||
document_id: str
|
||||
semantic_identifier: str
|
||||
link: str | None
|
||||
|
||||
|
||||
class MessageSnapshot(BaseModel):
|
||||
message: str
|
||||
message_type: MessageType
|
||||
documents: list[AbridgedSearchDoc]
|
||||
feedback_type: QAFeedbackType | None
|
||||
feedback_text: str | None
|
||||
time_created: datetime
|
||||
|
||||
@classmethod
|
||||
def build(cls, message: ChatMessage) -> "MessageSnapshot":
|
||||
latest_messages_feedback_obj = (
|
||||
message.chat_message_feedbacks[-1]
|
||||
if len(message.chat_message_feedbacks) > 0
|
||||
else None
|
||||
)
|
||||
feedback_type = (
|
||||
(
|
||||
QAFeedbackType.LIKE
|
||||
if latest_messages_feedback_obj.is_positive
|
||||
else QAFeedbackType.DISLIKE
|
||||
)
|
||||
if latest_messages_feedback_obj
|
||||
else None
|
||||
)
|
||||
feedback_text = (
|
||||
latest_messages_feedback_obj.feedback_text
|
||||
if latest_messages_feedback_obj
|
||||
else None
|
||||
)
|
||||
return cls(
|
||||
message=message.message,
|
||||
message_type=message.message_type,
|
||||
documents=[
|
||||
AbridgedSearchDoc(
|
||||
document_id=document.document_id,
|
||||
semantic_identifier=document.semantic_id,
|
||||
link=document.link,
|
||||
)
|
||||
for document in message.search_docs
|
||||
],
|
||||
feedback_type=feedback_type,
|
||||
feedback_text=feedback_text,
|
||||
time_created=message.time_sent,
|
||||
)
|
||||
|
||||
|
||||
class ChatSessionMinimal(BaseModel):
|
||||
id: UUID
|
||||
user_email: str
|
||||
name: str | None
|
||||
first_user_message: str
|
||||
first_ai_message: str
|
||||
assistant_id: int | None
|
||||
assistant_name: str | None
|
||||
time_created: datetime
|
||||
feedback_type: QAFeedbackType | Literal["mixed"] | None
|
||||
flow_type: SessionType
|
||||
conversation_length: int
|
||||
|
||||
|
||||
class ChatSessionSnapshot(BaseModel):
|
||||
id: UUID
|
||||
user_email: str
|
||||
name: str | None
|
||||
messages: list[MessageSnapshot]
|
||||
assistant_id: int | None
|
||||
assistant_name: str | None
|
||||
time_created: datetime
|
||||
flow_type: SessionType
|
||||
|
||||
|
||||
class QuestionAnswerPairSnapshot(BaseModel):
|
||||
chat_session_id: UUID
|
||||
# 1-indexed message number in the chat_session
|
||||
# e.g. the first message pair in the chat_session is 1, the second is 2, etc.
|
||||
message_pair_num: int
|
||||
user_message: str
|
||||
ai_response: str
|
||||
retrieved_documents: list[AbridgedSearchDoc]
|
||||
feedback_type: QAFeedbackType | None
|
||||
feedback_text: str | None
|
||||
persona_name: str | None
|
||||
user_email: str
|
||||
time_created: datetime
|
||||
flow_type: SessionType
|
||||
|
||||
@classmethod
|
||||
def from_chat_session_snapshot(
|
||||
cls,
|
||||
chat_session_snapshot: ChatSessionSnapshot,
|
||||
) -> list["QuestionAnswerPairSnapshot"]:
|
||||
message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
|
||||
for ind in range(1, len(chat_session_snapshot.messages), 2):
|
||||
message_pairs.append(
|
||||
(
|
||||
chat_session_snapshot.messages[ind - 1],
|
||||
chat_session_snapshot.messages[ind],
|
||||
)
|
||||
)
|
||||
|
||||
return [
|
||||
cls(
|
||||
chat_session_id=chat_session_snapshot.id,
|
||||
message_pair_num=ind + 1,
|
||||
user_message=user_message.message,
|
||||
ai_response=ai_message.message,
|
||||
retrieved_documents=ai_message.documents,
|
||||
feedback_type=ai_message.feedback_type,
|
||||
feedback_text=ai_message.feedback_text,
|
||||
persona_name=chat_session_snapshot.assistant_name,
|
||||
user_email=get_display_email(chat_session_snapshot.user_email),
|
||||
time_created=user_message.time_created,
|
||||
flow_type=chat_session_snapshot.flow_type,
|
||||
)
|
||||
for ind, (user_message, ai_message) in enumerate(message_pairs)
|
||||
]
|
||||
|
||||
def to_json(self) -> dict[str, str | None]:
|
||||
return {
|
||||
"chat_session_id": str(self.chat_session_id),
|
||||
"message_pair_num": str(self.message_pair_num),
|
||||
"user_message": self.user_message,
|
||||
"ai_response": self.ai_response,
|
||||
"retrieved_documents": "|".join(
|
||||
[
|
||||
doc.link or doc.semantic_identifier
|
||||
for doc in self.retrieved_documents
|
||||
]
|
||||
),
|
||||
"feedback_type": self.feedback_type.value if self.feedback_type else "",
|
||||
"feedback_text": self.feedback_text or "",
|
||||
"persona_name": self.persona_name,
|
||||
"user_email": self.user_email,
|
||||
"time_created": str(self.time_created),
|
||||
"flow_type": self.flow_type,
|
||||
}
|
||||
|
||||
|
||||
def determine_flow_type(chat_session: ChatSession) -> SessionType:
|
||||
return SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
|
||||
|
||||
|
||||
def fetch_and_process_chat_session_history_minimal(
|
||||
db_session: Session,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
feedback_filter: QAFeedbackType | None = None,
|
||||
limit: int | None = 500,
|
||||
) -> list[ChatSessionMinimal]:
|
||||
chat_sessions = fetch_chat_sessions_eagerly_by_time(
|
||||
start=start, end=end, db_session=db_session, limit=limit
|
||||
)
|
||||
|
||||
minimal_sessions = []
|
||||
for chat_session in chat_sessions:
|
||||
if not chat_session.messages:
|
||||
continue
|
||||
|
||||
first_user_message = next(
|
||||
(
|
||||
message.message
|
||||
for message in chat_session.messages
|
||||
if message.message_type == MessageType.USER
|
||||
),
|
||||
"",
|
||||
)
|
||||
first_ai_message = next(
|
||||
(
|
||||
message.message
|
||||
for message in chat_session.messages
|
||||
if message.message_type == MessageType.ASSISTANT
|
||||
),
|
||||
"",
|
||||
)
|
||||
|
||||
has_positive_feedback = any(
|
||||
feedback.is_positive
|
||||
for message in chat_session.messages
|
||||
for feedback in message.chat_message_feedbacks
|
||||
)
|
||||
|
||||
has_negative_feedback = any(
|
||||
not feedback.is_positive
|
||||
for message in chat_session.messages
|
||||
for feedback in message.chat_message_feedbacks
|
||||
)
|
||||
|
||||
feedback_type: QAFeedbackType | Literal["mixed"] | None = (
|
||||
"mixed"
|
||||
if has_positive_feedback and has_negative_feedback
|
||||
else QAFeedbackType.LIKE
|
||||
if has_positive_feedback
|
||||
else QAFeedbackType.DISLIKE
|
||||
if has_negative_feedback
|
||||
else None
|
||||
)
|
||||
|
||||
if feedback_filter:
|
||||
if feedback_filter == QAFeedbackType.LIKE and not has_positive_feedback:
|
||||
continue
|
||||
if feedback_filter == QAFeedbackType.DISLIKE and not has_negative_feedback:
|
||||
continue
|
||||
|
||||
flow_type = determine_flow_type(chat_session)
|
||||
|
||||
minimal_sessions.append(
|
||||
ChatSessionMinimal(
|
||||
id=chat_session.id,
|
||||
user_email=get_display_email(
|
||||
chat_session.user.email if chat_session.user else None
|
||||
),
|
||||
name=chat_session.description,
|
||||
first_user_message=first_user_message,
|
||||
first_ai_message=first_ai_message,
|
||||
assistant_id=chat_session.persona_id,
|
||||
assistant_name=(
|
||||
chat_session.persona.name if chat_session.persona else None
|
||||
),
|
||||
time_created=chat_session.time_created,
|
||||
feedback_type=feedback_type,
|
||||
flow_type=flow_type,
|
||||
conversation_length=len(
|
||||
[
|
||||
m
|
||||
for m in chat_session.messages
|
||||
if m.message_type != MessageType.SYSTEM
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return minimal_sessions
|
||||
|
||||
|
||||
def fetch_and_process_chat_session_history(
|
||||
db_session: Session,
|
||||
start: datetime,
|
||||
@@ -81,7 +319,7 @@ def snapshot_from_chat_session(
|
||||
except RuntimeError:
|
||||
return None
|
||||
|
||||
flow_type = SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
|
||||
flow_type = determine_flow_type(chat_session)
|
||||
|
||||
return ChatSessionSnapshot(
|
||||
id=chat_session.id,
|
||||
@@ -133,38 +371,22 @@ def get_user_chat_sessions(
|
||||
|
||||
@router.get("/admin/chat-session-history")
|
||||
def get_chat_session_history(
|
||||
page_num: int = Query(0, ge=0),
|
||||
page_size: int = Query(10, ge=1),
|
||||
feedback_type: QAFeedbackType | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
page_of_chat_sessions = get_page_of_chat_sessions(
|
||||
page_num=page_num,
|
||||
page_size=page_size,
|
||||
) -> list[ChatSessionMinimal]:
|
||||
return fetch_and_process_chat_session_history_minimal(
|
||||
db_session=db_session,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
start=start
|
||||
or (
|
||||
datetime.now(tz=timezone.utc) - timedelta(days=30)
|
||||
), # default is 30d lookback
|
||||
end=end or datetime.now(tz=timezone.utc),
|
||||
feedback_filter=feedback_type,
|
||||
)
|
||||
|
||||
total_filtered_chat_sessions_count = get_total_filtered_chat_sessions_count(
|
||||
db_session=db_session,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
feedback_filter=feedback_type,
|
||||
)
|
||||
|
||||
return PaginatedReturn(
|
||||
items=[
|
||||
ChatSessionMinimal.from_chat_session(chat_session)
|
||||
for chat_session in page_of_chat_sessions
|
||||
],
|
||||
total_items=total_filtered_chat_sessions_count,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/chat-session-history/{chat_session_id}")
|
||||
def get_chat_session_admin(
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.users import get_display_email
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.configs.constants import SessionType
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
|
||||
|
||||
class AbridgedSearchDoc(BaseModel):
|
||||
"""A subset of the info present in `SearchDoc`"""
|
||||
|
||||
document_id: str
|
||||
semantic_identifier: str
|
||||
link: str | None
|
||||
|
||||
|
||||
class MessageSnapshot(BaseModel):
|
||||
id: int
|
||||
message: str
|
||||
message_type: MessageType
|
||||
documents: list[AbridgedSearchDoc]
|
||||
feedback_type: QAFeedbackType | None
|
||||
feedback_text: str | None
|
||||
time_created: datetime
|
||||
|
||||
@classmethod
|
||||
def build(cls, message: ChatMessage) -> "MessageSnapshot":
|
||||
latest_messages_feedback_obj = (
|
||||
message.chat_message_feedbacks[-1]
|
||||
if len(message.chat_message_feedbacks) > 0
|
||||
else None
|
||||
)
|
||||
feedback_type = (
|
||||
(
|
||||
QAFeedbackType.LIKE
|
||||
if latest_messages_feedback_obj.is_positive
|
||||
else QAFeedbackType.DISLIKE
|
||||
)
|
||||
if latest_messages_feedback_obj
|
||||
else None
|
||||
)
|
||||
feedback_text = (
|
||||
latest_messages_feedback_obj.feedback_text
|
||||
if latest_messages_feedback_obj
|
||||
else None
|
||||
)
|
||||
return cls(
|
||||
id=message.id,
|
||||
message=message.message,
|
||||
message_type=message.message_type,
|
||||
documents=[
|
||||
AbridgedSearchDoc(
|
||||
document_id=document.document_id,
|
||||
semantic_identifier=document.semantic_id,
|
||||
link=document.link,
|
||||
)
|
||||
for document in message.search_docs
|
||||
],
|
||||
feedback_type=feedback_type,
|
||||
feedback_text=feedback_text,
|
||||
time_created=message.time_sent,
|
||||
)
|
||||
|
||||
|
||||
class ChatSessionMinimal(BaseModel):
|
||||
id: UUID
|
||||
user_email: str
|
||||
name: str | None
|
||||
first_user_message: str
|
||||
first_ai_message: str
|
||||
assistant_id: int | None
|
||||
assistant_name: str | None
|
||||
time_created: datetime
|
||||
feedback_type: QAFeedbackType | None
|
||||
flow_type: SessionType
|
||||
conversation_length: int
|
||||
|
||||
@classmethod
|
||||
def from_chat_session(cls, chat_session: ChatSession) -> "ChatSessionMinimal":
|
||||
first_user_message = next(
|
||||
(
|
||||
message.message
|
||||
for message in chat_session.messages
|
||||
if message.message_type == MessageType.USER
|
||||
),
|
||||
"",
|
||||
)
|
||||
first_ai_message = next(
|
||||
(
|
||||
message.message
|
||||
for message in chat_session.messages
|
||||
if message.message_type == MessageType.ASSISTANT
|
||||
),
|
||||
"",
|
||||
)
|
||||
|
||||
list_of_message_feedbacks = [
|
||||
feedback.is_positive
|
||||
for message in chat_session.messages
|
||||
for feedback in message.chat_message_feedbacks
|
||||
]
|
||||
session_feedback_type = None
|
||||
if list_of_message_feedbacks:
|
||||
if all(list_of_message_feedbacks):
|
||||
session_feedback_type = QAFeedbackType.LIKE
|
||||
elif not any(list_of_message_feedbacks):
|
||||
session_feedback_type = QAFeedbackType.DISLIKE
|
||||
else:
|
||||
session_feedback_type = QAFeedbackType.MIXED
|
||||
|
||||
return cls(
|
||||
id=chat_session.id,
|
||||
user_email=get_display_email(
|
||||
chat_session.user.email if chat_session.user else None
|
||||
),
|
||||
name=chat_session.description,
|
||||
first_user_message=first_user_message,
|
||||
first_ai_message=first_ai_message,
|
||||
assistant_id=chat_session.persona_id,
|
||||
assistant_name=(
|
||||
chat_session.persona.name if chat_session.persona else None
|
||||
),
|
||||
time_created=chat_session.time_created,
|
||||
feedback_type=session_feedback_type,
|
||||
flow_type=SessionType.SLACK
|
||||
if chat_session.onyxbot_flow
|
||||
else SessionType.CHAT,
|
||||
conversation_length=len(
|
||||
[
|
||||
message
|
||||
for message in chat_session.messages
|
||||
if message.message_type != MessageType.SYSTEM
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ChatSessionSnapshot(BaseModel):
|
||||
id: UUID
|
||||
user_email: str
|
||||
name: str | None
|
||||
messages: list[MessageSnapshot]
|
||||
assistant_id: int | None
|
||||
assistant_name: str | None
|
||||
time_created: datetime
|
||||
flow_type: SessionType
|
||||
|
||||
|
||||
class QuestionAnswerPairSnapshot(BaseModel):
|
||||
chat_session_id: UUID
|
||||
# 1-indexed message number in the chat_session
|
||||
# e.g. the first message pair in the chat_session is 1, the second is 2, etc.
|
||||
message_pair_num: int
|
||||
user_message: str
|
||||
ai_response: str
|
||||
retrieved_documents: list[AbridgedSearchDoc]
|
||||
feedback_type: QAFeedbackType | None
|
||||
feedback_text: str | None
|
||||
persona_name: str | None
|
||||
user_email: str
|
||||
time_created: datetime
|
||||
flow_type: SessionType
|
||||
|
||||
@classmethod
|
||||
def from_chat_session_snapshot(
|
||||
cls,
|
||||
chat_session_snapshot: ChatSessionSnapshot,
|
||||
) -> list["QuestionAnswerPairSnapshot"]:
|
||||
message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
|
||||
for ind in range(1, len(chat_session_snapshot.messages), 2):
|
||||
message_pairs.append(
|
||||
(
|
||||
chat_session_snapshot.messages[ind - 1],
|
||||
chat_session_snapshot.messages[ind],
|
||||
)
|
||||
)
|
||||
|
||||
return [
|
||||
cls(
|
||||
chat_session_id=chat_session_snapshot.id,
|
||||
message_pair_num=ind + 1,
|
||||
user_message=user_message.message,
|
||||
ai_response=ai_message.message,
|
||||
retrieved_documents=ai_message.documents,
|
||||
feedback_type=ai_message.feedback_type,
|
||||
feedback_text=ai_message.feedback_text,
|
||||
persona_name=chat_session_snapshot.assistant_name,
|
||||
user_email=get_display_email(chat_session_snapshot.user_email),
|
||||
time_created=user_message.time_created,
|
||||
flow_type=chat_session_snapshot.flow_type,
|
||||
)
|
||||
for ind, (user_message, ai_message) in enumerate(message_pairs)
|
||||
]
|
||||
|
||||
def to_json(self) -> dict[str, str | None]:
|
||||
return {
|
||||
"chat_session_id": str(self.chat_session_id),
|
||||
"message_pair_num": str(self.message_pair_num),
|
||||
"user_message": self.user_message,
|
||||
"ai_response": self.ai_response,
|
||||
"retrieved_documents": "|".join(
|
||||
[
|
||||
doc.link or doc.semantic_identifier
|
||||
for doc in self.retrieved_documents
|
||||
]
|
||||
),
|
||||
"feedback_type": self.feedback_type.value if self.feedback_type else "",
|
||||
"feedback_text": self.feedback_text or "",
|
||||
"persona_name": self.persona_name,
|
||||
"user_email": self.user_email,
|
||||
"time_created": str(self.time_created),
|
||||
"flow_type": self.flow_type,
|
||||
}
|
||||
@@ -24,7 +24,7 @@ from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.server.features.persona.models import CreatePersonaRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.server.settings.store import store_settings as store_base_settings
|
||||
@@ -57,7 +57,7 @@ class SeedConfiguration(BaseModel):
|
||||
llms: list[LLMProviderUpsertRequest] | None = None
|
||||
admin_user_emails: list[str] | None = None
|
||||
seeded_logo_path: str | None = None
|
||||
personas: list[PersonaUpsertRequest] | None = None
|
||||
personas: list[CreatePersonaRequest] | None = None
|
||||
settings: Settings | None = None
|
||||
enterprise_settings: EnterpriseSettings | None = None
|
||||
|
||||
@@ -128,7 +128,7 @@ def _seed_llms(
|
||||
)
|
||||
|
||||
|
||||
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
|
||||
def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
|
||||
if personas:
|
||||
logger.notice("Seeding Personas")
|
||||
for persona in personas:
|
||||
|
||||
@@ -111,7 +111,6 @@ async def login_as_anonymous_user(
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.delete_cookie("fastapiusersauth")
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
|
||||
@@ -5,7 +5,7 @@ from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
|
||||
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
|
||||
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits
|
||||
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
@@ -51,10 +51,8 @@ def get_group_token_limit_settings(
|
||||
) -> list[TokenRateLimitDisplay]:
|
||||
return [
|
||||
TokenRateLimitDisplay.from_db(token_rate_limit)
|
||||
for token_rate_limit in fetch_user_group_token_rate_limits_for_user(
|
||||
db_session=db_session,
|
||||
group_id=group_id,
|
||||
user=user,
|
||||
for token_rate_limit in fetch_user_group_token_rate_limits(
|
||||
db_session, group_id, user
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@@ -58,7 +58,6 @@ class UserGroup(BaseModel):
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair_relationship.cc_pair.credential
|
||||
),
|
||||
access_type=cc_pair_relationship.cc_pair.access_type,
|
||||
)
|
||||
for cc_pair_relationship in user_group_model.cc_pair_relationships
|
||||
if cc_pair_relationship.is_current
|
||||
|
||||
@@ -19,9 +19,6 @@ def prefix_external_group(ext_group_name: str) -> str:
|
||||
return f"external_group:{ext_group_name}"
|
||||
|
||||
|
||||
def build_ext_group_name_for_onyx(ext_group_name: str, source: DocumentSource) -> str:
|
||||
"""
|
||||
External groups may collide across sources, every source needs its own prefix.
|
||||
NOTE: the name is lowercased to handle case sensitivity for group names
|
||||
"""
|
||||
return f"{source.value}_{ext_group_name}".lower()
|
||||
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
|
||||
"""External groups may collide across sources, every source needs its own prefix."""
|
||||
return f"{source.value.upper()}_{ext_group_name}"
|
||||
|
||||
@@ -42,17 +42,8 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class UserUpdateWithRole(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
Role changes should be handled through a separate, admin-only process
|
||||
"""
|
||||
|
||||
|
||||
class AuthBackend(str, Enum):
|
||||
REDIS = "redis"
|
||||
POSTGRES = "postgres"
|
||||
|
||||
@@ -33,8 +33,6 @@ from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import RedisStrategy
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.exceptions import UserAlreadyExists
|
||||
from fastapi_users.jwt import decode_jwt
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
@@ -54,15 +52,13 @@ from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.email_utils import send_forgot_password_email
|
||||
from onyx.auth.email_utils import send_user_verification_email
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.schemas import UserUpdateWithRole
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
from onyx.auth.schemas import UserUpdate
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
@@ -78,7 +74,6 @@ from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.api_key import fetch_user_for_api_key
|
||||
from onyx.db.auth import get_access_token_db
|
||||
from onyx.db.auth import get_default_admin_user_emails
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.auth import get_user_db
|
||||
@@ -87,7 +82,6 @@ from onyx.db.engine import get_async_session
|
||||
from onyx.db.engine import get_async_session_with_tenant
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
@@ -215,7 +209,7 @@ def verify_email_domain(email: str) -> None:
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
|
||||
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
|
||||
|
||||
async def create(
|
||||
@@ -245,8 +239,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
referral_source=referral_source,
|
||||
request=request,
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if MULTI_TENANT:
|
||||
@@ -265,16 +261,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.role.is_web_login() and user_create.role.is_web_login():
|
||||
user_update = UserUpdateWithRole(
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
@@ -282,6 +278,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return user
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
@@ -583,14 +580,6 @@ def get_redis_strategy() -> RedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
|
||||
)
|
||||
|
||||
|
||||
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
"""
|
||||
A custom strategy that fetches the actual async Redis connection inside each method.
|
||||
@@ -599,7 +588,7 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lifetime_seconds: Optional[int] = SESSION_EXPIRE_TIME_SECONDS,
|
||||
lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
|
||||
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
|
||||
):
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
@@ -648,16 +637,9 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
await redis.delete(f"{self.key_prefix}{token}")
|
||||
|
||||
|
||||
if AUTH_BACKEND == AuthBackend.REDIS:
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
|
||||
)
|
||||
elif AUTH_BACKEND == AuthBackend.POSTGRES:
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="postgres", transport=cookie_transport, get_strategy=get_database_strategy
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid auth backend: {AUTH_BACKEND}")
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
|
||||
)
|
||||
|
||||
|
||||
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
|
||||
@@ -20,11 +20,10 @@ from sqlalchemy.orm import Session
|
||||
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
|
||||
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
@@ -101,10 +100,6 @@ def on_task_postrun(
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
if task.name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
|
||||
# this is a cloud / all tenant task ... no postrun is needed
|
||||
return
|
||||
|
||||
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
|
||||
if not kwargs:
|
||||
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
|
||||
@@ -166,40 +161,14 @@ def on_task_postrun(
|
||||
return
|
||||
|
||||
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
"""The first signal sent on celery worker startup"""
|
||||
|
||||
# NOTE(rkuo): start method "fork" is unsafe and we really need it to be "spawn"
|
||||
# But something is blocking set_start_method from working in the cloud unless
|
||||
# force=True. so we use force=True as a fallback.
|
||||
|
||||
all_start_methods: list[str] = multiprocessing.get_all_start_methods()
|
||||
logger.info(f"Multiprocessing all start methods: {all_start_methods}")
|
||||
|
||||
try:
|
||||
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
|
||||
except Exception:
|
||||
logger.info(
|
||||
"Multiprocessing set_start_method exceptioned. Trying force=True..."
|
||||
)
|
||||
try:
|
||||
multiprocessing.set_start_method(
|
||||
"spawn", force=True
|
||||
) # fork is unsafe, set to spawn
|
||||
except Exception:
|
||||
logger.info(
|
||||
"Multiprocessing set_start_method force=True exceptioned even with force=True."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Multiprocessing selected start method: {multiprocessing.get_start_method()}"
|
||||
)
|
||||
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
|
||||
|
||||
|
||||
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for redis to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout
|
||||
is reached."""
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
@@ -281,6 +250,51 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
|
||||
return
|
||||
|
||||
|
||||
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for Vespa to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
ready = False
|
||||
time_start = time.monotonic()
|
||||
logger.info("Vespa: Readiness probe starting.")
|
||||
while True:
|
||||
try:
|
||||
client = get_vespa_http_client()
|
||||
response = client.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
|
||||
response.raise_for_status()
|
||||
|
||||
response_dict = response.json()
|
||||
if response_dict["status"]["code"] == "up":
|
||||
ready = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
if not ready:
|
||||
msg = (
|
||||
f"Vespa: Readiness probe did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
logger.info("Vespa: Readiness probe succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
|
||||
@@ -318,8 +332,6 @@ def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
HttpxPool.close_all()
|
||||
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
@@ -468,13 +480,3 @@ def reset_tenant_id(
|
||||
) -> None:
|
||||
"""Signal handler to reset tenant ID in context var after task ends."""
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
|
||||
def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
|
||||
if not wait_for_vespa_with_timeout():
|
||||
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
@@ -13,7 +13,6 @@ from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
@@ -29,7 +28,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
self._last_reload = self.app.now() - self._reload_interval
|
||||
# Let the parent class handle store initialization
|
||||
self.setup_schedule()
|
||||
self._try_updating_schedule()
|
||||
self._update_tenant_tasks()
|
||||
logger.info(f"Set reload interval to {self._reload_interval}")
|
||||
|
||||
def setup_schedule(self) -> None:
|
||||
@@ -45,158 +44,105 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
or (now - self._last_reload) > self._reload_interval
|
||||
):
|
||||
logger.info("Reload interval reached, initiating task update")
|
||||
try:
|
||||
self._try_updating_schedule()
|
||||
except (AttributeError, KeyError) as e:
|
||||
logger.exception(f"Failed to process task configuration: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error updating tasks: {str(e)}")
|
||||
|
||||
self._update_tenant_tasks()
|
||||
self._last_reload = now
|
||||
logger.info("Task update completed, reset reload timer")
|
||||
return retval
|
||||
|
||||
def _generate_schedule(
|
||||
self, tenant_ids: list[str] | list[None]
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Given a list of tenant id's, generates a new beat schedule for celery."""
|
||||
logger.info("Fetching tasks to schedule")
|
||||
def _update_tenant_tasks(self) -> None:
|
||||
logger.info("Starting task update process")
|
||||
try:
|
||||
logger.info("Fetching all IDs")
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.info(f"Found {len(tenant_ids)} IDs")
|
||||
|
||||
new_schedule: dict[str, dict[str, Any]] = {}
|
||||
|
||||
if MULTI_TENANT:
|
||||
# cloud tasks only need the single task beat across all tenants
|
||||
get_cloud_tasks_to_schedule = fetch_versioned_implementation(
|
||||
"onyx.background.celery.tasks.beat_schedule",
|
||||
"get_cloud_tasks_to_schedule",
|
||||
logger.info("Fetching tasks to schedule")
|
||||
tasks_to_schedule = fetch_versioned_implementation(
|
||||
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
|
||||
)
|
||||
|
||||
cloud_tasks_to_schedule: list[
|
||||
dict[str, Any]
|
||||
] = get_cloud_tasks_to_schedule()
|
||||
for task in cloud_tasks_to_schedule:
|
||||
task_name = task["name"]
|
||||
cloud_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": task.get("kwargs", {}),
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(f"Adding options to task {task_name}: {options}")
|
||||
cloud_task["options"] = options
|
||||
new_schedule[task_name] = cloud_task
|
||||
new_beat_schedule: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# regular task beats are multiplied across all tenants
|
||||
get_tasks_to_schedule = fetch_versioned_implementation(
|
||||
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
|
||||
)
|
||||
current_schedule = self.schedule.items()
|
||||
|
||||
tasks_to_schedule: list[dict[str, Any]] = get_tasks_to_schedule()
|
||||
existing_tenants = set()
|
||||
for task_name, _ in current_schedule:
|
||||
if "-" in task_name:
|
||||
existing_tenants.add(task_name.split("-")[-1])
|
||||
logger.info(f"Found {len(existing_tenants)} existing items in schedule")
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
|
||||
logger.info(
|
||||
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
|
||||
)
|
||||
continue
|
||||
|
||||
for task in tasks_to_schedule:
|
||||
task_name = task["name"]
|
||||
tenant_task_name = f"{task['name']}-{tenant_id}"
|
||||
|
||||
logger.debug(f"Creating task configuration for {tenant_task_name}")
|
||||
tenant_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": {"tenant_id": tenant_id},
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(
|
||||
f"Adding options to task {tenant_task_name}: {options}"
|
||||
for tenant_id in tenant_ids:
|
||||
if (
|
||||
IGNORED_SYNCING_TENANT_LIST
|
||||
and tenant_id in IGNORED_SYNCING_TENANT_LIST
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
|
||||
)
|
||||
tenant_task["options"] = options
|
||||
new_schedule[tenant_task_name] = tenant_task
|
||||
continue
|
||||
|
||||
return new_schedule
|
||||
if tenant_id not in existing_tenants:
|
||||
logger.info(f"Processing new item: {tenant_id}")
|
||||
|
||||
def _try_updating_schedule(self) -> None:
|
||||
"""Only updates the actual beat schedule on the celery app when it changes"""
|
||||
for task in tasks_to_schedule():
|
||||
task_name = f"{task['name']}-{tenant_id}"
|
||||
logger.debug(f"Creating task configuration for {task_name}")
|
||||
new_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": {"tenant_id": tenant_id},
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(f"Adding options to task {task_name}: {options}")
|
||||
new_task["options"] = options
|
||||
new_beat_schedule[task_name] = new_task
|
||||
|
||||
logger.info("_try_updating_schedule starting")
|
||||
if self._should_update_schedule(current_schedule, new_beat_schedule):
|
||||
logger.info(
|
||||
"Schedule update required",
|
||||
extra={
|
||||
"new_tasks": len(new_beat_schedule),
|
||||
"current_tasks": len(current_schedule),
|
||||
},
|
||||
)
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.info(f"Found {len(tenant_ids)} IDs")
|
||||
# Create schedule entries
|
||||
entries = {}
|
||||
for name, entry in new_beat_schedule.items():
|
||||
entries[name] = self.Entry(
|
||||
name=name,
|
||||
app=self.app,
|
||||
task=entry["task"],
|
||||
schedule=entry["schedule"],
|
||||
options=entry.get("options", {}),
|
||||
kwargs=entry.get("kwargs", {}),
|
||||
)
|
||||
|
||||
# get current schedule and extract current tenants
|
||||
current_schedule = self.schedule.items()
|
||||
# Update the schedule using the scheduler's methods
|
||||
self.schedule.clear()
|
||||
self.schedule.update(entries)
|
||||
|
||||
# there are no more per tenant beat tasks, so comment this out
|
||||
# NOTE: we may not actualy need this scheduler any more and should
|
||||
# test reverting to a regular beat schedule implementation
|
||||
# Ensure changes are persisted
|
||||
self.sync()
|
||||
|
||||
# current_tenants = set()
|
||||
# for task_name, _ in current_schedule:
|
||||
# task_name = cast(str, task_name)
|
||||
# if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
|
||||
# continue
|
||||
logger.info("Schedule update completed successfully")
|
||||
else:
|
||||
logger.info("Schedule is up to date, no changes needed")
|
||||
except (AttributeError, KeyError) as e:
|
||||
logger.exception(f"Failed to process task configuration: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error updating tasks: {str(e)}")
|
||||
|
||||
# if "_" in task_name:
|
||||
# # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
|
||||
# # -> "12345678-abcd-efgh-ijkl-12345678"
|
||||
# current_tenants.add(task_name.split("_")[-1])
|
||||
# logger.info(f"Found {len(current_tenants)} existing items in schedule")
|
||||
|
||||
# for tenant_id in tenant_ids:
|
||||
# if tenant_id not in current_tenants:
|
||||
# logger.info(f"Processing new tenant: {tenant_id}")
|
||||
|
||||
new_schedule = self._generate_schedule(tenant_ids)
|
||||
|
||||
if DynamicTenantScheduler._compare_schedules(current_schedule, new_schedule):
|
||||
logger.info(
|
||||
"_try_updating_schedule: Current schedule is up to date, no changes needed"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Schedule update required",
|
||||
extra={
|
||||
"new_tasks": len(new_schedule),
|
||||
"current_tasks": len(current_schedule),
|
||||
},
|
||||
)
|
||||
|
||||
# Create schedule entries
|
||||
entries = {}
|
||||
for name, entry in new_schedule.items():
|
||||
entries[name] = self.Entry(
|
||||
name=name,
|
||||
app=self.app,
|
||||
task=entry["task"],
|
||||
schedule=entry["schedule"],
|
||||
options=entry.get("options", {}),
|
||||
kwargs=entry.get("kwargs", {}),
|
||||
)
|
||||
|
||||
# Update the schedule using the scheduler's methods
|
||||
self.schedule.clear()
|
||||
self.schedule.update(entries)
|
||||
|
||||
# Ensure changes are persisted
|
||||
self.sync()
|
||||
|
||||
logger.info("_try_updating_schedule: Schedule updated successfully")
|
||||
|
||||
@staticmethod
|
||||
def _compare_schedules(schedule1: dict, schedule2: dict) -> bool:
|
||||
"""Compare schedules to determine if an update is needed.
|
||||
True if equivalent, False if not."""
|
||||
current_tasks = set(name for name, _ in schedule1)
|
||||
new_tasks = set(schedule2.keys())
|
||||
if current_tasks != new_tasks:
|
||||
return False
|
||||
|
||||
return True
|
||||
def _should_update_schedule(
|
||||
self, current_schedule: dict, new_schedule: dict
|
||||
) -> bool:
|
||||
"""Compare schedules to determine if an update is needed."""
|
||||
logger.debug("Comparing current and new schedules")
|
||||
current_tasks = set(name for name, _ in current_schedule)
|
||||
new_tasks = set(new_schedule.keys())
|
||||
needs_update = current_tasks != new_tasks
|
||||
logger.debug(f"Schedule update needed: {needs_update}")
|
||||
return needs_update
|
||||
|
||||
|
||||
@beat_init.connect
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
@@ -49,20 +49,21 @@ def on_task_postrun(
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
@@ -50,25 +50,26 @@ def on_task_postrun(
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
|
||||
# rkuo: Transient errors keep happening in the indexing watchdog threads.
|
||||
# "SSL connection has been closed unexpectedly"
|
||||
# actually setting the spawn method in the cloud fixes 95% of these.
|
||||
# setting pre ping might help even more, but not worrying about that yet
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
|
||||
# rkuo: been seeing transient connection exceptions here, so upping the connection count
|
||||
# from just concurrency/concurrency to concurrency/concurrency*2
|
||||
SqlEngine.init_engine(
|
||||
pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
|
||||
)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -1,24 +1,21 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
@@ -52,33 +49,21 @@ def on_task_postrun(
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
|
||||
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
|
||||
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
|
||||
ssl_cert=VESPA_CLOUD_CERT_PATH,
|
||||
ssl_key=VESPA_CLOUD_KEY_PATH,
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_MONITORING_APP_NAME
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.monitoring")
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_MONITORING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=3)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
]
|
||||
)
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -6,7 +7,6 @@ from celery import bootsteps # type: ignore
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
@@ -17,7 +17,7 @@ from redis.lock import Lock as RedisLock
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.background.celery.tasks.indexing.utils import (
|
||||
from onyx.background.celery.tasks.indexing.tasks import (
|
||||
get_unfenced_index_attempt_ids,
|
||||
)
|
||||
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
@@ -73,20 +73,21 @@ def on_task_postrun(
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
@@ -134,7 +135,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
raise WorkerShutdown("Primary worker lock could not be acquired!")
|
||||
|
||||
# tacking on our own user data to the sender
|
||||
sender.primary_worker_lock = lock # type: ignore
|
||||
sender.primary_worker_lock = lock
|
||||
|
||||
# As currently designed, when this worker starts as "primary", we reinitialize redis
|
||||
# to a clean state (for our purposes, anyway)
|
||||
|
||||
@@ -91,28 +91,6 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
|
||||
return False
|
||||
|
||||
|
||||
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
|
||||
"""This is a redis specific way to build a list of tasks in a queue.
|
||||
|
||||
This helps us read the queue once and then efficiently look for missing tasks
|
||||
in the queue.
|
||||
"""
|
||||
|
||||
task_set: set[str] = set()
|
||||
|
||||
for priority in range(len(OnyxCeleryPriority)):
|
||||
queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
|
||||
|
||||
tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
|
||||
for task in tasks:
|
||||
task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
|
||||
task_id = task_dict.get("headers", {}).get("id")
|
||||
if task_id:
|
||||
task_set.add(task_id)
|
||||
|
||||
return task_set
|
||||
|
||||
|
||||
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
|
||||
"""Returns a list of current workers containing name_filter, or all workers if
|
||||
name_filter is None.
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
@@ -20,7 +17,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.models import TaskQueueState
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.server.documents.models import DeletionAttemptSnapshot
|
||||
@@ -158,25 +154,3 @@ def celery_is_worker_primary(worker: Any) -> bool:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def httpx_init_vespa_pool(
|
||||
max_keepalive_connections: int,
|
||||
timeout: int = VESPA_REQUEST_TIMEOUT,
|
||||
ssl_cert: str | None = None,
|
||||
ssl_key: str | None = None,
|
||||
) -> None:
|
||||
httpx_cert = None
|
||||
httpx_verify = False
|
||||
if ssl_cert and ssl_key:
|
||||
httpx_cert = cast(tuple[str, str], (ssl_cert, ssl_key))
|
||||
httpx_verify = True
|
||||
|
||||
HttpxPool.init_client(
|
||||
name="vespa",
|
||||
cert=httpx_cert,
|
||||
verify=httpx_verify,
|
||||
timeout=timeout,
|
||||
http2=False,
|
||||
limits=httpx.Limits(max_keepalive_connections=max_keepalive_connections),
|
||||
)
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
# Monitoring worker specific settings
|
||||
worker_concurrency = 1 # Single worker is sufficient for monitoring
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -2,259 +2,106 @@ from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
# we might be able to reduce this greatly if we can run a unified
|
||||
# loop across all tenants rather than tasks per tenant
|
||||
|
||||
# we set expires because it isn't necessary to queue up these tasks
|
||||
# it's only important that they run relatively regularly
|
||||
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# hack to slow down task dispatch in the cloud until
|
||||
# we have a better implementation (backpressure, etc)
|
||||
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 8
|
||||
|
||||
# tasks that only run in the cloud
|
||||
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
|
||||
# by the DynamicTenantScheduler
|
||||
cloud_tasks_to_schedule = [
|
||||
# cloud specific tasks
|
||||
# we set expires because it isn't necessary to queue up these tasks
|
||||
# it's only important that they run relatively regularly
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
|
||||
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
|
||||
"schedule": timedelta(hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
# remaining tasks are cloud generators for per tenant tasks
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"name": "check-for-prune",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-prune",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"priority": OnyxCeleryPriority.LOWEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=30 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-background-processes",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(minutes=5 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# Only add the LLM model update task if the API URL is configured
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
cloud_tasks_to_schedule.append(
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(
|
||||
hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER
|
||||
), # Check every hour
|
||||
"name": "check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"schedule": timedelta(hours=1), # Check every hour
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# tasks that run in either self-hosted on cloud
|
||||
tasks_to_schedule: list[dict] = []
|
||||
|
||||
if not MULTI_TENANT:
|
||||
tasks_to_schedule.extend(
|
||||
[
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-pruning",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-background-processes",
|
||||
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
"schedule": timedelta(minutes=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
# Only add the LLM model update task if the API URL is configured
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"schedule": timedelta(hours=1), # Check every hour
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return cloud_tasks_to_schedule
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return tasks_to_schedule
|
||||
|
||||
@@ -10,17 +10,14 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.search_settings import get_all_search_settings
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -33,7 +30,6 @@ class TaskDependencyError(RuntimeError):
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -45,7 +41,7 @@ def check_for_connector_deletion_task(
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
@@ -117,21 +113,11 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
# we need to load the state of the object inside the fence
|
||||
# to avoid a race condition with db.commit/fence deletion
|
||||
# at the end of this taskset
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
# there should be no in-progress sync records if this is up to date
|
||||
# clean it up just in case things got into a bad state
|
||||
cleanup_sync_records(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
)
|
||||
return None
|
||||
|
||||
# set a basic fence to start
|
||||
@@ -178,13 +164,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
)
|
||||
if tasks_generated is None:
|
||||
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
|
||||
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
)
|
||||
|
||||
except TaskDependencyError:
|
||||
redis_connector.delete.set_fence(None)
|
||||
raise
|
||||
|
||||
@@ -3,18 +3,14 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from time import sleep
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from pydantic import ValidationError
|
||||
from redis import Redis
|
||||
from redis.exceptions import LockError
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.document import upsert_document_external_perms
|
||||
@@ -25,46 +21,31 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
)
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyncPayload
|
||||
from onyx.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncPayload,
|
||||
)
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import doc_permission_sync_ctx
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -76,9 +57,6 @@ LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
"""Jobs / utils for kicking off doc permissions sync tasks."""
|
||||
|
||||
|
||||
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if external doc permissions sync is due."""
|
||||
|
||||
@@ -113,21 +91,15 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
# TODO(rkuo): merge into check function after lookup table for fences is added
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
@@ -144,32 +116,14 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
if _is_external_doc_permissions_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
payload_id = try_creating_permissions_sync_task(
|
||||
tasks_created = try_creating_permissions_sync_task(
|
||||
self.app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
if not tasks_created:
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}"
|
||||
)
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
lock_beat.reacquire()
|
||||
if not r.exists(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES):
|
||||
# clear any permission fences that don't have associated celery tasks in progress
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
validate_permission_sync_fences(tenant_id, r, r_celery, lock_beat)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Exception while validating permission sync fences"
|
||||
)
|
||||
|
||||
r.set(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=60)
|
||||
task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -188,15 +142,13 @@ def try_creating_permissions_sync_task(
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> str | None:
|
||||
"""Returns a randomized payload id on success.
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Returns None if no syncing is required."""
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
payload_id: str | None = None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
@@ -221,25 +173,6 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_PERMISSIONS,
|
||||
)
|
||||
|
||||
# set a basic fence to start
|
||||
redis_connector.permissions.set_active()
|
||||
payload = RedisConnectorPermissionSyncPayload(
|
||||
id=make_short_id(),
|
||||
submitted=datetime.now(timezone.utc),
|
||||
started=None,
|
||||
celery_task_id=None,
|
||||
)
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
|
||||
result = app.send_task(
|
||||
OnyxCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
kwargs=dict(
|
||||
@@ -251,12 +184,12 @@ def try_creating_permissions_sync_task(
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
|
||||
# fill in the celery task id
|
||||
redis_connector.permissions.set_active()
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorPermissionSyncPayload(
|
||||
started=None, celery_task_id=result.id
|
||||
)
|
||||
|
||||
payload_id = payload.celery_task_id
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
|
||||
return None
|
||||
@@ -264,7 +197,7 @@ def try_creating_permissions_sync_task(
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return payload_id
|
||||
return 1
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -285,8 +218,6 @@ def connector_permission_sync_generator_task(
|
||||
This task assumes that the task has already been properly fenced
|
||||
"""
|
||||
|
||||
LoggerContextVars.reset()
|
||||
|
||||
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
|
||||
doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
|
||||
doc_permission_sync_ctx_dict["request_id"] = self.request.id
|
||||
@@ -348,10 +279,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
@@ -374,17 +302,12 @@ def connector_permission_sync_generator_task(
|
||||
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
|
||||
|
||||
new_payload = RedisConnectorPermissionSyncPayload(
|
||||
id=payload.id,
|
||||
submitted=payload.submitted,
|
||||
started=datetime.now(timezone.utc),
|
||||
celery_task_id=payload.celery_task_id,
|
||||
)
|
||||
redis_connector.permissions.set_fence(new_payload)
|
||||
|
||||
callback = PermissionSyncCallback(redis_connector, lock, r)
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(
|
||||
cc_pair, callback
|
||||
)
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
@@ -434,8 +357,6 @@ def update_external_document_permissions_task(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> bool:
|
||||
start = time.monotonic()
|
||||
|
||||
document_external_access = DocExternalAccess.from_dict(
|
||||
serialized_doc_external_access
|
||||
)
|
||||
@@ -465,330 +386,10 @@ def update_external_document_permissions_task(
|
||||
document_ids=[doc_id],
|
||||
)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"connector_id={connector_id} "
|
||||
f"doc={doc_id} "
|
||||
f"action=update_permissions "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
logger.debug(
|
||||
f"Successfully synced postgres document permissions for {doc_id}"
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Exception in update_external_document_permissions_task: "
|
||||
f"connector_id={connector_id} "
|
||||
f"doc_id={doc_id}"
|
||||
)
|
||||
logger.exception("Error Syncing Document Permissions")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def validate_permission_sync_fences(
|
||||
tenant_id: str | None,
|
||||
r: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
# building lookup table can be expensive, so we won't bother
|
||||
# validating until the queue is small
|
||||
PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024
|
||||
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN:
|
||||
return
|
||||
|
||||
queued_upsert_tasks = celery_get_queued_task_ids(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
reserved_generator_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
|
||||
# validate all existing indexing jobs
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
|
||||
count=SCAN_ITER_COUNT_DEFAULT,
|
||||
):
|
||||
lock_beat.reacquire()
|
||||
validate_permission_sync_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
queued_upsert_tasks,
|
||||
reserved_generator_tasks,
|
||||
r,
|
||||
r_celery,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def validate_permission_sync_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
reserved_tasks: set[str],
|
||||
r: Redis,
|
||||
r_celery: Redis,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
This can happen if the indexing worker hard crashes or is terminated.
|
||||
Being in this bad state means the fence will never clear without help, so this function
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. This function renews the active signal with a 5 minute TTL under the following conditions
|
||||
1.2. When the task is seen in the redis queue
|
||||
1.3. When the task is seen in the reserved / prefetched list
|
||||
|
||||
2. Externally, the active signal is renewed when:
|
||||
2.1. The fence is created
|
||||
2.2. The indexing watchdog checks the spawned task.
|
||||
|
||||
3. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
whether a task is in the queue or currently executing.
|
||||
1. An unknown task id is always returned as state PENDING.
|
||||
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
|
||||
and the time it actually starts on the worker.
|
||||
|
||||
queued_tasks: the celery queue of lightweight permission sync tasks
|
||||
reserved_tasks: prefetched tasks for sync task generator
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"validate_permission_sync_fence - could not parse id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
# parse out metadata and initialize the helper class with it
|
||||
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector.permissions.fenced:
|
||||
return
|
||||
|
||||
# in the cloud, the payload format may have changed ...
|
||||
# it's a little sloppy, but just reset the fence for now if that happens
|
||||
# TODO: add intentional cleanup/abort logic
|
||||
try:
|
||||
payload = redis_connector.permissions.payload
|
||||
except ValidationError:
|
||||
task_logger.exception(
|
||||
"validate_permission_sync_fence - "
|
||||
"Resetting fence because fence schema is out of date: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.permissions.reset()
|
||||
return
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
if not payload.celery_task_id:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
# either the generator task must be in flight or its subtasks must be
|
||||
found = celery_find_task(
|
||||
payload.celery_task_id,
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
|
||||
r_celery,
|
||||
)
|
||||
if found:
|
||||
# the celery task exists in the redis queue
|
||||
redis_connector.permissions.set_active()
|
||||
return
|
||||
|
||||
if payload.celery_task_id in reserved_tasks:
|
||||
# the celery task was prefetched and is reserved within a worker
|
||||
redis_connector.permissions.set_active()
|
||||
return
|
||||
|
||||
# look up every task in the current taskset in the celery queue
|
||||
# every entry in the taskset should have an associated entry in the celery task queue
|
||||
# because we get the celery tasks first, the entries in our own permissions taskset
|
||||
# should be roughly a subset of the tasks in celery
|
||||
|
||||
# this check isn't very exact, but should be sufficient over a period of time
|
||||
# A single successful check over some number of attempts is sufficient.
|
||||
|
||||
# TODO: if the number of tasks in celery is much lower than than the taskset length
|
||||
# we might be able to shortcut the lookup since by definition some of the tasks
|
||||
# must not exist in celery.
|
||||
|
||||
tasks_scanned = 0
|
||||
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
|
||||
|
||||
for member in r.sscan_iter(redis_connector.permissions.taskset_key):
|
||||
tasks_scanned += 1
|
||||
|
||||
member_bytes = cast(bytes, member)
|
||||
member_str = member_bytes.decode("utf-8")
|
||||
if member_str in queued_tasks:
|
||||
continue
|
||||
|
||||
if member_str in reserved_tasks:
|
||||
continue
|
||||
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
task_logger.info(
|
||||
"validate_permission_sync_fence task check: "
|
||||
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
)
|
||||
|
||||
if tasks_not_in_celery == 0:
|
||||
redis_connector.permissions.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector.permissions.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
task_logger.warning(
|
||||
"validate_permission_sync_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.permissions.reset()
|
||||
return
|
||||
|
||||
|
||||
class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
PARENT_CHECK_INTERVAL = 60
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_connector: RedisConnector,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.redis_connector: RedisConnector = redis_connector
|
||||
self.redis_lock: RedisLock = redis_lock
|
||||
self.redis_client = redis_client
|
||||
|
||||
self.started: datetime = datetime.now(timezone.utc)
|
||||
self.redis_lock.reacquire()
|
||||
|
||||
self.last_tag: str = "PermissionSyncCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_connector.stop.fenced:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
try:
|
||||
self.redis_connector.permissions.set_active()
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - self.last_lock_monotonic >= (
|
||||
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
self.redis_lock.reacquire()
|
||||
self.last_lock_reacquire = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_tag = tag
|
||||
except LockError:
|
||||
logger.exception(
|
||||
f"PermissionSyncCallback - lock.reacquire exceptioned: "
|
||||
f"lock_timeout={self.redis_lock.timeout} "
|
||||
f"start={self.started} "
|
||||
f"last_tag={self.last_tag} "
|
||||
f"last_reacquired={self.last_lock_reacquire} "
|
||||
f"now={datetime.now(timezone.utc)}"
|
||||
)
|
||||
|
||||
redis_lock_dump(self.redis_lock, self.redis_client)
|
||||
raise
|
||||
|
||||
|
||||
"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.permissions.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.permissions.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
try:
|
||||
payload = redis_connector.permissions.payload
|
||||
except ValidationError:
|
||||
task_logger.exception(
|
||||
"Permissions sync payload failed to validate. "
|
||||
"Schema may have been updated."
|
||||
)
|
||||
return
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
remaining = redis_connector.permissions.get_remaining()
|
||||
task_logger.info(
|
||||
f"Permissions sync progress: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"id={payload.id} "
|
||||
f"remaining={remaining} "
|
||||
f"initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), payload.started)
|
||||
task_logger.info(
|
||||
f"Permissions sync finished: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"id={payload.id} "
|
||||
f"num_synced={initial}"
|
||||
)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_PERMISSIONS,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial,
|
||||
)
|
||||
|
||||
redis_connector.permissions.reset()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -10,7 +9,6 @@ from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
|
||||
@@ -22,12 +20,9 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
|
||||
)
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -38,18 +33,12 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
from onyx.redis.redis_connector_ext_group_sync import (
|
||||
RedisConnectorExternalGroupSyncPayload,
|
||||
)
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -102,20 +91,15 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
# r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
@@ -147,7 +131,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
if _is_external_group_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
tasks_created = try_creating_external_group_sync_task(
|
||||
self.app, cc_pair_id, r, tenant_id
|
||||
@@ -156,23 +139,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
continue
|
||||
|
||||
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
# lock_beat.reacquire()
|
||||
# if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
|
||||
# # clear any indexing fences that don't have associated celery tasks in progress
|
||||
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# # or be currently executing
|
||||
# try:
|
||||
# validate_external_group_sync_fences(
|
||||
# tenant_id, self.app, r, r_celery, lock_beat
|
||||
# )
|
||||
# except Exception:
|
||||
# task_logger.exception(
|
||||
# "Exception while validating external group sync fences"
|
||||
# )
|
||||
|
||||
# r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -215,12 +181,6 @@ def try_creating_external_group_sync_task(
|
||||
redis_connector.external_group_sync.generator_clear()
|
||||
redis_connector.external_group_sync.taskset_clear()
|
||||
|
||||
payload = RedisConnectorExternalGroupSyncPayload(
|
||||
submitted=datetime.now(timezone.utc),
|
||||
started=None,
|
||||
celery_task_id=None,
|
||||
)
|
||||
|
||||
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
|
||||
|
||||
result = app.send_task(
|
||||
@@ -234,17 +194,13 @@ def try_creating_external_group_sync_task(
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_GROUP,
|
||||
)
|
||||
payload = RedisConnectorExternalGroupSyncPayload(
|
||||
started=datetime.now(timezone.utc),
|
||||
celery_task_id=result.id,
|
||||
)
|
||||
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
|
||||
@@ -271,7 +227,7 @@ def connector_external_group_sync_generator_task(
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
External group sync task for a given connector credential pair
|
||||
Permission sync task that handles external group syncing for a given connector credential pair
|
||||
This task assumes that the task has already been properly fenced
|
||||
"""
|
||||
|
||||
@@ -279,65 +235,22 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
# before the primary worker can finalize the fence
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.external_group_sync.fence_key}"
|
||||
)
|
||||
|
||||
if not redis_connector.external_group_sync.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
f"connector_external_group_sync_generator_task - fence not found: "
|
||||
f"fence={redis_connector.external_group_sync.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector.external_group_sync.payload # The payload must exist
|
||||
if not payload:
|
||||
raise ValueError(
|
||||
"connector_external_group_sync_generator_task: payload invalid or not found"
|
||||
)
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_external_group_sync_generator_task - Waiting for fence: "
|
||||
f"fence={redis_connector.external_group_sync.fence_key}"
|
||||
)
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"connector_external_group_sync_generator_task - Fence found, continuing...: "
|
||||
f"fence={redis_connector.external_group_sync.fence_key}"
|
||||
)
|
||||
break
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
@@ -372,26 +285,11 @@ def connector_external_group_sync_generator_task(
|
||||
)
|
||||
|
||||
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_GROUP,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run external group sync: cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_GROUP,
|
||||
sync_status=SyncStatus.FAILED,
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.generator_clear()
|
||||
redis_connector.external_group_sync.taskset_clear()
|
||||
raise e
|
||||
@@ -400,135 +298,3 @@ def connector_external_group_sync_generator_task(
|
||||
redis_connector.external_group_sync.set_fence(None)
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
|
||||
def validate_external_group_sync_fences(
|
||||
tenant_id: str | None,
|
||||
celery_app: Celery,
|
||||
r: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
reserved_sync_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
|
||||
# validate all existing indexing jobs
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
|
||||
count=SCAN_ITER_COUNT_DEFAULT,
|
||||
):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
validate_external_group_sync_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
reserved_sync_tasks,
|
||||
r_celery,
|
||||
db_session,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def validate_external_group_sync_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
This can happen if the indexing worker hard crashes or is terminated.
|
||||
Being in this bad state means the fence will never clear without help, so this function
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. This function renews the active signal with a 5 minute TTL under the following conditions
|
||||
1.2. When the task is seen in the redis queue
|
||||
1.3. When the task is seen in the reserved / prefetched list
|
||||
|
||||
2. Externally, the active signal is renewed when:
|
||||
2.1. The fence is created
|
||||
2.2. The indexing watchdog checks the spawned task.
|
||||
|
||||
3. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
whether a task is in the queue or currently executing.
|
||||
1. An unknown task id is always returned as state PENDING.
|
||||
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
|
||||
and the time it actually starts on the worker.
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"validate_external_group_sync_fence - could not parse id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector.external_group_sync.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector.external_group_sync.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
# the fence is just barely set up.
|
||||
# if redis_connector_index.active():
|
||||
# return
|
||||
|
||||
# it would be odd to get here as there isn't that much that can go wrong during
|
||||
# initial fence setup, but it's still worth making sure we can recover
|
||||
logger.info(
|
||||
"validate_external_group_sync_fence - "
|
||||
f"Resetting fence in basic state without any activity: fence={fence_key}"
|
||||
)
|
||||
redis_connector.external_group_sync.reset()
|
||||
return
|
||||
|
||||
found = celery_find_task(
|
||||
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
if found:
|
||||
# the celery task exists in the redis queue
|
||||
# redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
if payload.celery_task_id in reserved_tasks:
|
||||
# the celery task was prefetched and is reserved within the indexing worker
|
||||
# redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
# if redis_connector_index.active():
|
||||
# return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
logger.warning(
|
||||
"validate_external_group_sync_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.reset()
|
||||
return
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,522 +0,0 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from redis.exceptions import LockError
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.index_attempt import create_index_attempt
|
||||
from onyx.db.index_attempt import delete_index_attempt
|
||||
from onyx.db.index_attempt import get_all_index_attempts_by_status
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
|
||||
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
|
||||
want to clean them up.
|
||||
|
||||
Unfenced = attempt not in terminal state and fence does not exist.
|
||||
"""
|
||||
unfenced_attempts: list[int] = []
|
||||
|
||||
# inner/outer/inner double check pattern to avoid race conditions when checking for
|
||||
# bad state
|
||||
# inner = index_attempt in non terminal state
|
||||
# outer = r.fence_key down
|
||||
|
||||
# check the db for index attempts in a non terminal state
|
||||
attempts: list[IndexAttempt] = []
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
|
||||
)
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
|
||||
)
|
||||
|
||||
for attempt in attempts:
|
||||
fence_key = RedisConnectorIndex.fence_key_with_ids(
|
||||
attempt.connector_credential_pair_id, attempt.search_settings_id
|
||||
)
|
||||
|
||||
# if the fence is down / doesn't exist, possible error but not confirmed
|
||||
if r.exists(fence_key):
|
||||
continue
|
||||
|
||||
# Between the time the attempts are first looked up and the time we see the fence down,
|
||||
# the attempt may have completed and taken down the fence normally.
|
||||
|
||||
# We need to double check that the index attempt is still in a non terminal state
|
||||
# and matches the original state, which confirms we are really in a bad state.
|
||||
attempt_2 = get_index_attempt(db_session, attempt.id)
|
||||
if not attempt_2:
|
||||
continue
|
||||
|
||||
if attempt.status != attempt_2.status:
|
||||
continue
|
||||
|
||||
unfenced_attempts.append(attempt.id)
|
||||
|
||||
return unfenced_attempts
|
||||
|
||||
|
||||
class IndexingCallback(IndexingHeartbeatInterface):
|
||||
PARENT_CHECK_INTERVAL = 60
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_pid: int,
|
||||
stop_key: str,
|
||||
generator_progress_key: str,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.parent_pid = parent_pid
|
||||
self.redis_lock: RedisLock = redis_lock
|
||||
self.stop_key: str = stop_key
|
||||
self.generator_progress_key: str = generator_progress_key
|
||||
self.redis_client = redis_client
|
||||
self.started: datetime = datetime.now(timezone.utc)
|
||||
self.redis_lock.reacquire()
|
||||
|
||||
self.last_tag: str = "IndexingCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_parent_check = time.monotonic()
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_client.exists(self.stop_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
|
||||
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
|
||||
# so leave this code in until we're ready to test it.
|
||||
|
||||
# if self.parent_pid:
|
||||
# # check if the parent pid is alive so we aren't running as a zombie
|
||||
# now = time.monotonic()
|
||||
# if now - self.last_parent_check > IndexingCallback.PARENT_CHECK_INTERVAL:
|
||||
# try:
|
||||
# # this is unintuitive, but it checks if the parent pid is still running
|
||||
# os.kill(self.parent_pid, 0)
|
||||
# except Exception:
|
||||
# logger.exception("IndexingCallback - parent pid check exceptioned")
|
||||
# raise
|
||||
# self.last_parent_check = now
|
||||
|
||||
try:
|
||||
current_time = time.monotonic()
|
||||
if current_time - self.last_lock_monotonic >= (
|
||||
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
self.redis_lock.reacquire()
|
||||
self.last_lock_reacquire = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_tag = tag
|
||||
except LockError:
|
||||
logger.exception(
|
||||
f"IndexingCallback - lock.reacquire exceptioned: "
|
||||
f"lock_timeout={self.redis_lock.timeout} "
|
||||
f"start={self.started} "
|
||||
f"last_tag={self.last_tag} "
|
||||
f"last_reacquired={self.last_lock_reacquire} "
|
||||
f"now={datetime.now(timezone.utc)}"
|
||||
)
|
||||
|
||||
redis_lock_dump(self.redis_lock, self.redis_client)
|
||||
raise
|
||||
|
||||
self.redis_client.incrby(self.generator_progress_key, amount)
|
||||
|
||||
|
||||
def validate_indexing_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
This can happen if the indexing worker hard crashes or is terminated.
|
||||
Being in this bad state means the fence will never clear without help, so this function
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. This function renews the active signal with a 5 minute TTL under the following conditions
|
||||
1.2. When the task is seen in the redis queue
|
||||
1.3. When the task is seen in the reserved / prefetched list
|
||||
|
||||
2. Externally, the active signal is renewed when:
|
||||
2.1. The fence is created
|
||||
2.2. The indexing watchdog checks the spawned task.
|
||||
|
||||
3. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
whether a task is in the queue or currently executing.
|
||||
1. An unknown task id is always returned as state PENDING.
|
||||
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
|
||||
and the time it actually starts on the worker.
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
parts = composite_id.split("/")
|
||||
if len(parts) != 2:
|
||||
return
|
||||
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector_index.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
# the fence is just barely set up.
|
||||
if redis_connector_index.active():
|
||||
return
|
||||
|
||||
# it would be odd to get here as there isn't that much that can go wrong during
|
||||
# initial fence setup, but it's still worth making sure we can recover
|
||||
logger.info(
|
||||
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
|
||||
)
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
found = celery_find_task(
|
||||
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
if found:
|
||||
# the celery task exists in the redis queue
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
if payload.celery_task_id in reserved_tasks:
|
||||
# the celery task was prefetched and is reserved within the indexing worker
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector_index.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
logger.warning(
|
||||
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
|
||||
f"index_attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
if payload.index_attempt_id:
|
||||
try:
|
||||
mark_attempt_failed(
|
||||
payload.index_attempt_id,
|
||||
db_session,
|
||||
"validate_indexing_fence - Canceling index attempt due to missing celery tasks: "
|
||||
f"index_attempt={payload.index_attempt_id}",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"validate_indexing_fence - Exception while marking index attempt as failed: "
|
||||
f"index_attempt={payload.index_attempt_id}",
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
|
||||
def validate_indexing_fences(
|
||||
tenant_id: str | None,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
"""Validates all indexing fences for this tenant ... aka makes sure
|
||||
indexing tasks sent to celery are still in flight.
|
||||
"""
|
||||
reserved_indexing_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
# Use replica for this because the worst thing that happens
|
||||
# is that we don't run the validation on this pass
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
validate_indexing_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
reserved_indexing_tasks,
|
||||
r_celery,
|
||||
db_session,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def _should_index(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
search_settings_primary: bool,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
"""Checks various global settings and past indexing attempts to determine if
|
||||
we should try to start indexing the cc pair / search setting combination.
|
||||
|
||||
Note that tactical checks such as preventing overlap with a currently running task
|
||||
are not handled here.
|
||||
|
||||
Return True if we should try to index, False if not.
|
||||
"""
|
||||
connector = cc_pair.connector
|
||||
|
||||
# uncomment for debugging
|
||||
# task_logger.info(f"_should_index: "
|
||||
# f"cc_pair={cc_pair.id} "
|
||||
# f"connector={cc_pair.connector_id} "
|
||||
# f"refresh_freq={connector.refresh_freq}")
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
return False
|
||||
|
||||
# User can still manually create single indexing attempts via the UI for the
|
||||
# currently in use index
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
if (
|
||||
search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
and secondary_index_building
|
||||
):
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
if search_settings_instance.status == IndexModelStatus.FUTURE:
|
||||
if last_index:
|
||||
# No new index if the last index attempt succeeded
|
||||
# Once is enough. The model will never be able to swap otherwise.
|
||||
if last_index.status == IndexingStatus.SUCCESS:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is waiting to start
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is running
|
||||
if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
return False
|
||||
else:
|
||||
if (
|
||||
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
|
||||
): # Ingestion API
|
||||
return False
|
||||
return True
|
||||
|
||||
# If the connector is paused or is the ingestion API, don't index
|
||||
# NOTE: during an embedding model switch over, the following logic
|
||||
# is bypassed by the above check for a future model
|
||||
if (
|
||||
not cc_pair.status.is_active()
|
||||
or connector.id == 0
|
||||
or connector.source == DocumentSource.INGESTION_API
|
||||
):
|
||||
return False
|
||||
|
||||
if search_settings_primary:
|
||||
if cc_pair.indexing_trigger is not None:
|
||||
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
|
||||
return True
|
||||
|
||||
# if no attempt has ever occurred, we should index regardless of refresh_freq
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
time_since_index = current_db_time - last_index.time_updated
|
||||
if time_since_index.total_seconds() < connector.refresh_freq:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_indexing_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
search_settings: SearchSettings,
|
||||
reindex: bool,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the indexing task from being
|
||||
created, then creates the task.
|
||||
|
||||
Does not check for scheduling related conditions as this function
|
||||
is used to trigger indexing immediately.
|
||||
"""
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
index_attempt_id: int | None = None
|
||||
|
||||
# we need to serialize any attempt to trigger indexing since it can be triggered
|
||||
# either via celery beat or manually (API call)
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
|
||||
# skip if already indexing
|
||||
if redis_connector_index.fenced:
|
||||
return None
|
||||
|
||||
# skip indexing if the cc_pair is deleting
|
||||
if redis_connector.delete.fenced:
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
redis_connector_index.generator_clear()
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorIndexPayload(
|
||||
index_attempt_id=None,
|
||||
started=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
celery_task_id=None,
|
||||
)
|
||||
|
||||
redis_connector_index.set_active()
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
# create the index attempt for tracking purposes
|
||||
# code elsewhere checks for index attempts without an associated redis key
|
||||
# and cleans them up
|
||||
# therefore we must create the attempt and the task after the fence goes up
|
||||
index_attempt_id = create_index_attempt(
|
||||
cc_pair.id,
|
||||
search_settings.id,
|
||||
from_beginning=reindex,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
custom_task_id = redis_connector_index.generate_generator_task_id()
|
||||
|
||||
# when the task is sent, we have yet to finish setting up the fence
|
||||
# therefore, the task must contain code that blocks until the fence is ready
|
||||
result = celery_app.send_task(
|
||||
OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=OnyxCeleryQueues.CONNECTOR_INDEXING,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
|
||||
|
||||
# now fill out the fence with the rest of the data
|
||||
redis_connector_index.set_active()
|
||||
|
||||
payload.index_attempt_id = index_attempt_id
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector_index.set_fence(payload)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"try_creating_indexing_task - Unexpected exception: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
if index_attempt_id is not None:
|
||||
delete_index_attempt(db_session, index_attempt_id)
|
||||
redis_connector_index.set_fence(None)
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return index_attempt_id
|
||||
@@ -14,16 +14,8 @@ from onyx.db.models import LLMProvider
|
||||
|
||||
def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||
# Handle case where response is wrapped in a "data" field
|
||||
if isinstance(model_list_json, dict):
|
||||
if "data" in model_list_json:
|
||||
model_list_json = model_list_json["data"]
|
||||
elif "models" in model_list_json:
|
||||
model_list_json = model_list_json["models"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid response from API - expected dict with 'data' or "
|
||||
f"'models' field, got {type(model_list_json)}"
|
||||
)
|
||||
if isinstance(model_list_json, dict) and "data" in model_list_json:
|
||||
model_list_json = model_list_json["data"]
|
||||
|
||||
if not isinstance(model_list_json, list):
|
||||
raise ValueError(
|
||||
@@ -35,18 +27,11 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||
for item in model_list_json:
|
||||
if isinstance(item, str):
|
||||
model_names.append(item)
|
||||
elif isinstance(item, dict):
|
||||
if "model_name" in item:
|
||||
model_names.append(item["model_name"])
|
||||
elif "id" in item:
|
||||
model_names.append(item["id"])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid item in model list - expected dict with model_name or id, got {type(item)}"
|
||||
)
|
||||
elif isinstance(item, dict) and "model_name" in item:
|
||||
model_names.append(item["model_name"])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid item in model list - expected string or dict, got {type(item)}"
|
||||
f"Invalid item in model list - expected string or dict with model_name, got {type(item)}"
|
||||
)
|
||||
|
||||
return model_names
|
||||
@@ -54,7 +39,6 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
|
||||
@@ -1,829 +0,0 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import SyncRecord
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60 # 6 minutes
|
||||
|
||||
_CONNECTOR_INDEX_ATTEMPT_START_LATENCY_KEY_FMT = (
|
||||
"monitoring_connector_index_attempt_start_latency:{cc_pair_id}:{index_attempt_id}"
|
||||
)
|
||||
|
||||
_CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT = (
|
||||
"monitoring_connector_index_attempt_run_success:{cc_pair_id}:{index_attempt_id}"
|
||||
)
|
||||
|
||||
_FINAL_METRIC_KEY_FMT = "sync_final_metrics:{sync_type}:{entity_id}:{sync_record_id}"
|
||||
|
||||
_SYNC_START_LATENCY_KEY_FMT = (
|
||||
"sync_start_latency:{sync_type}:{entity_id}:{sync_record_id}"
|
||||
)
|
||||
|
||||
_CONNECTOR_START_TIME_KEY_FMT = "connector_start_time:{cc_pair_id}:{index_attempt_id}"
|
||||
_CONNECTOR_END_TIME_KEY_FMT = "connector_end_time:{cc_pair_id}:{index_attempt_id}"
|
||||
_SYNC_START_TIME_KEY_FMT = "sync_start_time:{sync_type}:{entity_id}:{sync_record_id}"
|
||||
_SYNC_END_TIME_KEY_FMT = "sync_end_time:{sync_type}:{entity_id}:{sync_record_id}"
|
||||
|
||||
|
||||
def _mark_metric_as_emitted(redis_std: Redis, key: str) -> None:
|
||||
"""Mark a metric as having been emitted by setting a Redis key with expiration"""
|
||||
redis_std.set(key, "1", ex=24 * 60 * 60) # Expire after 1 day
|
||||
|
||||
|
||||
def _has_metric_been_emitted(redis_std: Redis, key: str) -> bool:
|
||||
"""Check if a metric has been emitted by checking for existence of Redis key"""
|
||||
return bool(redis_std.exists(key))
|
||||
|
||||
|
||||
class Metric(BaseModel):
|
||||
key: str | None # only required if we need to store that we have emitted this metric
|
||||
name: str
|
||||
value: Any
|
||||
tags: dict[str, str]
|
||||
|
||||
def log(self) -> None:
|
||||
"""Log the metric in a standardized format"""
|
||||
data = {
|
||||
"metric": self.name,
|
||||
"value": self.value,
|
||||
"tags": self.tags,
|
||||
}
|
||||
task_logger.info(json.dumps(data))
|
||||
|
||||
def emit(self, tenant_id: str | None) -> None:
|
||||
# Convert value to appropriate type based on the input value
|
||||
bool_value = None
|
||||
float_value = None
|
||||
int_value = None
|
||||
string_value = None
|
||||
# NOTE: have to do bool first, since `isinstance(True, int)` is true
|
||||
# e.g. bool is a subclass of int
|
||||
if isinstance(self.value, bool):
|
||||
bool_value = self.value
|
||||
elif isinstance(self.value, int):
|
||||
int_value = self.value
|
||||
elif isinstance(self.value, float):
|
||||
float_value = self.value
|
||||
elif isinstance(self.value, str):
|
||||
string_value = self.value
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid metric value type: {type(self.value)} "
|
||||
f"({self.value}) for metric {self.name}."
|
||||
)
|
||||
return
|
||||
|
||||
# don't send None values over the wire
|
||||
data = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"metric_name": self.name,
|
||||
"float_value": float_value,
|
||||
"int_value": int_value,
|
||||
"string_value": string_value,
|
||||
"bool_value": bool_value,
|
||||
"tags": self.tags,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
task_logger.info(f"Emitting metric: {data}")
|
||||
optional_telemetry(
|
||||
record_type=RecordType.METRIC,
|
||||
data=data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
|
||||
def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
|
||||
"""Collect metrics about queue lengths for different Celery queues"""
|
||||
metrics = []
|
||||
queue_mappings = {
|
||||
"celery_queue_length": "celery",
|
||||
"indexing_queue_length": "indexing",
|
||||
"sync_queue_length": "sync",
|
||||
"deletion_queue_length": "deletion",
|
||||
"pruning_queue_length": "pruning",
|
||||
"permissions_sync_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
|
||||
"external_group_sync_queue_length": OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
|
||||
"permissions_upsert_queue_length": OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
|
||||
}
|
||||
|
||||
for name, queue in queue_mappings.items():
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=None,
|
||||
name=name,
|
||||
value=celery_get_queue_length(queue, redis_celery),
|
||||
tags={"queue": name},
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _build_connector_start_latency_metric(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
recent_attempt: IndexAttempt,
|
||||
second_most_recent_attempt: IndexAttempt | None,
|
||||
redis_std: Redis,
|
||||
) -> Metric | None:
|
||||
if not recent_attempt.time_started:
|
||||
return None
|
||||
|
||||
# check if we already emitted a metric for this index attempt
|
||||
metric_key = _CONNECTOR_INDEX_ATTEMPT_START_LATENCY_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=recent_attempt.id,
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.info(
|
||||
f"Skipping metric for connector {cc_pair.connector.id} "
|
||||
f"index attempt {recent_attempt.id} because it has already been "
|
||||
"emitted"
|
||||
)
|
||||
return None
|
||||
|
||||
# Connector start latency
|
||||
# first run case - we should start as soon as it's created
|
||||
if not second_most_recent_attempt:
|
||||
desired_start_time = cc_pair.connector.time_created
|
||||
else:
|
||||
if not cc_pair.connector.refresh_freq:
|
||||
task_logger.error(
|
||||
"Found non-initial index attempt for connector "
|
||||
"without refresh_freq. This should never happen."
|
||||
)
|
||||
return None
|
||||
|
||||
desired_start_time = second_most_recent_attempt.time_updated + timedelta(
|
||||
seconds=cc_pair.connector.refresh_freq
|
||||
)
|
||||
|
||||
start_latency = (recent_attempt.time_started - desired_start_time).total_seconds()
|
||||
|
||||
task_logger.info(
|
||||
f"Start latency for index attempt {recent_attempt.id}: {start_latency:.2f}s "
|
||||
f"(desired: {desired_start_time}, actual: {recent_attempt.time_started})"
|
||||
)
|
||||
|
||||
job_id = build_job_id("connector", str(cc_pair.id), str(recent_attempt.id))
|
||||
|
||||
return Metric(
|
||||
key=metric_key,
|
||||
name="connector_start_latency",
|
||||
value=start_latency,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _build_connector_final_metrics(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
recent_attempts: list[IndexAttempt],
|
||||
redis_std: Redis,
|
||||
) -> list[Metric]:
|
||||
"""
|
||||
Final metrics for connector index attempts:
|
||||
- Boolean success/fail metric
|
||||
- If success, emit:
|
||||
* duration (seconds)
|
||||
* doc_count
|
||||
"""
|
||||
metrics = []
|
||||
for attempt in recent_attempts:
|
||||
metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=attempt.id,
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.info(
|
||||
f"Skipping final metrics for connector {cc_pair.connector.id} "
|
||||
f"index attempt {attempt.id}, already emitted."
|
||||
)
|
||||
continue
|
||||
|
||||
# We only emit final metrics if the attempt is in a terminal state
|
||||
if attempt.status not in [
|
||||
IndexingStatus.SUCCESS,
|
||||
IndexingStatus.FAILED,
|
||||
IndexingStatus.CANCELED,
|
||||
]:
|
||||
# Not finished; skip
|
||||
continue
|
||||
|
||||
job_id = build_job_id("connector", str(cc_pair.id), str(attempt.id))
|
||||
success = attempt.status == IndexingStatus.SUCCESS
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=metric_key, # We'll mark the same key for any final metrics
|
||||
name="connector_run_succeeded",
|
||||
value=success,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
"status": attempt.status.value,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if success:
|
||||
# Make sure we have valid time_started
|
||||
if attempt.time_started and attempt.time_updated:
|
||||
duration_seconds = (
|
||||
attempt.time_updated - attempt.time_started
|
||||
).total_seconds()
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=None, # No need for a new key, or you can reuse the same if you prefer
|
||||
name="connector_index_duration_seconds",
|
||||
value=duration_seconds,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Index attempt {attempt.id} succeeded but has missing time "
|
||||
f"(time_started={attempt.time_started}, time_updated={attempt.time_updated})."
|
||||
)
|
||||
|
||||
# For doc counts, choose whichever field is more relevant
|
||||
doc_count = attempt.total_docs_indexed or 0
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=None,
|
||||
name="connector_index_doc_count",
|
||||
value=doc_count,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
|
||||
"""Collect metrics about connector runs from the past hour"""
|
||||
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
|
||||
|
||||
# Get all connector credential pairs
|
||||
cc_pairs = db_session.scalars(select(ConnectorCredentialPair)).all()
|
||||
# Might be more than one search setting, or just one
|
||||
active_search_settings_list = get_active_search_settings_list(db_session)
|
||||
|
||||
metrics = []
|
||||
|
||||
# If you want to process each cc_pair against each search setting:
|
||||
for cc_pair in cc_pairs:
|
||||
for search_settings in active_search_settings_list:
|
||||
recent_attempts = (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair.id,
|
||||
IndexAttempt.search_settings_id == search_settings.id,
|
||||
)
|
||||
.order_by(IndexAttempt.time_created.desc())
|
||||
.limit(2)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not recent_attempts:
|
||||
continue
|
||||
|
||||
most_recent_attempt = recent_attempts[0]
|
||||
second_most_recent_attempt = (
|
||||
recent_attempts[1] if len(recent_attempts) > 1 else None
|
||||
)
|
||||
|
||||
if one_hour_ago > most_recent_attempt.time_created:
|
||||
continue
|
||||
|
||||
# Build a job_id for correlation
|
||||
job_id = build_job_id(
|
||||
"connector", str(cc_pair.id), str(most_recent_attempt.id)
|
||||
)
|
||||
|
||||
# Add raw start time metric if available
|
||||
if most_recent_attempt.time_started:
|
||||
start_time_key = _CONNECTOR_START_TIME_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=most_recent_attempt.id,
|
||||
)
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=start_time_key,
|
||||
name="connector_start_time",
|
||||
value=most_recent_attempt.time_started.timestamp(),
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add raw end time metric if available and in terminal state
|
||||
if (
|
||||
most_recent_attempt.status.is_terminal()
|
||||
and most_recent_attempt.time_updated
|
||||
):
|
||||
end_time_key = _CONNECTOR_END_TIME_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=most_recent_attempt.id,
|
||||
)
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=end_time_key,
|
||||
name="connector_end_time",
|
||||
value=most_recent_attempt.time_updated.timestamp(),
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Connector start latency
|
||||
start_latency_metric = _build_connector_start_latency_metric(
|
||||
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
|
||||
)
|
||||
|
||||
if start_latency_metric:
|
||||
metrics.append(start_latency_metric)
|
||||
|
||||
# Connector run success/failure
|
||||
final_metrics = _build_connector_final_metrics(
|
||||
cc_pair, recent_attempts, redis_std
|
||||
)
|
||||
metrics.extend(final_metrics)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
|
||||
"""
|
||||
Collect metrics for document set and group syncing:
|
||||
- Success/failure status
|
||||
- Start latency (for doc sets / user groups)
|
||||
- Duration & doc count (only if success)
|
||||
- Throughput (docs/min) (only if success)
|
||||
- Raw start/end times for each sync
|
||||
"""
|
||||
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
|
||||
|
||||
# Get all sync records that ended in the last hour
|
||||
recent_sync_records = db_session.scalars(
|
||||
select(SyncRecord)
|
||||
.where(SyncRecord.sync_end_time.isnot(None))
|
||||
.where(SyncRecord.sync_end_time >= one_hour_ago)
|
||||
.order_by(SyncRecord.sync_end_time.desc())
|
||||
).all()
|
||||
|
||||
task_logger.info(
|
||||
f"Collecting sync metrics for {len(recent_sync_records)} sync records"
|
||||
)
|
||||
|
||||
metrics = []
|
||||
|
||||
for sync_record in recent_sync_records:
|
||||
# Build a job_id for correlation
|
||||
job_id = build_job_id("sync_record", str(sync_record.id))
|
||||
|
||||
# Add raw start time metric
|
||||
start_time_key = _SYNC_START_TIME_KEY_FMT.format(
|
||||
sync_type=sync_record.sync_type,
|
||||
entity_id=sync_record.entity_id,
|
||||
sync_record_id=sync_record.id,
|
||||
)
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=start_time_key,
|
||||
name="sync_start_time",
|
||||
value=sync_record.sync_start_time.timestamp(),
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add raw end time metric if available
|
||||
if sync_record.sync_end_time:
|
||||
end_time_key = _SYNC_END_TIME_KEY_FMT.format(
|
||||
sync_type=sync_record.sync_type,
|
||||
entity_id=sync_record.entity_id,
|
||||
sync_record_id=sync_record.id,
|
||||
)
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=end_time_key,
|
||||
name="sync_end_time",
|
||||
value=sync_record.sync_end_time.timestamp(),
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Emit a SUCCESS/FAIL boolean metric
|
||||
# Use a single Redis key to avoid re-emitting final metrics
|
||||
final_metric_key = _FINAL_METRIC_KEY_FMT.format(
|
||||
sync_type=sync_record.sync_type,
|
||||
entity_id=sync_record.entity_id,
|
||||
sync_record_id=sync_record.id,
|
||||
)
|
||||
if not _has_metric_been_emitted(redis_std, final_metric_key):
|
||||
# Evaluate success
|
||||
sync_succeeded = sync_record.sync_status == SyncStatus.SUCCESS
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=final_metric_key,
|
||||
name="sync_run_succeeded",
|
||||
value=sync_succeeded,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
"status": str(sync_record.sync_status),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# If successful, emit additional metrics
|
||||
if sync_succeeded:
|
||||
if sync_record.sync_end_time and sync_record.sync_start_time:
|
||||
duration_seconds = (
|
||||
sync_record.sync_end_time - sync_record.sync_start_time
|
||||
).total_seconds()
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid times for sync record {sync_record.id}: "
|
||||
f"start={sync_record.sync_start_time}, end={sync_record.sync_end_time}"
|
||||
)
|
||||
duration_seconds = None
|
||||
|
||||
doc_count = sync_record.num_docs_synced or 0
|
||||
|
||||
sync_speed = None
|
||||
if duration_seconds and duration_seconds > 0:
|
||||
duration_mins = duration_seconds / 60.0
|
||||
sync_speed = (
|
||||
doc_count / duration_mins if duration_mins > 0 else None
|
||||
)
|
||||
|
||||
# Emit duration, doc count, speed
|
||||
if duration_seconds is not None:
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=final_metric_key,
|
||||
name="sync_duration_seconds",
|
||||
value=duration_seconds,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid sync record {sync_record.id} with no duration"
|
||||
)
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=final_metric_key,
|
||||
name="sync_doc_count",
|
||||
value=doc_count,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if sync_speed is not None:
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=final_metric_key,
|
||||
name="sync_speed_docs_per_min",
|
||||
value=sync_speed,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid sync record {sync_record.id} with no duration"
|
||||
)
|
||||
|
||||
# Emit start latency
|
||||
start_latency_key = _SYNC_START_LATENCY_KEY_FMT.format(
|
||||
sync_type=sync_record.sync_type,
|
||||
entity_id=sync_record.entity_id,
|
||||
sync_record_id=sync_record.id,
|
||||
)
|
||||
if not _has_metric_been_emitted(redis_std, start_latency_key):
|
||||
# Get the entity's last update time based on sync type
|
||||
entity: DocumentSet | UserGroup | None = None
|
||||
if sync_record.sync_type == SyncType.DOCUMENT_SET:
|
||||
entity = db_session.scalar(
|
||||
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
|
||||
)
|
||||
elif sync_record.sync_type == SyncType.USER_GROUP:
|
||||
entity = db_session.scalar(
|
||||
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
|
||||
)
|
||||
|
||||
if entity is None:
|
||||
task_logger.error(
|
||||
f"Sync record of type {sync_record.sync_type} doesn't have an entity "
|
||||
f"associated with it (id={sync_record.entity_id}). Skipping start latency metric."
|
||||
)
|
||||
|
||||
# Calculate start latency in seconds:
|
||||
# (actual sync start) - (last modified time)
|
||||
if (
|
||||
entity is not None
|
||||
and entity.time_last_modified_by_user
|
||||
and sync_record.sync_start_time
|
||||
):
|
||||
start_latency = (
|
||||
sync_record.sync_start_time - entity.time_last_modified_by_user
|
||||
).total_seconds()
|
||||
|
||||
if start_latency < 0:
|
||||
task_logger.error(
|
||||
f"Negative start latency for sync record {sync_record.id} "
|
||||
f"(start={sync_record.sync_start_time}, entity_modified={entity.time_last_modified_by_user})"
|
||||
)
|
||||
continue
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=start_latency_key,
|
||||
name="sync_start_latency_seconds",
|
||||
value=start_latency,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def build_job_id(
|
||||
job_type: Literal["connector", "sync_record"],
|
||||
primary_id: str,
|
||||
secondary_id: str | None = None,
|
||||
) -> str:
|
||||
if job_type == "connector":
|
||||
if secondary_id is None:
|
||||
raise ValueError(
|
||||
"secondary_id (attempt_id) is required for connector job_type"
|
||||
)
|
||||
return f"connector:{primary_id}:attempt:{secondary_id}"
|
||||
elif job_type == "sync_record":
|
||||
return f"sync_record:{primary_id}"
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
ignore_result=True,
|
||||
soft_time_limit=_MONITORING_SOFT_TIME_LIMIT,
|
||||
time_limit=_MONITORING_TIME_LIMIT,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
bind=True,
|
||||
)
|
||||
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
"""Collect and emit metrics about background processes.
|
||||
This task runs periodically to gather metrics about:
|
||||
- Queue lengths for different Celery queues
|
||||
- Connector run metrics (start latency, success rate)
|
||||
- Syncing speed metrics
|
||||
- Worker status and task counts
|
||||
"""
|
||||
if tenant_id is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
task_logger.info("Starting background monitoring")
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_monitoring: RedisLock = r.lock(
|
||||
OnyxRedisLocks.MONITOR_BACKGROUND_PROCESSES_LOCK,
|
||||
timeout=_MONITORING_SOFT_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_monitoring.acquire(blocking=False):
|
||||
task_logger.info("Skipping monitoring task because it is already running")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Get Redis client for Celery broker
|
||||
redis_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
redis_std = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Define metric collection functions and their dependencies
|
||||
metric_functions: list[Callable[[], list[Metric]]] = [
|
||||
lambda: _collect_queue_metrics(redis_celery),
|
||||
lambda: _collect_connector_metrics(db_session, redis_std),
|
||||
lambda: _collect_sync_metrics(db_session, redis_std),
|
||||
]
|
||||
|
||||
# Collect and log each metric
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
for metric_fn in metric_functions:
|
||||
metrics = metric_fn()
|
||||
for metric in metrics:
|
||||
# double check to make sure we aren't double-emitting metrics
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
|
||||
task_logger.info("Successfully collected background metrics")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception("Error collecting background process metrics")
|
||||
raise e
|
||||
finally:
|
||||
if lock_monitoring.owned():
|
||||
lock_monitoring.release()
|
||||
|
||||
task_logger.info("Background monitoring task finished")
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
|
||||
)
|
||||
def cloud_check_alembic() -> bool | None:
|
||||
"""A task to verify that all tenants are on the same alembic revision.
|
||||
|
||||
This check is expected to fail if a cloud alembic migration is currently running
|
||||
across all tenants.
|
||||
|
||||
TODO: have the cloud migration script set an activity signal that this check
|
||||
uses to know it doesn't make sense to run a check at the present time.
|
||||
"""
|
||||
time_start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.CLOUD_CHECK_ALEMBIC_BEAT_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
tenant_to_revision: dict[str, str | None] = {}
|
||||
revision_counts: dict[str, int] = {}
|
||||
out_of_date_tenants: dict[str, str | None] = {}
|
||||
top_revision: str = ""
|
||||
|
||||
try:
|
||||
# map each tenant_id to its revision
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
for tenant_id in tenant_ids:
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
|
||||
lock_beat.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
if tenant_id is None:
|
||||
continue
|
||||
|
||||
with get_session_with_tenant(tenant_id=None) as session:
|
||||
result = session.execute(
|
||||
text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
|
||||
)
|
||||
|
||||
result_scalar: str | None = result.scalar_one_or_none()
|
||||
tenant_to_revision[tenant_id] = result_scalar
|
||||
|
||||
# get the total count of each revision
|
||||
for k, v in tenant_to_revision.items():
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
revision_counts[v] = revision_counts.get(v, 0) + 1
|
||||
|
||||
# get the revision with the most counts
|
||||
sorted_revision_counts = sorted(
|
||||
revision_counts.items(), key=lambda item: item[1], reverse=True
|
||||
)
|
||||
|
||||
if len(sorted_revision_counts) == 0:
|
||||
task_logger.error(
|
||||
f"cloud_check_alembic - No revisions found for {len(tenant_ids)} tenant ids!"
|
||||
)
|
||||
else:
|
||||
top_revision, _ = sorted_revision_counts[0]
|
||||
|
||||
# build a list of out of date tenants
|
||||
for k, v in tenant_to_revision.items():
|
||||
if v == top_revision:
|
||||
continue
|
||||
|
||||
out_of_date_tenants[k] = v
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception during cloud alembic check")
|
||||
raise
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error("cloud_check_alembic - Lock not owned on completion")
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
|
||||
if len(out_of_date_tenants) > 0:
|
||||
task_logger.error(
|
||||
f"Found out of date tenants: "
|
||||
f"num_out_of_date_tenants={len(out_of_date_tenants)} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"revision={top_revision}"
|
||||
)
|
||||
for k, v in islice(out_of_date_tenants.items(), 5):
|
||||
task_logger.info(f"Out of date tenant: tenant={k} revision={v}")
|
||||
else:
|
||||
task_logger.info(
|
||||
f"All tenants are up to date: num_tenants={len(tenant_ids)} revision={top_revision}"
|
||||
)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(
|
||||
f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
return True
|
||||
@@ -13,11 +13,11 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
|
||||
from onyx.background.celery.tasks.indexing.tasks import IndexingCallback
|
||||
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -25,30 +25,21 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.connector import mark_ccpair_as_pruned
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import get_documents_for_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import pruning_ctx
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
"""Jobs / utils for kicking off pruning tasks."""
|
||||
|
||||
|
||||
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if pruning is due.
|
||||
|
||||
@@ -87,7 +78,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -96,7 +86,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
@@ -113,10 +103,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
@@ -213,14 +200,6 @@ def try_creating_prune_generator_task(
|
||||
priority=OnyxCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair.id,
|
||||
sync_type=SyncType.PRUNING,
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
redis_connector.prune.set_fence(True)
|
||||
except Exception:
|
||||
@@ -252,8 +231,6 @@ def connector_pruning_generator_task(
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list"""
|
||||
|
||||
LoggerContextVars.reset()
|
||||
|
||||
pruning_ctx_dict = pruning_ctx.get()
|
||||
pruning_ctx_dict["cc_pair_id"] = cc_pair_id
|
||||
pruning_ctx_dict["request_id"] = self.request.id
|
||||
@@ -367,52 +344,3 @@ def connector_pruning_generator_task(
|
||||
lock.release()
|
||||
|
||||
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")
|
||||
|
||||
|
||||
"""Monitoring pruning utils, called in monitor_vespa_sync"""
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.prune.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.prune.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.prune.get_remaining()
|
||||
task_logger.info(
|
||||
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
|
||||
task_logger.info(
|
||||
f"Connector pruning finished: cc_pair={cc_pair_id} num_pruned={initial}"
|
||||
)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.PRUNING,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial,
|
||||
)
|
||||
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
|
||||
@@ -1,22 +1,15 @@
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
|
||||
import httpx
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
from tenacity import RetryError
|
||||
|
||||
from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from onyx.db.document import delete_documents_complete__no_commit
|
||||
from onyx.db.document import fetch_chunk_count_for_document
|
||||
@@ -25,16 +18,11 @@ from onyx.db.document import get_document_connector_count
|
||||
from onyx.db.document import mark_document_as_modified
|
||||
from onyx.db.document import mark_document_as_synced
|
||||
from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
|
||||
|
||||
@@ -75,18 +63,14 @@ def document_by_cc_pair_cleanup_task(
|
||||
"""
|
||||
task_logger.debug(f"Task start: doc={document_id}")
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
action = "skip"
|
||||
chunks_affected = 0
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
active_search_settings.primary,
|
||||
active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
@@ -156,13 +140,11 @@ def document_by_cc_pair_cleanup_task(
|
||||
|
||||
db_session.commit()
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action={action} "
|
||||
f"refcount={count} "
|
||||
f"chunks={chunks_affected} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
f"chunks={chunks_affected}"
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
@@ -217,78 +199,3 @@ def document_by_cc_pair_cleanup_task(
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
ignore_result=True,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def cloud_beat_task_generator(
|
||||
self: Task,
|
||||
task_name: str,
|
||||
queue: str = OnyxCeleryTask.DEFAULT,
|
||||
priority: int = OnyxCeleryPriority.MEDIUM,
|
||||
expires: int = BEAT_EXPIRES_DEFAULT,
|
||||
) -> bool | None:
|
||||
"""a lightweight task used to kick off individual beat tasks per tenant."""
|
||||
time_start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
f"{OnyxRedisLocks.CLOUD_BEAT_TASK_GENERATOR_LOCK}:{task_name}",
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
for tenant_id in tenant_ids:
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
|
||||
lock_beat.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# needed in the cloud
|
||||
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
|
||||
continue
|
||||
|
||||
self.app.send_task(
|
||||
task_name,
|
||||
kwargs=dict(
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=queue,
|
||||
priority=priority,
|
||||
expires=expires,
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception during cloud_beat_task_generator")
|
||||
finally:
|
||||
if not lock_beat.owned():
|
||||
task_logger.error(
|
||||
"cloud_beat_task_generator - Lock not owned on completion"
|
||||
)
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
else:
|
||||
lock_beat.release()
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(
|
||||
f"cloud_beat_task_generator finished: "
|
||||
f"task={task_name} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
@@ -24,10 +23,6 @@ from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
|
||||
monitor_ccpair_permissions_taskset,
|
||||
)
|
||||
from onyx.background.celery.tasks.pruning.tasks import monitor_ccpair_pruning_taskset
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
@@ -38,6 +33,8 @@ from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector import mark_ccpair_as_pruned
|
||||
from onyx.db.connector_credential_pair import add_deletion_failure_message
|
||||
from onyx.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
@@ -56,29 +53,24 @@ from onyx.db.document_set import get_document_set_by_id
|
||||
from onyx.db.document_set import mark_document_set_as_synced
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.index_attempt import delete_index_attempts
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncPayload,
|
||||
)
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
@@ -98,7 +90,6 @@ logger = setup_logger()
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -287,21 +278,11 @@ def try_generate_document_set_sync_tasks(
|
||||
|
||||
# don't generate sync tasks if we're up to date
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
document_set = get_document_set_by_id(
|
||||
db_session=db_session,
|
||||
document_set_id=document_set_id,
|
||||
)
|
||||
document_set = get_document_set_by_id(db_session, document_set_id)
|
||||
if not document_set:
|
||||
return None
|
||||
|
||||
if document_set.is_up_to_date:
|
||||
# there should be no in-progress sync records if this is up to date
|
||||
# clean it up just in case things got into a bad state
|
||||
cleanup_sync_records(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
)
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
@@ -330,13 +311,6 @@ def try_generate_document_set_sync_tasks(
|
||||
f"document_set={document_set.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
)
|
||||
# set this only after all tasks have been added
|
||||
rds.set_fence(tasks_generated)
|
||||
return tasks_generated
|
||||
@@ -358,9 +332,8 @@ def try_generate_user_group_sync_tasks(
|
||||
return None
|
||||
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
fetch_user_group = cast(
|
||||
Callable[[Session, int], UserGroup | None],
|
||||
fetch_versioned_implementation("onyx.db.user_group", "fetch_user_group"),
|
||||
fetch_user_group = fetch_versioned_implementation(
|
||||
"onyx.db.user_group", "fetch_user_group"
|
||||
)
|
||||
|
||||
usergroup = fetch_user_group(db_session, usergroup_id)
|
||||
@@ -368,13 +341,6 @@ def try_generate_user_group_sync_tasks(
|
||||
return None
|
||||
|
||||
if usergroup.is_up_to_date:
|
||||
# there should be no in-progress sync records if this is up to date
|
||||
# clean it up just in case things got into a bad state
|
||||
cleanup_sync_records(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
)
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
@@ -402,16 +368,8 @@ def try_generate_user_group_sync_tasks(
|
||||
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
)
|
||||
# set this only after all tasks have been added
|
||||
rug.set_fence(tasks_generated)
|
||||
|
||||
return tasks_generated
|
||||
|
||||
|
||||
@@ -461,13 +419,6 @@ def monitor_document_set_taskset(
|
||||
f"remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=count,
|
||||
)
|
||||
return
|
||||
|
||||
document_set = cast(
|
||||
@@ -486,13 +437,6 @@ def monitor_document_set_taskset(
|
||||
task_logger.info(
|
||||
f"Successfully synced document set: document_set={document_set_id}"
|
||||
)
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
|
||||
rds.reset()
|
||||
|
||||
@@ -526,21 +470,10 @@ def monitor_connector_deletion_taskset(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
|
||||
)
|
||||
if remaining > 0:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=remaining,
|
||||
)
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
|
||||
@@ -612,29 +545,11 @@ def monitor_connector_deletion_taskset(
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair_id, error_message)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.FAILED,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
task_logger.exception(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
@@ -652,6 +567,83 @@ def monitor_connector_deletion_taskset(
|
||||
redis_connector.delete.reset()
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.prune.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.prune.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.prune.get_remaining()
|
||||
task_logger.info(
|
||||
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
|
||||
task_logger.info(
|
||||
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.permissions.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.permissions.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.permissions.get_remaining()
|
||||
task_logger.info(
|
||||
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
payload: RedisConnectorPermissionSyncPayload | None = (
|
||||
redis_connector.permissions.payload
|
||||
)
|
||||
start_time: datetime | None = payload.started if payload else None
|
||||
|
||||
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
|
||||
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
|
||||
|
||||
redis_connector.permissions.reset()
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
@@ -660,7 +652,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"Connector indexing: could not parse composite_id from {fence_key}"
|
||||
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
@@ -710,7 +702,6 @@ def monitor_ccpair_indexing_taskset(
|
||||
# inner/outer/inner double check pattern to avoid race conditions when checking for
|
||||
# bad state
|
||||
|
||||
# Verify: if the generator isn't complete, the task must not be in READY state
|
||||
# inner = get_completion / generator_complete not signaled
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
@@ -756,7 +747,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Connector indexing - Transient exception marking index attempt as failed: "
|
||||
"monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -766,20 +757,6 @@ def monitor_ccpair_indexing_taskset(
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
if redis_connector_index.watchdog_signaled():
|
||||
# if the generator is complete, don't clean up until the watchdog has exited
|
||||
task_logger.info(
|
||||
f"Connector indexing - Delaying finalization until watchdog has exited: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
@@ -796,20 +773,11 @@ def monitor_ccpair_indexing_taskset(
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
@shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
"""This is a celery beat task that monitors and finalizes various long running tasks.
|
||||
|
||||
The name monitor_vespa_sync is a bit of a misnomer since it checks many different tasks
|
||||
now. Should change that at some point.
|
||||
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
For many tasks, the count is 0, that means all tasks finished and we should clean up.
|
||||
If the count is 0, that means all tasks finished and we should clean up.
|
||||
|
||||
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
|
||||
do anything too expensive in this function!
|
||||
@@ -825,17 +793,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Replica usage notes
|
||||
#
|
||||
# False negatives are OK. (aka fail to to see a key that exists on the master).
|
||||
# We simply skip the monitoring work and it will be caught on the next pass.
|
||||
#
|
||||
# False positives are not OK, and are possible if we clear a fence on the master and
|
||||
# then read from the replica. In this case, monitoring work could be done on a fence
|
||||
# that no longer exists. To avoid this, we scan from the replica, but double check
|
||||
# the result on the master.
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
@@ -895,19 +852,17 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
# scan and monitor activity to completion
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
if r_replica.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
timings["connector"] = time.monotonic() - phase_start
|
||||
timings["connector_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorDelete.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
lock_beat.reacquire()
|
||||
|
||||
timings["connector_deletion"] = time.monotonic() - phase_start
|
||||
@@ -917,82 +872,70 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisDocumentSet.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["documentset"] = time.monotonic() - phase_start
|
||||
timings["documentset_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisUserGroup.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["usergroup"] = time.monotonic() - phase_start
|
||||
timings["usergroup_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["pruning"] = time.monotonic() - phase_start
|
||||
timings["pruning_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["indexing"] = time.monotonic() - phase_start
|
||||
timings["indexing_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
|
||||
count=SCAN_ITER_COUNT_DEFAULT,
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(
|
||||
tenant_id, key_bytes, r, db_session
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
|
||||
timings["permissions"] = time.monotonic() - phase_start
|
||||
timings["permissions_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
task_logger.exception("monitor_vespa_sync exceptioned.")
|
||||
return False
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
@@ -1019,15 +962,11 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
def vespa_metadata_sync_task(
|
||||
self: Task, document_id: str, tenant_id: str | None
|
||||
) -> bool:
|
||||
start = time.monotonic()
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
@@ -1073,16 +1012,9 @@ def vespa_metadata_sync_task(
|
||||
# r = get_redis_client(tenant_id=tenant_id)
|
||||
# r.delete(redis_syncing_key)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=sync "
|
||||
f"chunks={chunks_affected} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
return False
|
||||
except Exception as ex:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from celery import Celery
|
||||
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
|
||||
def get_app() -> Celery:
|
||||
from onyx.background.celery.apps.monitoring import celery_app
|
||||
|
||||
return celery_app
|
||||
|
||||
|
||||
app = get_app()
|
||||
@@ -4,10 +4,9 @@ not follow the expected behavior, etc.
|
||||
|
||||
NOTE: cannot use Celery directly due to
|
||||
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
|
||||
import multiprocessing as mp
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing.context import SpawnProcess
|
||||
from multiprocessing import Process
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
@@ -47,9 +46,7 @@ def _initializer(
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
|
||||
|
||||
# Initialize a new engine with desired parameters
|
||||
SqlEngine.init_engine(
|
||||
pool_size=4, max_overflow=12, pool_recycle=60, pool_pre_ping=True
|
||||
)
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
|
||||
|
||||
# Proceed with executing the target function
|
||||
return func(*args, **kwargs)
|
||||
@@ -66,7 +63,7 @@ class SimpleJob:
|
||||
"""Drop in replacement for `dask.distributed.Future`"""
|
||||
|
||||
id: int
|
||||
process: Optional["SpawnProcess"] = None
|
||||
process: Optional["Process"] = None
|
||||
|
||||
def cancel(self) -> bool:
|
||||
return self.release()
|
||||
@@ -134,10 +131,7 @@ class SimpleJobClient:
|
||||
job_id = self.job_id_counter
|
||||
self.job_id_counter += 1
|
||||
|
||||
# this approach allows us to always "spawn" a new process regardless of
|
||||
# get_start_method's current setting
|
||||
ctx = mp.get_context("spawn")
|
||||
process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
|
||||
process = Process(target=_run_in_process, args=(func, args), daemon=True)
|
||||
job = SimpleJob(id=job_id, process=process)
|
||||
process.start()
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.indexing.checkpointing import get_time_windows_for_index_attempt
|
||||
@@ -12,7 +11,6 @@ from onyx.background.indexing.tracer import OnyxTracer
|
||||
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
|
||||
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
|
||||
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
@@ -23,19 +21,16 @@ from onyx.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from onyx.db.index_attempt import mark_attempt_succeeded
|
||||
from onyx.db.index_attempt import transition_attempt_to_in_progress
|
||||
from onyx.db.index_attempt import update_docs_indexed
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
@@ -80,8 +75,7 @@ def _get_connector_runner(
|
||||
# it will never succeed
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=attempt.connector_credential_pair.id,
|
||||
attempt.connector_credential_pair.id, db_session
|
||||
)
|
||||
if cc_pair and cc_pair.status == ConnectorCredentialPairStatus.ACTIVE:
|
||||
update_connector_credential_pair(
|
||||
@@ -102,17 +96,10 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
|
||||
for doc in doc_batch:
|
||||
cleaned_doc = doc.model_copy()
|
||||
|
||||
# Postgres cannot handle NUL characters in text fields
|
||||
if "\x00" in cleaned_doc.id:
|
||||
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
|
||||
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
|
||||
|
||||
if cleaned_doc.title and "\x00" in cleaned_doc.title:
|
||||
logger.warning(
|
||||
f"NUL characters found in document title: {cleaned_doc.title}"
|
||||
)
|
||||
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
|
||||
|
||||
if "\x00" in cleaned_doc.semantic_identifier:
|
||||
logger.warning(
|
||||
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
|
||||
@@ -128,9 +115,6 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
|
||||
)
|
||||
section.link = section.link.replace("\x00", "")
|
||||
|
||||
# since text can be longer, just replace to avoid double scan
|
||||
section.text = section.text.replace("\x00", "")
|
||||
|
||||
cleaned_batch.append(cleaned_doc)
|
||||
|
||||
return cleaned_batch
|
||||
@@ -140,21 +124,9 @@ class ConnectorStopSignal(Exception):
|
||||
"""A custom exception used to signal a stop in processing."""
|
||||
|
||||
|
||||
class RunIndexingContext(BaseModel):
|
||||
index_name: str
|
||||
cc_pair_id: int
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
source: DocumentSource
|
||||
earliest_index_time: float
|
||||
from_beginning: bool
|
||||
is_primary: bool
|
||||
search_settings_status: IndexModelStatus
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
index_attempt: IndexAttempt,
|
||||
tenant_id: str | None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> None:
|
||||
@@ -168,77 +140,61 @@ def _run_indexing(
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
if not index_attempt_start:
|
||||
raise ValueError(
|
||||
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
|
||||
)
|
||||
|
||||
if index_attempt_start.search_settings is None:
|
||||
raise ValueError(
|
||||
"Search settings must be set for indexing. This should not be possible."
|
||||
)
|
||||
|
||||
# search_settings = index_attempt_start.search_settings
|
||||
db_connector = index_attempt_start.connector_credential_pair.connector
|
||||
db_credential = index_attempt_start.connector_credential_pair.credential
|
||||
ctx = RunIndexingContext(
|
||||
index_name=index_attempt_start.search_settings.index_name,
|
||||
cc_pair_id=index_attempt_start.connector_credential_pair.id,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
source=db_connector.source,
|
||||
earliest_index_time=(
|
||||
db_connector.indexing_start.timestamp()
|
||||
if db_connector.indexing_start
|
||||
else 0
|
||||
),
|
||||
from_beginning=index_attempt_start.from_beginning,
|
||||
# Only update cc-pair status for primary index jobs
|
||||
# Secondary index syncs at the end when swapping
|
||||
is_primary=(
|
||||
index_attempt_start.search_settings.status == IndexModelStatus.PRESENT
|
||||
),
|
||||
search_settings_status=index_attempt_start.search_settings.status,
|
||||
if index_attempt.search_settings is None:
|
||||
raise ValueError(
|
||||
"Search settings must be set for indexing. This should not be possible."
|
||||
)
|
||||
|
||||
last_successful_index_time = (
|
||||
ctx.earliest_index_time
|
||||
if ctx.from_beginning
|
||||
else get_last_successful_attempt_time(
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
earliest_index=ctx.earliest_index_time,
|
||||
search_settings=index_attempt_start.search_settings,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
)
|
||||
search_settings = index_attempt.search_settings
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=index_attempt_start.search_settings,
|
||||
callback=callback,
|
||||
)
|
||||
index_name = search_settings.index_name
|
||||
|
||||
# Only update cc-pair status for primary index jobs
|
||||
# Secondary index syncs at the end when swapping
|
||||
is_primary = search_settings.status == IndexModelStatus.PRESENT
|
||||
|
||||
# Indexing is only done into one index at a time
|
||||
document_index = get_default_document_index(
|
||||
index_attempt_start.search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
primary_index_name=index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
attempt_id=index_attempt_id,
|
||||
attempt_id=index_attempt.id,
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=(
|
||||
ctx.from_beginning
|
||||
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
|
||||
index_attempt.from_beginning
|
||||
or (search_settings.status == IndexModelStatus.FUTURE)
|
||||
),
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
db_cc_pair = index_attempt.connector_credential_pair
|
||||
db_connector = index_attempt.connector_credential_pair.connector
|
||||
db_credential = index_attempt.connector_credential_pair.credential
|
||||
earliest_index_time = (
|
||||
db_connector.indexing_start.timestamp() if db_connector.indexing_start else 0
|
||||
)
|
||||
|
||||
last_successful_index_time = (
|
||||
earliest_index_time
|
||||
if index_attempt.from_beginning
|
||||
else get_last_successful_attempt_time(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
earliest_index=earliest_index_time,
|
||||
search_settings=index_attempt.search_settings,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
|
||||
tracer = OnyxTracer()
|
||||
@@ -246,8 +202,8 @@ def _run_indexing(
|
||||
tracer.snap()
|
||||
|
||||
index_attempt_md = IndexAttemptMetadata(
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
)
|
||||
|
||||
batch_num = 0
|
||||
@@ -263,31 +219,21 @@ def _run_indexing(
|
||||
source_type=db_connector.source,
|
||||
)
|
||||
):
|
||||
cc_pair_loop: ConnectorCredentialPair | None = None
|
||||
index_attempt_loop: IndexAttempt | None = None
|
||||
|
||||
try:
|
||||
window_start = max(
|
||||
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
|
||||
datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
index_attempt_loop_start = get_index_attempt(
|
||||
db_session_temp, index_attempt_id
|
||||
)
|
||||
if not index_attempt_loop_start:
|
||||
raise RuntimeError(
|
||||
f"Index attempt {index_attempt_id} not found in DB."
|
||||
)
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session,
|
||||
attempt=index_attempt,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session_temp,
|
||||
attempt=index_attempt_loop_start,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
tracer_counter = 0
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
@@ -302,38 +248,24 @@ def _run_indexing(
|
||||
raise ConnectorStopSignal("Connector stop signal detected")
|
||||
|
||||
# TODO: should we move this into the above callback instead?
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
cc_pair_loop = get_connector_credential_pair_from_id(
|
||||
db_session_temp,
|
||||
ctx.cc_pair_id,
|
||||
db_session.refresh(db_cc_pair)
|
||||
if (
|
||||
(
|
||||
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and search_settings.status != IndexModelStatus.FUTURE
|
||||
)
|
||||
if not cc_pair_loop:
|
||||
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
|
||||
# if it's deleting, we don't care if this is a secondary index
|
||||
or db_cc_pair.status == ConnectorCredentialPairStatus.DELETING
|
||||
):
|
||||
# let the `except` block handle this
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
|
||||
if (
|
||||
(
|
||||
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and ctx.search_settings_status != IndexModelStatus.FUTURE
|
||||
)
|
||||
# if it's deleting, we don't care if this is a secondary index
|
||||
or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING
|
||||
):
|
||||
# let the `except` block handle this
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
|
||||
index_attempt_loop = get_index_attempt(
|
||||
db_session_temp, index_attempt_id
|
||||
db_session.refresh(index_attempt)
|
||||
if index_attempt.status != IndexingStatus.IN_PROGRESS:
|
||||
# Likely due to user manually disabling it or model swap
|
||||
raise RuntimeError(
|
||||
f"Index Attempt was canceled, status is {index_attempt.status}"
|
||||
)
|
||||
if not index_attempt_loop:
|
||||
raise RuntimeError(
|
||||
f"Index attempt {index_attempt_id} not found in DB."
|
||||
)
|
||||
|
||||
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
|
||||
# Likely due to user manually disabling it or model swap
|
||||
raise RuntimeError(
|
||||
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
|
||||
)
|
||||
|
||||
batch_description = []
|
||||
|
||||
@@ -357,15 +289,16 @@ def _run_indexing(
|
||||
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
|
||||
|
||||
# real work happens here!
|
||||
index_pipeline_result = indexing_pipeline(
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
document_batch=doc_batch_cleaned,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
net_doc_change += index_pipeline_result.new_docs
|
||||
chunk_count += index_pipeline_result.total_chunks
|
||||
document_count += index_pipeline_result.total_docs
|
||||
net_doc_change += new_docs
|
||||
chunk_count += total_batch_chunks
|
||||
document_count += len(doc_batch_cleaned)
|
||||
all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
|
||||
|
||||
# commit transaction so that the `update` below begins
|
||||
# with a brand new transaction. Postgres uses the start
|
||||
@@ -374,19 +307,18 @@ def _run_indexing(
|
||||
# be inaccurate
|
||||
db_session.commit()
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
update_docs_indexed(
|
||||
db_session=db_session_temp,
|
||||
index_attempt_id=index_attempt_id,
|
||||
total_docs_indexed=document_count,
|
||||
new_docs_indexed=net_doc_change,
|
||||
docs_removed_from_index=0,
|
||||
)
|
||||
|
||||
if callback:
|
||||
callback.progress("_run_indexing", len(doc_batch_cleaned))
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
update_docs_indexed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
total_docs_indexed=document_count,
|
||||
new_docs_indexed=net_doc_change,
|
||||
docs_removed_from_index=0,
|
||||
)
|
||||
|
||||
tracer_counter += 1
|
||||
if (
|
||||
INDEXING_TRACER_INTERVAL > 0
|
||||
@@ -399,35 +331,33 @@ def _run_indexing(
|
||||
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
|
||||
|
||||
run_end_dt = window_end
|
||||
if ctx.is_primary:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
net_docs=net_doc_change,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
|
||||
if isinstance(e, ConnectorStopSignal):
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=str(e),
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt.id,
|
||||
db_session,
|
||||
reason=str(e),
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
@@ -442,29 +372,23 @@ def _run_indexing(
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or (
|
||||
cc_pair_loop is not None and not cc_pair_loop.status.is_active()
|
||||
)
|
||||
or (
|
||||
index_attempt_loop is not None
|
||||
and index_attempt_loop.status != IndexingStatus.IN_PROGRESS
|
||||
)
|
||||
or not db_cc_pair.status.is_active()
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt.id,
|
||||
db_session,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
@@ -487,58 +411,56 @@ def _run_indexing(
|
||||
index_attempt_md.num_exceptions > 0
|
||||
and index_attempt_md.num_exceptions >= batch_num
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
failure_reason="All batches exceptioned.",
|
||||
)
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
)
|
||||
raise Exception(
|
||||
f"Connector failed - All batches exceptioned: batches={batch_num}"
|
||||
mark_attempt_failed(
|
||||
index_attempt.id,
|
||||
db_session,
|
||||
failure_reason="All batches exceptioned.",
|
||||
)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=index_attempt.connector_credential_pair.connector.id,
|
||||
credential_id=index_attempt.connector_credential_pair.credential.id,
|
||||
)
|
||||
raise Exception(
|
||||
f"Connector failed - All batches exceptioned: batches={batch_num}"
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
if index_attempt_md.num_exceptions == 0:
|
||||
mark_attempt_succeeded(index_attempt_id, db_session_temp)
|
||||
if index_attempt_md.num_exceptions == 0:
|
||||
mark_attempt_succeeded(index_attempt, db_session)
|
||||
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id or "N/A",
|
||||
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
|
||||
properties=None,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id or "N/A",
|
||||
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Connector succeeded: "
|
||||
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
else:
|
||||
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
|
||||
logger.info(
|
||||
f"Connector completed with some errors: "
|
||||
f"exceptions={index_attempt_md.num_exceptions} "
|
||||
f"batches={batch_num} "
|
||||
f"docs={document_count} "
|
||||
f"chunks={chunk_count} "
|
||||
f"elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
logger.info(
|
||||
f"Connector succeeded: "
|
||||
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
else:
|
||||
mark_attempt_partially_succeeded(index_attempt, db_session)
|
||||
logger.info(
|
||||
f"Connector completed with some errors: "
|
||||
f"exceptions={index_attempt_md.num_exceptions} "
|
||||
f"batches={batch_num} "
|
||||
f"docs={document_count} "
|
||||
f"chunks={chunk_count} "
|
||||
f"elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
@@ -558,35 +480,27 @@ def run_indexing_entrypoint(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# TODO: remove long running session entirely
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
tenant_str = ""
|
||||
if tenant_id is not None:
|
||||
tenant_str = f" for tenant {tenant_id}"
|
||||
|
||||
connector_name = attempt.connector_credential_pair.connector.name
|
||||
connector_config = (
|
||||
attempt.connector_credential_pair.connector.connector_specific_config
|
||||
logger.info(
|
||||
f"Indexing starting{tenant_str}: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
credential_id = attempt.connector_credential_pair.credential_id
|
||||
|
||||
logger.info(
|
||||
f"Indexing starting{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
_run_indexing(db_session, attempt, tenant_id, callback)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
logger.info(
|
||||
f"Indexing finished{tenant_str}: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
|
||||
|
||||
@@ -12,10 +12,10 @@ from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
|
||||
from onyx.chat.prompt_builder.build import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.build import default_build_system_message
|
||||
from onyx.chat.prompt_builder.build import default_build_user_message
|
||||
from onyx.chat.prompt_builder.build import LLMCall
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
CitationResponseHandler,
|
||||
)
|
||||
@@ -212,6 +212,19 @@ class Answer:
|
||||
current_llm_call
|
||||
) or ([], [])
|
||||
|
||||
# Quotes are no longer supported
|
||||
# answer_handler: AnswerResponseHandler
|
||||
# if self.answer_style_config.citation_config:
|
||||
# answer_handler = CitationResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
# )
|
||||
# elif self.answer_style_config.quotes_config:
|
||||
# answer_handler = QuotesResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# )
|
||||
# else:
|
||||
# raise ValueError("No answer style config provided")
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=final_search_results,
|
||||
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
||||
@@ -252,13 +265,11 @@ class Answer:
|
||||
user_query=self.question,
|
||||
prompt_config=self.prompt_config,
|
||||
files=self.latest_query_files,
|
||||
single_message_history=self.single_message_history,
|
||||
),
|
||||
message_history=self.message_history,
|
||||
llm_config=self.llm.config,
|
||||
raw_user_query=self.question,
|
||||
raw_user_uploaded_files=self.latest_query_files or [],
|
||||
single_message_history=self.single_message_history,
|
||||
raw_user_text=self.question,
|
||||
)
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
|
||||
@@ -25,7 +25,7 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import Prompt
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.prompts import get_prompts_by_ids
|
||||
from onyx.db.persona import get_prompts_by_ids
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
|
||||
@@ -7,7 +7,7 @@ from langchain_core.messages import BaseMessage
|
||||
from onyx.chat.models import ResponsePart
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
|
||||
from onyx.chat.prompt_builder.build import LLMCall
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -260,8 +261,13 @@ class CitationConfig(BaseModel):
|
||||
all_docs_useful: bool = False
|
||||
|
||||
|
||||
class QuotesConfig(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class AnswerStyleConfig(BaseModel):
|
||||
citation_config: CitationConfig
|
||||
citation_config: CitationConfig | None = None
|
||||
quotes_config: QuotesConfig | None = None
|
||||
document_pruning_config: DocumentPruningConfig = Field(
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
@@ -270,6 +276,20 @@ class AnswerStyleConfig(BaseModel):
|
||||
# right now, only used by the simple chat API
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
|
||||
if self.citation_config is None and self.quotes_config is None:
|
||||
raise ValueError(
|
||||
"One of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
if self.citation_config is not None and self.quotes_config is not None:
|
||||
raise ValueError(
|
||||
"Only one of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""Final representation of the Prompt configuration passed
|
||||
|
||||
@@ -254,7 +254,6 @@ def _get_force_search_settings(
|
||||
and new_msg_req.retrieval_options.run_search
|
||||
== OptionalSearchSetting.ALWAYS,
|
||||
new_msg_req.search_doc_ids,
|
||||
new_msg_req.query_override is not None,
|
||||
DISABLE_LLM_CHOOSE_SEARCH,
|
||||
]
|
||||
)
|
||||
@@ -303,11 +302,6 @@ def stream_chat_message_objects(
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
bypass_acl: bool = False,
|
||||
include_contexts: bool = False,
|
||||
# a string which represents the history of a conversation. Used in cases like
|
||||
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
|
||||
# messages.
|
||||
# NOTE: is not stored in the database at all.
|
||||
single_message_history: str | None = None,
|
||||
) -> ChatPacketStream:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
@@ -426,7 +420,9 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
# Every chat Session begins with an empty root message
|
||||
root_message = get_or_create_root_message(
|
||||
@@ -498,6 +494,14 @@ def stream_chat_message_objects(
|
||||
f"existing assistant message id: {existing_assistant_message_id}"
|
||||
)
|
||||
|
||||
# Disable Query Rephrasing for the first message
|
||||
# This leads to a better first response since the LLM rephrasing the question
|
||||
# leads to worst search quality
|
||||
if not history_msgs:
|
||||
new_msg_req.query_override = (
|
||||
new_msg_req.query_override or new_msg_req.message
|
||||
)
|
||||
|
||||
# load all files needed for this chat chain in memory
|
||||
files = load_all_chat_files(
|
||||
history_msgs, new_msg_req.file_descriptors, db_session
|
||||
@@ -703,7 +707,6 @@ def stream_chat_message_objects(
|
||||
],
|
||||
tools=tools,
|
||||
force_use_tool=_get_force_search_settings(new_msg_req, tools),
|
||||
single_message_history=single_message_history,
|
||||
)
|
||||
|
||||
reference_db_search_docs = None
|
||||
|
||||
@@ -15,12 +15,10 @@ from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
from onyx.llm.utils import check_message_tokens
|
||||
from onyx.llm.utils import message_to_prompt_and_imgs
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from onyx.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from onyx.prompts.prompt_utils import drop_messages_history_overflow
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
@@ -32,45 +30,30 @@ def default_build_system_message(
|
||||
prompt_config: PromptConfig,
|
||||
) -> SystemMessage | None:
|
||||
system_prompt = prompt_config.system_prompt.strip()
|
||||
tag_handled_prompt = handle_onyx_date_awareness(
|
||||
system_prompt,
|
||||
prompt_config,
|
||||
add_additional_info_if_no_tag=prompt_config.datetime_aware,
|
||||
)
|
||||
if prompt_config.datetime_aware:
|
||||
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
|
||||
|
||||
if not tag_handled_prompt:
|
||||
if not system_prompt:
|
||||
return None
|
||||
|
||||
return SystemMessage(content=tag_handled_prompt)
|
||||
system_msg = SystemMessage(content=system_prompt)
|
||||
|
||||
return system_msg
|
||||
|
||||
|
||||
def default_build_user_message(
|
||||
user_query: str,
|
||||
prompt_config: PromptConfig,
|
||||
files: list[InMemoryChatFile] = [],
|
||||
single_message_history: str | None = None,
|
||||
user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
|
||||
) -> HumanMessage:
|
||||
history_block = (
|
||||
HISTORY_BLOCK.format(history_str=single_message_history)
|
||||
if single_message_history
|
||||
else ""
|
||||
)
|
||||
|
||||
user_prompt = (
|
||||
CHAT_USER_CONTEXT_FREE_PROMPT.format(
|
||||
history_block=history_block,
|
||||
task_prompt=prompt_config.task_prompt,
|
||||
user_query=user_query,
|
||||
task_prompt=prompt_config.task_prompt, user_query=user_query
|
||||
)
|
||||
if prompt_config.task_prompt
|
||||
else user_query
|
||||
)
|
||||
user_prompt = user_prompt.strip()
|
||||
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
|
||||
user_msg = HumanMessage(
|
||||
content=build_content_with_imgs(tag_handled_prompt, files)
|
||||
if files
|
||||
else tag_handled_prompt
|
||||
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
|
||||
)
|
||||
return user_msg
|
||||
|
||||
@@ -81,8 +64,7 @@ class AnswerPromptBuilder:
|
||||
user_message: HumanMessage,
|
||||
message_history: list[PreviousMessage],
|
||||
llm_config: LLMConfig,
|
||||
raw_user_query: str,
|
||||
raw_user_uploaded_files: list[InMemoryChatFile],
|
||||
raw_user_text: str,
|
||||
single_message_history: str | None = None,
|
||||
) -> None:
|
||||
self.max_tokens = compute_max_llm_input_tokens(llm_config)
|
||||
@@ -91,7 +73,6 @@ class AnswerPromptBuilder:
|
||||
provider_type=llm_config.model_provider,
|
||||
model_name=llm_config.model_name,
|
||||
)
|
||||
self.llm_config = llm_config
|
||||
self.llm_tokenizer_encode_func = cast(
|
||||
Callable[[str], list[int]], llm_tokenizer.encode
|
||||
)
|
||||
@@ -100,29 +81,21 @@ class AnswerPromptBuilder:
|
||||
(
|
||||
self.message_history,
|
||||
self.history_token_cnts,
|
||||
) = translate_history_to_basemessages(
|
||||
message_history,
|
||||
exclude_images=not model_supports_image_input(
|
||||
self.llm_config.model_name,
|
||||
self.llm_config.model_provider,
|
||||
),
|
||||
)
|
||||
) = translate_history_to_basemessages(message_history)
|
||||
|
||||
# for cases where like the QA flow where we want to condense the chat history
|
||||
# into a single message rather than a sequence of User / Assistant messages
|
||||
self.single_message_history = single_message_history
|
||||
|
||||
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
|
||||
self.user_message_and_token_cnt = (
|
||||
user_message,
|
||||
check_message_tokens(
|
||||
user_message,
|
||||
self.llm_tokenizer_encode_func,
|
||||
),
|
||||
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
|
||||
)
|
||||
|
||||
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
|
||||
|
||||
# used for building a new prompt after a tool-call
|
||||
self.raw_user_query = raw_user_query
|
||||
self.raw_user_uploaded_files = raw_user_uploaded_files
|
||||
self.single_message_history = single_message_history
|
||||
self.raw_user_message = raw_user_text
|
||||
|
||||
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
|
||||
if not system_message:
|
||||
@@ -1,13 +1,12 @@
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.prompts import get_default_prompt
|
||||
from onyx.db.persona import get_default_prompt__read_only
|
||||
from onyx.db.search_settings import get_multilingual_expansion
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.factory import get_main_llm_from_tuple
|
||||
@@ -21,9 +20,9 @@ from onyx.prompts.constants import DEFAULT_IGNORE_STATEMENT
|
||||
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT
|
||||
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING
|
||||
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from onyx.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from onyx.prompts.prompt_utils import build_complete_context_str
|
||||
from onyx.prompts.prompt_utils import build_task_prompt_reminders
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
|
||||
from onyx.prompts.token_counts import (
|
||||
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
|
||||
@@ -98,12 +97,11 @@ def compute_max_document_tokens(
|
||||
|
||||
|
||||
def compute_max_document_tokens_for_persona(
|
||||
db_session: Session,
|
||||
persona: Persona,
|
||||
actual_user_input: str | None = None,
|
||||
max_llm_token_override: int | None = None,
|
||||
) -> int:
|
||||
prompt = persona.prompts[0] if persona.prompts else get_default_prompt(db_session)
|
||||
prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only()
|
||||
return compute_max_document_tokens(
|
||||
prompt_config=PromptConfig.from_model(prompt),
|
||||
llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
|
||||
@@ -127,11 +125,10 @@ def build_citations_system_message(
|
||||
system_prompt = prompt_config.system_prompt.strip()
|
||||
if prompt_config.include_citations:
|
||||
system_prompt += REQUIRE_CITATION_STATEMENT
|
||||
tag_handled_prompt = handle_onyx_date_awareness(
|
||||
system_prompt, prompt_config, add_additional_info_if_no_tag=True
|
||||
)
|
||||
if prompt_config.datetime_aware:
|
||||
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
|
||||
|
||||
return SystemMessage(content=tag_handled_prompt)
|
||||
return SystemMessage(content=system_prompt)
|
||||
|
||||
|
||||
def build_citations_user_message(
|
||||
@@ -147,7 +144,9 @@ def build_citations_user_message(
|
||||
)
|
||||
|
||||
history_block = (
|
||||
HISTORY_BLOCK.format(history_str=history_message) if history_message else ""
|
||||
HISTORY_BLOCK.format(history_str=history_message) + "\n"
|
||||
if history_message
|
||||
else ""
|
||||
)
|
||||
query, img_urls = message_to_prompt_and_imgs(message)
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ from onyx.llm.utils import message_to_prompt_and_imgs
|
||||
from onyx.prompts.direct_qa_prompts import CONTEXT_BLOCK
|
||||
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from onyx.prompts.direct_qa_prompts import JSON_PROMPT
|
||||
from onyx.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from onyx.prompts.prompt_utils import build_complete_context_str
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
|
||||
|
||||
def _build_strong_llm_quotes_prompt(
|
||||
@@ -39,11 +39,10 @@ def _build_strong_llm_quotes_prompt(
|
||||
language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
|
||||
).strip()
|
||||
|
||||
tag_handled_prompt = handle_onyx_date_awareness(
|
||||
full_prompt, prompt, add_additional_info_if_no_tag=True
|
||||
)
|
||||
if prompt.datetime_aware:
|
||||
full_prompt = add_date_time_to_prompt(prompt_str=full_prompt)
|
||||
|
||||
return HumanMessage(content=tag_handled_prompt)
|
||||
return HumanMessage(content=full_prompt)
|
||||
|
||||
|
||||
def build_quotes_user_message(
|
||||
|
||||
@@ -7,11 +7,30 @@ from onyx.db.models import ChatMessage
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
|
||||
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
|
||||
|
||||
|
||||
def build_dummy_prompt(
|
||||
system_prompt: str, task_prompt: str, retrieval_disabled: bool
|
||||
) -> str:
|
||||
if retrieval_disabled:
|
||||
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
return PARAMATERIZED_PROMPT.format(
|
||||
context_docs_str="<CONTEXT_DOCS>",
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
|
||||
def translate_onyx_msg_to_langchain(
|
||||
msg: ChatMessage | PreviousMessage,
|
||||
exclude_images: bool = False,
|
||||
) -> BaseMessage:
|
||||
files: list[InMemoryChatFile] = []
|
||||
|
||||
@@ -19,9 +38,7 @@ def translate_onyx_msg_to_langchain(
|
||||
# attached. Just ignore them for now.
|
||||
if not isinstance(msg, ChatMessage):
|
||||
files = msg.files
|
||||
content = build_content_with_imgs(
|
||||
msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
|
||||
)
|
||||
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
|
||||
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
raise ValueError("System messages are not currently part of history")
|
||||
@@ -35,12 +52,9 @@ def translate_onyx_msg_to_langchain(
|
||||
|
||||
def translate_history_to_basemessages(
|
||||
history: list[ChatMessage] | list["PreviousMessage"],
|
||||
exclude_images: bool = False,
|
||||
) -> tuple[list[BaseMessage], list[int]]:
|
||||
history_basemessages = [
|
||||
translate_onyx_msg_to_langchain(msg, exclude_images)
|
||||
for msg in history
|
||||
if msg.token_count != 0
|
||||
translate_onyx_msg_to_langchain(msg) for msg in history if msg.token_count != 0
|
||||
]
|
||||
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
|
||||
return history_basemessages, history_token_counts
|
||||
|
||||
@@ -5,7 +5,7 @@ from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
|
||||
from onyx.chat.models import ResponsePart
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
|
||||
from onyx.chat.prompt_builder.build import LLMCall
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.message import build_tool_message
|
||||
@@ -62,7 +62,7 @@ class ToolResponseHandler:
|
||||
llm_call.force_use_tool.args
|
||||
if llm_call.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=llm_call.prompt_builder.raw_user_query,
|
||||
query=llm_call.prompt_builder.raw_user_message,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
llm=llm,
|
||||
force_run=True,
|
||||
@@ -76,7 +76,7 @@ class ToolResponseHandler:
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=llm_call.tools,
|
||||
query=llm_call.prompt_builder.raw_user_query,
|
||||
query=llm_call.prompt_builder.raw_user_message,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -95,7 +95,7 @@ class ToolResponseHandler:
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
query=llm_call.prompt_builder.raw_user_query,
|
||||
query=llm_call.prompt_builder.raw_user_message,
|
||||
llm=llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
|
||||
@@ -3,7 +3,6 @@ import os
|
||||
import urllib.parse
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentIndexType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
@@ -56,12 +55,12 @@ MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
)
|
||||
|
||||
AUTH_BACKEND = AuthBackend(os.environ.get("AUTH_BACKEND") or AuthBackend.REDIS.value)
|
||||
REDIS_AUTH_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS")
|
||||
or os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS")
|
||||
or 86400 * 7
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
# Default request timeout, mostly used by connectors
|
||||
@@ -93,12 +92,6 @@ OAUTH_CLIENT_SECRET = (
|
||||
|
||||
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
|
||||
|
||||
# Duration (in seconds) for which the FastAPI Users JWT token remains valid in the user's browser.
|
||||
# By default, this is set to match the Redis expiry time for consistency.
|
||||
AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("AUTH_COOKIE_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
# for basic auth
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
@@ -200,8 +193,6 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
|
||||
# this assumes that other redis settings remain the same as the primary
|
||||
REDIS_REPLICA_HOST = os.environ.get("REDIS_REPLICA_HOST") or REDIS_HOST
|
||||
|
||||
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
|
||||
|
||||
@@ -478,12 +469,6 @@ INDEXING_SIZE_WARNING_THRESHOLD = int(
|
||||
# 0 disables this behavior and is the default.
|
||||
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)
|
||||
|
||||
# Enable multi-threaded embedding model calls for parallel processing
|
||||
# Note: only applies for API-based embedding models
|
||||
INDEXING_EMBEDDING_MODEL_NUM_THREADS = int(
|
||||
os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1
|
||||
)
|
||||
|
||||
# During an indexing attempt, specifies the number of batches which are allowed to
|
||||
# exception without aborting the attempt.
|
||||
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)
|
||||
@@ -617,8 +602,3 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
|
||||
|
||||
# Set to true to mock LLM responses for testing purposes
|
||||
MOCK_LLM_RESPONSE = (
|
||||
os.environ.get("MOCK_LLM_RESPONSE") if os.environ.get("MOCK_LLM_RESPONSE") else None
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
|
||||
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
|
||||
|
||||
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
|
||||
PERSONAS_YAML = "./onyx/seeding/personas.yaml"
|
||||
|
||||
|
||||
@@ -47,7 +47,6 @@ POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
|
||||
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
|
||||
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
@@ -79,8 +78,6 @@ KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"
|
||||
|
||||
# NOTE: we use this timeout / 4 in various places to refresh a lock
|
||||
# might be worth separating this timeout into separate timeouts for each situation
|
||||
CELERY_GENERIC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
@@ -200,7 +197,6 @@ class SessionType(str, Enum):
|
||||
class QAFeedbackType(str, Enum):
|
||||
LIKE = "like" # User likes the answer, used for metrics
|
||||
DISLIKE = "dislike" # User dislikes the answer, used for metrics
|
||||
MIXED = "mixed" # User likes some answers and dislikes other, used for chat session metrics
|
||||
|
||||
|
||||
class SearchFeedbackType(str, Enum):
|
||||
@@ -264,9 +260,6 @@ class OnyxCeleryQueues:
|
||||
# Indexing queue
|
||||
CONNECTOR_INDEXING = "connector_indexing"
|
||||
|
||||
# Monitoring queue
|
||||
MONITORING = "monitoring"
|
||||
|
||||
|
||||
class OnyxRedisLocks:
|
||||
PRIMARY_WORKER = "da_lock:primary_worker"
|
||||
@@ -281,7 +274,6 @@ class OnyxRedisLocks:
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
|
||||
|
||||
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
|
||||
"da_lock:connector_doc_permissions_sync"
|
||||
@@ -294,14 +286,9 @@ class OnyxRedisLocks:
|
||||
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
|
||||
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
|
||||
|
||||
CLOUD_BEAT_TASK_GENERATOR_LOCK = "da_lock:cloud_beat_task_generator"
|
||||
CLOUD_CHECK_ALEMBIC_BEAT_LOCK = "da_lock:cloud_check_alembic"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
|
||||
VALIDATE_EXTERNAL_GROUP_SYNC_FENCES = "signal:validate_external_group_sync_fences"
|
||||
VALIDATE_PERMISSION_SYNC_FENCES = "signal:validate_permission_sync_fences"
|
||||
|
||||
|
||||
class OnyxCeleryPriority(int, Enum):
|
||||
@@ -312,19 +299,7 @@ class OnyxCeleryPriority(int, Enum):
|
||||
LOWEST = auto()
|
||||
|
||||
|
||||
# a prefix used to distinguish system wide tasks in the cloud
|
||||
ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud"
|
||||
|
||||
# the tenant id we use for system level redis operations
|
||||
ONYX_CLOUD_TENANT_ID = "cloud"
|
||||
|
||||
|
||||
class OnyxCeleryTask:
|
||||
DEFAULT = "celery"
|
||||
|
||||
CLOUD_BEAT_TASK_GENERATOR = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_generate_beat_tasks"
|
||||
CLOUD_CHECK_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_alembic"
|
||||
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
CHECK_FOR_INDEXING = "check_for_indexing"
|
||||
@@ -332,10 +307,7 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
import contextvars
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import Future
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
@@ -24,9 +20,9 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
# NOTE: all are made lowercase to avoid case sensitivity issues
|
||||
# These field types are considered metadata by default when
|
||||
# treat_all_non_attachment_fields_as_metadata is False
|
||||
DEFAULT_METADATA_FIELD_TYPES = {
|
||||
# these are the field types that are considered metadata rather
|
||||
# than sections
|
||||
_METADATA_FIELD_TYPES = {
|
||||
"singlecollaborator",
|
||||
"collaborator",
|
||||
"createdby",
|
||||
@@ -64,42 +60,21 @@ class AirtableConnector(LoadConnector):
|
||||
self,
|
||||
base_id: str,
|
||||
table_name_or_id: str,
|
||||
treat_all_non_attachment_fields_as_metadata: bool = False,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.base_id = base_id
|
||||
self.table_name_or_id = table_name_or_id
|
||||
self.batch_size = batch_size
|
||||
self._airtable_client: AirtableApi | None = None
|
||||
self.treat_all_non_attachment_fields_as_metadata = (
|
||||
treat_all_non_attachment_fields_as_metadata
|
||||
)
|
||||
self.airtable_client: AirtableApi | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._airtable_client = AirtableApi(credentials["airtable_access_token"])
|
||||
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
|
||||
return None
|
||||
|
||||
@property
|
||||
def airtable_client(self) -> AirtableApi:
|
||||
if not self._airtable_client:
|
||||
raise AirtableClientNotSetUpError()
|
||||
return self._airtable_client
|
||||
|
||||
def _extract_field_values(
|
||||
self,
|
||||
field_id: str,
|
||||
field_name: str,
|
||||
field_info: Any,
|
||||
field_type: str,
|
||||
base_id: str,
|
||||
table_id: str,
|
||||
view_id: str | None,
|
||||
record_id: str,
|
||||
) -> list[tuple[str, str]]:
|
||||
def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
|
||||
"""
|
||||
Extract value(s) + links from a field regardless of its type.
|
||||
Attachments are represented as multiple sections, and therefore
|
||||
returned as a list of tuples (value, link).
|
||||
Extract value(s) from a field regardless of its type.
|
||||
Returns either a single string or list of strings for attachments.
|
||||
"""
|
||||
if field_info is None:
|
||||
return []
|
||||
@@ -110,11 +85,8 @@ class AirtableConnector(LoadConnector):
|
||||
if field_type == "multipleRecordLinks":
|
||||
return []
|
||||
|
||||
# default link to use for non-attachment fields
|
||||
default_link = f"https://airtable.com/{base_id}/{table_id}/{record_id}"
|
||||
|
||||
if field_type == "multipleAttachments":
|
||||
attachment_texts: list[tuple[str, str]] = []
|
||||
attachment_texts: list[str] = []
|
||||
for attachment in field_info:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename", "")
|
||||
@@ -127,37 +99,16 @@ class AirtableConnector(LoadConnector):
|
||||
backoff=2,
|
||||
max_delay=10,
|
||||
)
|
||||
def get_attachment_with_retry(url: str, record_id: str) -> bytes | None:
|
||||
try:
|
||||
attachment_response = requests.get(url)
|
||||
attachment_response.raise_for_status()
|
||||
def get_attachment_with_retry(url: str) -> bytes | None:
|
||||
attachment_response = requests.get(url)
|
||||
if attachment_response.status_code == 200:
|
||||
return attachment_response.content
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code == 410:
|
||||
logger.info(f"Refreshing attachment for {filename}")
|
||||
# Re-fetch the record to get a fresh URL
|
||||
refreshed_record = self.airtable_client.table(
|
||||
base_id, table_id
|
||||
).get(record_id)
|
||||
for refreshed_attachment in refreshed_record["fields"][
|
||||
field_name
|
||||
]:
|
||||
if refreshed_attachment.get("filename") == filename:
|
||||
new_url = refreshed_attachment.get("url")
|
||||
if new_url:
|
||||
attachment_response = requests.get(new_url)
|
||||
attachment_response.raise_for_status()
|
||||
return attachment_response.content
|
||||
return None
|
||||
|
||||
logger.error(f"Failed to refresh attachment for {filename}")
|
||||
|
||||
raise
|
||||
|
||||
attachment_content = get_attachment_with_retry(url, record_id)
|
||||
attachment_content = get_attachment_with_retry(url)
|
||||
if attachment_content:
|
||||
try:
|
||||
file_ext = get_file_ext(filename)
|
||||
attachment_id = attachment["id"]
|
||||
attachment_text = extract_file_text(
|
||||
BytesIO(attachment_content),
|
||||
filename,
|
||||
@@ -165,20 +116,7 @@ class AirtableConnector(LoadConnector):
|
||||
extension=file_ext,
|
||||
)
|
||||
if attachment_text:
|
||||
# slightly nicer loading experience if we can specify the view ID
|
||||
if view_id:
|
||||
attachment_link = (
|
||||
f"https://airtable.com/{base_id}/{table_id}/{view_id}/{record_id}"
|
||||
f"/{field_id}/{attachment_id}?blocks=hide"
|
||||
)
|
||||
else:
|
||||
attachment_link = (
|
||||
f"https://airtable.com/{base_id}/{table_id}/{record_id}"
|
||||
f"/{field_id}/{attachment_id}?blocks=hide"
|
||||
)
|
||||
attachment_texts.append(
|
||||
(f"{filename}:\n{attachment_text}", attachment_link)
|
||||
)
|
||||
attachment_texts.append(f"{filename}:\n{attachment_text}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to process attachment {filename}: {str(e)}"
|
||||
@@ -193,31 +131,23 @@ class AirtableConnector(LoadConnector):
|
||||
combined.append(collab_name)
|
||||
if collab_email:
|
||||
combined.append(f"({collab_email})")
|
||||
return [(" ".join(combined) if combined else str(field_info), default_link)]
|
||||
return [" ".join(combined) if combined else str(field_info)]
|
||||
|
||||
if isinstance(field_info, list):
|
||||
return [(item, default_link) for item in field_info]
|
||||
return [str(item) for item in field_info]
|
||||
|
||||
return [(str(field_info), default_link)]
|
||||
return [str(field_info)]
|
||||
|
||||
def _should_be_metadata(self, field_type: str) -> bool:
|
||||
"""Determine if a field type should be treated as metadata.
|
||||
|
||||
When treat_all_non_attachment_fields_as_metadata is True, all fields except
|
||||
attachments are treated as metadata. Otherwise, only fields with types listed
|
||||
in DEFAULT_METADATA_FIELD_TYPES are treated as metadata."""
|
||||
if self.treat_all_non_attachment_fields_as_metadata:
|
||||
return field_type.lower() != "multipleattachments"
|
||||
return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES
|
||||
"""Determine if a field type should be treated as metadata."""
|
||||
return field_type.lower() in _METADATA_FIELD_TYPES
|
||||
|
||||
def _process_field(
|
||||
self,
|
||||
field_id: str,
|
||||
field_name: str,
|
||||
field_info: Any,
|
||||
field_type: str,
|
||||
table_id: str,
|
||||
view_id: str | None,
|
||||
record_id: str,
|
||||
) -> tuple[list[Section], dict[str, Any]]:
|
||||
"""
|
||||
@@ -235,22 +165,12 @@ class AirtableConnector(LoadConnector):
|
||||
return [], {}
|
||||
|
||||
# Get the value(s) for the field
|
||||
field_value_and_links = self._extract_field_values(
|
||||
field_id=field_id,
|
||||
field_name=field_name,
|
||||
field_info=field_info,
|
||||
field_type=field_type,
|
||||
base_id=self.base_id,
|
||||
table_id=table_id,
|
||||
view_id=view_id,
|
||||
record_id=record_id,
|
||||
)
|
||||
if len(field_value_and_links) == 0:
|
||||
field_values = self._get_field_value(field_info, field_type)
|
||||
if len(field_values) == 0:
|
||||
return [], {}
|
||||
|
||||
# Determine if it should be metadata or a section
|
||||
if self._should_be_metadata(field_type):
|
||||
field_values = [value for value, _ in field_value_and_links]
|
||||
if len(field_values) > 1:
|
||||
return [], {field_name: field_values}
|
||||
return [], {field_name: field_values[0]}
|
||||
@@ -258,7 +178,7 @@ class AirtableConnector(LoadConnector):
|
||||
# Otherwise, create relevant sections
|
||||
sections = [
|
||||
Section(
|
||||
link=link,
|
||||
link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
|
||||
text=(
|
||||
f"{field_name}:\n"
|
||||
"------------------------\n"
|
||||
@@ -266,7 +186,7 @@ class AirtableConnector(LoadConnector):
|
||||
"------------------------"
|
||||
),
|
||||
)
|
||||
for text, link in field_value_and_links
|
||||
for text in field_values
|
||||
]
|
||||
return sections, {}
|
||||
|
||||
@@ -275,7 +195,7 @@ class AirtableConnector(LoadConnector):
|
||||
record: RecordDict,
|
||||
table_schema: TableSchema,
|
||||
primary_field_name: str | None,
|
||||
) -> Document | None:
|
||||
) -> Document:
|
||||
"""Process a single Airtable record into a Document.
|
||||
|
||||
Args:
|
||||
@@ -299,35 +219,23 @@ class AirtableConnector(LoadConnector):
|
||||
primary_field_value = (
|
||||
fields.get(primary_field_name) if primary_field_name else None
|
||||
)
|
||||
view_id = table_schema.views[0].id if table_schema.views else None
|
||||
|
||||
for field_schema in table_schema.fields:
|
||||
field_name = field_schema.name
|
||||
field_val = fields.get(field_name)
|
||||
field_type = field_schema.type
|
||||
|
||||
logger.debug(
|
||||
f"Processing field '{field_name}' of type '{field_type}' "
|
||||
f"for record '{record_id}'."
|
||||
)
|
||||
|
||||
field_sections, field_metadata = self._process_field(
|
||||
field_id=field_schema.id,
|
||||
field_name=field_name,
|
||||
field_info=field_val,
|
||||
field_type=field_type,
|
||||
table_id=table_id,
|
||||
view_id=view_id,
|
||||
record_id=record_id,
|
||||
)
|
||||
|
||||
sections.extend(field_sections)
|
||||
metadata.update(field_metadata)
|
||||
|
||||
if not sections:
|
||||
logger.warning(f"No sections found for record {record_id}")
|
||||
return None
|
||||
|
||||
semantic_id = (
|
||||
f"{table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
@@ -364,47 +272,18 @@ class AirtableConnector(LoadConnector):
|
||||
primary_field_name = field.name
|
||||
break
|
||||
|
||||
logger.info(f"Starting to process Airtable records for {table.name}.")
|
||||
record_documents: list[Document] = []
|
||||
for record in records:
|
||||
document = self._process_record(
|
||||
record=record,
|
||||
table_schema=table_schema,
|
||||
primary_field_name=primary_field_name,
|
||||
)
|
||||
record_documents.append(document)
|
||||
|
||||
# Process records in parallel batches using ThreadPoolExecutor
|
||||
PARALLEL_BATCH_SIZE = 8
|
||||
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
|
||||
if len(record_documents) >= self.batch_size:
|
||||
yield record_documents
|
||||
record_documents = []
|
||||
|
||||
# Process records in batches
|
||||
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
|
||||
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
|
||||
record_documents: list[Document] = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit batch tasks
|
||||
future_to_record: dict[Future, RecordDict] = {}
|
||||
for record in batch_records:
|
||||
# Capture the current context so that the thread gets the current tenant ID
|
||||
current_context = contextvars.copy_context()
|
||||
future_to_record[
|
||||
executor.submit(
|
||||
current_context.run,
|
||||
self._process_record,
|
||||
record=record,
|
||||
table_schema=table_schema,
|
||||
primary_field_name=primary_field_name,
|
||||
)
|
||||
] = record
|
||||
|
||||
# Wait for all tasks in this batch to complete
|
||||
for future in as_completed(future_to_record):
|
||||
record = future_to_record[future]
|
||||
try:
|
||||
document = future.result()
|
||||
if document:
|
||||
record_documents.append(document)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to process record {record['id']}")
|
||||
raise e
|
||||
|
||||
yield record_documents
|
||||
record_documents = []
|
||||
|
||||
# Yield any remaining records
|
||||
if record_documents:
|
||||
yield record_documents
|
||||
|
||||
@@ -232,29 +232,20 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
}
|
||||
|
||||
# Get labels
|
||||
label_dicts = (
|
||||
confluence_object.get("metadata", {}).get("labels", {}).get("results", [])
|
||||
)
|
||||
page_labels = [label.get("name") for label in label_dicts if label.get("name")]
|
||||
label_dicts = confluence_object["metadata"]["labels"]["results"]
|
||||
page_labels = [label["name"] for label in label_dicts]
|
||||
if page_labels:
|
||||
doc_metadata["labels"] = page_labels
|
||||
|
||||
# Get last modified and author email
|
||||
version_dict = confluence_object.get("version", {})
|
||||
last_modified = (
|
||||
datetime_from_string(version_dict.get("when"))
|
||||
if version_dict.get("when")
|
||||
else None
|
||||
)
|
||||
author_email = version_dict.get("by", {}).get("email")
|
||||
|
||||
title = confluence_object.get("title", "Untitled Document")
|
||||
last_modified = datetime_from_string(confluence_object["version"]["when"])
|
||||
author_email = confluence_object["version"].get("by", {}).get("email")
|
||||
|
||||
return Document(
|
||||
id=object_url,
|
||||
sections=[Section(link=object_url, text=object_text)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=title,
|
||||
semantic_identifier=confluence_object["title"],
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=author_email)] if author_email else None
|
||||
|
||||
@@ -135,6 +135,32 @@ class OnyxConfluence(Confluence):
|
||||
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
|
||||
self._wrap_methods()
|
||||
|
||||
def get_current_user(self, expand: str | None = None) -> Any:
|
||||
"""
|
||||
Implements a method that isn't in the third party client.
|
||||
|
||||
Get information about the current user
|
||||
:param expand: OPTIONAL expand for get status of user.
|
||||
Possible param is "status". Results are "Active, Deactivated"
|
||||
:return: Returns the user details
|
||||
"""
|
||||
|
||||
from atlassian.errors import ApiPermissionError # type:ignore
|
||||
|
||||
url = "rest/api/user/current"
|
||||
params = {}
|
||||
if expand:
|
||||
params["expand"] = expand
|
||||
try:
|
||||
response = self.get(url, params=params)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403:
|
||||
raise ApiPermissionError(
|
||||
"The calling user does not have permission", reason=e
|
||||
)
|
||||
raise
|
||||
return response
|
||||
|
||||
def _wrap_methods(self) -> None:
|
||||
"""
|
||||
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
@@ -337,9 +363,6 @@ class OnyxConfluence(Confluence):
|
||||
fetch the permissions of a space.
|
||||
This is better logging than calling the get_space_permissions method
|
||||
because it returns a jsonrpc response.
|
||||
TODO: Make this call these endpoints for newer confluence versions:
|
||||
- /rest/api/space/{spaceKey}/permissions
|
||||
- /rest/api/space/{spaceKey}/permissions/anonymous
|
||||
"""
|
||||
url = "rpc/json-rpc/confluenceservice-v2"
|
||||
data = {
|
||||
@@ -358,32 +381,6 @@ class OnyxConfluence(Confluence):
|
||||
|
||||
return response.get("result", [])
|
||||
|
||||
def get_current_user(self, expand: str | None = None) -> Any:
|
||||
"""
|
||||
Implements a method that isn't in the third party client.
|
||||
|
||||
Get information about the current user
|
||||
:param expand: OPTIONAL expand for get status of user.
|
||||
Possible param is "status". Results are "Active, Deactivated"
|
||||
:return: Returns the user details
|
||||
"""
|
||||
|
||||
from atlassian.errors import ApiPermissionError # type:ignore
|
||||
|
||||
url = "rest/api/user/current"
|
||||
params = {}
|
||||
if expand:
|
||||
params["expand"] = expand
|
||||
try:
|
||||
response = self.get(url, params=params)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403:
|
||||
raise ApiPermissionError(
|
||||
"The calling user does not have permission", reason=e
|
||||
)
|
||||
raise
|
||||
return response
|
||||
|
||||
|
||||
def _validate_connector_configuration(
|
||||
credentials: dict[str, Any],
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
@@ -46,17 +45,7 @@ class ConnectorRunner:
|
||||
def run(self) -> GenerateDocumentsOutput:
|
||||
"""Adds additional exception logging to the connector."""
|
||||
try:
|
||||
start = time.monotonic()
|
||||
for batch in self.doc_batch_generator:
|
||||
# to know how long connector is taking
|
||||
logger.debug(
|
||||
f"Connector took {time.monotonic() - start} seconds to build a batch."
|
||||
)
|
||||
|
||||
yield batch
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
yield from self.doc_batch_generator
|
||||
except Exception:
|
||||
exc_type, _, exc_traceback = sys.exc_info()
|
||||
|
||||
|
||||
@@ -30,14 +30,13 @@ _FIREFLIES_API_QUERY = """
|
||||
transcripts(fromDate: $fromDate, toDate: $toDate, limit: $limit, skip: $skip) {
|
||||
id
|
||||
title
|
||||
organizer_email
|
||||
host_email
|
||||
participants
|
||||
date
|
||||
transcript_url
|
||||
sentences {
|
||||
text
|
||||
speaker_name
|
||||
start_time
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -45,37 +44,16 @@ _FIREFLIES_API_QUERY = """
|
||||
|
||||
|
||||
def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
sections: List[Section] = []
|
||||
current_speaker_name = None
|
||||
current_link = ""
|
||||
current_text = ""
|
||||
|
||||
if transcript["sentences"] is None:
|
||||
meeting_text = ""
|
||||
sentences = transcript.get("sentences", [])
|
||||
if sentences:
|
||||
for sentence in sentences:
|
||||
meeting_text += sentence.get("speaker_name") or "Unknown Speaker"
|
||||
meeting_text += ": " + sentence.get("text", "") + "\n\n"
|
||||
else:
|
||||
return None
|
||||
|
||||
for sentence in transcript["sentences"]:
|
||||
if sentence["speaker_name"] != current_speaker_name:
|
||||
if current_speaker_name is not None:
|
||||
sections.append(
|
||||
Section(
|
||||
link=current_link,
|
||||
text=current_text.strip(),
|
||||
)
|
||||
)
|
||||
current_speaker_name = sentence.get("speaker_name") or "Unknown Speaker"
|
||||
current_link = f"{transcript['transcript_url']}?t={sentence['start_time']}"
|
||||
current_text = f"{current_speaker_name}: "
|
||||
|
||||
cleaned_text = sentence["text"].replace("\xa0", " ")
|
||||
current_text += f"{cleaned_text} "
|
||||
|
||||
# Sometimes these links (links with a timestamp) do not work, it is a bug with Fireflies.
|
||||
sections.append(
|
||||
Section(
|
||||
link=current_link,
|
||||
text=current_text.strip(),
|
||||
)
|
||||
)
|
||||
meeting_link = transcript["transcript_url"]
|
||||
|
||||
fireflies_id = _FIREFLIES_ID_PREFIX + transcript["id"]
|
||||
|
||||
@@ -84,22 +62,27 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
meeting_date_unix = transcript["date"]
|
||||
meeting_date = datetime.fromtimestamp(meeting_date_unix / 1000, tz=timezone.utc)
|
||||
|
||||
meeting_organizer_email = transcript["organizer_email"]
|
||||
organizer_email_user_info = [BasicExpertInfo(email=meeting_organizer_email)]
|
||||
meeting_host_email = transcript["host_email"]
|
||||
host_email_user_info = [BasicExpertInfo(email=meeting_host_email)]
|
||||
|
||||
meeting_participants_email_list = []
|
||||
for participant in transcript.get("participants", []):
|
||||
if participant != meeting_organizer_email and participant:
|
||||
if participant != meeting_host_email and participant:
|
||||
meeting_participants_email_list.append(BasicExpertInfo(email=participant))
|
||||
|
||||
return Document(
|
||||
id=fireflies_id,
|
||||
sections=sections,
|
||||
sections=[
|
||||
Section(
|
||||
link=meeting_link,
|
||||
text=meeting_text,
|
||||
)
|
||||
],
|
||||
source=DocumentSource.FIREFLIES,
|
||||
semantic_identifier=meeting_title,
|
||||
metadata={},
|
||||
doc_updated_at=meeting_date,
|
||||
primary_owners=organizer_email_user_info,
|
||||
primary_owners=host_email_user_info,
|
||||
secondary_owners=meeting_participants_email_list,
|
||||
)
|
||||
|
||||
|
||||
@@ -258,7 +258,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
user_emails.append(email)
|
||||
return user_emails
|
||||
|
||||
def get_all_drive_ids(self) -> set[str]:
|
||||
def _get_all_drive_ids(self) -> set[str]:
|
||||
primary_drive_service = get_drive_service(
|
||||
creds=self.creds,
|
||||
user_email=self.primary_admin_email,
|
||||
@@ -353,7 +353,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
all_org_emails: list[str] = self._get_all_user_emails()
|
||||
|
||||
all_drive_ids: set[str] = self.get_all_drive_ids()
|
||||
all_drive_ids: set[str] = self._get_all_drive_ids()
|
||||
|
||||
drive_ids_to_retrieve: set[str] = set()
|
||||
folder_ids_to_retrieve: set[str] = set()
|
||||
@@ -437,7 +437,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# If all 3 are true, we already yielded from get_all_files_for_oauth
|
||||
return
|
||||
|
||||
all_drive_ids = self.get_all_drive_ids()
|
||||
all_drive_ids = self._get_all_drive_ids()
|
||||
drive_ids_to_retrieve: set[str] = set()
|
||||
folder_ids_to_retrieve: set[str] = set()
|
||||
if self._requested_shared_drive_ids or self._requested_folder_ids:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user