mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 08:15:48 +00:00
Compare commits
1 Commits
error_supp
...
pr-3208
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2357d9234f |
7
.github/pull_request_template.md
vendored
7
.github/pull_request_template.md
vendored
@@ -1,14 +1,11 @@
|
||||
## Description
|
||||
|
||||
[Provide a brief description of the changes in this PR]
|
||||
|
||||
## How Has This Been Tested?
|
||||
|
||||
## How Has This Been Tested?
|
||||
[Describe the tests you ran to verify your changes]
|
||||
|
||||
|
||||
## Backporting (check the box to trigger backport action)
|
||||
|
||||
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
|
||||
|
||||
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
|
||||
- [ ] [Optional] Override Linear Check
|
||||
|
||||
@@ -67,7 +67,6 @@ jobs:
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
@@ -60,8 +60,6 @@ jobs:
|
||||
push: true
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
2
.github/workflows/pr-chromatic-tests.yml
vendored
2
.github/workflows/pr-chromatic-tests.yml
vendored
@@ -8,8 +8,6 @@ on: push
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
MOCK_LLM_RESPONSE: true
|
||||
|
||||
jobs:
|
||||
playwright-tests:
|
||||
|
||||
22
.github/workflows/pr-helm-chart-testing.yml
vendored
22
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -21,10 +21,10 @@ jobs:
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4.2.0
|
||||
with:
|
||||
version: v3.17.0
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.7.0
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
@@ -37,6 +37,22 @@ jobs:
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# rkuo: I don't think we need python?
|
||||
# - name: Set up Python
|
||||
# uses: actions/setup-python@v5
|
||||
# with:
|
||||
# python-version: '3.11'
|
||||
# cache: 'pip'
|
||||
# cache-dependency-path: |
|
||||
# backend/requirements/default.txt
|
||||
# backend/requirements/dev.txt
|
||||
# backend/requirements/model_server.txt
|
||||
# - run: |
|
||||
# python -m pip install --upgrade pip
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
@@ -46,7 +62,7 @@ jobs:
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.12.0
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
|
||||
29
.github/workflows/pr-linear-check.yml
vendored
29
.github/workflows/pr-linear-check.yml
vendored
@@ -1,29 +0,0 @@
|
||||
name: Ensure PR references Linear
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, reopened, synchronize]
|
||||
|
||||
jobs:
|
||||
linear-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check PR body for Linear link or override
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
# Looking for "https://linear.app" in the body
|
||||
if echo "$PR_BODY" | grep -qE "https://linear\.app"; then
|
||||
echo "Found a Linear link. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Looking for a checked override: "[x] Override Linear Check"
|
||||
if echo "$PR_BODY" | grep -q "\[x\].*Override Linear Check"; then
|
||||
echo "Override box is checked. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Otherwise, fail the run
|
||||
echo "No Linear link or override found in the PR description."
|
||||
exit 1
|
||||
@@ -39,12 +39,6 @@ env:
|
||||
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
|
||||
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
|
||||
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
|
||||
# Sharepoint
|
||||
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
|
||||
@@ -9,10 +9,8 @@ founders@onyx.app for more information. Please visit https://github.com/onyx-dot
|
||||
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.8-dev
|
||||
# DO_NOT_TRACK is used to disable telemetry for Unstructured
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
DO_NOT_TRACK="true"
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
|
||||
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""add chat session specific temperature override
|
||||
|
||||
Revision ID: 2f80c6a2550f
|
||||
Revises: 33ea50e88f24
|
||||
Create Date: 2025-01-31 10:30:27.289646
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2f80c6a2550f"
|
||||
down_revision = "33ea50e88f24"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"temperature_override_enabled",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_session", "temperature_override")
|
||||
op.drop_column("user", "temperature_override_enabled")
|
||||
@@ -1,80 +0,0 @@
|
||||
"""foreign key input prompts
|
||||
|
||||
Revision ID: 33ea50e88f24
|
||||
Revises: a6df6b88ef81
|
||||
Create Date: 2025-01-29 10:54:22.141765
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "33ea50e88f24"
|
||||
down_revision = "a6df6b88ef81"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Safely drop constraints if exists
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE inputprompt__user
|
||||
DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE inputprompt__user
|
||||
DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
|
||||
"""
|
||||
)
|
||||
|
||||
# Recreate with ON DELETE CASCADE
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_input_prompt_id_fkey",
|
||||
"inputprompt__user",
|
||||
"inputprompt",
|
||||
["input_prompt_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_user_id_fkey",
|
||||
"inputprompt__user",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the new FKs with ondelete
|
||||
op.drop_constraint(
|
||||
"inputprompt__user_input_prompt_id_fkey",
|
||||
"inputprompt__user",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"inputprompt__user_user_id_fkey",
|
||||
"inputprompt__user",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Recreate them without cascading
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_input_prompt_id_fkey",
|
||||
"inputprompt__user",
|
||||
"inputprompt",
|
||||
["input_prompt_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_user_id_fkey",
|
||||
"inputprompt__user",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -1,37 +0,0 @@
|
||||
"""lowercase_user_emails
|
||||
|
||||
Revision ID: 4d58345da04a
|
||||
Revises: f1ca58b2f2ec
|
||||
Create Date: 2025-01-29 07:48:46.784041
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4d58345da04a"
|
||||
down_revision = "f1ca58b2f2ec"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Get database connection
|
||||
connection = op.get_bind()
|
||||
|
||||
# Update all user emails to lowercase
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET email = LOWER(email)
|
||||
WHERE email != LOWER(email)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Cannot restore original case of emails
|
||||
pass
|
||||
@@ -1,29 +0,0 @@
|
||||
"""remove recent assistants
|
||||
|
||||
Revision ID: a6df6b88ef81
|
||||
Revises: 4d58345da04a
|
||||
Create Date: 2025-01-29 10:25:52.790407
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a6df6b88ef81"
|
||||
down_revision = "4d58345da04a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("user", "recent_assistants")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
|
||||
),
|
||||
)
|
||||
@@ -1,33 +0,0 @@
|
||||
"""add passthrough auth to tool
|
||||
|
||||
Revision ID: f1ca58b2f2ec
|
||||
Revises: c7bf5721733e
|
||||
Create Date: 2024-03-19
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "f1ca58b2f2ec"
|
||||
down_revision: Union[str, None] = "c7bf5721733e"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add passthrough_auth column to tool table with default value of False
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column(
|
||||
"passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove passthrough_auth column from tool table
|
||||
op.drop_column("tool", "passthrough_auth")
|
||||
@@ -32,7 +32,6 @@ def perform_ttl_management_task(
|
||||
|
||||
@celery_app.task(
|
||||
name="check_ttl_management_task",
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
@@ -57,7 +56,6 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
|
||||
@celery_app.task(
|
||||
name="autogenerate_usage_report_task",
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
|
||||
|
||||
@@ -1,72 +1,30 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
|
||||
)
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
tasks_to_schedule as base_tasks_to_schedule,
|
||||
)
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
ee_cloud_tasks_to_schedule = [
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
},
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
ee_tasks_to_schedule: list[dict] = []
|
||||
|
||||
if not MULTI_TENANT:
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule
|
||||
return base_cloud_tasks_to_schedule
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
|
||||
@@ -4,20 +4,6 @@ import os
|
||||
# Applicable for OIDC Auth
|
||||
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
|
||||
|
||||
# Applicable for OIDC Auth, allows you to override the scopes that
|
||||
# are requested from the OIDC provider. Currently used when passing
|
||||
# over access tokens to tool calls and the tool needs more scopes
|
||||
OIDC_SCOPE_OVERRIDE: list[str] | None = None
|
||||
_OIDC_SCOPE_OVERRIDE = os.environ.get("OIDC_SCOPE_OVERRIDE")
|
||||
|
||||
if _OIDC_SCOPE_OVERRIDE:
|
||||
try:
|
||||
OIDC_SCOPE_OVERRIDE = [
|
||||
scope.strip() for scope in _OIDC_SCOPE_OVERRIDE.split(",")
|
||||
]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Applicable for SAML Auth
|
||||
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_config"
|
||||
|
||||
|
||||
@@ -98,9 +98,10 @@ def get_page_of_chat_sessions(
|
||||
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
|
||||
|
||||
subquery = (
|
||||
select(ChatSession.id)
|
||||
select(ChatSession.id, ChatSession.time_created)
|
||||
.filter(*conditions)
|
||||
.order_by(desc(ChatSession.time_created), ChatSession.id)
|
||||
.order_by(ChatSession.id, desc(ChatSession.time_created))
|
||||
.distinct(ChatSession.id)
|
||||
.limit(page_size)
|
||||
.offset(page_num * page_size)
|
||||
.subquery()
|
||||
@@ -117,11 +118,7 @@ def get_page_of_chat_sessions(
|
||||
ChatMessage.chat_message_feedbacks
|
||||
),
|
||||
)
|
||||
.order_by(
|
||||
desc(ChatSession.time_created),
|
||||
ChatSession.id,
|
||||
asc(ChatMessage.id), # Ensure chronological message order
|
||||
)
|
||||
.order_by(desc(ChatSession.time_created), asc(ChatMessage.id))
|
||||
)
|
||||
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
@@ -13,7 +13,6 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -258,7 +257,6 @@ def _fetch_all_page_restrictions(
|
||||
slim_docs: list[SlimDocument],
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess],
|
||||
is_cloud: bool,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
For all pages, if a page has restrictions, then use those restrictions.
|
||||
@@ -267,12 +265,6 @@ def _fetch_all_page_restrictions(
|
||||
document_restrictions: list[DocExternalAccess] = []
|
||||
|
||||
for slim_doc in slim_docs:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("confluence_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
|
||||
|
||||
if slim_doc.perm_sync_data is None:
|
||||
raise ValueError(
|
||||
f"No permission sync data found for document {slim_doc.id}"
|
||||
@@ -342,7 +334,7 @@ def _fetch_all_page_restrictions(
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -367,12 +359,6 @@ def confluence_doc_sync(
|
||||
logger.debug("Fetching all slim documents from confluence")
|
||||
for doc_batch in confluence_connector.retrieve_all_slim_documents():
|
||||
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("confluence_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("confluence_doc_sync", 1)
|
||||
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
logger.debug("Fetching all page restrictions for space")
|
||||
@@ -381,5 +367,4 @@ def confluence_doc_sync(
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
is_cloud=is_cloud,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
@@ -14,8 +14,6 @@ def _build_group_member_email_map(
|
||||
) -> dict[str, set[str]]:
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
for user_result in confluence_client.paginated_cql_user_retrieval():
|
||||
logger.debug(f"Processing groups for user: {user_result}")
|
||||
|
||||
user = user_result.get("user", {})
|
||||
if not user:
|
||||
logger.warning(f"user result missing user field: {user_result}")
|
||||
@@ -35,17 +33,10 @@ def _build_group_member_email_map(
|
||||
logger.warning(f"user result missing email field: {user_result}")
|
||||
continue
|
||||
|
||||
all_users_groups: set[str] = set()
|
||||
for group in confluence_client.paginated_groups_by_user_retrieval(user):
|
||||
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
|
||||
group_id = group["name"]
|
||||
group_member_emails.setdefault(group_id, set()).add(email)
|
||||
all_users_groups.add(group_id)
|
||||
|
||||
if not group_member_emails:
|
||||
logger.warning(f"No groups found for user with email: {email}")
|
||||
else:
|
||||
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
|
||||
|
||||
return group_member_emails
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -29,7 +28,7 @@ def _get_slim_doc_generator(
|
||||
|
||||
|
||||
def gmail_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -45,12 +44,6 @@ def gmail_doc_sync(
|
||||
document_external_access: list[DocExternalAccess] = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("gmail_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("gmail_doc_sync", 1)
|
||||
|
||||
if slim_doc.perm_sync_data is None:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
continue
|
||||
|
||||
@@ -10,7 +10,6 @@ from onyx.connectors.google_utils.resources import get_drive_service
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -43,22 +42,24 @@ def _fetch_permissions_for_permission_ids(
|
||||
if not permission_info or not doc_id:
|
||||
return []
|
||||
|
||||
# Check cache first for all permission IDs
|
||||
permissions = [
|
||||
_PERMISSION_ID_PERMISSION_MAP[pid]
|
||||
for pid in permission_ids
|
||||
if pid in _PERMISSION_ID_PERMISSION_MAP
|
||||
]
|
||||
|
||||
# If we found all permissions in cache, return them
|
||||
if len(permissions) == len(permission_ids):
|
||||
return permissions
|
||||
|
||||
owner_email = permission_info.get("owner_email")
|
||||
|
||||
drive_service = get_drive_service(
|
||||
creds=google_drive_connector.creds,
|
||||
user_email=(owner_email or google_drive_connector.primary_admin_email),
|
||||
)
|
||||
|
||||
# Otherwise, fetch all permissions and update cache
|
||||
fetched_permissions = execute_paginated_retrieval(
|
||||
retrieval_function=drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
@@ -68,6 +69,7 @@ def _fetch_permissions_for_permission_ids(
|
||||
)
|
||||
|
||||
permissions_for_doc_id = []
|
||||
# Update cache and return all permissions
|
||||
for permission in fetched_permissions:
|
||||
permissions_for_doc_id.append(permission)
|
||||
_PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
|
||||
@@ -129,7 +131,7 @@ def _get_permissions_from_slim_doc(
|
||||
|
||||
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -147,12 +149,6 @@ def gdrive_doc_sync(
|
||||
document_external_accesses = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("gdrive_doc_sync", 1)
|
||||
|
||||
ext_access = _get_permissions_from_slim_doc(
|
||||
google_drive_connector=google_drive_connector,
|
||||
slim_doc=slim_doc,
|
||||
|
||||
@@ -7,7 +7,6 @@ from onyx.connectors.slack.connector import get_channels
|
||||
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.connector import SlackPollConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -15,7 +14,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _get_slack_document_ids_and_channels(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> dict[str, list[str]]:
|
||||
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
@@ -25,14 +24,6 @@ def _get_slack_document_ids_and_channels(
|
||||
channel_doc_map: dict[str, list[str]] = {}
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"_get_slack_document_ids_and_channels: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("_get_slack_document_ids_and_channels", 1)
|
||||
|
||||
if doc_metadata.perm_sync_data is None:
|
||||
continue
|
||||
channel_id = doc_metadata.perm_sync_data["channel_id"]
|
||||
@@ -123,7 +114,7 @@ def _fetch_channel_permissions(
|
||||
|
||||
|
||||
def slack_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -136,7 +127,7 @@ def slack_doc_sync(
|
||||
)
|
||||
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
|
||||
channel_doc_map = _get_slack_document_ids_and_channels(
|
||||
cc_pair=cc_pair, callback=callback
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
workspace_permissions = _fetch_workspace_permissions(
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
|
||||
@@ -15,13 +15,11 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
|
||||
# Defining the input/output types for the sync functions
|
||||
DocSyncFuncType = Callable[
|
||||
[
|
||||
ConnectorCredentialPair,
|
||||
IndexingHeartbeatInterface | None,
|
||||
],
|
||||
list[DocExternalAccess],
|
||||
]
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from httpx_oauth.clients.openid import BASE_SCOPES
|
||||
from httpx_oauth.clients.openid import OpenID
|
||||
|
||||
from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
|
||||
from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
|
||||
from ee.onyx.server.analytics.api import router as analytics_router
|
||||
from ee.onyx.server.auth_check import check_ee_router_auth
|
||||
@@ -90,13 +88,7 @@ def get_application() -> FastAPI:
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
OpenID(
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_CLIENT_SECRET,
|
||||
OPENID_CONFIG_URL,
|
||||
# BASE_SCOPES is the same as not setting this
|
||||
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
|
||||
),
|
||||
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
|
||||
auth_backend,
|
||||
USER_AUTH_SECRET,
|
||||
associate_by_email=True,
|
||||
|
||||
@@ -111,7 +111,6 @@ async def login_as_anonymous_user(
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.delete_cookie("fastapiusersauth")
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
|
||||
@@ -58,7 +58,6 @@ class UserGroup(BaseModel):
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair_relationship.cc_pair.credential
|
||||
),
|
||||
access_type=cc_pair_relationship.cc_pair.access_type,
|
||||
)
|
||||
for cc_pair_relationship in user_group_model.cc_pair_relationships
|
||||
if cc_pair_relationship.is_current
|
||||
|
||||
@@ -42,17 +42,8 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class UserUpdateWithRole(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
Role changes should be handled through a separate, admin-only process
|
||||
"""
|
||||
|
||||
|
||||
class AuthBackend(str, Enum):
|
||||
REDIS = "redis"
|
||||
POSTGRES = "postgres"
|
||||
|
||||
@@ -33,8 +33,6 @@ from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import RedisStrategy
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.exceptions import UserAlreadyExists
|
||||
from fastapi_users.jwt import decode_jwt
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
@@ -54,15 +52,13 @@ from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.email_utils import send_forgot_password_email
|
||||
from onyx.auth.email_utils import send_user_verification_email
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.schemas import UserUpdateWithRole
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
from onyx.auth.schemas import UserUpdate
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
@@ -78,7 +74,6 @@ from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.api_key import fetch_user_for_api_key
|
||||
from onyx.db.auth import get_access_token_db
|
||||
from onyx.db.auth import get_default_admin_user_emails
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.auth import get_user_db
|
||||
@@ -87,7 +82,6 @@ from onyx.db.engine import get_async_session
|
||||
from onyx.db.engine import get_async_session_with_tenant
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
@@ -215,7 +209,7 @@ def verify_email_domain(email: str) -> None:
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
|
||||
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
|
||||
|
||||
async def create(
|
||||
@@ -245,8 +239,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
referral_source=referral_source,
|
||||
request=request,
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if MULTI_TENANT:
|
||||
@@ -265,16 +261,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.role.is_web_login() and user_create.role.is_web_login():
|
||||
user_update = UserUpdateWithRole(
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
@@ -282,6 +278,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return user
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
@@ -583,14 +580,6 @@ def get_redis_strategy() -> RedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
|
||||
)
|
||||
|
||||
|
||||
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
"""
|
||||
A custom strategy that fetches the actual async Redis connection inside each method.
|
||||
@@ -599,7 +588,7 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lifetime_seconds: Optional[int] = SESSION_EXPIRE_TIME_SECONDS,
|
||||
lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
|
||||
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
|
||||
):
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
@@ -648,16 +637,9 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
await redis.delete(f"{self.key_prefix}{token}")
|
||||
|
||||
|
||||
if AUTH_BACKEND == AuthBackend.REDIS:
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
|
||||
)
|
||||
elif AUTH_BACKEND == AuthBackend.POSTGRES:
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="postgres", transport=cookie_transport, get_strategy=get_database_strategy
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid auth backend: {AUTH_BACKEND}")
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
|
||||
)
|
||||
|
||||
|
||||
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
|
||||
@@ -23,8 +23,8 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
@@ -198,8 +198,7 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
|
||||
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for redis to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout
|
||||
is reached."""
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
@@ -281,6 +280,51 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
|
||||
return
|
||||
|
||||
|
||||
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for Vespa to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
ready = False
|
||||
time_start = time.monotonic()
|
||||
logger.info("Vespa: Readiness probe starting.")
|
||||
while True:
|
||||
try:
|
||||
client = get_vespa_http_client()
|
||||
response = client.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
|
||||
response.raise_for_status()
|
||||
|
||||
response_dict = response.json()
|
||||
if response_dict["status"]["code"] == "up":
|
||||
ready = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
if not ready:
|
||||
msg = (
|
||||
f"Vespa: Readiness probe did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
logger.info("Vespa: Readiness probe succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
|
||||
@@ -318,8 +362,6 @@ def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
HttpxPool.close_all()
|
||||
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
@@ -468,13 +510,3 @@ def reset_tenant_id(
|
||||
) -> None:
|
||||
"""Signal handler to reset tenant ID in context var after task ends."""
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
|
||||
def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
|
||||
if not wait_for_vespa_with_timeout():
|
||||
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
@@ -7,6 +8,7 @@ from celery.beat import PersistentScheduler # type: ignore
|
||||
from celery.signals import beat_init
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import SqlEngine
|
||||
@@ -79,7 +81,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
cloud_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": task.get("kwargs", {}),
|
||||
"kwargs": {},
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(f"Adding options to task {task_name}: {options}")
|
||||
@@ -130,25 +132,21 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
# get current schedule and extract current tenants
|
||||
current_schedule = self.schedule.items()
|
||||
|
||||
# there are no more per tenant beat tasks, so comment this out
|
||||
# NOTE: we may not actualy need this scheduler any more and should
|
||||
# test reverting to a regular beat schedule implementation
|
||||
current_tenants = set()
|
||||
for task_name, _ in current_schedule:
|
||||
task_name = cast(str, task_name)
|
||||
if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
|
||||
continue
|
||||
|
||||
# current_tenants = set()
|
||||
# for task_name, _ in current_schedule:
|
||||
# task_name = cast(str, task_name)
|
||||
# if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
|
||||
# continue
|
||||
if "_" in task_name:
|
||||
# example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
|
||||
# -> "12345678-abcd-efgh-ijkl-12345678"
|
||||
current_tenants.add(task_name.split("_")[-1])
|
||||
logger.info(f"Found {len(current_tenants)} existing items in schedule")
|
||||
|
||||
# if "_" in task_name:
|
||||
# # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
|
||||
# # -> "12345678-abcd-efgh-ijkl-12345678"
|
||||
# current_tenants.add(task_name.split("_")[-1])
|
||||
# logger.info(f"Found {len(current_tenants)} existing items in schedule")
|
||||
|
||||
# for tenant_id in tenant_ids:
|
||||
# if tenant_id not in current_tenants:
|
||||
# logger.info(f"Processing new tenant: {tenant_id}")
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id not in current_tenants:
|
||||
logger.info(f"Processing new tenant: {tenant_id}")
|
||||
|
||||
new_schedule = self._generate_schedule(tenant_ids)
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -68,7 +68,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -10,10 +10,6 @@ from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -58,27 +54,16 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
|
||||
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
|
||||
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
|
||||
ssl_cert=VESPA_CLOUD_CERT_PATH,
|
||||
ssl_key=VESPA_CLOUD_KEY_PATH,
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -86,7 +86,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
|
||||
@@ -91,28 +91,6 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
|
||||
return False
|
||||
|
||||
|
||||
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
|
||||
"""This is a redis specific way to build a list of tasks in a queue.
|
||||
|
||||
This helps us read the queue once and then efficiently look for missing tasks
|
||||
in the queue.
|
||||
"""
|
||||
|
||||
task_set: set[str] = set()
|
||||
|
||||
for priority in range(len(OnyxCeleryPriority)):
|
||||
queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
|
||||
|
||||
tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
|
||||
for task in tasks:
|
||||
task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
|
||||
task_id = task_dict.get("headers", {}).get("id")
|
||||
if task_id:
|
||||
task_set.add(task_id)
|
||||
|
||||
return task_set
|
||||
|
||||
|
||||
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
|
||||
"""Returns a list of current workers containing name_filter, or all workers if
|
||||
name_filter is None.
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
@@ -20,7 +17,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.models import TaskQueueState
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.server.documents.models import DeletionAttemptSnapshot
|
||||
@@ -158,25 +154,3 @@ def celery_is_worker_primary(worker: Any) -> bool:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def httpx_init_vespa_pool(
|
||||
max_keepalive_connections: int,
|
||||
timeout: int = VESPA_REQUEST_TIMEOUT,
|
||||
ssl_cert: str | None = None,
|
||||
ssl_key: str | None = None,
|
||||
) -> None:
|
||||
httpx_cert = None
|
||||
httpx_verify = False
|
||||
if ssl_cert and ssl_key:
|
||||
httpx_cert = cast(tuple[str, str], (ssl_cert, ssl_key))
|
||||
httpx_verify = True
|
||||
|
||||
HttpxPool.init_client(
|
||||
name="vespa",
|
||||
cert=httpx_cert,
|
||||
verify=httpx_verify,
|
||||
timeout=timeout,
|
||||
http2=False,
|
||||
limits=httpx.Limits(max_keepalive_connections=max_keepalive_connections),
|
||||
)
|
||||
|
||||
@@ -16,241 +16,125 @@ from shared_configs.configs import MULTI_TENANT
|
||||
# it's only important that they run relatively regularly
|
||||
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# hack to slow down task dispatch in the cloud until
|
||||
# we have a better implementation (backpressure, etc)
|
||||
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 8
|
||||
|
||||
# tasks that only run in the cloud
|
||||
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
|
||||
# the name attribute must start with ONYX_CELERY_CLOUD_PREFIX = "cloud" to be filtered
|
||||
# by the DynamicTenantScheduler
|
||||
cloud_tasks_to_schedule = [
|
||||
# cloud specific tasks
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
|
||||
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
|
||||
"schedule": timedelta(hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
# remaining tasks are cloud generators for per tenant tasks
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"task": OnyxCeleryTask.CLOUD_CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-prune",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=30 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-background-processes",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(minutes=5 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
cloud_tasks_to_schedule.append(
|
||||
# tasks that run in either self-hosted on cloud
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOWEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-background-processes",
|
||||
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
"schedule": timedelta(minutes=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
if not MULTI_TENANT:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(
|
||||
hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER
|
||||
), # Check every hour
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# tasks that run in either self-hosted on cloud
|
||||
tasks_to_schedule: list[dict] = []
|
||||
|
||||
if not MULTI_TENANT:
|
||||
tasks_to_schedule.extend(
|
||||
[
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
# Only add the LLM model update task if the API URL is configured
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"schedule": timedelta(hours=1), # Check every hour
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-pruning",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-background-processes",
|
||||
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
"schedule": timedelta(minutes=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Only add the LLM model update task if the API URL is configured
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"schedule": timedelta(hours=1), # Check every hour
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return cloud_tasks_to_schedule
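A sketch of how a beat worker might pick between the two lists at startup, assuming a get_tasks_to_schedule() accessor exists alongside get_cloud_tasks_to_schedule() and reusing the build_beat_schedule sketch shown earlier; this wiring is illustrative, not taken from the diff.

from celery import Celery

def install_beat_schedule(celery_app: Celery) -> None:
    # multi-tenant (cloud) deployments use the generator-based cloud entries,
    # self-hosted deployments use the plain per-tenant entries
    entries = get_cloud_tasks_to_schedule() if MULTI_TENANT else get_tasks_to_schedule()
    celery_app.conf.beat_schedule = build_beat_schedule(entries)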
|
||||
|
||||
@@ -33,7 +33,6 @@ class TaskDependencyError(RuntimeError):
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -140,6 +139,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
submitted=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
)
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
|
||||
try:
|
||||
@@ -178,13 +184,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
)
|
||||
if tasks_generated is None:
|
||||
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
|
||||
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
)
|
||||
|
||||
except TaskDependencyError:
|
||||
redis_connector.delete.set_fence(None)
|
||||
raise
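The hunk above reorders things so the sync record exists before the fence is published, and clears the fence again if task generation fails. A condensed sketch of that ordering, with build_fence_payload and generate_cleanup_tasks standing in for the real helpers:

def start_connector_deletion(db_session, redis_connector, cc_pair_id: int) -> None:
    fence_payload = build_fence_payload()  # illustrative stand-in
    # 1) create the sync record first so the monitoring task can always find it
    insert_sync_record(
        db_session=db_session,
        entity_id=cc_pair_id,
        sync_type=SyncType.CONNECTOR_DELETION,
    )
    # 2) only then publish the fence that the monitoring task keys off of
    redis_connector.delete.set_fence(fence_payload)
    try:
        generate_cleanup_tasks(cc_pair_id)  # illustrative stand-in
    except TaskDependencyError:
        # roll the fence back so the deletion does not look "in progress" forever
        redis_connector.delete.set_fence(None)
        raise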
|
||||
|
||||
@@ -3,18 +3,14 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from time import sleep
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from pydantic import ValidationError
|
||||
from redis import Redis
|
||||
from redis.exceptions import LockError
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.document import upsert_document_external_perms
|
||||
@@ -25,10 +21,6 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
)
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
@@ -39,32 +31,21 @@ from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyncPayload
|
||||
from onyx.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncPayload,
|
||||
)
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import doc_permission_sync_ctx
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -76,9 +57,6 @@ LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
"""Jobs / utils for kicking off doc permissions sync tasks."""
|
||||
|
||||
|
||||
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if external doc permissions sync is due."""
|
||||
|
||||
@@ -113,17 +91,11 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
# TODO(rkuo): merge into check function after lookup table for fences is added
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
@@ -144,32 +116,14 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
if _is_external_doc_permissions_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
payload_id = try_creating_permissions_sync_task(
|
||||
tasks_created = try_creating_permissions_sync_task(
|
||||
self.app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
if not tasks_created:
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}"
|
||||
)
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
lock_beat.reacquire()
|
||||
if not r.exists(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES):
|
||||
# clear any permission fences that don't have associated celery tasks in progress
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
validate_permission_sync_fences(tenant_id, r, r_celery, lock_beat)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Exception while validating permission sync fences"
|
||||
)
|
||||
|
||||
r.set(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=60)
|
||||
task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -188,15 +142,13 @@ def try_creating_permissions_sync_task(
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> str | None:
|
||||
"""Returns a randomized payload id on success.
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Returns None if no syncing is required."""
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
payload_id: str | None = None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
@@ -221,25 +173,6 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_PERMISSIONS,
|
||||
)
|
||||
|
||||
# set a basic fence to start
|
||||
redis_connector.permissions.set_active()
|
||||
payload = RedisConnectorPermissionSyncPayload(
|
||||
id=make_short_id(),
|
||||
submitted=datetime.now(timezone.utc),
|
||||
started=None,
|
||||
celery_task_id=None,
|
||||
)
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
|
||||
result = app.send_task(
|
||||
OnyxCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
kwargs=dict(
|
||||
@@ -251,12 +184,12 @@ def try_creating_permissions_sync_task(
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
|
||||
# fill in the celery task id
|
||||
redis_connector.permissions.set_active()
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorPermissionSyncPayload(
|
||||
started=None, celery_task_id=result.id
|
||||
)
|
||||
|
||||
payload_id = payload.celery_task_id
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
|
||||
return None
|
||||
@@ -264,7 +197,7 @@ def try_creating_permissions_sync_task(
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return payload_id
|
||||
return 1
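The return type here changes between a randomized payload id and a task count depending on the side of the diff. A minimal sketch of the payload-id variant, assuming the fence is stored as JSON under a single Redis key (the real RedisConnectorPermissionSync class handles serialization differently):

import json
from datetime import datetime, timezone
from uuid import uuid4

from redis import Redis

def make_short_id_sketch() -> str:
    # stand-in for onyx.server.utils.make_short_id
    return uuid4().hex[:8]

def create_permission_sync_fence(r: Redis, fence_key: str) -> str:
    payload = {
        "id": make_short_id_sketch(),
        "submitted": datetime.now(timezone.utc).isoformat(),
        "started": None,
        "celery_task_id": None,  # filled in once app.send_task returns
    }
    r.set(fence_key, json.dumps(payload))
    # returning the short id lets later log lines be correlated with this attempt
    return payload["id"]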
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -285,8 +218,6 @@ def connector_permission_sync_generator_task(
|
||||
This task assumes that the task has already been properly fenced
|
||||
"""
|
||||
|
||||
LoggerContextVars.reset()
|
||||
|
||||
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
|
||||
doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
|
||||
doc_permission_sync_ctx_dict["request_id"] = self.request.id
|
||||
@@ -374,17 +305,12 @@ def connector_permission_sync_generator_task(
|
||||
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
|
||||
|
||||
new_payload = RedisConnectorPermissionSyncPayload(
|
||||
id=payload.id,
|
||||
submitted=payload.submitted,
|
||||
started=datetime.now(timezone.utc),
|
||||
celery_task_id=payload.celery_task_id,
|
||||
)
|
||||
redis_connector.permissions.set_fence(new_payload)
|
||||
|
||||
callback = PermissionSyncCallback(redis_connector, lock, r)
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(
|
||||
cc_pair, callback
|
||||
)
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
@@ -434,8 +360,6 @@ def update_external_document_permissions_task(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> bool:
|
||||
start = time.monotonic()
|
||||
|
||||
document_external_access = DocExternalAccess.from_dict(
|
||||
serialized_doc_external_access
|
||||
)
|
||||
@@ -465,330 +389,12 @@ def update_external_document_permissions_task(
|
||||
document_ids=[doc_id],
|
||||
)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"connector_id={connector_id} "
|
||||
f"doc={doc_id} "
|
||||
f"action=update_permissions "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
logger.debug(
|
||||
f"Successfully synced postgres document permissions for {doc_id}"
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Exception in update_external_document_permissions_task: "
|
||||
f"connector_id={connector_id} "
|
||||
f"doc_id={doc_id}"
|
||||
logger.exception(
|
||||
f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
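The added timing lines above use time.monotonic(), which is the right clock for durations because it never jumps with wall-clock changes. A small sketch of the pattern; apply_permission_update is an illustrative stand-in for the real work:

import time

from onyx.background.celery.apps.app_base import task_logger

def timed_permission_update(connector_id: int, doc_id: str) -> None:
    start = time.monotonic()
    try:
        apply_permission_update(doc_id)  # illustrative stand-in
    finally:
        elapsed = time.monotonic() - start
        task_logger.info(
            f"connector_id={connector_id} doc={doc_id} "
            f"action=update_permissions elapsed={elapsed:.2f}"
        )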
|
||||
|
||||
|
||||
def validate_permission_sync_fences(
|
||||
tenant_id: str | None,
|
||||
r: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
# building lookup table can be expensive, so we won't bother
|
||||
# validating until the queue is small
|
||||
PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024
|
||||
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN:
|
||||
return
|
||||
|
||||
queued_upsert_tasks = celery_get_queued_task_ids(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
reserved_generator_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
|
||||
# validate all existing indexing jobs
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
|
||||
count=SCAN_ITER_COUNT_DEFAULT,
|
||||
):
|
||||
lock_beat.reacquire()
|
||||
validate_permission_sync_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
queued_upsert_tasks,
|
||||
reserved_generator_tasks,
|
||||
r,
|
||||
r_celery,
|
||||
)
|
||||
return
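The queue-length guard above keeps the validation cheap: the queued/reserved lookup sets are only built when the upsert queue is small. A compressed sketch of that shape, with validate_one_fence as an illustrative per-fence check:

from redis import Redis

from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids

MAX_QUEUE_LEN = 1024  # mirrors PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN

def validate_fences_sketch(r: Redis, r_celery: Redis, queue: str, fence_prefix: str) -> None:
    # building the lookup set is the expensive part, so skip while the queue is long
    if celery_get_queue_length(queue, r_celery) > MAX_QUEUE_LEN:
        return
    queued = celery_get_queued_task_ids(queue, r_celery)
    for key_bytes in r.scan_iter(fence_prefix + "*", count=1000):
        validate_one_fence(key_bytes, queued)  # illustrative per-fence check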
|
||||
|
||||
|
||||
def validate_permission_sync_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
reserved_tasks: set[str],
|
||||
r: Redis,
|
||||
r_celery: Redis,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
This can happen if the indexing worker hard crashes or is terminated.
|
||||
Being in this bad state means the fence will never clear without help, so this function
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. This function renews the active signal with a 5 minute TTL under the following conditions
|
||||
1.2. When the task is seen in the redis queue
|
||||
1.3. When the task is seen in the reserved / prefetched list
|
||||
|
||||
2. Externally, the active signal is renewed when:
|
||||
2.1. The fence is created
|
||||
2.2. The indexing watchdog checks the spawned task.
|
||||
|
||||
3. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
whether a task is in the queue or currently executing.
|
||||
1. An unknown task id is always returned as state PENDING.
|
||||
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
|
||||
and the time it actually starts on the worker.
|
||||
|
||||
queued_tasks: the celery queue of lightweight permission sync tasks
|
||||
reserved_tasks: prefetched tasks for sync task generator
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"validate_permission_sync_fence - could not parse id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
# parse out metadata and initialize the helper class with it
|
||||
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector.permissions.fenced:
|
||||
return
|
||||
|
||||
# in the cloud, the payload format may have changed ...
|
||||
# it's a little sloppy, but just reset the fence for now if that happens
|
||||
# TODO: add intentional cleanup/abort logic
|
||||
try:
|
||||
payload = redis_connector.permissions.payload
|
||||
except ValidationError:
|
||||
task_logger.exception(
|
||||
"validate_permission_sync_fence - "
|
||||
"Resetting fence because fence schema is out of date: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.permissions.reset()
|
||||
return
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
if not payload.celery_task_id:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
# either the generator task must be in flight or its subtasks must be
|
||||
found = celery_find_task(
|
||||
payload.celery_task_id,
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
|
||||
r_celery,
|
||||
)
|
||||
if found:
|
||||
# the celery task exists in the redis queue
|
||||
redis_connector.permissions.set_active()
|
||||
return
|
||||
|
||||
if payload.celery_task_id in reserved_tasks:
|
||||
# the celery task was prefetched and is reserved within a worker
|
||||
redis_connector.permissions.set_active()
|
||||
return
|
||||
|
||||
# look up every task in the current taskset in the celery queue
|
||||
# every entry in the taskset should have an associated entry in the celery task queue
|
||||
# because we get the celery tasks first, the entries in our own permissions taskset
|
||||
# should be roughly a subset of the tasks in celery
|
||||
|
||||
# this check isn't very exact, but should be sufficient over a period of time
|
||||
# A single successful check over some number of attempts is sufficient.
|
||||
|
||||
# TODO: if the number of tasks in celery is much lower than the taskset length
|
||||
# we might be able to shortcut the lookup since by definition some of the tasks
|
||||
# must not exist in celery.
|
||||
|
||||
tasks_scanned = 0
|
||||
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
|
||||
|
||||
for member in r.sscan_iter(redis_connector.permissions.taskset_key):
|
||||
tasks_scanned += 1
|
||||
|
||||
member_bytes = cast(bytes, member)
|
||||
member_str = member_bytes.decode("utf-8")
|
||||
if member_str in queued_tasks:
|
||||
continue
|
||||
|
||||
if member_str in reserved_tasks:
|
||||
continue
|
||||
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
task_logger.info(
|
||||
"validate_permission_sync_fence task check: "
|
||||
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
)
|
||||
|
||||
if tasks_not_in_celery == 0:
|
||||
redis_connector.permissions.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector.permissions.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
task_logger.warning(
|
||||
"validate_permission_sync_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.permissions.reset()
|
||||
return
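The "active signal" described in the docstring is essentially a Redis key with a TTL that gets refreshed whenever the task is observed anywhere in the pipeline; the fence is only reset once that signal has been allowed to lapse. A minimal sketch, assuming the five-minute TTL the docstring mentions:

from redis import Redis

ACTIVE_TTL = 300  # seconds; long enough to bridge the queue -> worker handoff gaps

def renew_active(r: Redis, active_key: str) -> None:
    r.set(active_key, 1, ex=ACTIVE_TTL)

def should_reset_fence(r: Redis, active_key: str, task_visible: bool) -> bool:
    if task_visible:
        renew_active(r, active_key)
        return False
    # no direct evidence of the task and the grace period has expired
    return not r.exists(active_key)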
|
||||
|
||||
|
||||
class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
PARENT_CHECK_INTERVAL = 60
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_connector: RedisConnector,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.redis_connector: RedisConnector = redis_connector
|
||||
self.redis_lock: RedisLock = redis_lock
|
||||
self.redis_client = redis_client
|
||||
|
||||
self.started: datetime = datetime.now(timezone.utc)
|
||||
self.redis_lock.reacquire()
|
||||
|
||||
self.last_tag: str = "PermissionSyncCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_connector.stop.fenced:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
try:
|
||||
self.redis_connector.permissions.set_active()
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - self.last_lock_monotonic >= (
|
||||
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
self.redis_lock.reacquire()
|
||||
self.last_lock_reacquire = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_tag = tag
|
||||
except LockError:
|
||||
logger.exception(
|
||||
f"PermissionSyncCallback - lock.reacquire exceptioned: "
|
||||
f"lock_timeout={self.redis_lock.timeout} "
|
||||
f"start={self.started} "
|
||||
f"last_tag={self.last_tag} "
|
||||
f"last_reacquired={self.last_lock_reacquire} "
|
||||
f"now={datetime.now(timezone.utc)}"
|
||||
)
|
||||
|
||||
redis_lock_dump(self.redis_lock, self.redis_client)
|
||||
raise
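progress() above refreshes the beat lock on a monotonic-clock cadence of a quarter of the lock timeout, so the lock is renewed well before it can expire. The same idea in isolation, as a small illustrative helper:

import time

from redis.lock import Lock as RedisLock

class LockKeepalive:
    def __init__(self, lock: RedisLock, lock_timeout: int) -> None:
        self.lock = lock
        self.interval = lock_timeout / 4  # refresh well before expiry
        self._last = time.monotonic()

    def tick(self) -> None:
        now = time.monotonic()
        if now - self._last >= self.interval:
            self.lock.reacquire()  # resets the TTL of a timeout-bounded lock
            self._last = now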
|
||||
|
||||
|
||||
"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.permissions.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.permissions.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
try:
|
||||
payload = redis_connector.permissions.payload
|
||||
except ValidationError:
|
||||
task_logger.exception(
|
||||
"Permissions sync payload failed to validate. "
|
||||
"Schema may have been updated."
|
||||
)
|
||||
return
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
remaining = redis_connector.permissions.get_remaining()
|
||||
task_logger.info(
|
||||
f"Permissions sync progress: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"id={payload.id} "
|
||||
f"remaining={remaining} "
|
||||
f"initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), payload.started)
|
||||
task_logger.info(
|
||||
f"Permissions sync finished: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"id={payload.id} "
|
||||
f"num_synced={initial}"
|
||||
)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_PERMISSIONS,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial,
|
||||
)
|
||||
|
||||
redis_connector.permissions.reset()
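get_remaining() is not shown in this diff; one plausible implementation is simply the cardinality of the taskset, since each outstanding lightweight task leaves a member there and removes it on completion. Sketch only:

from redis import Redis

def get_remaining_sketch(r: Redis, taskset_key: str) -> int:
    # the set shrinks as tasks complete, so SCARD is the remaining-work count
    return int(r.scard(taskset_key))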
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -10,7 +9,6 @@ from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
|
||||
@@ -22,12 +20,9 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
|
||||
)
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -38,18 +33,12 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
from onyx.redis.redis_connector_ext_group_sync import (
|
||||
RedisConnectorExternalGroupSyncPayload,
|
||||
)
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -102,17 +91,12 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
# r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
@@ -147,7 +131,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
if _is_external_group_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
tasks_created = try_creating_external_group_sync_task(
|
||||
self.app, cc_pair_id, r, tenant_id
|
||||
@@ -156,23 +139,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
continue
|
||||
|
||||
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
# lock_beat.reacquire()
|
||||
# if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
|
||||
# # clear any indexing fences that don't have associated celery tasks in progress
|
||||
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# # or be currently executing
|
||||
# try:
|
||||
# validate_external_group_sync_fences(
|
||||
# tenant_id, self.app, r, r_celery, lock_beat
|
||||
# )
|
||||
# except Exception:
|
||||
# task_logger.exception(
|
||||
# "Exception while validating external group sync fences"
|
||||
# )
|
||||
|
||||
# r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -215,12 +181,6 @@ def try_creating_external_group_sync_task(
|
||||
redis_connector.external_group_sync.generator_clear()
|
||||
redis_connector.external_group_sync.taskset_clear()
|
||||
|
||||
payload = RedisConnectorExternalGroupSyncPayload(
|
||||
submitted=datetime.now(timezone.utc),
|
||||
started=None,
|
||||
celery_task_id=None,
|
||||
)
|
||||
|
||||
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
|
||||
|
||||
result = app.send_task(
|
||||
@@ -234,17 +194,13 @@ def try_creating_external_group_sync_task(
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_GROUP,
|
||||
)
|
||||
payload = RedisConnectorExternalGroupSyncPayload(
|
||||
started=datetime.now(timezone.utc),
|
||||
celery_task_id=result.id,
|
||||
)
|
||||
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
|
||||
@@ -271,7 +227,7 @@ def connector_external_group_sync_generator_task(
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
External group sync task for a given connector credential pair
|
||||
Permission sync task that handles external group syncing for a given connector credential pair
|
||||
This task assumes that the task has already been properly fenced
|
||||
"""
|
||||
|
||||
@@ -279,59 +235,19 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
# before the primary worker can finalize the fence
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.external_group_sync.fence_key}"
|
||||
)
|
||||
|
||||
if not redis_connector.external_group_sync.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
f"connector_external_group_sync_generator_task - fence not found: "
|
||||
f"fence={redis_connector.external_group_sync.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector.external_group_sync.payload # The payload must exist
|
||||
if not payload:
|
||||
raise ValueError(
|
||||
"connector_external_group_sync_generator_task: payload invalid or not found"
|
||||
)
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_external_group_sync_generator_task - Waiting for fence: "
|
||||
f"fence={redis_connector.external_group_sync.fence_key}"
|
||||
)
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"connector_external_group_sync_generator_task - Fence found, continuing...: "
|
||||
f"fence={redis_connector.external_group_sync.fence_key}"
|
||||
)
|
||||
break
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
@@ -372,26 +288,11 @@ def connector_external_group_sync_generator_task(
|
||||
)
|
||||
|
||||
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_GROUP,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run external group sync: cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_GROUP,
|
||||
sync_status=SyncStatus.FAILED,
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.generator_clear()
|
||||
redis_connector.external_group_sync.taskset_clear()
|
||||
raise e
|
||||
@@ -400,135 +301,3 @@ def connector_external_group_sync_generator_task(
|
||||
redis_connector.external_group_sync.set_fence(None)
|
||||
if lock.owned():
|
||||
lock.release()
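The wait loop above exists because the primary worker sends the celery task before it finishes writing celery_task_id into the fence; the spawned task therefore polls until the payload is complete or a timeout is hit. Condensed sketch of that loop:

import time

def wait_for_fence_ready(redis_connector, timeout: int) -> None:
    start = time.monotonic()
    while True:
        if time.monotonic() - start > timeout:
            raise ValueError("timed out waiting for fence to be ready")
        if not redis_connector.external_group_sync.fenced:
            raise ValueError("fence not found")
        payload = redis_connector.external_group_sync.payload
        if payload is not None and payload.celery_task_id is not None:
            return  # fence fully populated; safe to proceed
        time.sleep(1)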
|
||||
|
||||
|
||||
def validate_external_group_sync_fences(
|
||||
tenant_id: str | None,
|
||||
celery_app: Celery,
|
||||
r: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
reserved_sync_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
|
||||
# validate all existing indexing jobs
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
|
||||
count=SCAN_ITER_COUNT_DEFAULT,
|
||||
):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
validate_external_group_sync_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
reserved_sync_tasks,
|
||||
r_celery,
|
||||
db_session,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def validate_external_group_sync_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
This can happen if the indexing worker hard crashes or is terminated.
|
||||
Being in this bad state means the fence will never clear without help, so this function
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. This function renews the active signal with a 5 minute TTL under the following conditions
|
||||
1.2. When the task is seen in the redis queue
|
||||
1.3. When the task is seen in the reserved / prefetched list
|
||||
|
||||
2. Externally, the active signal is renewed when:
|
||||
2.1. The fence is created
|
||||
2.2. The indexing watchdog checks the spawned task.
|
||||
|
||||
3. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
whether a task is in the queue or currently executing.
|
||||
1. An unknown task id is always returned as state PENDING.
|
||||
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
|
||||
and the time it actually starts on the worker.
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"validate_external_group_sync_fence - could not parse id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector.external_group_sync.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector.external_group_sync.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
# the fence is just barely set up.
|
||||
# if redis_connector_index.active():
|
||||
# return
|
||||
|
||||
# it would be odd to get here as there isn't that much that can go wrong during
|
||||
# initial fence setup, but it's still worth making sure we can recover
|
||||
logger.info(
|
||||
"validate_external_group_sync_fence - "
|
||||
f"Resetting fence in basic state without any activity: fence={fence_key}"
|
||||
)
|
||||
redis_connector.external_group_sync.reset()
|
||||
return
|
||||
|
||||
found = celery_find_task(
|
||||
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
if found:
|
||||
# the celery task exists in the redis queue
|
||||
# redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
if payload.celery_task_id in reserved_tasks:
|
||||
# the celery task was prefetched and is reserved within the indexing worker
|
||||
# redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
# if redis_connector_index.active():
|
||||
# return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
logger.warning(
|
||||
"validate_external_group_sync_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.reset()
|
||||
return
|
||||
|
||||
@@ -15,7 +15,7 @@ from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.background.celery.tasks.indexing.utils import _should_index
|
||||
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
|
||||
@@ -23,32 +23,32 @@ from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_ta
|
||||
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingMode
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.swap_index import check_index_swap
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
@@ -68,12 +68,15 @@ logger = setup_logger()
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Occasionally does some validation of existing state to clear up error conditions."""
|
||||
debug_tenants = {
|
||||
"tenant_i-043470d740845ec56",
|
||||
"tenant_82b497ce-88aa-4fbd-841a-92cae43529c8",
|
||||
}
|
||||
time_start = time.monotonic()
|
||||
|
||||
tasks_created = 0
|
||||
locked = False
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
@@ -120,18 +123,48 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
# kick off index attempts
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
# debugging logic - remove after we're done
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
search_settings_list = get_active_search_settings_list(db_session)
|
||||
search_settings_list: list[SearchSettings] = get_active_search_settings(
|
||||
db_session
|
||||
)
|
||||
for search_settings_instance in search_settings_list:
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair search settings lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
)
|
||||
if redis_connector_index.fenced:
|
||||
continue
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing get_connector_credential_pair_from_id: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -139,10 +172,28 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing get_last_attempt_for_cc_pair: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair should index: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
search_settings_primary = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
search_settings_primary = True
|
||||
@@ -175,6 +226,15 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
cc_pair.id, None, db_session
|
||||
)
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair try_creating_indexing_task: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
# using a task queue and only allowing one task per cc_pair/search_setting
|
||||
# prevents us from starving out certain attempts
|
||||
attempt_id = try_creating_indexing_task(
|
||||
@@ -195,6 +255,24 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
)
|
||||
tasks_created += 1
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair try_creating_indexing_task finished: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
# debugging logic - remove after we're done
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing unfenced lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
# Fail any index attempts in the DB that don't have fences
|
||||
@@ -204,7 +282,24 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
db_session, redis_client
|
||||
)
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing after get unfenced lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
for attempt_id in unfenced_attempt_ids:
|
||||
# debugging logic - remove after we're done
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing unfenced attempt id lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
attempt = get_index_attempt(db_session, attempt_id)
|
||||
@@ -222,6 +317,15 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
attempt.id, db_session, failure_reason=failure_reason
|
||||
)
|
||||
|
||||
# debugging logic - remove after we're done
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing validate fences lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
# we want to run this less frequently than the overall task
|
||||
if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
|
||||
@@ -230,7 +334,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
# or be currently executing
|
||||
try:
|
||||
validate_indexing_fences(
|
||||
tenant_id, redis_client_replica, redis_client_celery, lock_beat
|
||||
tenant_id, self.app, redis_client, redis_client_celery, lock_beat
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while validating indexing fences")
|
||||
@@ -304,14 +408,6 @@ def connector_indexing_task(
|
||||
attempt_found = False
|
||||
n_final_progress: int | None = None
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
@@ -578,9 +674,6 @@ def connector_indexing_proxy_task(
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
# renew watchdog signal (this has a shorter timeout than set_active)
|
||||
redis_connector_index.set_watchdog(True)
|
||||
|
||||
# renew active signal
|
||||
redis_connector_index.set_active()
|
||||
|
||||
@@ -687,10 +780,67 @@ def connector_indexing_proxy_task(
|
||||
)
|
||||
continue
|
||||
|
||||
redis_connector_index.set_watchdog(False)
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - finished: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return
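The proxy task above acts as a watchdog: it polls the spawned indexing work every five seconds and keeps two liveness signals fresh, with the watchdog signal on a shorter TTL than the active signal. A stripped-down sketch; the job handle is illustrative:

from time import sleep

def watchdog_loop(redis_connector_index, job) -> None:
    while not job.done():  # illustrative job handle for the spawned process
        sleep(5)
        redis_connector_index.set_watchdog(True)  # short-TTL "I'm watching" signal
        redis_connector_index.set_active()  # longer-TTL "work in flight" signal
    redis_connector_index.set_watchdog(False)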
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLOUD_CHECK_FOR_INDEXING,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def cloud_check_for_indexing(self: Task) -> bool | None:
|
||||
"""a lightweight task used to kick off individual check tasks for each tenant."""
|
||||
time_start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.CLOUD_CHECK_INDEXING_BEAT_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
for tenant_id in tenant_ids:
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
|
||||
lock_beat.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
kwargs=dict(
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=BEAT_EXPIRES_DEFAULT,
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception during cloud indexing check")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error("cloud_check_for_indexing - Lock not owned on completion")
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(
|
||||
f"cloud_check_for_indexing finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -291,20 +291,17 @@ def validate_indexing_fence(
|
||||
|
||||
def validate_indexing_fences(
|
||||
tenant_id: str | None,
|
||||
r_replica: Redis,
|
||||
celery_app: Celery,
|
||||
r: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
"""Validates all indexing fences for this tenant ... aka makes sure
|
||||
indexing tasks sent to celery are still in flight.
|
||||
"""
|
||||
reserved_indexing_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
# Use replica for this because the worst thing that happens
|
||||
# is that we don't run the validation on this pass
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
# validate all existing indexing jobs
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -14,16 +8,8 @@ from onyx.db.models import LLMProvider

def _process_model_list_response(model_list_json: Any) -> list[str]:
    # Handle case where response is wrapped in a "data" field
    if isinstance(model_list_json, dict):
        if "data" in model_list_json:
            model_list_json = model_list_json["data"]
        elif "models" in model_list_json:
            model_list_json = model_list_json["models"]
        else:
            raise ValueError(
                "Invalid response from API - expected dict with 'data' or "
                f"'models' field, got {type(model_list_json)}"
            )
    if isinstance(model_list_json, dict) and "data" in model_list_json:
        model_list_json = model_list_json["data"]

    if not isinstance(model_list_json, list):
        raise ValueError(
@@ -35,18 +27,11 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
    for item in model_list_json:
        if isinstance(item, str):
            model_names.append(item)
        elif isinstance(item, dict):
            if "model_name" in item:
                model_names.append(item["model_name"])
            elif "id" in item:
                model_names.append(item["id"])
            else:
                raise ValueError(
                    f"Invalid item in model list - expected dict with model_name or id, got {type(item)}"
                )
        elif isinstance(item, dict) and "model_name" in item:
            model_names.append(item["model_name"])
        else:
            raise ValueError(
                f"Invalid item in model list - expected string or dict, got {type(item)}"
                f"Invalid item in model list - expected string or dict with model_name, got {type(item)}"
            )

    return model_names
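
For reference, the branches above tolerate a few provider response shapes. A hedged sketch of representative inputs, assuming the multi-branch version of the parser shown on one side of this diff (the payload values are invented):

    openai_style = {"data": [{"id": "gpt-4o"}, {"id": "gpt-4o-mini"}]}  # dict wrapped in "data"
    wrapped_models = {"models": [{"model_name": "llama3"}]}             # dict wrapped in "models"
    plain_list = ["model-a", "model-b"]                                 # bare list of strings

    for payload in (openai_style, wrapped_models, plain_list):
        print(_process_model_list_response(payload))
    # -> ['gpt-4o', 'gpt-4o-mini'], ['llama3'], ['model-a', 'model-b']
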
@@ -54,7 +39,6 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:

@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
    bind=True,

@@ -1,10 +1,7 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
@@ -13,34 +10,25 @@ from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import SyncRecord
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
_MONITORING_SOFT_TIME_LIMIT = 60 * 5  # 5 minutes
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60  # 6 minutes
@@ -53,17 +41,6 @@ _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT = (
    "monitoring_connector_index_attempt_run_success:{cc_pair_id}:{index_attempt_id}"
)

_FINAL_METRIC_KEY_FMT = "sync_final_metrics:{sync_type}:{entity_id}:{sync_record_id}"

_SYNC_START_LATENCY_KEY_FMT = (
    "sync_start_latency:{sync_type}:{entity_id}:{sync_record_id}"
)

_CONNECTOR_START_TIME_KEY_FMT = "connector_start_time:{cc_pair_id}:{index_attempt_id}"
_CONNECTOR_END_TIME_KEY_FMT = "connector_end_time:{cc_pair_id}:{index_attempt_id}"
_SYNC_START_TIME_KEY_FMT = "sync_start_time:{sync_type}:{entity_id}:{sync_record_id}"
_SYNC_END_TIME_KEY_FMT = "sync_end_time:{sync_type}:{entity_id}:{sync_record_id}"


def _mark_metric_as_emitted(redis_std: Redis, key: str) -> None:
    """Mark a metric as having been emitted by setting a Redis key with expiration"""
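
The emitted-metric bookkeeping above is a plain set-if-absent pattern on Redis: a key is written with an expiration when a metric is emitted, and its mere existence suppresses re-emission. A minimal sketch of the idea with an assumed one-day TTL and generic helper names (the module's actual TTL and private helpers are not shown in this diff):

    from redis import Redis

    _ONE_DAY_SECONDS = 24 * 60 * 60  # assumed TTL, not taken from the source

    def mark_emitted(redis_std: Redis, key: str) -> None:
        # any value works; only the key's existence matters
        redis_std.set(key, "1", ex=_ONE_DAY_SECONDS)

    def has_been_emitted(redis_std: Redis, key: str) -> bool:
        return bool(redis_std.exists(key))
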
@@ -126,7 +103,6 @@ class Metric(BaseModel):
            }.items()
            if v is not None
        }
        task_logger.info(f"Emitting metric: {data}")
        optional_telemetry(
            record_type=RecordType.METRIC,
            data=data,
@@ -201,111 +177,45 @@ def _build_connector_start_latency_metric(
|
||||
|
||||
start_latency = (recent_attempt.time_started - desired_start_time).total_seconds()
|
||||
|
||||
task_logger.info(
|
||||
f"Start latency for index attempt {recent_attempt.id}: {start_latency:.2f}s "
|
||||
f"(desired: {desired_start_time}, actual: {recent_attempt.time_started})"
|
||||
)
|
||||
|
||||
job_id = build_job_id("connector", str(cc_pair.id), str(recent_attempt.id))
|
||||
|
||||
return Metric(
|
||||
key=metric_key,
|
||||
name="connector_start_latency",
|
||||
value=start_latency,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
tags={},
|
||||
)
|
||||
|
||||
|
||||
def _build_connector_final_metrics(
|
||||
def _build_run_success_metrics(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
recent_attempts: list[IndexAttempt],
|
||||
redis_std: Redis,
|
||||
) -> list[Metric]:
|
||||
"""
|
||||
Final metrics for connector index attempts:
|
||||
- Boolean success/fail metric
|
||||
- If success, emit:
|
||||
* duration (seconds)
|
||||
* doc_count
|
||||
"""
|
||||
metrics = []
|
||||
for attempt in recent_attempts:
|
||||
metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=attempt.id,
|
||||
)
|
||||
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.info(
|
||||
f"Skipping final metrics for connector {cc_pair.connector.id} "
|
||||
f"index attempt {attempt.id}, already emitted."
|
||||
f"Skipping metric for connector {cc_pair.connector.id} "
|
||||
f"index attempt {attempt.id} because it has already been "
|
||||
"emitted"
|
||||
)
|
||||
continue
|
||||
|
||||
# We only emit final metrics if the attempt is in a terminal state
|
||||
if attempt.status not in [
|
||||
if attempt.status in [
|
||||
IndexingStatus.SUCCESS,
|
||||
IndexingStatus.FAILED,
|
||||
IndexingStatus.CANCELED,
|
||||
]:
|
||||
# Not finished; skip
|
||||
continue
|
||||
|
||||
job_id = build_job_id("connector", str(cc_pair.id), str(attempt.id))
|
||||
success = attempt.status == IndexingStatus.SUCCESS
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=metric_key, # We'll mark the same key for any final metrics
|
||||
name="connector_run_succeeded",
|
||||
value=success,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
"status": attempt.status.value,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if success:
|
||||
# Make sure we have valid time_started
|
||||
if attempt.time_started and attempt.time_updated:
|
||||
duration_seconds = (
|
||||
attempt.time_updated - attempt.time_started
|
||||
).total_seconds()
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=None, # No need for a new key, or you can reuse the same if you prefer
|
||||
name="connector_index_duration_seconds",
|
||||
value=duration_seconds,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Index attempt {attempt.id} succeeded but has missing time "
|
||||
f"(time_started={attempt.time_started}, time_updated={attempt.time_updated})."
|
||||
)
|
||||
|
||||
# For doc counts, choose whichever field is more relevant
|
||||
doc_count = attempt.total_docs_indexed or 0
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=None,
|
||||
name="connector_index_doc_count",
|
||||
value=doc_count,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
key=metric_key,
|
||||
name="connector_run_succeeded",
|
||||
value=attempt.status == IndexingStatus.SUCCESS,
|
||||
tags={"source": str(cc_pair.connector.source)},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -314,337 +224,178 @@ def _build_connector_final_metrics(

def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
    """Collect metrics about connector runs from the past hour"""
    # NOTE: use get_db_current_time since the IndexAttempt times are set based on DB time
    one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)

    # Get all connector credential pairs
    cc_pairs = db_session.scalars(select(ConnectorCredentialPair)).all()
    # Might be more than one search setting, or just one
    active_search_settings_list = get_active_search_settings_list(db_session)

    metrics = []

    # If you want to process each cc_pair against each search setting:
    for cc_pair in cc_pairs:
        for search_settings in active_search_settings_list:
recent_attempts = (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair.id,
|
||||
IndexAttempt.search_settings_id == search_settings.id,
|
||||
)
|
||||
.order_by(IndexAttempt.time_created.desc())
|
||||
.limit(2)
|
||||
.all()
|
||||
# Get all attempts in the last hour
|
||||
recent_attempts = (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair.id,
|
||||
IndexAttempt.time_created >= one_hour_ago,
|
||||
)
|
||||
.order_by(IndexAttempt.time_created.desc())
|
||||
.all()
|
||||
)
|
||||
most_recent_attempt = recent_attempts[0] if recent_attempts else None
|
||||
second_most_recent_attempt = (
|
||||
recent_attempts[1] if len(recent_attempts) > 1 else None
|
||||
)
|
||||
|
||||
if not recent_attempts:
|
||||
continue
|
||||
# if no metric to emit, skip
|
||||
if most_recent_attempt is None:
|
||||
continue
|
||||
|
||||
most_recent_attempt = recent_attempts[0]
|
||||
second_most_recent_attempt = (
|
||||
recent_attempts[1] if len(recent_attempts) > 1 else None
|
||||
)
|
||||
# Connector start latency
|
||||
start_latency_metric = _build_connector_start_latency_metric(
|
||||
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
|
||||
)
|
||||
if start_latency_metric:
|
||||
metrics.append(start_latency_metric)
|
||||
|
||||
if one_hour_ago > most_recent_attempt.time_created:
|
||||
continue
|
||||
|
||||
# Build a job_id for correlation
|
||||
job_id = build_job_id(
|
||||
"connector", str(cc_pair.id), str(most_recent_attempt.id)
|
||||
)
|
||||
|
||||
# Add raw start time metric if available
|
||||
if most_recent_attempt.time_started:
|
||||
start_time_key = _CONNECTOR_START_TIME_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=most_recent_attempt.id,
|
||||
)
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=start_time_key,
|
||||
name="connector_start_time",
|
||||
value=most_recent_attempt.time_started.timestamp(),
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add raw end time metric if available and in terminal state
|
||||
if (
|
||||
most_recent_attempt.status.is_terminal()
|
||||
and most_recent_attempt.time_updated
|
||||
):
|
||||
end_time_key = _CONNECTOR_END_TIME_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=most_recent_attempt.id,
|
||||
)
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=end_time_key,
|
||||
name="connector_end_time",
|
||||
value=most_recent_attempt.time_updated.timestamp(),
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Connector start latency
|
||||
start_latency_metric = _build_connector_start_latency_metric(
|
||||
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
|
||||
)
|
||||
|
||||
if start_latency_metric:
|
||||
metrics.append(start_latency_metric)
|
||||
|
||||
# Connector run success/failure
|
||||
final_metrics = _build_connector_final_metrics(
|
||||
cc_pair, recent_attempts, redis_std
|
||||
)
|
||||
metrics.extend(final_metrics)
|
||||
# Connector run success/failure
|
||||
run_success_metrics = _build_run_success_metrics(
|
||||
cc_pair, recent_attempts, redis_std
|
||||
)
|
||||
metrics.extend(run_success_metrics)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
|
||||
"""
|
||||
Collect metrics for document set and group syncing:
|
||||
- Success/failure status
|
||||
- Start latency (for doc sets / user groups)
|
||||
- Duration & doc count (only if success)
|
||||
- Throughput (docs/min) (only if success)
|
||||
- Raw start/end times for each sync
|
||||
"""
|
||||
"""Collect metrics about document set and group syncing speed"""
|
||||
# NOTE: use get_db_current_time since the SyncRecord times are set based on DB time
|
||||
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
|
||||
|
||||
# Get all sync records that ended in the last hour
|
||||
# Get all sync records from the last hour
|
||||
recent_sync_records = db_session.scalars(
|
||||
select(SyncRecord)
|
||||
.where(SyncRecord.sync_end_time.isnot(None))
|
||||
.where(SyncRecord.sync_end_time >= one_hour_ago)
|
||||
.order_by(SyncRecord.sync_end_time.desc())
|
||||
.where(SyncRecord.sync_start_time >= one_hour_ago)
|
||||
.order_by(SyncRecord.sync_start_time.desc())
|
||||
).all()
|
||||
|
||||
task_logger.info(
|
||||
f"Collecting sync metrics for {len(recent_sync_records)} sync records"
|
||||
)
|
||||
|
||||
metrics = []
|
||||
|
||||
for sync_record in recent_sync_records:
|
||||
# Build a job_id for correlation
|
||||
job_id = build_job_id("sync_record", str(sync_record.id))
|
||||
# Skip if no end time (sync still in progress)
|
||||
if not sync_record.sync_end_time:
|
||||
continue
|
||||
|
||||
# Add raw start time metric
|
||||
start_time_key = _SYNC_START_TIME_KEY_FMT.format(
|
||||
sync_type=sync_record.sync_type,
|
||||
entity_id=sync_record.entity_id,
|
||||
sync_record_id=sync_record.id,
|
||||
# Check if we already emitted a metric for this sync record
|
||||
metric_key = (
|
||||
f"sync_speed:{sync_record.sync_type}:"
|
||||
f"{sync_record.entity_id}:{sync_record.id}"
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.debug(
|
||||
f"Skipping metric for sync record {sync_record.id} "
|
||||
"because it has already been emitted"
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate sync duration in minutes
|
||||
sync_duration_mins = (
|
||||
sync_record.sync_end_time - sync_record.sync_start_time
|
||||
).total_seconds() / 60.0
|
||||
|
||||
# Calculate sync speed (docs/min) - avoid division by zero
|
||||
sync_speed = (
|
||||
sync_record.num_docs_synced / sync_duration_mins
|
||||
if sync_duration_mins > 0
|
||||
else None
|
||||
)
|
||||
|
||||
if sync_speed is None:
|
||||
task_logger.error(
|
||||
"Something went wrong with sync speed calculation. "
|
||||
f"Sync record: {sync_record.id}"
|
||||
)
|
||||
continue
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=start_time_key,
|
||||
name="sync_start_time",
|
||||
value=sync_record.sync_start_time.timestamp(),
|
||||
key=metric_key,
|
||||
name="sync_speed_docs_per_min",
|
||||
value=sync_speed,
|
||||
tags={
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
"status": str(sync_record.sync_status),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add sync start latency metric
|
||||
start_latency_key = (
|
||||
f"sync_start_latency:{sync_record.sync_type}"
|
||||
f":{sync_record.entity_id}:{sync_record.id}"
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, start_latency_key):
|
||||
task_logger.debug(
|
||||
f"Skipping start latency metric for sync record {sync_record.id} "
|
||||
"because it has already been emitted"
|
||||
)
|
||||
continue
|
||||
|
||||
# Get the entity's last update time based on sync type
|
||||
entity: DocumentSet | UserGroup | None = None
|
||||
if sync_record.sync_type == SyncType.DOCUMENT_SET:
|
||||
entity = db_session.scalar(
|
||||
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
|
||||
)
|
||||
elif sync_record.sync_type == SyncType.USER_GROUP:
|
||||
entity = db_session.scalar(
|
||||
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
|
||||
)
|
||||
else:
|
||||
# Skip other sync types
|
||||
task_logger.debug(
|
||||
f"Skipping sync record {sync_record.id} "
|
||||
f"with type {sync_record.sync_type} "
|
||||
f"and id {sync_record.entity_id} "
|
||||
"because it is not a document set or user group"
|
||||
)
|
||||
continue
|
||||
|
||||
if entity is None:
|
||||
task_logger.error(
|
||||
f"Could not find entity for sync record {sync_record.id} "
|
||||
f"with type {sync_record.sync_type} and id {sync_record.entity_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate start latency in seconds
|
||||
start_latency = (
|
||||
sync_record.sync_start_time - entity.time_last_modified_by_user
|
||||
).total_seconds()
|
||||
if start_latency < 0:
|
||||
task_logger.error(
f"Start latency is negative for sync record {sync_record.id} "
f"with type {sync_record.sync_type} and id {sync_record.entity_id}. "
"This is likely because the entity was updated between the time "
"the sync finished and this job ran. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=start_latency_key,
|
||||
name="sync_start_latency_seconds",
|
||||
value=start_latency,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add raw end time metric if available
|
||||
if sync_record.sync_end_time:
|
||||
end_time_key = _SYNC_END_TIME_KEY_FMT.format(
|
||||
sync_type=sync_record.sync_type,
|
||||
entity_id=sync_record.entity_id,
|
||||
sync_record_id=sync_record.id,
|
||||
)
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=end_time_key,
|
||||
name="sync_end_time",
|
||||
value=sync_record.sync_end_time.timestamp(),
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Emit a SUCCESS/FAIL boolean metric
|
||||
# Use a single Redis key to avoid re-emitting final metrics
|
||||
final_metric_key = _FINAL_METRIC_KEY_FMT.format(
|
||||
sync_type=sync_record.sync_type,
|
||||
entity_id=sync_record.entity_id,
|
||||
sync_record_id=sync_record.id,
|
||||
)
|
||||
if not _has_metric_been_emitted(redis_std, final_metric_key):
|
||||
# Evaluate success
|
||||
sync_succeeded = sync_record.sync_status == SyncStatus.SUCCESS
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=final_metric_key,
|
||||
name="sync_run_succeeded",
|
||||
value=sync_succeeded,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
"status": str(sync_record.sync_status),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# If successful, emit additional metrics
|
||||
if sync_succeeded:
|
||||
if sync_record.sync_end_time and sync_record.sync_start_time:
|
||||
duration_seconds = (
|
||||
sync_record.sync_end_time - sync_record.sync_start_time
|
||||
).total_seconds()
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid times for sync record {sync_record.id}: "
|
||||
f"start={sync_record.sync_start_time}, end={sync_record.sync_end_time}"
|
||||
)
|
||||
duration_seconds = None
|
||||
|
||||
doc_count = sync_record.num_docs_synced or 0
|
||||
|
||||
sync_speed = None
|
||||
if duration_seconds and duration_seconds > 0:
|
||||
duration_mins = duration_seconds / 60.0
|
||||
sync_speed = (
|
||||
doc_count / duration_mins if duration_mins > 0 else None
|
||||
)
|
||||
|
||||
# Emit duration, doc count, speed
|
||||
if duration_seconds is not None:
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=final_metric_key,
|
||||
name="sync_duration_seconds",
|
||||
value=duration_seconds,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid sync record {sync_record.id} with no duration"
|
||||
)
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=final_metric_key,
|
||||
name="sync_doc_count",
|
||||
value=doc_count,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if sync_speed is not None:
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=final_metric_key,
|
||||
name="sync_speed_docs_per_min",
|
||||
value=sync_speed,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid sync record {sync_record.id} with no duration"
|
||||
)
|
||||
|
||||
# Emit start latency
|
||||
start_latency_key = _SYNC_START_LATENCY_KEY_FMT.format(
|
||||
sync_type=sync_record.sync_type,
|
||||
entity_id=sync_record.entity_id,
|
||||
sync_record_id=sync_record.id,
|
||||
)
|
||||
if not _has_metric_been_emitted(redis_std, start_latency_key):
|
||||
# Get the entity's last update time based on sync type
|
||||
entity: DocumentSet | UserGroup | None = None
|
||||
if sync_record.sync_type == SyncType.DOCUMENT_SET:
|
||||
entity = db_session.scalar(
|
||||
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
|
||||
)
|
||||
elif sync_record.sync_type == SyncType.USER_GROUP:
|
||||
entity = db_session.scalar(
|
||||
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
|
||||
)
|
||||
|
||||
if entity is None:
|
||||
task_logger.error(
|
||||
f"Sync record of type {sync_record.sync_type} doesn't have an entity "
|
||||
f"associated with it (id={sync_record.entity_id}). Skipping start latency metric."
|
||||
)
|
||||
|
||||
# Calculate start latency in seconds:
|
||||
# (actual sync start) - (last modified time)
|
||||
if (
|
||||
entity is not None
|
||||
and entity.time_last_modified_by_user
|
||||
and sync_record.sync_start_time
|
||||
):
|
||||
start_latency = (
|
||||
sync_record.sync_start_time - entity.time_last_modified_by_user
|
||||
).total_seconds()
|
||||
|
||||
if start_latency < 0:
|
||||
task_logger.error(
|
||||
f"Negative start latency for sync record {sync_record.id} "
|
||||
f"(start={sync_record.sync_start_time}, entity_modified={entity.time_last_modified_by_user})"
|
||||
)
|
||||
continue
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=start_latency_key,
|
||||
name="sync_start_latency_seconds",
|
||||
value=start_latency,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def build_job_id(
    job_type: Literal["connector", "sync_record"],
    primary_id: str,
    secondary_id: str | None = None,
) -> str:
    if job_type == "connector":
        if secondary_id is None:
            raise ValueError(
                "secondary_id (attempt_id) is required for connector job_type"
            )
        return f"connector:{primary_id}:attempt:{secondary_id}"
    elif job_type == "sync_record":
        return f"sync_record:{primary_id}"

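A quick usage sketch of the job_id formats produced above (the IDs are made up for illustration):

    # correlates a connector index attempt across metrics -> "connector:42:attempt:1337"
    connector_job_id = build_job_id("connector", "42", "1337")

    # correlates a document set / user group sync -> "sync_record:7"
    sync_job_id = build_job_id("sync_record", "7")
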
@shared_task(
    name=OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
    ignore_result=True,
    soft_time_limit=_MONITORING_SOFT_TIME_LIMIT,
    time_limit=_MONITORING_TIME_LIMIT,
    queue=OnyxCeleryQueues.MONITORING,
@@ -658,9 +409,6 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
- Syncing speed metrics
|
||||
- Worker status and task counts
|
||||
"""
|
||||
if tenant_id is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
task_logger.info("Starting background monitoring")
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@@ -685,20 +433,14 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
lambda: _collect_connector_metrics(db_session, redis_std),
|
||||
lambda: _collect_sync_metrics(db_session, redis_std),
|
||||
]
|
||||
|
||||
# Collect and log each metric
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
for metric_fn in metric_functions:
|
||||
metrics = metric_fn()
|
||||
for metric in metrics:
|
||||
# double check to make sure we aren't double-emitting metrics
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
|
||||
if metric.key is not None:
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
if metric.key:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
|
||||
task_logger.info("Successfully collected background metrics")
|
||||
@@ -714,116 +456,3 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
lock_monitoring.release()
|
||||
|
||||
task_logger.info("Background monitoring task finished")
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
)
def cloud_check_alembic() -> bool | None:
    """A task to verify that all tenants are on the same alembic revision.

    This check is expected to fail if a cloud alembic migration is currently running
    across all tenants.

    TODO: have the cloud migration script set an activity signal that this check
    can use to know that running the check at the present time doesn't make sense.
    """
    time_start = time.monotonic()

    redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)

    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.CLOUD_CHECK_ALEMBIC_BEAT_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # these tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return None

    last_lock_time = time.monotonic()

    tenant_to_revision: dict[str, str | None] = {}
    revision_counts: dict[str, int] = {}
    out_of_date_tenants: dict[str, str | None] = {}
    top_revision: str = ""

    try:
        # map each tenant_id to its revision
        tenant_ids = get_all_tenant_ids()
        for tenant_id in tenant_ids:
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
|
||||
lock_beat.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
if tenant_id is None:
|
||||
continue
|
||||
|
||||
with get_session_with_tenant(tenant_id=None) as session:
|
||||
result = session.execute(
|
||||
text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
|
||||
)
|
||||
|
||||
result_scalar: str | None = result.scalar_one_or_none()
|
||||
tenant_to_revision[tenant_id] = result_scalar
|
||||
|
||||
# get the total count of each revision
|
||||
for k, v in tenant_to_revision.items():
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
revision_counts[v] = revision_counts.get(v, 0) + 1
|
||||
|
||||
# get the revision with the most counts
|
||||
sorted_revision_counts = sorted(
|
||||
revision_counts.items(), key=lambda item: item[1], reverse=True
|
||||
)
|
||||
|
||||
if len(sorted_revision_counts) == 0:
|
||||
task_logger.error(
|
||||
f"cloud_check_alembic - No revisions found for {len(tenant_ids)} tenant ids!"
|
||||
)
|
||||
else:
|
||||
top_revision, _ = sorted_revision_counts[0]
|
||||
|
||||
# build a list of out of date tenants
|
||||
for k, v in tenant_to_revision.items():
|
||||
if v == top_revision:
|
||||
continue
|
||||
|
||||
out_of_date_tenants[k] = v
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception during cloud alembic check")
|
||||
raise
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error("cloud_check_alembic - Lock not owned on completion")
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
|
||||
if len(out_of_date_tenants) > 0:
|
||||
task_logger.error(
|
||||
f"Found out of date tenants: "
|
||||
f"num_out_of_date_tenants={len(out_of_date_tenants)} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"revision={top_revision}"
|
||||
)
|
||||
for k, v in islice(out_of_date_tenants.items(), 5):
|
||||
task_logger.info(f"Out of date tenant: tenant={k} revision={v}")
|
||||
else:
|
||||
task_logger.info(
|
||||
f"All tenants are up to date: num_tenants={len(tenant_ids)} revision={top_revision}"
|
||||
)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(
|
||||
f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -25,30 +25,21 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.connector import mark_ccpair_as_pruned
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import get_documents_for_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import pruning_ctx
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
"""Jobs / utils for kicking off pruning tasks."""
|
||||
|
||||
|
||||
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if pruning is due.
|
||||
|
||||
@@ -87,7 +78,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -213,14 +203,6 @@ def try_creating_prune_generator_task(
|
||||
priority=OnyxCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair.id,
|
||||
sync_type=SyncType.PRUNING,
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
redis_connector.prune.set_fence(True)
|
||||
except Exception:
|
||||
@@ -252,8 +234,6 @@ def connector_pruning_generator_task(
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list"""
|
||||
|
||||
LoggerContextVars.reset()
|
||||
|
||||
pruning_ctx_dict = pruning_ctx.get()
|
||||
pruning_ctx_dict["cc_pair_id"] = cc_pair_id
|
||||
pruning_ctx_dict["request_id"] = self.request.id
|
||||
@@ -367,52 +347,3 @@ def connector_pruning_generator_task(
|
||||
lock.release()
|
||||
|
||||
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")
|
||||
|
||||
|
||||
"""Monitoring pruning utils, called in monitor_vespa_sync"""
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.prune.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.prune.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.prune.get_remaining()
|
||||
task_logger.info(
|
||||
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
|
||||
task_logger.info(
|
||||
f"Connector pruning finished: cc_pair={cc_pair_id} num_pruned={initial}"
|
||||
)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.PRUNING,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial,
|
||||
)
|
||||
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
|
||||
@@ -1,22 +1,15 @@
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
|
||||
import httpx
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
from tenacity import RetryError
|
||||
|
||||
from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from onyx.db.document import delete_documents_complete__no_commit
|
||||
from onyx.db.document import fetch_chunk_count_for_document
|
||||
@@ -25,16 +18,11 @@ from onyx.db.document import get_document_connector_count
|
||||
from onyx.db.document import mark_document_as_modified
|
||||
from onyx.db.document import mark_document_as_synced
|
||||
from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
|
||||
|
||||
@@ -75,18 +63,14 @@ def document_by_cc_pair_cleanup_task(
|
||||
"""
|
||||
task_logger.debug(f"Task start: doc={document_id}")
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
action = "skip"
|
||||
chunks_affected = 0
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
active_search_settings.primary,
|
||||
active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
@@ -156,13 +140,11 @@ def document_by_cc_pair_cleanup_task(
|
||||
|
||||
db_session.commit()
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action={action} "
|
||||
f"refcount={count} "
|
||||
f"chunks={chunks_affected} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
f"chunks={chunks_affected}"
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
@@ -217,78 +199,3 @@ def document_by_cc_pair_cleanup_task(
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
ignore_result=True,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def cloud_beat_task_generator(
|
||||
self: Task,
|
||||
task_name: str,
|
||||
queue: str = OnyxCeleryTask.DEFAULT,
|
||||
priority: int = OnyxCeleryPriority.MEDIUM,
|
||||
expires: int = BEAT_EXPIRES_DEFAULT,
|
||||
) -> bool | None:
|
||||
"""a lightweight task used to kick off individual beat tasks per tenant."""
|
||||
time_start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
f"{OnyxRedisLocks.CLOUD_BEAT_TASK_GENERATOR_LOCK}:{task_name}",
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
for tenant_id in tenant_ids:
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
|
||||
lock_beat.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# needed in the cloud
|
||||
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
|
||||
continue
|
||||
|
||||
self.app.send_task(
|
||||
task_name,
|
||||
kwargs=dict(
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=queue,
|
||||
priority=priority,
|
||||
expires=expires,
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception during cloud_beat_task_generator")
|
||||
finally:
|
||||
if not lock_beat.owned():
|
||||
task_logger.error(
|
||||
"cloud_beat_task_generator - Lock not owned on completion"
|
||||
)
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
else:
|
||||
lock_beat.release()
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(
|
||||
f"cloud_beat_task_generator finished: "
|
||||
f"task={task_name} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -24,10 +24,6 @@ from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
|
||||
monitor_ccpair_permissions_taskset,
|
||||
)
|
||||
from onyx.background.celery.tasks.pruning.tasks import monitor_ccpair_pruning_taskset
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
@@ -38,6 +34,8 @@ from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector import mark_ccpair_as_pruned
|
||||
from onyx.db.connector_credential_pair import add_deletion_failure_message
|
||||
from onyx.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
@@ -63,22 +61,23 @@ from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncPayload,
|
||||
)
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
@@ -98,7 +97,6 @@ logger = setup_logger()
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -652,6 +650,83 @@ def monitor_connector_deletion_taskset(
|
||||
redis_connector.delete.reset()
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.prune.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.prune.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.prune.get_remaining()
|
||||
task_logger.info(
|
||||
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
|
||||
task_logger.info(
|
||||
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.permissions.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.permissions.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.permissions.get_remaining()
|
||||
task_logger.info(
|
||||
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
payload: RedisConnectorPermissionSyncPayload | None = (
|
||||
redis_connector.permissions.payload
|
||||
)
|
||||
start_time: datetime | None = payload.started if payload else None
|
||||
|
||||
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
|
||||
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
|
||||
|
||||
redis_connector.permissions.reset()
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
@@ -660,7 +735,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"Connector indexing: could not parse composite_id from {fence_key}"
|
||||
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
@@ -710,7 +785,6 @@ def monitor_ccpair_indexing_taskset(
    # inner/outer/inner double check pattern to avoid race conditions when checking for
    # bad state

    # Verify: if the generator isn't complete, the task must not be in READY state
    # inner = get_completion / generator_complete not signaled
    # outer = result.state in READY state
    status_int = redis_connector_index.get_completion()
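
A schematic of the inner/outer/inner double check described above, written with hypothetical callables rather than the real Redis and Celery handles; it only illustrates the ordering of the checks, not the actual state handling:

    from typing import Callable

    def detect_bad_state(
        generator_is_complete: Callable[[], bool],  # inner condition, e.g. a completion flag in Redis
        task_is_ready: Callable[[], bool],          # outer condition, e.g. the Celery result is in a READY state
    ) -> bool:
        if generator_is_complete():         # inner check #1
            return False                    # generator finished, nothing suspicious
        if not task_is_ready():             # outer check
            return False                    # task still running, fine for now
        # re-check the inner condition after observing the outer one; if the generator
        # completed in between, the first observation was just a race, not a bad state
        return not generator_is_complete()  # inner check #2
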
@@ -756,7 +830,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Connector indexing - Transient exception marking index attempt as failed: "
|
||||
"monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -766,20 +840,6 @@ def monitor_ccpair_indexing_taskset(
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
if redis_connector_index.watchdog_signaled():
|
||||
# if the generator is complete, don't clean up until the watchdog has exited
|
||||
task_logger.info(
|
||||
f"Connector indexing - Delaying finalization until watchdog has exited: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
@@ -796,20 +856,11 @@ def monitor_ccpair_indexing_taskset(
    redis_connector_index.reset()


@shared_task(
    name=OnyxCeleryTask.MONITOR_VESPA_SYNC,
    ignore_result=True,
    soft_time_limit=300,
    bind=True,
)
@shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
    """This is a celery beat task that monitors and finalizes various long running tasks.

    The name monitor_vespa_sync is a bit of a misnomer since it checks many different tasks
    now. Should change that at some point.

    """This is a celery beat task that monitors and finalizes metadata sync tasksets.
    It scans for fence values and then gets the counts of any associated tasksets.
    For many tasks, the count is 0, which means all tasks finished and we should clean up.
    If the count is 0, that means all tasks finished and we should clean up.

    This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
    do anything too expensive in this function!
@@ -825,17 +876,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:

    r = get_redis_client(tenant_id=tenant_id)

    # Replica usage notes
    #
    # False negatives are OK (aka we fail to see a key that exists on the master).
    # We simply skip the monitoring work and it will be caught on the next pass.
    #
    # False positives are not OK, and are possible if we clear a fence on the master and
    # then read from the replica. In this case, monitoring work could be done on a fence
    # that no longer exists. To avoid this, we scan from the replica, but double check
    # the result on the master.
    r_replica = get_redis_replica_client(tenant_id=tenant_id)

    lock_beat: RedisLock = r.lock(
        OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
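
A minimal sketch of the scan-on-replica, double-check-on-master pattern laid out in the "Replica usage notes" above. The client names and fence prefix here are illustrative only:

    from collections.abc import Iterator

    from redis import Redis

    def iter_live_fences(
        r_master: Redis, r_replica: Redis, prefix: str = "examplefence_"
    ) -> Iterator[bytes]:
        # scan the cheaper replica; a stale view only delays work until the next pass
        for key_bytes in r_replica.scan_iter(prefix + "*", count=1000):
            # the replica may lag behind: confirm the fence still exists on the master
            # before doing any monitoring work against it
            if r_master.exists(key_bytes):
                yield key_bytes
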
@@ -895,19 +935,17 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
# scan and monitor activity to completion
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
if r_replica.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
timings["connector"] = time.monotonic() - phase_start
|
||||
timings["connector_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorDelete.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
lock_beat.reacquire()
|
||||
|
||||
timings["connector_deletion"] = time.monotonic() - phase_start
|
||||
@@ -917,82 +955,70 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisDocumentSet.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["documentset"] = time.monotonic() - phase_start
|
||||
timings["documentset_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisUserGroup.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["usergroup"] = time.monotonic() - phase_start
|
||||
timings["usergroup_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["pruning"] = time.monotonic() - phase_start
|
||||
timings["pruning_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["indexing"] = time.monotonic() - phase_start
|
||||
timings["indexing_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
|
||||
count=SCAN_ITER_COUNT_DEFAULT,
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(
|
||||
tenant_id, key_bytes, r, db_session
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
|
||||
timings["permissions"] = time.monotonic() - phase_start
|
||||
timings["permissions_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
task_logger.exception("monitor_vespa_sync exceptioned.")
|
||||
return False
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
@@ -1019,15 +1045,11 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
def vespa_metadata_sync_task(
|
||||
self: Task, document_id: str, tenant_id: str | None
|
||||
) -> bool:
|
||||
start = time.monotonic()
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
@@ -1073,16 +1095,9 @@ def vespa_metadata_sync_task(
|
||||
# r = get_redis_client(tenant_id=tenant_id)
|
||||
# r.delete(redis_syncing_key)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=sync "
|
||||
f"chunks={chunks_affected} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
return False
|
||||
except Exception as ex:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
|
||||
@@ -35,7 +35,6 @@ from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
@@ -220,10 +219,9 @@ def _run_indexing(
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# Indexing is only done into one index at a time
|
||||
document_index = get_default_document_index(
|
||||
index_attempt_start.search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
primary_index_name=ctx.index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
|
||||
@@ -254,7 +254,6 @@ def _get_force_search_settings(
|
||||
and new_msg_req.retrieval_options.run_search
|
||||
== OptionalSearchSetting.ALWAYS,
|
||||
new_msg_req.search_doc_ids,
|
||||
new_msg_req.query_override is not None,
|
||||
DISABLE_LLM_CHOOSE_SEARCH,
|
||||
]
|
||||
)
|
||||
@@ -426,7 +425,9 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
# Every chat Session begins with an empty root message
|
||||
root_message = get_or_create_root_message(
|
||||
@@ -498,6 +499,14 @@ def stream_chat_message_objects(
|
||||
f"existing assistant message id: {existing_assistant_message_id}"
|
||||
)
|
||||
|
||||
# Disable Query Rephrasing for the first message
|
||||
# This leads to a better first response since the LLM rephrasing the question
|
||||
# leads to worse search quality
|
||||
if not history_msgs:
|
||||
new_msg_req.query_override = (
|
||||
new_msg_req.query_override or new_msg_req.message
|
||||
)
|
||||
|
||||
# load all files needed for this chat chain in memory
|
||||
files = load_all_chat_files(
|
||||
history_msgs, new_msg_req.file_descriptors, db_session
|
||||
|
||||
@@ -15,12 +15,11 @@ from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
from onyx.llm.utils import check_message_tokens
|
||||
from onyx.llm.utils import message_to_prompt_and_imgs
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from onyx.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from onyx.prompts.prompt_utils import drop_messages_history_overflow
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
@@ -32,16 +31,15 @@ def default_build_system_message(
|
||||
prompt_config: PromptConfig,
|
||||
) -> SystemMessage | None:
|
||||
system_prompt = prompt_config.system_prompt.strip()
|
||||
tag_handled_prompt = handle_onyx_date_awareness(
|
||||
system_prompt,
|
||||
prompt_config,
|
||||
add_additional_info_if_no_tag=prompt_config.datetime_aware,
|
||||
)
|
||||
if prompt_config.datetime_aware:
|
||||
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
|
||||
|
||||
if not tag_handled_prompt:
|
||||
if not system_prompt:
|
||||
return None
|
||||
|
||||
return SystemMessage(content=tag_handled_prompt)
|
||||
system_msg = SystemMessage(content=system_prompt)
|
||||
|
||||
return system_msg
|
||||
|
||||
|
||||
def default_build_user_message(
|
||||
@@ -66,11 +64,8 @@ def default_build_user_message(
|
||||
else user_query
|
||||
)
|
||||
user_prompt = user_prompt.strip()
|
||||
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
|
||||
user_msg = HumanMessage(
|
||||
content=build_content_with_imgs(tag_handled_prompt, files)
|
||||
if files
|
||||
else tag_handled_prompt
|
||||
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
|
||||
)
|
||||
return user_msg
|
||||
|
||||
@@ -91,7 +86,6 @@ class AnswerPromptBuilder:
|
||||
provider_type=llm_config.model_provider,
|
||||
model_name=llm_config.model_name,
|
||||
)
|
||||
self.llm_config = llm_config
|
||||
self.llm_tokenizer_encode_func = cast(
|
||||
Callable[[str], list[int]], llm_tokenizer.encode
|
||||
)
|
||||
@@ -100,21 +94,12 @@ class AnswerPromptBuilder:
|
||||
(
|
||||
self.message_history,
|
||||
self.history_token_cnts,
|
||||
) = translate_history_to_basemessages(
|
||||
message_history,
|
||||
exclude_images=not model_supports_image_input(
|
||||
self.llm_config.model_name,
|
||||
self.llm_config.model_provider,
|
||||
),
|
||||
)
|
||||
) = translate_history_to_basemessages(message_history)
|
||||
|
||||
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
|
||||
self.user_message_and_token_cnt = (
|
||||
user_message,
|
||||
check_message_tokens(
|
||||
user_message,
|
||||
self.llm_tokenizer_encode_func,
|
||||
),
|
||||
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
|
||||
)
|
||||
|
||||
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
|
||||
|
||||
@@ -21,9 +21,9 @@ from onyx.prompts.constants import DEFAULT_IGNORE_STATEMENT
|
||||
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT
|
||||
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING
|
||||
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from onyx.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from onyx.prompts.prompt_utils import build_complete_context_str
|
||||
from onyx.prompts.prompt_utils import build_task_prompt_reminders
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
|
||||
from onyx.prompts.token_counts import (
|
||||
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
|
||||
@@ -127,11 +127,10 @@ def build_citations_system_message(
|
||||
system_prompt = prompt_config.system_prompt.strip()
|
||||
if prompt_config.include_citations:
|
||||
system_prompt += REQUIRE_CITATION_STATEMENT
|
||||
tag_handled_prompt = handle_onyx_date_awareness(
|
||||
system_prompt, prompt_config, add_additional_info_if_no_tag=True
|
||||
)
|
||||
if prompt_config.datetime_aware:
|
||||
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
|
||||
|
||||
return SystemMessage(content=tag_handled_prompt)
|
||||
return SystemMessage(content=system_prompt)
|
||||
|
||||
|
||||
def build_citations_user_message(
|
||||
|
||||
@@ -9,8 +9,8 @@ from onyx.llm.utils import message_to_prompt_and_imgs
|
||||
from onyx.prompts.direct_qa_prompts import CONTEXT_BLOCK
|
||||
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from onyx.prompts.direct_qa_prompts import JSON_PROMPT
|
||||
from onyx.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from onyx.prompts.prompt_utils import build_complete_context_str
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
|
||||
|
||||
def _build_strong_llm_quotes_prompt(
|
||||
@@ -39,11 +39,10 @@ def _build_strong_llm_quotes_prompt(
|
||||
language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
|
||||
).strip()
|
||||
|
||||
tag_handled_prompt = handle_onyx_date_awareness(
|
||||
full_prompt, prompt, add_additional_info_if_no_tag=True
|
||||
)
|
||||
if prompt.datetime_aware:
|
||||
full_prompt = add_date_time_to_prompt(prompt_str=full_prompt)
|
||||
|
||||
return HumanMessage(content=tag_handled_prompt)
|
||||
return HumanMessage(content=full_prompt)
|
||||
|
||||
|
||||
def build_quotes_user_message(
|
||||
|
||||
@@ -11,7 +11,6 @@ from onyx.llm.utils import build_content_with_imgs
|
||||
|
||||
def translate_onyx_msg_to_langchain(
|
||||
msg: ChatMessage | PreviousMessage,
|
||||
exclude_images: bool = False,
|
||||
) -> BaseMessage:
|
||||
files: list[InMemoryChatFile] = []
|
||||
|
||||
@@ -19,9 +18,7 @@ def translate_onyx_msg_to_langchain(
|
||||
# attached. Just ignore them for now.
|
||||
if not isinstance(msg, ChatMessage):
|
||||
files = msg.files
|
||||
content = build_content_with_imgs(
|
||||
msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
|
||||
)
|
||||
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
|
||||
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
raise ValueError("System messages are not currently part of history")
|
||||
@@ -35,12 +32,9 @@ def translate_onyx_msg_to_langchain(
|
||||
|
||||
def translate_history_to_basemessages(
|
||||
history: list[ChatMessage] | list["PreviousMessage"],
|
||||
exclude_images: bool = False,
|
||||
) -> tuple[list[BaseMessage], list[int]]:
|
||||
history_basemessages = [
|
||||
translate_onyx_msg_to_langchain(msg, exclude_images)
|
||||
for msg in history
|
||||
if msg.token_count != 0
|
||||
translate_onyx_msg_to_langchain(msg) for msg in history if msg.token_count != 0
|
||||
]
|
||||
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
|
||||
return history_basemessages, history_token_counts
|
||||
|
||||
@@ -3,7 +3,6 @@ import os
|
||||
import urllib.parse
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentIndexType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
@@ -56,12 +55,12 @@ MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
)
|
||||
|
||||
AUTH_BACKEND = AuthBackend(os.environ.get("AUTH_BACKEND") or AuthBackend.REDIS.value)
|
||||
REDIS_AUTH_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS")
|
||||
or os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS")
|
||||
or 86400 * 7
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
# Default request timeout, mostly used by connectors
|
||||
@@ -93,12 +92,6 @@ OAUTH_CLIENT_SECRET = (
|
||||
|
||||
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
|
||||
|
||||
# Duration (in seconds) for which the FastAPI Users JWT token remains valid in the user's browser.
|
||||
# By default, this is set to match the Redis expiry time for consistency.
|
||||
AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("AUTH_COOKIE_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
# for basic auth
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
@@ -200,8 +193,6 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""

# this assumes that other redis settings remain the same as the primary
REDIS_REPLICA_HOST = os.environ.get("REDIS_REPLICA_HOST") or REDIS_HOST

REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
|
||||
|
||||
@@ -478,12 +469,6 @@ INDEXING_SIZE_WARNING_THRESHOLD = int(
|
||||
# 0 disables this behavior and is the default.
|
||||
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)
|
||||
|
||||
# Enable multi-threaded embedding model calls for parallel processing
|
||||
# Note: only applies for API-based embedding models
|
||||
INDEXING_EMBEDDING_MODEL_NUM_THREADS = int(
|
||||
os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1
|
||||
)
|
||||
|
||||
# During an indexing attempt, specifies the number of batches which are allowed to
|
||||
# exception without aborting the attempt.
|
||||
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)
|
||||
@@ -561,6 +546,13 @@ except json.JSONDecodeError:
|
||||
# LLM Model Update API endpoint
|
||||
LLM_MODEL_UPDATE_API_URL = os.environ.get("LLM_MODEL_UPDATE_API_URL")
|
||||
|
||||
# Multimodal settings
# Enable usage of image summaries:
# add summaries to Vespa at indexing time so that they can also be used in answer generation
|
||||
CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED = (
|
||||
os.environ.get("CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
#####
|
||||
# Enterprise Edition Configs
|
||||
#####
|
||||
@@ -617,8 +609,3 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
|
||||
|
||||
# Set to true to mock LLM responses for testing purposes
|
||||
MOCK_LLM_RESPONSE = (
|
||||
os.environ.get("MOCK_LLM_RESPONSE") if os.environ.get("MOCK_LLM_RESPONSE") else None
|
||||
)
|
||||
|
||||
@@ -93,3 +93,29 @@ BING_API_KEY = os.environ.get("BING_API_KEY") or None
|
||||
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
|
||||
|
||||
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)
|
||||
|
||||
# Custom System Prompt for Image Summarization
# Results may be better if the prompt is written in the main language of the source.
# If no prompt is provided by the user, the following default prompt is used:
|
||||
CONFLUENCE_IMAGE_SUMMARIZATION_SYSTEM_PROMPT = (
|
||||
os.environ.get("CONFLUENCE_IMAGE_SUMMARIZATION_SYSTEM_PROMPT")
|
||||
or """
|
||||
You are an assistant for summarizing images for retrieval.
|
||||
Summarize the content of the following image and be as precise as possible.
|
||||
The summary will be embedded and used to retrieve the original image.
|
||||
Therefore, write a concise summary of the image that is optimized for retrieval.
|
||||
"""
|
||||
)
|
||||
|
||||
# Custom User Prompt for Image Summarization
# Results may be better if the prompt is written in the main language of the source.
# If no prompt is provided by the user, the following default prompt is used:
|
||||
CONFLUENCE_IMAGE_SUMMARIZATION_USER_PROMPT = (
|
||||
os.environ.get("CONFLUENCE_IMAGE_SUMMARIZATION_USER_PROMPT")
|
||||
or """
|
||||
The image has the file name '{title}' and is embedded on a Confluence page with the title '{page_title}'.
|
||||
Describe precisely and concisely what the image shows in the context of the page and what it is used for.
|
||||
The following is the XML source text of the page:
|
||||
{confluence_xml}
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -294,14 +294,11 @@ class OnyxRedisLocks:
|
||||
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
|
||||
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
|
||||
|
||||
CLOUD_BEAT_TASK_GENERATOR_LOCK = "da_lock:cloud_beat_task_generator"
|
||||
CLOUD_CHECK_ALEMBIC_BEAT_LOCK = "da_lock:cloud_check_alembic"
|
||||
CLOUD_CHECK_INDEXING_BEAT_LOCK = "da_lock:cloud_check_indexing_beat"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
|
||||
VALIDATE_EXTERNAL_GROUP_SYNC_FENCES = "signal:validate_external_group_sync_fences"
|
||||
VALIDATE_PERMISSION_SYNC_FENCES = "signal:validate_permission_sync_fences"
|
||||
|
||||
|
||||
class OnyxCeleryPriority(int, Enum):
|
||||
@@ -320,11 +317,6 @@ ONYX_CLOUD_TENANT_ID = "cloud"
|
||||
|
||||
|
||||
class OnyxCeleryTask:
|
||||
DEFAULT = "celery"
|
||||
|
||||
CLOUD_BEAT_TASK_GENERATOR = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_generate_beat_tasks"
|
||||
CLOUD_CHECK_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_alembic"
|
||||
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
CHECK_FOR_INDEXING = "check_for_indexing"
|
||||
@@ -332,10 +324,8 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
@@ -353,6 +343,8 @@ class OnyxCeleryTask:
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
|
||||
|
||||
CLOUD_CHECK_FOR_INDEXING = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_for_indexing"
|
||||
|
||||
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
import contextvars
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import Future
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
@@ -24,9 +20,9 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
# NOTE: all are made lowercase to avoid case sensitivity issues
|
||||
# These field types are considered metadata by default when
|
||||
# treat_all_non_attachment_fields_as_metadata is False
|
||||
DEFAULT_METADATA_FIELD_TYPES = {
|
||||
# these are the field types that are considered metadata rather
|
||||
# than sections
|
||||
_METADATA_FIELD_TYPES = {
|
||||
"singlecollaborator",
|
||||
"collaborator",
|
||||
"createdby",
|
||||
@@ -64,42 +60,21 @@ class AirtableConnector(LoadConnector):
|
||||
self,
|
||||
base_id: str,
|
||||
table_name_or_id: str,
|
||||
treat_all_non_attachment_fields_as_metadata: bool = False,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.base_id = base_id
|
||||
self.table_name_or_id = table_name_or_id
|
||||
self.batch_size = batch_size
|
||||
self._airtable_client: AirtableApi | None = None
|
||||
self.treat_all_non_attachment_fields_as_metadata = (
|
||||
treat_all_non_attachment_fields_as_metadata
|
||||
)
|
||||
self.airtable_client: AirtableApi | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._airtable_client = AirtableApi(credentials["airtable_access_token"])
|
||||
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
|
||||
return None
|
||||
|
||||
@property
|
||||
def airtable_client(self) -> AirtableApi:
|
||||
if not self._airtable_client:
|
||||
raise AirtableClientNotSetUpError()
|
||||
return self._airtable_client
|
||||
|
||||
def _extract_field_values(
|
||||
self,
|
||||
field_id: str,
|
||||
field_name: str,
|
||||
field_info: Any,
|
||||
field_type: str,
|
||||
base_id: str,
|
||||
table_id: str,
|
||||
view_id: str | None,
|
||||
record_id: str,
|
||||
) -> list[tuple[str, str]]:
|
||||
def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
|
||||
"""
|
||||
Extract value(s) + links from a field regardless of its type.
|
||||
Attachments are represented as multiple sections, and therefore
|
||||
returned as a list of tuples (value, link).
|
||||
Extract value(s) from a field regardless of its type.
|
||||
Returns either a single string or list of strings for attachments.
|
||||
"""
|
||||
if field_info is None:
|
||||
return []
|
||||
@@ -110,11 +85,8 @@ class AirtableConnector(LoadConnector):
|
||||
if field_type == "multipleRecordLinks":
|
||||
return []
|
||||
|
||||
# default link to use for non-attachment fields
|
||||
default_link = f"https://airtable.com/{base_id}/{table_id}/{record_id}"
|
||||
|
||||
if field_type == "multipleAttachments":
|
||||
attachment_texts: list[tuple[str, str]] = []
|
||||
attachment_texts: list[str] = []
|
||||
for attachment in field_info:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename", "")
|
||||
@@ -127,37 +99,16 @@ class AirtableConnector(LoadConnector):
|
||||
backoff=2,
|
||||
max_delay=10,
|
||||
)
|
||||
def get_attachment_with_retry(url: str, record_id: str) -> bytes | None:
|
||||
try:
|
||||
attachment_response = requests.get(url)
|
||||
attachment_response.raise_for_status()
|
||||
def get_attachment_with_retry(url: str) -> bytes | None:
|
||||
attachment_response = requests.get(url)
|
||||
if attachment_response.status_code == 200:
|
||||
return attachment_response.content
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code == 410:
|
||||
logger.info(f"Refreshing attachment for {filename}")
|
||||
# Re-fetch the record to get a fresh URL
|
||||
refreshed_record = self.airtable_client.table(
|
||||
base_id, table_id
|
||||
).get(record_id)
|
||||
for refreshed_attachment in refreshed_record["fields"][
|
||||
field_name
|
||||
]:
|
||||
if refreshed_attachment.get("filename") == filename:
|
||||
new_url = refreshed_attachment.get("url")
|
||||
if new_url:
|
||||
attachment_response = requests.get(new_url)
|
||||
attachment_response.raise_for_status()
|
||||
return attachment_response.content
|
||||
return None
|
||||
|
||||
logger.error(f"Failed to refresh attachment for {filename}")
|
||||
|
||||
raise
|
||||
|
||||
attachment_content = get_attachment_with_retry(url, record_id)
|
||||
attachment_content = get_attachment_with_retry(url)
|
||||
if attachment_content:
|
||||
try:
|
||||
file_ext = get_file_ext(filename)
|
||||
attachment_id = attachment["id"]
|
||||
attachment_text = extract_file_text(
|
||||
BytesIO(attachment_content),
|
||||
filename,
|
||||
@@ -165,20 +116,7 @@ class AirtableConnector(LoadConnector):
|
||||
extension=file_ext,
|
||||
)
|
||||
if attachment_text:
|
||||
# slightly nicer loading experience if we can specify the view ID
|
||||
if view_id:
|
||||
attachment_link = (
|
||||
f"https://airtable.com/{base_id}/{table_id}/{view_id}/{record_id}"
|
||||
f"/{field_id}/{attachment_id}?blocks=hide"
|
||||
)
|
||||
else:
|
||||
attachment_link = (
|
||||
f"https://airtable.com/{base_id}/{table_id}/{record_id}"
|
||||
f"/{field_id}/{attachment_id}?blocks=hide"
|
||||
)
|
||||
attachment_texts.append(
|
||||
(f"{filename}:\n{attachment_text}", attachment_link)
|
||||
)
|
||||
attachment_texts.append(f"{filename}:\n{attachment_text}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to process attachment {filename}: {str(e)}"
|
||||
@@ -193,31 +131,23 @@ class AirtableConnector(LoadConnector):
|
||||
combined.append(collab_name)
|
||||
if collab_email:
|
||||
combined.append(f"({collab_email})")
|
||||
return [(" ".join(combined) if combined else str(field_info), default_link)]
|
||||
return [" ".join(combined) if combined else str(field_info)]
|
||||
|
||||
if isinstance(field_info, list):
|
||||
return [(item, default_link) for item in field_info]
|
||||
return [str(item) for item in field_info]
|
||||
|
||||
return [(str(field_info), default_link)]
|
||||
return [str(field_info)]
|
||||
|
||||
def _should_be_metadata(self, field_type: str) -> bool:
|
||||
"""Determine if a field type should be treated as metadata.
|
||||
|
||||
When treat_all_non_attachment_fields_as_metadata is True, all fields except
|
||||
attachments are treated as metadata. Otherwise, only fields with types listed
|
||||
in DEFAULT_METADATA_FIELD_TYPES are treated as metadata."""
|
||||
if self.treat_all_non_attachment_fields_as_metadata:
|
||||
return field_type.lower() != "multipleattachments"
|
||||
return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES
|
||||
"""Determine if a field type should be treated as metadata."""
|
||||
return field_type.lower() in _METADATA_FIELD_TYPES
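For context on the connector change above, here is a minimal standalone sketch (not the connector code itself, and with an abbreviated field-type set) of how the version of _should_be_metadata that takes the treat_all_non_attachment_fields_as_metadata flag is expected to route fields, based on the two implementations shown in this hunk:

DEFAULT_METADATA_FIELD_TYPES = {"singlecollaborator", "collaborator", "createdby"}  # abbreviated for illustration

def should_be_metadata(field_type: str, treat_all_non_attachment_fields_as_metadata: bool) -> bool:
    # With the flag on, every field except attachments becomes metadata;
    # with it off, only the default metadata field types do.
    if treat_all_non_attachment_fields_as_metadata:
        return field_type.lower() != "multipleattachments"
    return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES

assert should_be_metadata("singleLineText", True) is True
assert should_be_metadata("singleLineText", False) is False
assert should_be_metadata("multipleAttachments", True) is False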
|
||||
|
||||
def _process_field(
|
||||
self,
|
||||
field_id: str,
|
||||
field_name: str,
|
||||
field_info: Any,
|
||||
field_type: str,
|
||||
table_id: str,
|
||||
view_id: str | None,
|
||||
record_id: str,
|
||||
) -> tuple[list[Section], dict[str, Any]]:
|
||||
"""
|
||||
@@ -235,22 +165,12 @@ class AirtableConnector(LoadConnector):
|
||||
return [], {}
|
||||
|
||||
# Get the value(s) for the field
|
||||
field_value_and_links = self._extract_field_values(
|
||||
field_id=field_id,
|
||||
field_name=field_name,
|
||||
field_info=field_info,
|
||||
field_type=field_type,
|
||||
base_id=self.base_id,
|
||||
table_id=table_id,
|
||||
view_id=view_id,
|
||||
record_id=record_id,
|
||||
)
|
||||
if len(field_value_and_links) == 0:
|
||||
field_values = self._get_field_value(field_info, field_type)
|
||||
if len(field_values) == 0:
|
||||
return [], {}
|
||||
|
||||
# Determine if it should be metadata or a section
|
||||
if self._should_be_metadata(field_type):
|
||||
field_values = [value for value, _ in field_value_and_links]
|
||||
if len(field_values) > 1:
|
||||
return [], {field_name: field_values}
|
||||
return [], {field_name: field_values[0]}
|
||||
@@ -258,7 +178,7 @@ class AirtableConnector(LoadConnector):
|
||||
# Otherwise, create relevant sections
|
||||
sections = [
|
||||
Section(
|
||||
link=link,
|
||||
link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
|
||||
text=(
|
||||
f"{field_name}:\n"
|
||||
"------------------------\n"
|
||||
@@ -266,7 +186,7 @@ class AirtableConnector(LoadConnector):
|
||||
"------------------------"
|
||||
),
|
||||
)
|
||||
for text, link in field_value_and_links
|
||||
for text in field_values
|
||||
]
|
||||
return sections, {}
|
||||
|
||||
@@ -275,7 +195,7 @@ class AirtableConnector(LoadConnector):
|
||||
record: RecordDict,
|
||||
table_schema: TableSchema,
|
||||
primary_field_name: str | None,
|
||||
) -> Document | None:
|
||||
) -> Document:
|
||||
"""Process a single Airtable record into a Document.
|
||||
|
||||
Args:
|
||||
@@ -299,35 +219,23 @@ class AirtableConnector(LoadConnector):
|
||||
primary_field_value = (
|
||||
fields.get(primary_field_name) if primary_field_name else None
|
||||
)
|
||||
view_id = table_schema.views[0].id if table_schema.views else None
|
||||
|
||||
for field_schema in table_schema.fields:
|
||||
field_name = field_schema.name
|
||||
field_val = fields.get(field_name)
|
||||
field_type = field_schema.type
|
||||
|
||||
logger.debug(
|
||||
f"Processing field '{field_name}' of type '{field_type}' "
|
||||
f"for record '{record_id}'."
|
||||
)
|
||||
|
||||
field_sections, field_metadata = self._process_field(
|
||||
field_id=field_schema.id,
|
||||
field_name=field_name,
|
||||
field_info=field_val,
|
||||
field_type=field_type,
|
||||
table_id=table_id,
|
||||
view_id=view_id,
|
||||
record_id=record_id,
|
||||
)
|
||||
|
||||
sections.extend(field_sections)
|
||||
metadata.update(field_metadata)
|
||||
|
||||
if not sections:
|
||||
logger.warning(f"No sections found for record {record_id}")
|
||||
return None
|
||||
|
||||
semantic_id = (
|
||||
f"{table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
@@ -364,47 +272,18 @@ class AirtableConnector(LoadConnector):
|
||||
primary_field_name = field.name
|
||||
break
|
||||
|
||||
logger.info(f"Starting to process Airtable records for {table.name}.")
|
||||
record_documents: list[Document] = []
|
||||
for record in records:
|
||||
document = self._process_record(
|
||||
record=record,
|
||||
table_schema=table_schema,
|
||||
primary_field_name=primary_field_name,
|
||||
)
|
||||
record_documents.append(document)
|
||||
|
||||
# Process records in parallel batches using ThreadPoolExecutor
|
||||
PARALLEL_BATCH_SIZE = 8
|
||||
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
|
||||
if len(record_documents) >= self.batch_size:
|
||||
yield record_documents
|
||||
record_documents = []
|
||||
|
||||
# Process records in batches
|
||||
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
|
||||
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
|
||||
record_documents: list[Document] = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit batch tasks
|
||||
future_to_record: dict[Future, RecordDict] = {}
|
||||
for record in batch_records:
|
||||
# Capture the current context so that the thread gets the current tenant ID
|
||||
current_context = contextvars.copy_context()
|
||||
future_to_record[
|
||||
executor.submit(
|
||||
current_context.run,
|
||||
self._process_record,
|
||||
record=record,
|
||||
table_schema=table_schema,
|
||||
primary_field_name=primary_field_name,
|
||||
)
|
||||
] = record
|
||||
|
||||
# Wait for all tasks in this batch to complete
|
||||
for future in as_completed(future_to_record):
|
||||
record = future_to_record[future]
|
||||
try:
|
||||
document = future.result()
|
||||
if document:
|
||||
record_documents.append(document)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to process record {record['id']}")
|
||||
raise e
|
||||
|
||||
yield record_documents
|
||||
record_documents = []
|
||||
|
||||
# Yield any remaining records
|
||||
if record_documents:
|
||||
yield record_documents
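The batching above uses contextvars.copy_context() so that worker threads see the caller's context (e.g., the current tenant ID). A minimal self-contained sketch of that pattern, with illustrative names rather than the connector's own:

import contextvars
from concurrent.futures import ThreadPoolExecutor, as_completed

tenant_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("tenant_id", default="unknown")

def process(record: int) -> str:
    # Runs inside the copied context, so the value set by the caller is visible here.
    return f"record={record} tenant={tenant_id_var.get()}"

tenant_id_var.set("tenant-123")
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = []
    for record in range(8):
        ctx = contextvars.copy_context()  # one copy per task, as in the connector code above
        futures.append(executor.submit(ctx.run, process, record))
    for future in as_completed(futures):
        print(future.result())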
|
||||
|
||||
@@ -4,7 +4,8 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP, CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED
|
||||
from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@@ -29,6 +30,7 @@ from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Potential Improvements
|
||||
@@ -46,6 +48,7 @@ _ATTACHMENT_EXPANSION_FIELDS = [
|
||||
"version",
|
||||
"space",
|
||||
"metadata.labels",
|
||||
"history.lastUpdated",
|
||||
]
|
||||
|
||||
_RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
@@ -129,7 +132,17 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
|
||||
|
||||
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
|
||||
# If image summarization is enabled, but not supported by the default LLM, raise an error.
|
||||
if CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED:
|
||||
llm, _ = get_default_llms()
|
||||
if not llm.vision_support():
|
||||
raise ValueError(
|
||||
"The configured default LLM doesn't seem to have vision support for image summarization."
|
||||
)
|
||||
|
||||
self.llm = llm
|
||||
else:
|
||||
self.llm = None
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
@@ -194,67 +207,53 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
return comment_string
|
||||
|
||||
def _convert_object_to_document(
|
||||
self, confluence_object: dict[str, Any]
|
||||
) -> Document | None:
|
||||
def _convert_page_to_document(self, page: dict[str, Any]) -> Document | None:
|
||||
"""
|
||||
Takes in a confluence object, extracts all metadata, and converts it into a document.
|
||||
If its a page, it extracts the text, adds the comments for the document text.
|
||||
If its an attachment, it just downloads the attachment and converts that into a document.
|
||||
If it's a page, it extracts the text, adds the comments for the document text.
|
||||
If it's an attachment, it just downloads the attachment and converts that into a document.
|
||||
If image summarization is enabled (env var CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED is set),
|
||||
images are extracted and summarized by the default LLM.
|
||||
"""
|
||||
# The url and the id are the same
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
object_text = None
|
||||
# Extract text from page
|
||||
if confluence_object["type"] == "page":
|
||||
object_text = extract_text_from_confluence_html(
|
||||
confluence_client=self.confluence_client,
|
||||
confluence_object=confluence_object,
|
||||
fetched_titles={confluence_object.get("title", "")},
|
||||
)
|
||||
# Add comments to text
|
||||
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
|
||||
elif confluence_object["type"] == "attachment":
|
||||
object_text = attachment_to_content(
|
||||
confluence_client=self.confluence_client, attachment=confluence_object
|
||||
)
|
||||
|
||||
if object_text is None:
|
||||
# This only happens for attachments that are not parseable
|
||||
return None
|
||||
logger.notice(f"processing page: {object_url}")
|
||||
|
||||
# Get space name
|
||||
doc_metadata: dict[str, str | list[str]] = {
|
||||
"Wiki Space Name": confluence_object["space"]["name"]
|
||||
"Wiki Space Name": page["space"]["name"]
|
||||
}
|
||||
|
||||
# Get labels
|
||||
label_dicts = (
|
||||
confluence_object.get("metadata", {}).get("labels", {}).get("results", [])
|
||||
)
|
||||
page_labels = [label.get("name") for label in label_dicts if label.get("name")]
|
||||
label_dicts = page["metadata"]["labels"]["results"]
|
||||
page_labels = [label["name"] for label in label_dicts]
|
||||
if page_labels:
|
||||
doc_metadata["labels"] = page_labels
|
||||
|
||||
# Get last modified and author email
|
||||
version_dict = confluence_object.get("version", {})
|
||||
last_modified = (
|
||||
datetime_from_string(version_dict.get("when"))
|
||||
if version_dict.get("when")
|
||||
else None
|
||||
)
|
||||
author_email = version_dict.get("by", {}).get("email")
|
||||
last_modified = datetime_from_string(page["version"]["when"])
|
||||
author_email = page["version"].get("by", {}).get("email")
|
||||
|
||||
title = confluence_object.get("title", "Untitled Document")
|
||||
# Extract text from page
|
||||
object_text = extract_text_from_confluence_html(
|
||||
confluence_client=self.confluence_client,
|
||||
confluence_object=page,
|
||||
fetched_titles={page.get("title", "")},
|
||||
)
|
||||
# Add comments to text
|
||||
object_text += self._get_comment_string_for_page_id(page["id"])
|
||||
|
||||
if object_text is None:
|
||||
# This only happens for attachments that are not parsable
|
||||
return None
|
||||
|
||||
return Document(
|
||||
id=object_url,
|
||||
sections=[Section(link=object_url, text=object_text)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=title,
|
||||
semantic_identifier=page["title"],
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=author_email)] if author_email else None
|
||||
@@ -278,11 +277,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
expand=",".join(_PAGE_EXPANSION_FIELDS),
|
||||
limit=self.batch_size,
|
||||
):
|
||||
logger.debug(f"_fetch_document_batches: {page['id']}")
|
||||
confluence_page_ids.append(page["id"])
|
||||
doc = self._convert_object_to_document(page)
|
||||
doc = self._convert_page_to_document(page)
|
||||
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
@@ -291,16 +290,37 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
for confluence_page_id in confluence_page_ids:
|
||||
attachment_query = self._construct_attachment_query(confluence_page_id)
|
||||
# TODO: maybe should add time filter as well?
|
||||
# Fetch each page's attachments directly after the page itself,
# so the page's XML text can be used as context when summarizing its images
# (only if CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED is True; otherwise images are skipped)
|
||||
attachment_cql = f"type=attachment and container='{page['id']}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
|
||||
confluence_xml = page["body"]["storage"]["value"]
|
||||
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=attachment_query,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
doc = self._convert_object_to_document(attachment)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
content = attachment_to_content(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
page_context=confluence_xml,
|
||||
llm=self.llm,
|
||||
)
|
||||
|
||||
if content:
|
||||
# get link to attachment in webui
|
||||
webui_link = _attachment_to_webui_link(
|
||||
self.confluence_client, attachment
|
||||
)
|
||||
# add prefix to text to mark attachments as such
|
||||
text = f"## Text representation of attachment {attachment['title']}:\n{content}"
|
||||
doc.sections.append(Section(text=text, link=webui_link))
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
@@ -357,7 +377,8 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
if not validate_attachment_filetype(attachment):
|
||||
media_type = attachment["metadata"]["mediaType"]
|
||||
if not validate_attachment_filetype(media_type):
|
||||
continue
|
||||
attachment_restrictions = attachment.get("restrictions")
|
||||
if not attachment_restrictions:
|
||||
@@ -379,11 +400,15 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
attachment["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
),
|
||||
perm_sync_data=attachment_perm_sync_data,
|
||||
perm_sync_data=page_perm_sync_data,
|
||||
)
|
||||
)
|
||||
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
|
||||
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
|
||||
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
|
||||
yield doc_metadata_list
|
||||
doc_metadata_list = []
|
||||
|
||||
yield doc_metadata_list
|
||||
|
||||
def _attachment_to_webui_link(
|
||||
confluence_client: OnyxConfluence, attachment: dict[str, Any]
|
||||
) -> str:
|
||||
"""Extracts the webui link to images."""
|
||||
return confluence_client.url + attachment["_links"]["webui"]
|
||||
|
||||
@@ -2,12 +2,18 @@ import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from urllib.parse import quote
|
||||
|
||||
import bs4
|
||||
import bs4 # type: ignore
|
||||
import requests # type: ignore
|
||||
|
||||
from onyx.configs.chat_configs import CONFLUENCE_IMAGE_SUMMARIZATION_SYSTEM_PROMPT, CONFLUENCE_IMAGE_SUMMARIZATION_USER_PROMPT
|
||||
from onyx.file_processing.image_summarization import summarize_image_pipeline
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
@@ -17,6 +23,7 @@ from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -179,36 +186,49 @@ def extract_text_from_confluence_html(
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
    return attachment["metadata"]["mediaType"] not in [
        "image/jpeg",
        "image/png",
        "image/gif",
        "image/svg+xml",
        "video/mp4",
        "video/quicktime",
    ]
def validate_attachment_filetype(media_type: str) -> bool:
    if media_type.startswith("video/") or media_type == "application/gliffy+json":
        return False
    return True
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
page_context: str,
|
||||
llm: LLM,
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if not validate_attachment_filetype(attachment):
|
||||
return None
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
media_type = attachment["metadata"]["mediaType"]
|
||||
|
||||
if media_type.startswith("video/") or media_type == "application/gliffy+json":
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
f"Cannot convert attachment {download_link} with unsupported media type to text: {media_type}."
|
||||
)
|
||||
return None
|
||||
|
||||
if media_type.startswith("image/"):
|
||||
if CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED:
|
||||
try:
|
||||
# get images from page
|
||||
summarization = _summarize_image_attachment(
|
||||
attachment=attachment,
|
||||
page_context=page_context,
|
||||
confluence_client=confluence_client,
|
||||
llm=llm,
|
||||
)
|
||||
return summarization
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to summarize image: {attachment}", exc_info=e)
|
||||
else:
|
||||
logger.warning(
|
||||
"Image summarization is disabled. Skipping image attachment %s",
|
||||
download_link,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
@@ -222,9 +242,17 @@ def attachment_to_content(
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
|
||||
if not extracted_text:
|
||||
logger.warning(
|
||||
"Text conversion of attachment %s resulted in an empty string.",
|
||||
download_link,
|
||||
)
|
||||
return None
|
||||
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"Skipping attachment {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
)
|
||||
@@ -279,3 +307,41 @@ def datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_object = datetime_object.astimezone(timezone.utc)
|
||||
|
||||
return datetime_object
|
||||
|
||||
|
||||
def _attachment_to_download_link(
|
||||
confluence_client: OnyxConfluence, attachment: dict[str, Any]
|
||||
) -> str:
|
||||
"""Extracts the download link to images."""
|
||||
return confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
|
||||
def _summarize_image_attachment(
|
||||
attachment: Dict[str, Any],
|
||||
page_context: str,
|
||||
confluence_client: OnyxConfluence,
|
||||
llm: LLM,
|
||||
) -> str:
|
||||
title = attachment["title"]
|
||||
download_link = _attachment_to_download_link(confluence_client, attachment)
|
||||
|
||||
try:
|
||||
# get image from url
|
||||
image_data = confluence_client.get(
|
||||
download_link, absolute=True, not_json_response=True
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ValueError(
|
||||
f"Failed to fetch image for summarization from {download_link}"
|
||||
) from e
|
||||
|
||||
# get image summary
|
||||
# format user prompt: add page title and XML content of page to provide a better summarization
|
||||
user_prompt = CONFLUENCE_IMAGE_SUMMARIZATION_USER_PROMPT.format(
|
||||
title=title, page_title=attachment["title"], confluence_xml=page_context
|
||||
)
|
||||
summary = summarize_image_pipeline(
|
||||
llm, image_data, user_prompt, CONFLUENCE_IMAGE_SUMMARIZATION_SYSTEM_PROMPT
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
@@ -46,17 +45,7 @@ class ConnectorRunner:
    def run(self) -> GenerateDocumentsOutput:
        """Adds additional exception logging to the connector."""
        try:
            start = time.monotonic()
            for batch in self.doc_batch_generator:
                # to know how long connector is taking
                logger.debug(
                    f"Connector took {time.monotonic() - start} seconds to build a batch."
                )

                yield batch

                start = time.monotonic()

            yield from self.doc_batch_generator
        except Exception:
            exc_type, _, exc_traceback = sys.exc_info()
|
||||
|
||||
|
||||
@@ -50,9 +50,6 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
    current_link = ""
    current_text = ""

    if transcript["sentences"] is None:
        return None

    for sentence in transcript["sentences"]:
        if sentence["speaker_name"] != current_speaker_name:
            if current_speaker_name is not None:
|
||||
|
||||
@@ -150,16 +150,6 @@ class Document(DocumentBase):
    id: str  # This must be unique or during indexing/reindexing, chunks will be overwritten
    source: DocumentSource

    def get_total_char_length(self) -> int:
        """Calculate the total character length of the document including sections, metadata, and identifiers."""
        section_length = sum(len(section.text) for section in self.sections)
        identifier_length = len(self.semantic_identifier) + len(self.title or "")
        metadata_length = sum(
            len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v)
            for k, v in self.metadata.items()
        )
        return section_length + identifier_length + metadata_length

    def to_short_descriptor(self) -> str:
        """Used when logging the identity of a document"""
        return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import io
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
from typing import Optional
|
||||
|
||||
import msal # type: ignore
|
||||
from office365.graph_client import GraphClient # type: ignore
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from office365.onedrive.sites.site import Site # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -27,25 +29,16 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class SiteDescriptor(BaseModel):
|
||||
"""Data class for storing SharePoint site information.
|
||||
|
||||
Args:
|
||||
url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests)
|
||||
drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library")
|
||||
If None, all drives will be accessed.
|
||||
folder_path: The folder path within the drive to access (e.g. "test/nested with spaces")
|
||||
If None, all folders will be accessed.
|
||||
"""
|
||||
|
||||
url: str
|
||||
drive_name: str | None
|
||||
folder_path: str | None
|
||||
@dataclass
|
||||
class SiteData:
|
||||
url: str | None
|
||||
folder: Optional[str]
|
||||
sites: list = field(default_factory=list)
|
||||
driveitems: list = field(default_factory=list)
|
||||
|
||||
|
||||
def _convert_driveitem_to_document(
|
||||
driveitem: DriveItem,
|
||||
drive_name: str,
|
||||
) -> Document:
|
||||
file_text = extract_file_text(
|
||||
file=io.BytesIO(driveitem.get_content().execute_query().value),
|
||||
@@ -65,7 +58,7 @@ def _convert_driveitem_to_document(
|
||||
email=driveitem.last_modified_by.user.email,
|
||||
)
|
||||
],
|
||||
metadata={"drive": drive_name},
|
||||
metadata={},
|
||||
)
|
||||
return doc
|
||||
|
||||
@@ -77,179 +70,93 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
sites: list[str] = [],
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self._graph_client: GraphClient | None = None
|
||||
self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
|
||||
sites
|
||||
)
|
||||
self.msal_app: msal.ConfidentialClientApplication | None = None
|
||||
|
||||
@property
|
||||
def graph_client(self) -> GraphClient:
|
||||
if self._graph_client is None:
|
||||
raise ConnectorMissingCredentialError("Sharepoint")
|
||||
|
||||
return self._graph_client
|
||||
self.graph_client: GraphClient | None = None
|
||||
self.site_data: list[SiteData] = self._extract_site_and_folder(sites)
|
||||
|
||||
@staticmethod
|
||||
def _extract_site_and_drive_info(site_urls: list[str]) -> list[SiteDescriptor]:
|
||||
def _extract_site_and_folder(site_urls: list[str]) -> list[SiteData]:
|
||||
site_data_list = []
|
||||
for url in site_urls:
|
||||
parts = url.strip().split("/")
|
||||
if "sites" in parts:
|
||||
sites_index = parts.index("sites")
|
||||
site_url = "/".join(parts[: sites_index + 2])
|
||||
remaining_parts = parts[sites_index + 2 :]
|
||||
|
||||
# Extract drive name and folder path
|
||||
if remaining_parts:
|
||||
drive_name = unquote(remaining_parts[0])
|
||||
folder_path = (
|
||||
"/".join(unquote(part) for part in remaining_parts[1:])
|
||||
if len(remaining_parts) > 1
|
||||
else None
|
||||
)
|
||||
else:
|
||||
drive_name = None
|
||||
folder_path = None
|
||||
|
||||
folder = (
|
||||
parts[sites_index + 2] if len(parts) > sites_index + 2 else None
|
||||
)
|
||||
site_data_list.append(
|
||||
SiteDescriptor(
|
||||
url=site_url,
|
||||
drive_name=drive_name,
|
||||
folder_path=folder_path,
|
||||
)
|
||||
SiteData(url=site_url, folder=folder, sites=[], driveitems=[])
|
||||
)
|
||||
return site_data_list
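A quick illustration of the URL parsing introduced above (a hedged sketch with a made-up site URL, mirroring the split-on-"sites" logic of the new _extract_site_and_drive_info):

from urllib.parse import unquote

url = "https://example.sharepoint.com/sites/team-site/Shared%20Documents/test/nested"
parts = url.strip().split("/")
sites_index = parts.index("sites")
site_url = "/".join(parts[: sites_index + 2])  # "https://example.sharepoint.com/sites/team-site"
remaining_parts = parts[sites_index + 2 :]
drive_name = unquote(remaining_parts[0])  # "Shared Documents"
folder_path = "/".join(unquote(part) for part in remaining_parts[1:]) or None  # "test/nested"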
|
||||
|
||||
def _fetch_driveitems(
|
||||
def _populate_sitedata_driveitems(
|
||||
self,
|
||||
site_descriptor: SiteDescriptor,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> list[tuple[DriveItem, str]]:
|
||||
final_driveitems: list[tuple[DriveItem, str]] = []
|
||||
try:
|
||||
site = self.graph_client.sites.get_by_url(site_descriptor.url)
|
||||
) -> None:
|
||||
filter_str = ""
|
||||
if start is not None and end is not None:
|
||||
filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}"
|
||||
|
||||
# Get all drives in the site
|
||||
drives = site.drives.get().execute_query()
|
||||
logger.debug(f"Found drives: {[drive.name for drive in drives]}")
|
||||
for element in self.site_data:
|
||||
sites: list[Site] = []
|
||||
for site in element.sites:
|
||||
site_sublist = site.lists.get().execute_query()
|
||||
sites.extend(site_sublist)
|
||||
|
||||
# Filter drives based on the requested drive name
|
||||
if site_descriptor.drive_name:
|
||||
drives = [
|
||||
drive
|
||||
for drive in drives
|
||||
if drive.name == site_descriptor.drive_name
|
||||
or (
|
||||
drive.name == "Documents"
|
||||
and site_descriptor.drive_name == "Shared Documents"
|
||||
)
|
||||
]
|
||||
if not drives:
|
||||
logger.warning(f"Drive '{site_descriptor.drive_name}' not found")
|
||||
return []
|
||||
|
||||
# Process each matching drive
|
||||
for drive in drives:
|
||||
for site in sites:
|
||||
try:
|
||||
root_folder = drive.root
|
||||
if site_descriptor.folder_path:
|
||||
# If a specific folder is requested, navigate to it
|
||||
for folder_part in site_descriptor.folder_path.split("/"):
|
||||
root_folder = root_folder.get_by_path(folder_part)
|
||||
|
||||
# Get all items recursively
|
||||
query = root_folder.get_files(
|
||||
recursive=True,
|
||||
page_size=1000,
|
||||
)
|
||||
query = site.drive.root.get_files(True, 1000)
|
||||
if filter_str:
|
||||
query = query.filter(filter_str)
|
||||
driveitems = query.execute_query()
|
||||
logger.debug(
|
||||
f"Found {len(driveitems)} items in drive '{drive.name}'"
|
||||
)
|
||||
|
||||
# Use "Shared Documents" as the library name for the default "Documents" drive
|
||||
drive_name = (
|
||||
"Shared Documents" if drive.name == "Documents" else drive.name
|
||||
)
|
||||
|
||||
# Filter items based on folder path if specified
|
||||
if site_descriptor.folder_path:
|
||||
# Filter items to ensure they're in the specified folder or its subfolders
|
||||
# The path will be in format: /drives/{drive_id}/root:/folder/path
|
||||
driveitems = [
|
||||
if element.folder:
|
||||
filtered_driveitems = [
|
||||
item
|
||||
for item in driveitems
|
||||
if any(
|
||||
path_part == site_descriptor.folder_path
|
||||
or path_part.startswith(
|
||||
site_descriptor.folder_path + "/"
|
||||
)
|
||||
for path_part in item.parent_reference.path.split(
|
||||
"root:/"
|
||||
)[1].split("/")
|
||||
)
|
||||
if element.folder in item.parent_reference.path
|
||||
]
|
||||
if len(driveitems) == 0:
|
||||
all_paths = [
|
||||
item.parent_reference.path for item in driveitems
|
||||
]
|
||||
logger.warning(
|
||||
f"Nothing found for folder '{site_descriptor.folder_path}' "
|
||||
f"in; any of valid paths: {all_paths}"
|
||||
)
|
||||
element.driveitems.extend(filtered_driveitems)
|
||||
else:
|
||||
element.driveitems.extend(driveitems)
|
||||
|
||||
# Filter items based on time window if specified
|
||||
if start is not None and end is not None:
|
||||
driveitems = [
|
||||
item
|
||||
for item in driveitems
|
||||
if start
|
||||
<= item.last_modified_datetime.replace(tzinfo=timezone.utc)
|
||||
<= end
|
||||
]
|
||||
logger.debug(
|
||||
f"Found {len(driveitems)} items within time window in drive '{drive.name}'"
|
||||
)
|
||||
except Exception:
|
||||
# Sites include things that do not contain .drive.root so this fails
|
||||
# but this is fine, as there are no actual documents in those
|
||||
pass
|
||||
|
||||
for item in driveitems:
|
||||
final_driveitems.append((item, drive_name))
|
||||
def _populate_sitedata_sites(self) -> None:
|
||||
if self.graph_client is None:
|
||||
raise ConnectorMissingCredentialError("Sharepoint")
|
||||
|
||||
except Exception as e:
|
||||
# Some drives might not be accessible
|
||||
logger.warning(f"Failed to process drive: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
# Sites include things that do not contain drives so this fails
|
||||
# but this is fine, as there are no actual documents in those
|
||||
logger.warning(f"Failed to process site: {str(e)}")
|
||||
|
||||
return final_driveitems
|
||||
|
||||
def _fetch_sites(self) -> list[SiteDescriptor]:
|
||||
sites = self.graph_client.sites.get_all().execute_query()
|
||||
site_descriptors = [
|
||||
SiteDescriptor(
|
||||
url=sites.resource_url,
|
||||
drive_name=None,
|
||||
folder_path=None,
|
||||
)
|
||||
]
|
||||
return site_descriptors
|
||||
if self.site_data:
|
||||
for element in self.site_data:
|
||||
element.sites = [
|
||||
self.graph_client.sites.get_by_url(element.url)
|
||||
.get()
|
||||
.execute_query()
|
||||
]
|
||||
else:
|
||||
sites = self.graph_client.sites.get_all().execute_query()
|
||||
self.site_data = [
|
||||
SiteData(url=None, folder=None, sites=sites, driveitems=[])
|
||||
]
|
||||
|
||||
def _fetch_from_sharepoint(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
site_descriptors = self.site_descriptors or self._fetch_sites()
|
||||
if self.graph_client is None:
|
||||
raise ConnectorMissingCredentialError("Sharepoint")
|
||||
|
||||
self._populate_sitedata_sites()
|
||||
self._populate_sitedata_driveitems(start=start, end=end)
|
||||
|
||||
# goes over all urls, converts them into Document objects and then yields them in batches
|
||||
doc_batch: list[Document] = []
|
||||
for site_descriptor in site_descriptors:
|
||||
driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end)
|
||||
for driveitem, drive_name in driveitems:
|
||||
for element in self.site_data:
|
||||
for driveitem in element.driveitems:
|
||||
logger.debug(f"Processing: {driveitem.web_url}")
|
||||
doc_batch.append(_convert_driveitem_to_document(driveitem, drive_name))
|
||||
doc_batch.append(_convert_driveitem_to_document(driveitem))
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
@@ -261,26 +168,22 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
sp_client_secret = credentials["sp_client_secret"]
|
||||
sp_directory_id = credentials["sp_directory_id"]
|
||||
|
||||
authority_url = f"https://login.microsoftonline.com/{sp_directory_id}"
|
||||
self.msal_app = msal.ConfidentialClientApplication(
|
||||
authority=authority_url,
|
||||
client_id=sp_client_id,
|
||||
client_credential=sp_client_secret,
|
||||
)
|
||||
|
||||
def _acquire_token_func() -> dict[str, Any]:
|
||||
"""
|
||||
Acquire token via MSAL
|
||||
"""
|
||||
if self.msal_app is None:
|
||||
raise RuntimeError("MSAL app is not initialized")
|
||||
|
||||
token = self.msal_app.acquire_token_for_client(
|
||||
authority_url = f"https://login.microsoftonline.com/{sp_directory_id}"
|
||||
app = msal.ConfidentialClientApplication(
|
||||
authority=authority_url,
|
||||
client_id=sp_client_id,
|
||||
client_credential=sp_client_secret,
|
||||
)
|
||||
token = app.acquire_token_for_client(
|
||||
scopes=["https://graph.microsoft.com/.default"]
|
||||
)
|
||||
return token
|
||||
|
||||
self._graph_client = GraphClient(_acquire_token_func)
|
||||
self.graph_client = GraphClient(_acquire_token_func)
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
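Review note: for readers unfamiliar with MSAL, a minimal standalone sketch of the client-credentials flow that both versions of load_credentials above rely on (assumes only the msal package; the function name and error handling here are illustrative, not part of this change):

import msal


def acquire_graph_token(directory_id: str, client_id: str, client_secret: str) -> dict:
    # Confidential client authenticating against the tenant's authority URL
    app = msal.ConfidentialClientApplication(
        client_id=client_id,
        client_credential=client_secret,
        authority=f"https://login.microsoftonline.com/{directory_id}",
    )
    # Client-credentials grant; on success the dict contains "access_token",
    # on failure it contains "error" / "error_description"
    token = app.acquire_token_for_client(
        scopes=["https://graph.microsoft.com/.default"]
    )
    if "access_token" not in token:
        raise RuntimeError(f"Token acquisition failed: {token.get('error_description')}")
    return token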
@@ -289,19 +192,19 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.fromtimestamp(start, timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, timezone.utc)
|
||||
start_datetime = datetime.utcfromtimestamp(start)
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime)
|
||||
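A note on the two timestamp conversions shown in poll_source above (standard Python behavior, not something introduced here): datetime.utcfromtimestamp returns a naive datetime and is deprecated as of Python 3.12, while datetime.fromtimestamp(ts, timezone.utc) returns a timezone-aware one. A tiny illustrative check:

from datetime import datetime, timezone

ts = 1_700_000_000
naive = datetime.utcfromtimestamp(ts)             # tzinfo is None (deprecated in 3.12+)
aware = datetime.fromtimestamp(ts, timezone.utc)  # tzinfo is timezone.utc
assert naive.tzinfo is None and aware.tzinfo is timezone.utc
# Same UTC wall-clock instant, but only `aware` compares safely
# against other timezone-aware datetimes.
assert naive == aware.replace(tzinfo=None)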
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = SharepointConnector(sites=os.environ["SHAREPOINT_SITES"].split(","))
|
||||
connector = SharepointConnector(sites=os.environ["SITES"].split(","))
|
||||
|
||||
connector.load_credentials(
|
||||
{
|
||||
"sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
|
||||
"sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
|
||||
"sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
|
||||
"sp_client_id": os.environ["SP_CLIENT_ID"],
|
||||
"sp_client_secret": os.environ["SP_CLIENT_SECRET"],
|
||||
"sp_directory_id": os.environ["SP_CLIENT_DIRECTORY_ID"],
|
||||
}
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
|
||||
@@ -104,11 +104,8 @@ def make_slack_api_rate_limited(
|
||||
f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
|
||||
)
|
||||
time.sleep(retry_after)
|
||||
elif error in ["already_reacted", "no_reaction", "internal_error"]:
|
||||
# Log internal_error and return the response instead of failing
|
||||
logger.warning(
|
||||
f"Slack call encountered '{error}', skipping and continuing..."
|
||||
)
|
||||
elif error in ["already_reacted", "no_reaction"]:
|
||||
# The response isn't used for reactions; this is basically just a pass
|
||||
return e.response
|
||||
else:
|
||||
# Raise the error for non-transient errors
|
||||
|
||||
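For context, a condensed sketch of the retry pattern this hunk adjusts (assumes slack_sdk; the wrapper name is illustrative, and the real helper wraps arbitrary Slack API calls):

import time
from typing import Any, Callable

from slack_sdk.errors import SlackApiError


def call_with_slack_retries(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    while True:
        try:
            return func(*args, **kwargs)
        except SlackApiError as e:
            error = e.response["error"]
            if error == "ratelimited":
                # Slack reports the backoff interval in the Retry-After header
                retry_after = int(e.response.headers.get("Retry-After", 1))
                time.sleep(retry_after)
            elif error in ("already_reacted", "no_reaction"):
                # Harmless for reaction endpoints; treat as a no-op success
                return e.response
            else:
                raise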
@@ -180,28 +180,23 @@ class TeamsConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self.graph_client: GraphClient | None = None
|
||||
self.requested_team_list: list[str] = teams
|
||||
self.msal_app: msal.ConfidentialClientApplication | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
teams_client_id = credentials["teams_client_id"]
|
||||
teams_client_secret = credentials["teams_client_secret"]
|
||||
teams_directory_id = credentials["teams_directory_id"]
|
||||
|
||||
authority_url = f"https://login.microsoftonline.com/{teams_directory_id}"
|
||||
self.msal_app = msal.ConfidentialClientApplication(
|
||||
authority=authority_url,
|
||||
client_id=teams_client_id,
|
||||
client_credential=teams_client_secret,
|
||||
)
|
||||
|
||||
def _acquire_token_func() -> dict[str, Any]:
|
||||
"""
|
||||
Acquire token via MSAL
|
||||
"""
|
||||
if self.msal_app is None:
|
||||
raise RuntimeError("MSAL app is not initialized")
|
||||
|
||||
token = self.msal_app.acquire_token_for_client(
|
||||
authority_url = f"https://login.microsoftonline.com/{teams_directory_id}"
|
||||
app = msal.ConfidentialClientApplication(
|
||||
authority=authority_url,
|
||||
client_id=teams_client_id,
|
||||
client_credential=teams_client_secret,
|
||||
)
|
||||
token = app.acquire_token_for_client(
|
||||
scopes=["https://graph.microsoft.com/.default"]
|
||||
)
|
||||
return token
|
||||
|
||||
@@ -67,7 +67,10 @@ class SearchPipeline:
|
||||
self.rerank_metrics_callback = rerank_metrics_callback
|
||||
|
||||
self.search_settings = get_current_search_settings(db_session)
|
||||
self.document_index = get_default_document_index(self.search_settings, None)
|
||||
self.document_index = get_default_document_index(
|
||||
primary_index_name=self.search_settings.index_name,
|
||||
secondary_index_name=None,
|
||||
)
|
||||
self.prompt_config: PromptConfig | None = prompt_config
|
||||
|
||||
# Preprocessing steps generate this
|
||||
|
||||
@@ -28,9 +28,6 @@ class SyncType(str, PyEnum):
|
||||
DOCUMENT_SET = "document_set"
|
||||
USER_GROUP = "user_group"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
PRUNING = "pruning" # not really a sync, but close enough
|
||||
EXTERNAL_PERMISSIONS = "external_permissions"
|
||||
EXTERNAL_GROUP = "external_group"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
@@ -432,7 +432,7 @@ def get_paginated_index_attempts_for_cc_pair_id(
|
||||
stmt = stmt.order_by(IndexAttempt.time_started.desc())
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.offset(page * page_size).limit(page_size)
|
||||
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
return list(db_session.execute(stmt).scalars().all())
|
||||
|
||||
|
||||
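The two offset expressions above differ only in whether `page` is treated as 0-based or 1-based; a quick illustrative sanity check:

page_size = 10
# 0-based pages: page 0 -> rows 0..9, page 1 -> rows 10..19
assert [page * page_size for page in (0, 1, 2)] == [0, 10, 20]
# 1-based pages: page 1 -> rows 0..9, page 2 -> rows 10..19
assert [(page - 1) * page_size for page in (1, 2, 3)] == [0, 10, 20]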
@@ -193,13 +193,13 @@ def fetch_input_prompts_by_user(
|
||||
"""
|
||||
Returns all prompts belonging to the user or public prompts,
|
||||
excluding those the user has specifically disabled.
|
||||
Also, if `user_id` is None and AUTH_TYPE is DISABLED, then all prompts are returned.
|
||||
"""
|
||||
|
||||
# Start with a basic query for InputPrompt
|
||||
query = select(InputPrompt)
|
||||
|
||||
# If we have a user, left join to InputPrompt__User so we can check "disabled"
|
||||
if user_id is not None:
|
||||
# If we have a user, left join to InputPrompt__User to check "disabled"
|
||||
IPU = aliased(InputPrompt__User)
|
||||
query = query.join(
|
||||
IPU,
|
||||
@@ -208,30 +208,25 @@ def fetch_input_prompts_by_user(
|
||||
)
|
||||
|
||||
# Exclude disabled prompts
|
||||
# i.e. keep only those where (IPU.disabled is NULL or False)
|
||||
query = query.where(or_(IPU.disabled.is_(None), IPU.disabled.is_(False)))
|
||||
|
||||
if include_public:
|
||||
# Return both user-owned and public prompts
|
||||
# user-owned or public
|
||||
query = query.where(
|
||||
or_(
|
||||
InputPrompt.user_id == user_id,
|
||||
InputPrompt.is_public,
|
||||
)
|
||||
(InputPrompt.user_id == user_id) | (InputPrompt.is_public)
|
||||
)
|
||||
else:
|
||||
# Return only user-owned prompts
|
||||
# only user-owned prompts
|
||||
query = query.where(InputPrompt.user_id == user_id)
|
||||
|
||||
else:
|
||||
# user_id is None
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
# If auth is disabled, return all prompts
|
||||
query = query.where(True) # type: ignore
|
||||
elif include_public:
|
||||
# Anonymous usage
|
||||
query = query.where(InputPrompt.is_public)
|
||||
# If no user is logged in, get all prompts (public and private)
|
||||
if user_id is None and AUTH_TYPE == AuthType.DISABLED:
|
||||
query = query.where(True) # type: ignore
|
||||
|
||||
# Default to returning all prompts
|
||||
# If no user is logged in but we want to include public prompts
|
||||
elif include_public:
|
||||
query = query.where(InputPrompt.is_public)
|
||||
|
||||
if active is not None:
|
||||
query = query.where(InputPrompt.active == active)
|
||||
|
||||
@@ -3,8 +3,6 @@ from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
@@ -126,29 +124,10 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
|
||||
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
) -> list[LLMProviderModel]:
|
||||
stmt = select(LLMProviderModel)
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def fetch_existing_llm_providers_for_user(
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
) -> list[LLMProviderModel]:
|
||||
if not user:
|
||||
if AUTH_TYPE != AuthType.DISABLED:
|
||||
# User is anonymous
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.is_public == True # noqa: E712
|
||||
)
|
||||
).all()
|
||||
)
|
||||
else:
|
||||
# If auth is disabled, user has access to all providers
|
||||
return fetch_existing_llm_providers(db_session)
|
||||
|
||||
return list(db_session.scalars(select(LLMProviderModel)).all())
|
||||
stmt = select(LLMProviderModel).distinct()
|
||||
user_groups_select = select(User__UserGroup.user_group_id).where(
|
||||
User__UserGroup.user_id == user.id
|
||||
|
||||
@@ -150,7 +150,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
|
||||
# if specified, controls the assistants that are shown to the user + their order
|
||||
# if not specified, all assistants are shown
|
||||
temperature_override_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
chosen_assistants: Mapped[list[int] | None] = mapped_column(
|
||||
@@ -162,7 +161,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
hidden_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=[]
|
||||
)
|
||||
|
||||
recent_assistants: Mapped[list[dict]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=list, server_default="[]"
|
||||
)
|
||||
pinned_assistants: Mapped[list[int] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
@@ -746,34 +747,6 @@ class SearchSettings(Base):
|
||||
def api_key(self) -> str | None:
|
||||
return self.cloud_provider.api_key if self.cloud_provider is not None else None
|
||||
|
||||
@property
|
||||
def large_chunks_enabled(self) -> bool:
|
||||
"""
|
||||
Given multipass usage and an embedder, decides whether large chunks are allowed
|
||||
based on model/provider constraints.
|
||||
"""
|
||||
# Only local models that support a larger context are from Nomic
|
||||
# Cohere does not support larger contexts (they recommend not going above ~512 tokens)
|
||||
return SearchSettings.can_use_large_chunks(
|
||||
self.multipass_indexing, self.model_name, self.provider_type
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def can_use_large_chunks(
|
||||
multipass: bool, model_name: str, provider_type: EmbeddingProvider | None
|
||||
) -> bool:
|
||||
"""
|
||||
Given multipass usage and an embedder, decides whether large chunks are allowed
|
||||
based on model/provider constraints.
|
||||
"""
|
||||
# Only local models that support a larger context are from Nomic
|
||||
# Cohere does not support larger contexts (they recommend not going above ~512 tokens)
|
||||
return (
|
||||
multipass
|
||||
and model_name.startswith("nomic-ai")
|
||||
and provider_type != EmbeddingProvider.COHERE
|
||||
)
|
||||
|
||||
|
||||
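A small illustrative call pattern for the static helper above (the model name and provider combinations are made up purely to show the boolean logic; EmbeddingProvider is imported the same way as elsewhere in this diff):

from shared_configs.enums import EmbeddingProvider

# Large chunks require multipass indexing, a Nomic model, and a non-Cohere provider
assert SearchSettings.can_use_large_chunks(
    multipass=True, model_name="nomic-ai/nomic-embed-text-v1", provider_type=None
)
assert not SearchSettings.can_use_large_chunks(
    multipass=False, model_name="nomic-ai/nomic-embed-text-v1", provider_type=None
)
assert not SearchSettings.can_use_large_chunks(
    multipass=True,
    model_name="nomic-ai/nomic-embed-text-v1",
    provider_type=EmbeddingProvider.COHERE,
)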
class IndexAttempt(Base):
|
||||
"""
|
||||
@@ -1116,10 +1089,6 @@ class ChatSession(Base):
|
||||
llm_override: Mapped[LLMOverride | None] = mapped_column(
|
||||
PydanticType(LLMOverride), nullable=True
|
||||
)
|
||||
|
||||
# The latest temperature override specified by the user
|
||||
temperature_override: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
|
||||
prompt_override: Mapped[PromptOverride | None] = mapped_column(
|
||||
PydanticType(PromptOverride), nullable=True
|
||||
)
|
||||
@@ -1461,8 +1430,6 @@ class Tool(Base):
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
# whether to pass through the user's OAuth token as Authorization header
|
||||
passthrough_auth: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
|
||||
# Relationship to Persona through the association table
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -291,9 +291,8 @@ def get_personas_for_user(
|
||||
include_deleted: bool = False,
|
||||
joinedload_all: bool = False,
|
||||
) -> Sequence[Persona]:
|
||||
stmt = select(Persona)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
|
||||
stmt = select(Persona).distinct()
|
||||
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
|
||||
if not include_default:
|
||||
stmt = stmt.where(Persona.builtin_persona.is_(False))
|
||||
if not include_slack_bot_personas:
|
||||
@@ -303,16 +302,14 @@ def get_personas_for_user(
|
||||
|
||||
if joinedload_all:
|
||||
stmt = stmt.options(
|
||||
selectinload(Persona.prompts),
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.groups),
|
||||
selectinload(Persona.users),
|
||||
selectinload(Persona.labels),
|
||||
joinedload(Persona.prompts),
|
||||
joinedload(Persona.tools),
|
||||
joinedload(Persona.document_sets),
|
||||
joinedload(Persona.groups),
|
||||
joinedload(Persona.users),
|
||||
)
|
||||
|
||||
results = db_session.execute(stmt).scalars().all()
|
||||
return results
|
||||
return db_session.execute(stmt).unique().scalars().all()
|
||||
|
||||
|
||||
def get_personas(db_session: Session) -> Sequence[Persona]:
|
||||
|
||||
@@ -29,21 +29,9 @@ from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import PRESERVED_SEARCH_FIELDS
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ActiveSearchSettings:
|
||||
primary: SearchSettings
|
||||
secondary: SearchSettings | None
|
||||
|
||||
def __init__(
|
||||
self, primary: SearchSettings, secondary: SearchSettings | None
|
||||
) -> None:
|
||||
self.primary = primary
|
||||
self.secondary = secondary
|
||||
|
||||
|
||||
def create_search_settings(
|
||||
search_settings: SavedSearchSettings,
|
||||
db_session: Session,
|
||||
@@ -155,27 +143,21 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
|
||||
return latest_settings
|
||||
|
||||
|
||||
def get_active_search_settings(db_session: Session) -> ActiveSearchSettings:
|
||||
"""Returns active search settings. Secondary search settings may be None."""
|
||||
|
||||
# Get the primary and secondary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
return ActiveSearchSettings(
|
||||
primary=primary_search_settings, secondary=secondary_search_settings
|
||||
)
|
||||
|
||||
|
||||
def get_active_search_settings_list(db_session: Session) -> list[SearchSettings]:
|
||||
"""Returns active search settings as a list. Primary settings are the first element,
|
||||
and if secondary search settings exist, they will be the second element."""
|
||||
|
||||
def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
|
||||
"""Returns active search settings. The first entry will always be the current search
|
||||
settings. If there are new search settings that are being migrated to, those will be
|
||||
the second entry."""
|
||||
search_settings_list: list[SearchSettings] = []
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
search_settings_list.append(active_search_settings.primary)
|
||||
if active_search_settings.secondary:
|
||||
search_settings_list.append(active_search_settings.secondary)
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings_list.append(primary_search_settings)
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings_list.append(secondary_search_settings)
|
||||
|
||||
return search_settings_list
|
||||
|
||||
|
||||
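One plausible way the two variants above get consumed (sketch only; db_session is assumed to come from the app's usual session factory):

active = get_active_search_settings(db_session)
primary_index = active.primary.index_name
secondary_index = active.secondary.index_name if active.secondary else None

# The list form is convenient when both settings should be handled uniformly,
# e.g. applying the same maintenance step to the current and the future index.
for settings in get_active_search_settings_list(db_session):
    logger.info(f"Active index: {settings.index_name}")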
@@ -15,7 +15,6 @@ from onyx.db.models import User
|
||||
from onyx.db.persona import mark_persona_as_deleted
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.db.prompts import get_default_prompt
|
||||
from onyx.tools.built_in_tools import get_search_tool
|
||||
from onyx.utils.errors import EERequiredError
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
@@ -48,10 +47,6 @@ def create_slack_channel_persona(
|
||||
) -> Persona:
|
||||
"""NOTE: does not commit changes"""
|
||||
|
||||
search_tool = get_search_tool(db_session)
|
||||
if search_tool is None:
|
||||
raise ValueError("Search tool not found")
|
||||
|
||||
# create/update persona associated with the Slack channel
|
||||
persona_name = _build_persona_name(channel_name)
|
||||
default_prompt = get_default_prompt(db_session)
|
||||
@@ -65,7 +60,6 @@ def create_slack_channel_persona(
|
||||
llm_filter_extraction=enable_auto_filters,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
prompt_ids=[default_prompt.id],
|
||||
tool_ids=[search_tool.id],
|
||||
document_set_ids=document_set_ids,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
|
||||
@@ -8,64 +8,20 @@ from sqlalchemy.orm import Session
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import SyncRecord
|
||||
from onyx.setup import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def insert_sync_record(
|
||||
db_session: Session,
|
||||
entity_id: int,
|
||||
entity_id: int | None,
|
||||
sync_type: SyncType,
|
||||
) -> SyncRecord:
|
||||
"""Insert a new sync record into the database, cancelling any existing in-progress records.
|
||||
"""Insert a new sync record into the database.
|
||||
|
||||
Args:
|
||||
db_session: The database session to use
|
||||
entity_id: The ID of the entity being synced (document set ID, user group ID, etc.)
|
||||
sync_type: The type of sync operation
|
||||
"""
|
||||
# If an existing in-progress sync record exists, mark as cancelled
|
||||
existing_in_progress_sync_record = fetch_latest_sync_record(
|
||||
db_session, entity_id, sync_type, sync_status=SyncStatus.IN_PROGRESS
|
||||
)
|
||||
|
||||
if existing_in_progress_sync_record is not None:
|
||||
logger.info(
|
||||
f"Cancelling existing in-progress sync record {existing_in_progress_sync_record.id} "
|
||||
f"for entity_id={entity_id} sync_type={sync_type}"
|
||||
)
|
||||
mark_sync_records_as_cancelled(db_session, entity_id, sync_type)
|
||||
|
||||
return _create_sync_record(db_session, entity_id, sync_type)
|
||||
|
||||
|
||||
def mark_sync_records_as_cancelled(
|
||||
db_session: Session,
|
||||
entity_id: int | None,
|
||||
sync_type: SyncType,
|
||||
) -> None:
|
||||
stmt = (
|
||||
update(SyncRecord)
|
||||
.where(
|
||||
and_(
|
||||
SyncRecord.entity_id == entity_id,
|
||||
SyncRecord.sync_type == sync_type,
|
||||
SyncRecord.sync_status == SyncStatus.IN_PROGRESS,
|
||||
)
|
||||
)
|
||||
.values(sync_status=SyncStatus.CANCELED)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _create_sync_record(
|
||||
db_session: Session,
|
||||
entity_id: int | None,
|
||||
sync_type: SyncType,
|
||||
) -> SyncRecord:
|
||||
"""Create and insert a new sync record into the database."""
|
||||
sync_record = SyncRecord(
|
||||
entity_id=entity_id,
|
||||
sync_type=sync_type,
|
||||
@@ -83,7 +39,6 @@ def fetch_latest_sync_record(
|
||||
db_session: Session,
|
||||
entity_id: int,
|
||||
sync_type: SyncType,
|
||||
sync_status: SyncStatus | None = None,
|
||||
) -> SyncRecord | None:
|
||||
"""Fetch the most recent sync record for a given entity ID and status.
|
||||
|
||||
@@ -104,9 +59,6 @@ def fetch_latest_sync_record(
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if sync_status is not None:
|
||||
stmt = stmt.where(SyncRecord.sync_status == sync_status)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
@@ -38,7 +38,6 @@ def create_tool(
|
||||
custom_headers: list[Header] | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
passthrough_auth: bool,
|
||||
) -> Tool:
|
||||
new_tool = Tool(
|
||||
name=name,
|
||||
@@ -49,7 +48,6 @@ def create_tool(
|
||||
if custom_headers
|
||||
else [],
|
||||
user_id=user_id,
|
||||
passthrough_auth=passthrough_auth,
|
||||
)
|
||||
db_session.add(new_tool)
|
||||
db_session.commit()
|
||||
@@ -64,7 +62,6 @@ def update_tool(
|
||||
custom_headers: list[Header] | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
passthrough_auth: bool | None,
|
||||
) -> Tool:
|
||||
tool = get_tool_by_id(tool_id, db_session)
|
||||
if tool is None:
|
||||
@@ -82,8 +79,6 @@ def update_tool(
|
||||
tool.custom_headers = [
|
||||
cast(HeaderItemDict, header.model_dump()) for header in custom_headers
|
||||
]
|
||||
if passthrough_auth is not None:
|
||||
tool.passthrough_auth = passthrough_auth
|
||||
db_session.commit()
|
||||
|
||||
return tool
|
||||
|
||||
@@ -4,63 +4,24 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ENABLE_MULTIPASS_INDEXING
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import MultipassConfig
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
DEFAULT_BATCH_SIZE = 30
|
||||
DEFAULT_INDEX_NAME = "danswer_chunk"
|
||||
|
||||
|
||||
def should_use_multipass(search_settings: SearchSettings | None) -> bool:
|
||||
"""
|
||||
Determines whether multipass should be used based on the search settings
|
||||
or the default config if settings are unavailable.
|
||||
"""
|
||||
if search_settings is not None:
|
||||
return search_settings.multipass_indexing
|
||||
return ENABLE_MULTIPASS_INDEXING
|
||||
|
||||
|
||||
def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig:
|
||||
"""
|
||||
Determines whether to enable multipass and large chunks by examining
|
||||
the current search settings and the embedder configuration.
|
||||
"""
|
||||
if not search_settings:
|
||||
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
|
||||
|
||||
multipass = should_use_multipass(search_settings)
|
||||
enable_large_chunks = SearchSettings.can_use_large_chunks(
|
||||
multipass, search_settings.model_name, search_settings.provider_type
|
||||
)
|
||||
return MultipassConfig(
|
||||
multipass_indexing=multipass, enable_large_chunks=enable_large_chunks
|
||||
)
|
||||
|
||||
|
||||
def get_both_index_properties(
|
||||
db_session: Session,
|
||||
) -> tuple[str, str | None, bool, bool | None]:
|
||||
def get_both_index_names(db_session: Session) -> tuple[str, str | None]:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
config_1 = get_multipass_config(search_settings)
|
||||
|
||||
search_settings_new = get_secondary_search_settings(db_session)
|
||||
if not search_settings_new:
|
||||
return search_settings.index_name, None, config_1.enable_large_chunks, None
|
||||
return search_settings.index_name, None
|
||||
|
||||
config_2 = get_multipass_config(search_settings)
|
||||
return (
|
||||
search_settings.index_name,
|
||||
search_settings_new.index_name,
|
||||
config_1.enable_large_chunks,
|
||||
config_2.enable_large_chunks,
|
||||
)
|
||||
return search_settings.index_name, search_settings_new.index_name
|
||||
|
||||
|
||||
def translate_boost_count_to_multiplier(boost: int) -> float:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
@@ -9,28 +7,17 @@ from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
def get_default_document_index(
|
||||
search_settings: SearchSettings,
|
||||
secondary_search_settings: SearchSettings | None,
|
||||
httpx_client: httpx.Client | None = None,
|
||||
primary_index_name: str,
|
||||
secondary_index_name: str | None,
|
||||
) -> DocumentIndex:
|
||||
"""Primary index is the index that is used for querying/updating etc.
|
||||
The secondary index is for when both the currently used index and the upcoming
|
||||
index need to be updated; updates are applied to both indices"""
|
||||
|
||||
secondary_index_name: str | None = None
|
||||
secondary_large_chunks_enabled: bool | None = None
|
||||
if secondary_search_settings:
|
||||
secondary_index_name = secondary_search_settings.index_name
|
||||
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
|
||||
|
||||
# Currently only supporting Vespa
|
||||
return VespaIndex(
|
||||
index_name=search_settings.index_name,
|
||||
index_name=primary_index_name,
|
||||
secondary_index_name=secondary_index_name,
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
|
||||
|
||||
@@ -40,6 +27,6 @@ def get_current_primary_default_document_index(db_session: Session) -> DocumentI
|
||||
"""
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
return get_default_document_index(
|
||||
search_settings,
|
||||
None,
|
||||
primary_index_name=search_settings.index_name,
|
||||
secondary_index_name=None,
|
||||
)
|
||||
|
||||
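A sketch of how the new search-settings-based signature above is meant to be called (mirrors the call sites elsewhere in this diff; db_session is assumed):

search_settings = get_current_search_settings(db_session)
secondary_settings = get_secondary_search_settings(db_session)

document_index = get_default_document_index(
    search_settings,
    secondary_settings,  # None when no index migration is in flight
    httpx_client=None,   # falls back to a temporary client per operation
)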
@@ -231,22 +231,21 @@ def _get_chunks_via_visit_api(
|
||||
return document_chunks
|
||||
|
||||
|
||||
# TODO(rkuo): candidate for removal if not being used
|
||||
# @retry(tries=10, delay=1, backoff=2)
|
||||
# def get_all_vespa_ids_for_document_id(
|
||||
# document_id: str,
|
||||
# index_name: str,
|
||||
# filters: IndexFilters | None = None,
|
||||
# get_large_chunks: bool = False,
|
||||
# ) -> list[str]:
|
||||
# document_chunks = _get_chunks_via_visit_api(
|
||||
# chunk_request=VespaChunkRequest(document_id=document_id),
|
||||
# index_name=index_name,
|
||||
# filters=filters or IndexFilters(access_control_list=None),
|
||||
# field_names=[DOCUMENT_ID],
|
||||
# get_large_chunks=get_large_chunks,
|
||||
# )
|
||||
# return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]
|
||||
@retry(tries=10, delay=1, backoff=2)
|
||||
def get_all_vespa_ids_for_document_id(
|
||||
document_id: str,
|
||||
index_name: str,
|
||||
filters: IndexFilters | None = None,
|
||||
get_large_chunks: bool = False,
|
||||
) -> list[str]:
|
||||
document_chunks = _get_chunks_via_visit_api(
|
||||
chunk_request=VespaChunkRequest(document_id=document_id),
|
||||
index_name=index_name,
|
||||
filters=filters or IndexFilters(access_control_list=None),
|
||||
field_names=[DOCUMENT_ID],
|
||||
get_large_chunks=get_large_chunks,
|
||||
)
|
||||
return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]
|
||||
|
||||
|
||||
def parallel_visit_api_retrieval(
|
||||
|
||||
@@ -25,6 +25,7 @@ from onyx.configs.chat_configs import VESPA_SEARCHER_THREADS
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.document_index.document_index_utils import get_document_chunk_ids
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
@@ -40,12 +41,12 @@ from onyx.document_index.vespa.chunk_retrieval import (
|
||||
)
|
||||
from onyx.document_index.vespa.chunk_retrieval import query_vespa
|
||||
from onyx.document_index.vespa.deletion import delete_vespa_chunks
|
||||
from onyx.document_index.vespa.indexing_utils import BaseHTTPXClientContext
|
||||
from onyx.document_index.vespa.indexing_utils import batch_index_vespa_chunks
|
||||
from onyx.document_index.vespa.indexing_utils import check_for_final_chunk_existence
|
||||
from onyx.document_index.vespa.indexing_utils import clean_chunk_id_copy
|
||||
from onyx.document_index.vespa.indexing_utils import GlobalHTTPXClientContext
|
||||
from onyx.document_index.vespa.indexing_utils import TemporaryHTTPXClientContext
|
||||
from onyx.document_index.vespa.indexing_utils import (
|
||||
get_multipass_config,
|
||||
)
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
@@ -131,34 +132,12 @@ class VespaIndex(DocumentIndex):
|
||||
self,
|
||||
index_name: str,
|
||||
secondary_index_name: str | None,
|
||||
large_chunks_enabled: bool,
|
||||
secondary_large_chunks_enabled: bool | None,
|
||||
multitenant: bool = False,
|
||||
httpx_client: httpx.Client | None = None,
|
||||
) -> None:
|
||||
self.index_name = index_name
|
||||
self.secondary_index_name = secondary_index_name
|
||||
|
||||
self.large_chunks_enabled = large_chunks_enabled
|
||||
self.secondary_large_chunks_enabled = secondary_large_chunks_enabled
|
||||
|
||||
self.multitenant = multitenant
|
||||
|
||||
self.httpx_client_context: BaseHTTPXClientContext
|
||||
|
||||
if httpx_client:
|
||||
self.httpx_client_context = GlobalHTTPXClientContext(httpx_client)
|
||||
else:
|
||||
self.httpx_client_context = TemporaryHTTPXClientContext(
|
||||
get_vespa_http_client
|
||||
)
|
||||
|
||||
self.index_to_large_chunks_enabled: dict[str, bool] = {}
|
||||
self.index_to_large_chunks_enabled[index_name] = large_chunks_enabled
|
||||
if secondary_index_name and secondary_large_chunks_enabled:
|
||||
self.index_to_large_chunks_enabled[
|
||||
secondary_index_name
|
||||
] = secondary_large_chunks_enabled
|
||||
self.http_client = get_vespa_http_client()
|
||||
|
||||
def ensure_indices_exist(
|
||||
self,
|
||||
@@ -352,7 +331,7 @@ class VespaIndex(DocumentIndex):
|
||||
# indexing / updates / deletes since we have to make a large volume of requests.
|
||||
with (
|
||||
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
|
||||
self.httpx_client_context as http_client,
|
||||
get_vespa_http_client() as http_client,
|
||||
):
|
||||
# We require the start and end index for each document in order to
|
||||
# know precisely which chunks to delete. This information exists for
|
||||
@@ -411,11 +390,9 @@ class VespaIndex(DocumentIndex):
|
||||
for doc_id in all_doc_ids
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@staticmethod
|
||||
def _apply_updates_batched(
|
||||
cls,
|
||||
updates: list[_VespaUpdateRequest],
|
||||
httpx_client: httpx.Client,
|
||||
batch_size: int = BATCH_SIZE,
|
||||
) -> None:
|
||||
"""Runs a batch of updates in parallel via the ThreadPoolExecutor."""
|
||||
@@ -437,7 +414,7 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
with (
|
||||
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
|
||||
httpx_client as http_client,
|
||||
get_vespa_http_client() as http_client,
|
||||
):
|
||||
for update_batch in batch_generator(updates, batch_size):
|
||||
future_to_document_id = {
|
||||
@@ -478,7 +455,7 @@ class VespaIndex(DocumentIndex):
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
chunk_id_start_time = time.monotonic()
|
||||
with self.httpx_client_context as http_client:
|
||||
with get_vespa_http_client() as http_client:
|
||||
for update_request in update_requests:
|
||||
for doc_info in update_request.minimal_document_indexing_info:
|
||||
for index_name in index_names:
|
||||
@@ -534,8 +511,7 @@ class VespaIndex(DocumentIndex):
|
||||
)
|
||||
)
|
||||
|
||||
with self.httpx_client_context as httpx_client:
|
||||
self._apply_updates_batched(processed_updates_requests, httpx_client)
|
||||
self._apply_updates_batched(processed_updates_requests)
|
||||
logger.debug(
|
||||
"Finished updating Vespa documents in %.2f seconds",
|
||||
time.monotonic() - update_start,
|
||||
@@ -547,7 +523,6 @@ class VespaIndex(DocumentIndex):
|
||||
index_name: str,
|
||||
fields: VespaDocumentFields,
|
||||
doc_id: str,
|
||||
http_client: httpx.Client,
|
||||
) -> None:
|
||||
"""
|
||||
Update a single "chunk" (document) in Vespa using its chunk ID.
|
||||
@@ -579,17 +554,18 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}?create=true"
|
||||
|
||||
try:
|
||||
resp = http_client.put(
|
||||
vespa_url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=update_dict,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_message = f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). Details: {e.response.text}"
|
||||
logger.error(error_message)
|
||||
raise
|
||||
with get_vespa_http_client(http2=False) as http_client:
|
||||
try:
|
||||
resp = http_client.put(
|
||||
vespa_url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=update_dict,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_message = f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). Details: {e.response.text}"
|
||||
logger.error(error_message)
|
||||
raise
|
||||
|
||||
def update_single(
|
||||
self,
|
||||
@@ -603,16 +579,24 @@ class VespaIndex(DocumentIndex):
|
||||
function will complete with no errors or exceptions.
|
||||
Handle other exceptions if you wish to implement retry behavior.
|
||||
"""
|
||||
|
||||
doc_chunk_count = 0
|
||||
|
||||
with self.httpx_client_context as httpx_client:
|
||||
for (
|
||||
index_name,
|
||||
large_chunks_enabled,
|
||||
) in self.index_to_large_chunks_enabled.items():
|
||||
index_names = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
with get_vespa_http_client(http2=False) as http_client:
|
||||
for index_name in index_names:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
multipass_config = get_multipass_config(
|
||||
db_session=db_session,
|
||||
primary_index=index_name == self.index_name,
|
||||
)
|
||||
large_chunks_enabled = multipass_config.enable_large_chunks
|
||||
enriched_doc_infos = VespaIndex.enrich_basic_chunk_info(
|
||||
index_name=index_name,
|
||||
http_client=httpx_client,
|
||||
http_client=http_client,
|
||||
document_id=doc_id,
|
||||
previous_chunk_count=chunk_count,
|
||||
new_chunk_count=0,
|
||||
@@ -628,7 +612,10 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
for doc_chunk_id in doc_chunk_ids:
|
||||
self.update_single_chunk(
|
||||
doc_chunk_id, index_name, fields, doc_id, httpx_client
|
||||
doc_chunk_id=doc_chunk_id,
|
||||
index_name=index_name,
|
||||
fields=fields,
|
||||
doc_id=doc_id,
|
||||
)
|
||||
|
||||
return doc_chunk_count
|
||||
@@ -650,13 +637,19 @@ class VespaIndex(DocumentIndex):
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
with self.httpx_client_context as http_client, concurrent.futures.ThreadPoolExecutor(
|
||||
with get_vespa_http_client(
|
||||
http2=False
|
||||
) as http_client, concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=NUM_THREADS
|
||||
) as executor:
|
||||
for (
|
||||
index_name,
|
||||
large_chunks_enabled,
|
||||
) in self.index_to_large_chunks_enabled.items():
|
||||
for index_name in index_names:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
multipass_config = get_multipass_config(
|
||||
db_session=db_session,
|
||||
primary_index=index_name == self.index_name,
|
||||
)
|
||||
large_chunks_enabled = multipass_config.enable_large_chunks
|
||||
|
||||
enriched_doc_infos = VespaIndex.enrich_basic_chunk_info(
|
||||
index_name=index_name,
|
||||
http_client=http_client,
|
||||
@@ -825,9 +818,6 @@ class VespaIndex(DocumentIndex):
|
||||
"""
|
||||
Deletes all entries in the specified index with the given tenant_id.
|
||||
|
||||
Currently unused, but we anticipate this being useful. The entire flow does not
|
||||
use the httpx connection pool of an instance.
|
||||
|
||||
Parameters:
|
||||
tenant_id (str): The tenant ID whose documents are to be deleted.
|
||||
index_name (str): The name of the index from which to delete documents.
|
||||
@@ -860,8 +850,6 @@ class VespaIndex(DocumentIndex):
|
||||
"""
|
||||
Retrieves all document IDs with the specified tenant_id, handling pagination.
|
||||
|
||||
Internal helper function for delete_entries_by_tenant_id.
|
||||
|
||||
Parameters:
|
||||
tenant_id (str): The tenant ID to search for.
|
||||
index_name (str): The name of the index to search in.
|
||||
@@ -894,8 +882,8 @@ class VespaIndex(DocumentIndex):
|
||||
f"Querying for document IDs with tenant_id: {tenant_id}, offset: {offset}"
|
||||
)
|
||||
|
||||
with get_vespa_http_client() as http_client:
|
||||
response = http_client.get(url, params=query_params, timeout=None)
|
||||
with get_vespa_http_client(no_timeout=True) as http_client:
|
||||
response = http_client.get(url, params=query_params)
|
||||
response.raise_for_status()
|
||||
|
||||
search_result = response.json()
|
||||
@@ -925,11 +913,6 @@ class VespaIndex(DocumentIndex):
|
||||
"""
|
||||
Deletes documents in batches using multiple threads.
|
||||
|
||||
Internal helper function for delete_entries_by_tenant_id.
|
||||
|
||||
This is a class method and does not use the httpx pool of the instance.
|
||||
This is OK because we don't use this method often.
|
||||
|
||||
Parameters:
|
||||
delete_requests (List[_VespaDeleteRequest]): The list of delete requests.
|
||||
batch_size (int): The number of documents to delete in each batch.
|
||||
@@ -942,14 +925,13 @@ class VespaIndex(DocumentIndex):
|
||||
response = http_client.delete(
|
||||
delete_request.url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=None,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
logger.debug(f"Starting batch deletion for {len(delete_requests)} documents")
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
|
||||
with get_vespa_http_client() as http_client:
|
||||
with get_vespa_http_client(no_timeout=True) as http_client:
|
||||
for batch_start in range(0, len(delete_requests), batch_size):
|
||||
batch = delete_requests[batch_start : batch_start + batch_size]
|
||||
|
||||
|
||||
@@ -1,19 +1,21 @@
|
||||
import concurrent.futures
|
||||
import json
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
|
||||
import httpx
|
||||
from retry import retry
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ENABLE_MULTIPASS_INDEXING
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info_old
|
||||
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
|
||||
@@ -48,9 +50,10 @@ from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import TITLE
|
||||
from onyx.document_index.vespa_constants import TITLE_EMBEDDING
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import EmbeddingProvider
|
||||
from onyx.indexing.models import MultipassConfig
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -272,42 +275,46 @@ def check_for_final_chunk_existence(
|
||||
index += 1
|
||||
|
||||
|
||||
class BaseHTTPXClientContext(ABC):
|
||||
"""Abstract base class for an HTTPX client context manager."""
|
||||
|
||||
@abstractmethod
|
||||
def __enter__(self) -> httpx.Client:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __exit__(self, exc_type, exc_value, traceback): # type: ignore
|
||||
pass
|
||||
def should_use_multipass(search_settings: SearchSettings | None) -> bool:
|
||||
"""
|
||||
Determines whether multipass should be used based on the search settings
|
||||
or the default config if settings are unavailable.
|
||||
"""
|
||||
if search_settings is not None:
|
||||
return search_settings.multipass_indexing
|
||||
return ENABLE_MULTIPASS_INDEXING
|
||||
|
||||
|
||||
class GlobalHTTPXClientContext(BaseHTTPXClientContext):
|
||||
"""Context manager for a global HTTPX client that does not close it."""
|
||||
|
||||
def __init__(self, client: httpx.Client):
|
||||
self._client = client
|
||||
|
||||
def __enter__(self) -> httpx.Client:
|
||||
return self._client # Reuse the global client
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback): # type: ignore
|
||||
pass # Do nothing; don't close the global client
|
||||
def can_use_large_chunks(multipass: bool, search_settings: SearchSettings) -> bool:
|
||||
"""
|
||||
Given multipass usage and an embedder, decides whether large chunks are allowed
|
||||
based on model/provider constraints.
|
||||
"""
|
||||
# Only local models that support a larger context are from Nomic
|
||||
# Cohere does not support larger contexts (they recommend not going above ~512 tokens)
|
||||
return (
|
||||
multipass
|
||||
and search_settings.model_name.startswith("nomic-ai")
|
||||
and search_settings.provider_type != EmbeddingProvider.COHERE
|
||||
)
|
||||
|
||||
|
||||
class TemporaryHTTPXClientContext(BaseHTTPXClientContext):
|
||||
"""Context manager for a temporary HTTPX client that closes it after use."""
|
||||
|
||||
def __init__(self, client_factory: Callable[[], httpx.Client]):
|
||||
self._client_factory = client_factory
|
||||
self._client: httpx.Client | None = None # Client will be created in __enter__
|
||||
|
||||
def __enter__(self) -> httpx.Client:
|
||||
self._client = self._client_factory() # Create a new client
|
||||
return self._client
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback): # type: ignore
|
||||
if self._client:
|
||||
self._client.close()
|
||||
def get_multipass_config(
|
||||
db_session: Session, primary_index: bool = True
|
||||
) -> MultipassConfig:
|
||||
"""
|
||||
Determines whether to enable multipass and large chunks by examining
|
||||
the current search settings and the embedder configuration.
|
||||
"""
|
||||
search_settings = (
|
||||
get_current_search_settings(db_session)
|
||||
if primary_index
|
||||
else get_secondary_search_settings(db_session)
|
||||
)
|
||||
multipass = should_use_multipass(search_settings)
|
||||
if not search_settings:
|
||||
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
|
||||
enable_large_chunks = can_use_large_chunks(multipass, search_settings)
|
||||
return MultipassConfig(
|
||||
multipass_indexing=multipass, enable_large_chunks=enable_large_chunks
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
@@ -8,10 +7,6 @@ from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
|
||||
from onyx.document_index.vespa_constants import VESPA_APP_CONTAINER_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# NOTE: This does not seem to be used in reality despite the Vespa Docs pointing to this code
|
||||
# See here for reference: https://docs.vespa.ai/en/documents.html
|
||||
@@ -55,7 +50,7 @@ def remove_invalid_unicode_chars(text: str) -> str:
|
||||
"""Vespa does not take in unicode chars that aren't valid for XML.
|
||||
This removes them."""
|
||||
_illegal_xml_chars_RE: re.Pattern = re.compile(
|
||||
"[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFDD0-\uFDEF\uFFFE\uFFFF]"
|
||||
"[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]"
|
||||
)
|
||||
return _illegal_xml_chars_RE.sub("", text)
|
||||
|
||||
@@ -74,37 +69,3 @@ def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx
|
||||
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
|
||||
http2=http2,
|
||||
)
|
||||
|
||||
|
||||
def wait_for_vespa_with_timeout(wait_interval: int = 5, wait_limit: int = 60) -> bool:
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Returns True if Vespa is ready, False otherwise."""
|
||||
|
||||
time_start = time.monotonic()
|
||||
logger.info("Vespa: Readiness probe starting.")
|
||||
while True:
|
||||
try:
|
||||
client = get_vespa_http_client()
|
||||
response = client.get(f"{VESPA_APP_CONTAINER_URL}/state/v1/health")
|
||||
response.raise_for_status()
|
||||
|
||||
response_dict = response.json()
|
||||
if response_dict["status"]["code"] == "up":
|
||||
logger.info("Vespa: Readiness probe succeeded. Continuing...")
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
if time_elapsed > wait_limit:
|
||||
logger.info(
|
||||
f"Vespa: Readiness probe did not succeed within the timeout "
|
||||
f"({wait_limit} seconds)."
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info(
|
||||
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit:.1f}"
|
||||
)
|
||||
|
||||
time.sleep(wait_interval)
|
||||
|
||||
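A typical call site for the readiness helper above (illustrative only; the actual startup hook is not part of this hunk):

if not wait_for_vespa_with_timeout(wait_interval=5, wait_limit=300):
    raise RuntimeError("Vespa did not become ready in time; aborting startup")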
@@ -358,13 +358,7 @@ def extract_file_text(
|
||||
|
||||
try:
|
||||
if get_unstructured_api_key():
|
||||
try:
|
||||
return unstructured_to_text(file, file_name)
|
||||
except Exception as unstructured_error:
|
||||
logger.error(
|
||||
f"Failed to process with Unstructured: {str(unstructured_error)}. Falling back to normal processing."
|
||||
)
|
||||
# Fall through to normal processing
|
||||
return unstructured_to_text(file, file_name)
|
||||
|
||||
if file_name or extension:
|
||||
if extension is not None:
|
||||
|
||||
103
backend/onyx/file_processing/image_summarization.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def summarize_image_pipeline(
|
||||
llm: LLM,
|
||||
image_data: bytes,
|
||||
query: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
) -> str | None:
|
||||
"""Pipeline to generate a summary of an image.
|
||||
Resizes the image if it is larger than 20 MB and encodes it as a base64 string.
|
||||
Finally, it uses the default LLM to generate a textual summary of the image."""
|
||||
# resize image if it's bigger than 20MB
|
||||
image_data = _resize_image_if_needed(image_data)
|
||||
|
||||
# encode image (base64)
|
||||
encoded_image = _encode_image(image_data)
|
||||
|
||||
summary = _summarize_image(
|
||||
encoded_image,
|
||||
llm,
|
||||
query,
|
||||
system_prompt,
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def _summarize_image(
|
||||
encoded_image: str,
|
||||
llm: LLM,
|
||||
query: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
) -> str | None:
|
||||
"""Use default LLM (if it is multimodal) to generate a summary of an image."""
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": query},
|
||||
{"type": "image_url", "image_url": {"url": encoded_image}},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
try:
|
||||
model_output = message_to_string(llm.invoke(messages))
|
||||
|
||||
return model_output
|
||||
|
||||
except Exception as e:
|
||||
if CONTINUE_ON_CONNECTOR_FAILURE:
|
||||
# Summary of this image will be empty
|
||||
# prevents an infinite retry loop of the indexing if single summaries fail
|
||||
# for example because content filters were triggered
|
||||
logger.warning(f"Summarization failed with error: {e}.")
|
||||
return None
|
||||
else:
|
||||
raise ValueError(f"Summarization failed. Messages: {messages}") from e
|
||||
|
||||
|
||||
def _encode_image(image_data: bytes) -> str:
|
||||
"""Getting the base64 string."""
|
||||
base64_encoded_data = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
return f"data:image/jpeg;base64,{base64_encoded_data}"
|
||||
|
||||
|
||||
def _resize_image_if_needed(image_data: bytes, max_size_mb: int = 20) -> bytes:
|
||||
"""Resize image if it's larger than the specified max size in MB."""
|
||||
max_size_bytes = max_size_mb * 1024 * 1024
|
||||
|
||||
if len(image_data) > max_size_bytes:
|
||||
with Image.open(BytesIO(image_data)) as img:
|
||||
logger.info("resizing image...")
|
||||
|
||||
# Reduce dimensions for better size reduction
|
||||
img.thumbnail((800, 800), Image.Resampling.LANCZOS)
|
||||
output = BytesIO()
|
||||
|
||||
# Save with lower quality for compression
|
||||
img.save(output, format="JPEG", quality=85)
|
||||
resized_data = output.getvalue()
|
||||
|
||||
return resized_data
|
||||
|
||||
return image_data
|
||||
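A minimal usage sketch for the new module above (the LLM factory import is an assumption based on the rest of the codebase and is not shown in this diff; the prompts are placeholders):

from onyx.llm.factory import get_default_llms  # assumed helper, not part of this diff

llm, _fast_llm = get_default_llms()

with open("diagram.png", "rb") as f:
    image_bytes = f.read()

summary = summarize_image_pipeline(
    llm=llm,
    image_data=image_bytes,
    query="Describe the contents of this image for search indexing.",
    system_prompt="You write concise, factual image descriptions.",
)
print(summary or "<no summary produced>")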
@@ -52,7 +52,7 @@ def _sdk_partition_request(
|
||||
|
||||
def unstructured_to_text(file: IO[Any], file_name: str) -> str:
|
||||
logger.debug(f"Starting to read file: {file_name}")
|
||||
req = _sdk_partition_request(file, file_name, strategy="fast")
|
||||
req = _sdk_partition_request(file, file_name, strategy="auto")
|
||||
|
||||
unstructured_client = UnstructuredClient(api_key_auth=get_unstructured_api_key())
|
||||
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class HttpxPool:
|
||||
"""Class to manage a global httpx Client instance"""
|
||||
|
||||
_clients: dict[str, httpx.Client] = {}
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
|
||||
# Default parameters for creation
|
||||
DEFAULT_KWARGS = {
|
||||
"http2": True,
|
||||
"limits": lambda: httpx.Limits(),
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _init_client(cls, **kwargs: Any) -> httpx.Client:
|
||||
"""Private helper method to create and return an httpx.Client."""
|
||||
merged_kwargs = {**cls.DEFAULT_KWARGS, **kwargs}
|
||||
return httpx.Client(**merged_kwargs)
|
||||
|
||||
@classmethod
|
||||
def init_client(cls, name: str, **kwargs: Any) -> None:
|
||||
"""Allow the caller to init the client with extra params."""
|
||||
with cls._lock:
|
||||
if name not in cls._clients:
|
||||
cls._clients[name] = cls._init_client(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def close_client(cls, name: str) -> None:
|
||||
"""Allow the caller to close the client."""
|
||||
with cls._lock:
|
||||
client = cls._clients.pop(name, None)
|
||||
if client:
|
||||
client.close()
|
||||
|
||||
@classmethod
|
||||
def close_all(cls) -> None:
|
||||
"""Close all registered clients."""
|
||||
with cls._lock:
|
||||
for client in cls._clients.values():
|
||||
client.close()
|
||||
cls._clients.clear()
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str) -> httpx.Client:
|
||||
"""Gets the httpx.Client. Will init to default settings if not init'd."""
|
||||
with cls._lock:
|
||||
if name not in cls._clients:
|
||||
cls._clients[name] = cls._init_client()
|
||||
return cls._clients[name]
|
||||
@@ -225,6 +225,7 @@ class Chunker:
|
||||
|
||||
for section_idx, section in enumerate(document.sections):
|
||||
section_text = clean_text(section.text)
|
||||
|
||||
section_link_text = section.link or ""
|
||||
# If there is no useful content, not even the title, just drop it
|
||||
if not section_text and (not document.title or section_idx > 0):
|
||||
@@ -287,10 +288,16 @@ class Chunker:
|
||||
current_offset = len(shared_precompare_cleanup(chunk_text))
|
||||
# In the case where the whole section is shorter than a chunk, either add
|
||||
# to chunk or start a new one
|
||||
# If the next section is from a page attachment, a new chunk is started
|
||||
# (to ensure that image summaries in particular are stored in separate chunks)
|
||||
prefix = "## Text representation of attachment "
|
||||
next_section_tokens = (
|
||||
len(self.tokenizer.tokenize(SECTION_SEPARATOR)) + section_token_count
|
||||
)
|
||||
if next_section_tokens + current_token_count <= content_token_limit:
|
||||
if (
|
||||
next_section_tokens + current_token_count <= content_token_limit
|
||||
and prefix not in section_text
|
||||
):
|
||||
if chunk_text:
|
||||
chunk_text += SECTION_SEPARATOR
|
||||
chunk_text += section_text
|
||||
|
||||
@@ -31,15 +31,14 @@ from onyx.db.document import upsert_documents
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.models import Document as DBDocument
from onyx.db.search_settings import get_current_search_settings
from onyx.db.tag import create_or_add_document_tag
from onyx.db.tag import create_or_add_document_tag_list
from onyx.document_index.document_index_utils import (
    get_multipass_config,
)
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.document_index.vespa.indexing_utils import (
    get_multipass_config,
)
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -358,6 +357,7 @@ def index_doc_batch(
        is_public=False,
    )

    logger.debug("Filtering Documents")
    filtered_documents = filter_fnc(document_batch)

    ctx = index_doc_batch_prepare(
@@ -380,15 +380,6 @@ def index_doc_batch(
            new_docs=0, total_docs=len(filtered_documents), total_chunks=0
        )

    doc_descriptors = [
        {
            "doc_id": doc.id,
            "doc_length": doc.get_total_char_length(),
        }
        for doc in ctx.updatable_docs
    ]
    logger.debug(f"Starting indexing process for documents: {doc_descriptors}")

    logger.debug("Starting chunking")
    chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)

@@ -536,8 +527,7 @@ def build_indexing_pipeline(
    callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
    """Builds a pipeline which takes in a list (batch) of docs and indexes them."""
    search_settings = get_current_search_settings(db_session)
    multipass_config = get_multipass_config(search_settings)
    multipass_config = get_multipass_config(db_session, primary_index=True)

    chunker = chunker or Chunker(
        tokenizer=embedder.embedding_model.tokenizer,

@@ -55,7 +55,9 @@ class DocAwareChunk(BaseChunk):

    def to_short_descriptor(self) -> str:
        """Used when logging the identity of a chunk"""
        return f"{self.source_document.to_short_descriptor()} Chunk ID: {self.chunk_id}"
        return (
            f"Chunk ID: '{self.chunk_id}'; {self.source_document.to_short_descriptor()}"
        )


class IndexChunk(DocAwareChunk):

@@ -18,7 +18,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import get_current_tenant_id
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

@@ -28,28 +28,30 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24  # 1 Day


class PgRedisKVStore(KeyValueStore):
    def __init__(self, redis_client: Redis | None = None) -> None:
        self.tenant_id = get_current_tenant_id()

    def __init__(
        self, redis_client: Redis | None = None, tenant_id: str | None = None
    ) -> None:
        # If no redis_client is provided, fall back to the context var
        if redis_client is not None:
            self.redis_client = redis_client
        else:
            self.redis_client = get_redis_client(tenant_id=self.tenant_id)
            tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
            self.redis_client = get_redis_client(tenant_id=tenant_id)

    @contextmanager
    def _get_session(self) -> Iterator[Session]:
    def get_session(self) -> Iterator[Session]:
        engine = get_sqlalchemy_engine()
        with Session(engine, expire_on_commit=False) as session:
            if MULTI_TENANT:
                if self.tenant_id == POSTGRES_DEFAULT_SCHEMA:
                tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
                if tenant_id == POSTGRES_DEFAULT_SCHEMA:
                    raise HTTPException(
                        status_code=401, detail="User must authenticate"
                    )
                if not is_valid_schema_name(self.tenant_id):
                if not is_valid_schema_name(tenant_id):
                    raise HTTPException(status_code=400, detail="Invalid tenant ID")
                # Set the search_path to the tenant's schema
                session.execute(text(f'SET search_path = "{self.tenant_id}"'))
                session.execute(text(f'SET search_path = "{tenant_id}"'))
            yield session

    def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
@@ -64,7 +66,7 @@ class PgRedisKVStore(KeyValueStore):

        encrypted_val = val if encrypt else None
        plain_val = val if not encrypt else None
        with self._get_session() as session:
        with self.get_session() as session:
            obj = session.query(KVStore).filter_by(key=key).first()
            if obj:
                obj.value = plain_val
@@ -86,7 +88,7 @@ class PgRedisKVStore(KeyValueStore):
        except Exception as e:
            logger.error(f"Failed to get value from Redis for key '{key}': {str(e)}")

        with self._get_session() as session:
        with self.get_session() as session:
            obj = session.query(KVStore).filter_by(key=key).first()
            if not obj:
                raise KvKeyNotFoundError
@@ -111,7 +113,7 @@ class PgRedisKVStore(KeyValueStore):
        except Exception as e:
            logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")

        with self._get_session() as session:
        with self.get_session() as session:
            result = session.query(KVStore).filter_by(key=key).delete()  # type: ignore
            if result == 0:
                raise KvKeyNotFoundError
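For context on the `search_path` handling above, a minimal sketch of a tenant-scoped SQLAlchemy session; the schema validator and the way the engine is obtained are simplified assumptions, not the project's real helpers:

```python
from collections.abc import Iterator
from contextlib import contextmanager

from sqlalchemy import text
from sqlalchemy.orm import Session

_ALLOWED_SCHEMA_CHARS = set("abcdefghijklmnopqrstuvwxyz0123456789_")


def is_valid_schema_name(name: str) -> bool:
    # simplified stand-in for the real validator
    return bool(name) and set(name.lower()) <= _ALLOWED_SCHEMA_CHARS


@contextmanager
def tenant_session(engine, tenant_id: str) -> Iterator[Session]:
    if not is_valid_schema_name(tenant_id):
        raise ValueError("Invalid tenant ID")
    with Session(engine, expire_on_commit=False) as session:
        # scope every query in this session to the tenant's schema
        session.execute(text(f'SET search_path = "{tenant_id}"'))
        yield session
```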
@@ -26,7 +26,6 @@ from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue

from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.configs.model_configs import (
    DISABLE_LITELLM_STREAMING,
)
@@ -373,6 +372,10 @@ class DefaultMultiLLM(LLM):
    # # Return the lesser of available tokens or configured max
    # return min(self._max_output_tokens, available_output_tokens)

    def vision_support(self) -> bool | None:
        model = f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}"
        return litellm.supports_vision(model=model)

    def _completion(
        self,
        prompt: LanguageModelInput,
@@ -388,7 +391,6 @@ class DefaultMultiLLM(LLM):

        try:
            return litellm.completion(
                mock_response=MOCK_LLM_RESPONSE,
                # model choice
                model=f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
                # NOTE: have to pass in None instead of empty string for these

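The two litellm features used above can be exercised in isolation, assuming a litellm version that supports them; the model names and the canned string below are examples, not values from this diff:

```python
import litellm

MOCK_LLM_RESPONSE = "This is a canned test response."

# mock_response short-circuits the network call, which is how tests can run
# without a real provider key (behavior as documented by litellm)
resp = litellm.completion(
    model="openai/gpt-4o-mini",
    messages=[{"role": "user", "content": "ping"}],
    mock_response=MOCK_LLM_RESPONSE,
)
print(resp.choices[0].message.content)  # the canned response

# supports_vision consults litellm's model map, mirroring vision_support() above
print(litellm.supports_vision(model="openai/gpt-4o"))
```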
@@ -84,6 +84,9 @@ class LLM(abc.ABC):
        if LOG_DANSWER_MODEL_INTERACTIONS:
            log_prompt(prompt)

    def vision_support(self) -> bool | None:
        return None

    def invoke(
        self,
        prompt: LanguageModelInput,

@@ -142,7 +142,6 @@ def build_content_with_imgs(
    img_urls: list[str] | None = None,
    b64_imgs: list[str] | None = None,
    message_type: MessageType = MessageType.USER,
    exclude_images: bool = False,
) -> str | list[str | dict[str, Any]]:  # matching Langchain's BaseMessage content type
    files = files or []

@@ -158,7 +157,7 @@ def build_content_with_imgs(

    message_main_content = _build_content(message, files)

    if exclude_images or (not img_files and not img_urls):
    if not img_files and not img_urls:
        return message_main_content

    return cast(
@@ -383,19 +382,9 @@ def _strip_colon_from_model_name(model_name: str) -> str:
    return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name


def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | None:
    stripped_model_name = _strip_extra_provider_from_model_name(model_name)

    model_names = [
        model_name,
        _strip_extra_provider_from_model_name(model_name),
        # Remove leading extra provider. Usually for cases where user has a
        # customer model proxy which appends another prefix
        # remove :XXXX from the end, if present. Needed for ollama.
        _strip_colon_from_model_name(model_name),
        _strip_colon_from_model_name(stripped_model_name),
    ]

def _find_model_obj(
    model_map: dict, provider: str, model_names: list[str | None]
) -> dict | None:
    # Filter out None values and deduplicate model names
    filtered_model_names = [name for name in model_names if name]

@@ -428,10 +417,21 @@ def get_llm_max_tokens(
        return GEN_AI_MAX_TOKENS

    try:
        extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
            model_name
        )
        model_obj = _find_model_obj(
            model_map,
            model_provider,
            model_name,
            [
                model_name,
                # Remove leading extra provider. Usually for cases where user has a
                # customer model proxy which appends another prefix
                extra_provider_stripped_model_name,
                # remove :XXXX from the end, if present. Needed for ollama.
                _strip_colon_from_model_name(model_name),
                _strip_colon_from_model_name(extra_provider_stripped_model_name),
            ],
        )
        if not model_obj:
            raise RuntimeError(
@@ -523,23 +523,3 @@ def get_max_input_tokens(
        raise RuntimeError("No tokens for input for the LLM given settings")

    return input_toks


def model_supports_image_input(model_name: str, model_provider: str) -> bool:
    model_map = get_model_map()
    try:
        model_obj = _find_model_obj(
            model_map,
            model_provider,
            model_name,
        )
        if not model_obj:
            raise RuntimeError(
                f"No litellm entry found for {model_provider}/{model_name}"
            )
        return model_obj.get("supports_vision", False)
    except Exception:
        logger.exception(
            f"Failed to get model object for {model_provider}/{model_name}"
        )
        return False

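The lookup above tries several normalized names against the model map; a small sketch of building that candidate list. The helper names mirror the diff, but their bodies here are simplified assumptions:

```python
def _strip_extra_provider_from_model_name(model_name: str) -> str:
    # e.g. "custom-proxy/llama3:8b" -> "llama3:8b" (simplified)
    return model_name.split("/")[-1]


def _strip_colon_from_model_name(model_name: str) -> str:
    # e.g. "llama3:8b" -> "llama3" (needed for ollama-style tags)
    return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name


def candidate_model_names(model_name: str) -> list[str]:
    stripped = _strip_extra_provider_from_model_name(model_name)
    names = [
        model_name,
        stripped,
        _strip_colon_from_model_name(model_name),
        _strip_colon_from_model_name(stripped),
    ]
    # drop empties and deduplicate while preserving order
    return list(dict.fromkeys(n for n in names if n))


print(candidate_model_names("custom-proxy/llama3:8b"))
# ['custom-proxy/llama3:8b', 'llama3:8b', 'custom-proxy/llama3', 'llama3']
```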
@@ -109,9 +109,7 @@ from onyx.utils.variable_functionality import global_version
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SENTRY_DSN
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

@@ -214,7 +212,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:

    if not MULTI_TENANT:
        # We cache this at the beginning so there is no delay in the first telemetry
        CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
        get_or_generate_uuid()

    # If we are multi-tenant, we need to only set up initial public tables

@@ -1,8 +1,6 @@
import threading
import time
from collections.abc import Callable
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import Any

@@ -13,7 +11,6 @@ from requests import RequestException
from requests import Response
from retry import retry

from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.app_configs import SKIP_WARM_UP
from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
@@ -158,7 +155,6 @@ class EmbeddingModel:
        text_type: EmbedTextType,
        batch_size: int,
        max_seq_length: int,
        num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
    ) -> list[Embedding]:
        text_batches = batch_list(texts, batch_size)

@@ -167,14 +163,12 @@ class EmbeddingModel:
        )

        embeddings: list[Embedding] = []

        def process_batch(
            batch_idx: int, text_batch: list[str]
        ) -> tuple[int, list[Embedding]]:
        for idx, text_batch in enumerate(text_batches, start=1):
            if self.callback:
                if self.callback.should_stop():
                    raise RuntimeError("_batch_encode_texts detected stop signal")

            logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
            embed_request = EmbedRequest(
                model_name=self.model_name,
                texts=text_batch,
@@ -190,52 +184,11 @@ class EmbeddingModel:
                api_url=self.api_url,
            )

            start_time = time.time()
            response = self._make_model_server_request(embed_request)
            end_time = time.time()

            processing_time = end_time - start_time
            logger.info(
                f"Batch {batch_idx} processing time: {processing_time:.2f} seconds"
            )

            return batch_idx, response.embeddings

        # only multi thread if:
        # 1. num_threads is greater than 1
        # 2. we are using an API-based embedding model (provider_type is not None)
        # 3. there are more than 1 batch (no point in threading if only 1)
        if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                future_to_batch = {
                    executor.submit(process_batch, idx, batch): idx
                    for idx, batch in enumerate(text_batches, start=1)
                }

                # Collect results in order
                batch_results: list[tuple[int, list[Embedding]]] = []
                for future in as_completed(future_to_batch):
                    try:
                        result = future.result()
                        batch_results.append(result)
                        if self.callback:
                            self.callback.progress("_batch_encode_texts", 1)
                    except Exception as e:
                        logger.exception("Embedding model failed to process batch")
                        raise e

                # Sort by batch index and extend embeddings
                batch_results.sort(key=lambda x: x[0])
                for _, batch_embeddings in batch_results:
                    embeddings.extend(batch_embeddings)
        else:
            # Original sequential processing
            for idx, text_batch in enumerate(text_batches, start=1):
                _, batch_embeddings = process_batch(idx, text_batch)
                embeddings.extend(batch_embeddings)
                if self.callback:
                    self.callback.progress("_batch_encode_texts", 1)
            embeddings.extend(response.embeddings)

            if self.callback:
                self.callback.progress("_batch_encode_texts", 1)
        return embeddings

    def encode(

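The hunk above parallelizes per-batch embedding requests and then restores the original order before flattening. A compact sketch of that pattern with a stub request function; all names here are illustrative:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def embed_batch(batch_idx: int, batch: list[str]) -> tuple[int, list[list[float]]]:
    # stand-in for the model-server call; returns (index, embeddings)
    return batch_idx, [[float(len(t))] for t in batch]


def embed_all(batches: list[list[str]], num_threads: int) -> list[list[float]]:
    results: list[tuple[int, list[list[float]]]] = []
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = {
            executor.submit(embed_batch, idx, batch): idx
            for idx, batch in enumerate(batches, start=1)
        }
        for future in as_completed(futures):
            results.append(future.result())
    # completion order is arbitrary, so sort by batch index before flattening
    results.sort(key=lambda pair: pair[0])
    return [emb for _, embs in results for emb in embs]


print(embed_all([["a", "bb"], ["ccc"]], num_threads=2))  # [[1.0], [2.0], [3.0]]
```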
@@ -14,7 +14,6 @@ from typing import Set

from prometheus_client import Gauge
from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
@@ -123,9 +122,6 @@ class SlackbotHandler:
        self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
        self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}

        # Store Redis lock objects here so we can release them properly
        self.redis_locks: Dict[str | None, Lock] = {}

        self.running = True
        self.pod_id = self.get_pod_id()
        self._shutdown_event = Event()
@@ -163,15 +159,10 @@ class SlackbotHandler:
        while not self._shutdown_event.is_set():
            try:
                self.acquire_tenants()

                # After we finish acquiring and managing Slack bots,
                # set the gauge to the number of active tenants (those with Slack bots).
                active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set(
                    len(self.tenant_ids)
                )
                logger.debug(
                    f"Current active tenants with Slack bots: {len(self.tenant_ids)}"
                )
                logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
            except Exception as e:
                logger.exception(f"Error in Slack acquisition: {e}")
            self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
@@ -180,9 +171,7 @@ class SlackbotHandler:
        while not self._shutdown_event.is_set():
            try:
                self.send_heartbeats()
                logger.debug(
                    f"Sent heartbeats for {len(self.tenant_ids)} active tenants"
                )
                logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
            except Exception as e:
                logger.exception(f"Error in heartbeat loop: {e}")
            self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
@@ -190,21 +179,17 @@ class SlackbotHandler:
    def _manage_clients_per_tenant(
        self, db_session: Session, tenant_id: str | None, bot: SlackBot
    ) -> None:
        """
        - If the tokens are missing or empty, close the socket client and remove them.
        - If the tokens have changed, close the existing socket client and reconnect.
        - If the tokens are new, warm up the model and start a new socket client.
        """
        slack_bot_tokens = SlackBotTokens(
            bot_token=bot.bot_token,
            app_token=bot.app_token,
        )
        tenant_bot_pair = (tenant_id, bot.id)

        # If the tokens are missing or empty, close the socket client and remove them.
        # If the tokens are not set, we need to close the socket client and delete the tokens
        # for the tenant and app
        if not slack_bot_tokens:
            logger.debug(
                f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}"
                f"No Slack bot token found for tenant {tenant_id}, bot {bot.id}"
            )
            if tenant_bot_pair in self.socket_clients:
                asyncio.run(self.socket_clients[tenant_bot_pair].close())
@@ -219,10 +204,9 @@ class SlackbotHandler:
        if not tokens_exist or tokens_changed:
            if tokens_exist:
                logger.info(
                    f"Slack Bot tokens changed for tenant={tenant_id}, bot {bot.id}; reconnecting"
                    f"Slack Bot tokens have changed for tenant {tenant_id}, bot {bot.id} - reconnecting"
                )
            else:
                # Warm up the model if needed
                search_settings = get_current_search_settings(db_session)
                embedding_model = EmbeddingModel.from_db_model(
                    search_settings=search_settings,
@@ -233,168 +217,77 @@ class SlackbotHandler:

        self.slack_bot_tokens[tenant_bot_pair] = slack_bot_tokens

        # Close any existing connection first
        if tenant_bot_pair in self.socket_clients:
            asyncio.run(self.socket_clients[tenant_bot_pair].close())

        self.start_socket_client(bot.id, tenant_id, slack_bot_tokens)

    def acquire_tenants(self) -> None:
        """
        - Attempt to acquire a Redis lock for each tenant.
        - If acquired, check if that tenant actually has Slack bots.
        - If yes, store them in self.tenant_ids and manage the socket connections.
        - If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
        """
        all_tenants = get_all_tenant_ids()
        tenant_ids = get_all_tenant_ids()

        # 1) Try to acquire locks for new tenants
        for tenant_id in all_tenants:
        for tenant_id in tenant_ids:
            if (
                DISALLOWED_SLACK_BOT_TENANT_LIST is not None
                and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
            ):
                logger.debug(f"Tenant {tenant_id} is disallowed; skipping.")
                logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
                continue

            # Already acquired in a previous loop iteration?
            if tenant_id in self.tenant_ids:
                logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
                continue

            # Respect max tenant limit per pod
            if len(self.tenant_ids) >= MAX_TENANTS_PER_POD:
                logger.info(
                    f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}); not acquiring more."
                    f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}) Not acquiring any more tenants"
                )
                break

            redis_client = get_redis_client(tenant_id=tenant_id)
            # Acquire a Redis lock (non-blocking)
            rlock = redis_client.lock(
                OnyxRedisLocks.SLACK_BOT_LOCK, timeout=TENANT_LOCK_EXPIRATION
            pod_id = self.pod_id
            acquired = redis_client.set(
                OnyxRedisLocks.SLACK_BOT_LOCK,
                pod_id,
                nx=True,
                ex=TENANT_LOCK_EXPIRATION,
            )
            lock_acquired = rlock.acquire(blocking=False)

            if not lock_acquired and not DEV_MODE:
                logger.debug(
                    f"Another pod holds the lock for tenant {tenant_id}, skipping."
                )
            if not acquired and not DEV_MODE:
                logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
                continue

            if lock_acquired:
                logger.debug(f"Acquired lock for tenant {tenant_id}.")
                self.redis_locks[tenant_id] = rlock
            else:
                # DEV_MODE will skip the lock acquisition guard
                logger.debug(
                    f"Running in DEV_MODE. Not enforcing lock for {tenant_id}."
                )
            logger.debug(f"Acquired lock for tenant {tenant_id}")

            # Now check if this tenant actually has Slack bots
            self.tenant_ids.add(tenant_id)

        for tenant_id in self.tenant_ids:
            token = CURRENT_TENANT_ID_CONTEXTVAR.set(
                tenant_id or POSTGRES_DEFAULT_SCHEMA
            )
            try:
                with get_session_with_tenant(tenant_id) as db_session:
                    bots: list[SlackBot] = []
                    try:
                        bots = list(fetch_slack_bots(db_session=db_session))
                    except KvKeyNotFoundError:
                        # No Slackbot tokens, pass
                        pass
                    except Exception as e:
                        logger.exception(
                            f"Error fetching Slack bots for tenant {tenant_id}: {e}"
                        )

                    if bots:
                        # Mark as active tenant
                        self.tenant_ids.add(tenant_id)
                    bots = fetch_slack_bots(db_session=db_session)
                    for bot in bots:
                        self._manage_clients_per_tenant(
                            db_session=db_session,
                            tenant_id=tenant_id,
                            bot=bot,
                        )
                    else:
                        # If no Slack bots, release lock immediately (unless in DEV_MODE)
                        if lock_acquired and not DEV_MODE:
                            rlock.release()
                            del self.redis_locks[tenant_id]
                        logger.debug(
                            f"No Slack bots for tenant {tenant_id}; lock released (if held)."
                        )
            finally:
                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

        # 2) Make sure tenants we're handling still have Slack bots
        for tenant_id in list(self.tenant_ids):
            token = CURRENT_TENANT_ID_CONTEXTVAR.set(
                tenant_id or POSTGRES_DEFAULT_SCHEMA
            )
            redis_client = get_redis_client(tenant_id=tenant_id)

            try:
                with get_session_with_tenant(tenant_id) as db_session:
                    # Attempt to fetch Slack bots
                    try:
                        bots = list(fetch_slack_bots(db_session=db_session))
                    except KvKeyNotFoundError:
                        # No Slackbot tokens, pass (and remove below)
                        bots = []
                        logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
                        if (tenant_id, bot.id) in self.socket_clients:
                            asyncio.run(self.socket_clients[tenant_id, bot.id].close())
                            del self.socket_clients[tenant_id, bot.id]
                            del self.slack_bot_tokens[tenant_id, bot.id]
                    except Exception as e:
                        logger.exception(f"Error handling tenant {tenant_id}: {e}")
                        bots = []

                    if not bots:
                        logger.info(
                            f"Tenant {tenant_id} no longer has Slack bots. Removing."
                        )
                        self._remove_tenant(tenant_id)

                        # NOTE: We release the lock here (in the same scope it was acquired)
                        if tenant_id in self.redis_locks and not DEV_MODE:
                            try:
                                self.redis_locks[tenant_id].release()
                                del self.redis_locks[tenant_id]
                                logger.info(f"Released lock for tenant {tenant_id}")
                            except Exception as e:
                                logger.error(
                                    f"Error releasing lock for tenant {tenant_id}: {e}"
                                )
                    else:
                        # Manage or reconnect Slack bot sockets
                        for bot in bots:
                            self._manage_clients_per_tenant(
                                db_session=db_session,
                                tenant_id=tenant_id,
                                bot=bot,
                            )
            finally:
                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

    def _remove_tenant(self, tenant_id: str | None) -> None:
        """
        Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
        (Lock release now happens in `acquire_tenants()`, not here.)
        """
        # Close all socket clients for this tenant
        for (t_id, slack_bot_id), client in list(self.socket_clients.items()):
            if t_id == tenant_id:
                asyncio.run(client.close())
                del self.socket_clients[(t_id, slack_bot_id)]
                del self.slack_bot_tokens[(t_id, slack_bot_id)]
                logger.info(
                    f"Stopped SocketModeClient for tenant: {t_id}, app: {slack_bot_id}"
                )

        # Remove from active set
        if tenant_id in self.tenant_ids:
            self.tenant_ids.remove(tenant_id)

    def send_heartbeats(self) -> None:
        current_time = int(time.time())
        logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} active tenants")
        logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
        for tenant_id in self.tenant_ids:
            redis_client = get_redis_client(tenant_id=tenant_id)
            heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
@@ -422,7 +315,6 @@ class SlackbotHandler:
        )
        socket_client.connect()
        self.socket_clients[tenant_id, slack_bot_id] = socket_client
        # Ensure tenant is tracked as active
        self.tenant_ids.add(tenant_id)
        logger.info(
            f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
@@ -430,7 +322,7 @@ class SlackbotHandler:

    def stop_socket_clients(self) -> None:
        logger.info(f"Stopping {len(self.socket_clients)} socket clients")
        for (tenant_id, slack_bot_id), client in list(self.socket_clients.items()):
        for (tenant_id, slack_bot_id), client in self.socket_clients.items():
            asyncio.run(client.close())
            logger.info(
                f"Stopped SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
@@ -448,19 +340,17 @@ class SlackbotHandler:
        logger.info(f"Stopping {len(self.socket_clients)} socket clients")
        self.stop_socket_clients()

        # Release locks for all tenants we currently hold
        # Release locks for all tenants
        logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
        for tenant_id in list(self.tenant_ids):
            if tenant_id in self.redis_locks:
                try:
                    self.redis_locks[tenant_id].release()
                    logger.info(f"Released lock for tenant {tenant_id}")
                except Exception as e:
                    logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
                finally:
                    del self.redis_locks[tenant_id]
        for tenant_id in self.tenant_ids:
            try:
                redis_client = get_redis_client(tenant_id=tenant_id)
                redis_client.delete(OnyxRedisLocks.SLACK_BOT_LOCK)
                logger.info(f"Released lock for tenant {tenant_id}")
            except Exception as e:
                logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")

        # Wait for background threads to finish (with a timeout)
        # Wait for background threads to finish (with timeout)
        logger.info("Waiting for background threads to finish...")
        self.acquire_thread.join(timeout=5)
        self.heartbeat_thread.join(timeout=5)
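The acquisition logic above replaces a `redis.lock.Lock` object with a plain `SET NX EX` key, so any pod can inspect or expire another pod's claim. A minimal sketch of that primitive; the key name, TTL, and the assumption that responses are raw bytes are illustrative:

```python
import redis

TENANT_LOCK_EXPIRATION = 1800  # seconds; illustrative value


def try_acquire_tenant(r: redis.Redis, lock_key: str, pod_id: str) -> bool:
    # NX: only set if the key does not exist; EX: auto-expire so a dead pod
    # cannot hold on to the tenant forever
    return bool(r.set(lock_key, pod_id, nx=True, ex=TENANT_LOCK_EXPIRATION))


def release_tenant(r: redis.Redis, lock_key: str, pod_id: str) -> None:
    # only delete the lock if this pod still owns it (assumes decode_responses=False)
    if r.get(lock_key) == pod_id.encode():
        r.delete(lock_key)
```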
@@ -537,36 +427,30 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
            # Let the tag flow handle this case, don't reply twice
            return False

    # Check if this is a bot message (either via bot_profile or bot_message subtype)
    is_bot_message = bool(
        event.get("bot_profile") or event.get("subtype") == "bot_message"
    )
    if is_bot_message:
    if event.get("bot_profile"):
        channel_name, _ = get_channel_name_from_id(
            client=client.web_client, channel_id=channel
        )

        with get_session_with_tenant(client.tenant_id) as db_session:
            slack_channel_config = get_slack_channel_config_for_bot_and_channel(
                db_session=db_session,
                slack_bot_id=client.slack_bot_id,
                channel_name=channel_name,
            )

            # If OnyxBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
            if (not bot_tag_id or bot_tag_id not in msg) and (
                not slack_channel_config
                or not slack_channel_config.channel_config.get("respond_to_bots")
            ):
                channel_specific_logger.info(
                    "Ignoring message from bot since respond_to_bots is disabled"
                )
                channel_specific_logger.info("Ignoring message from bot")
                return False

    # Ignore things like channel_join, channel_leave, etc.
    # NOTE: "file_share" is just a message with a file attachment, so we
    # should not ignore it
    message_subtype = event.get("subtype")
    if message_subtype not in [None, "file_share", "bot_message"]:
    if message_subtype not in [None, "file_share"]:
        channel_specific_logger.info(
            f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
        )

@@ -19,8 +19,9 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()


_DANSWER_DATETIME_REPLACEMENT_PAT = "[[CURRENT_DATETIME]]"
_BASIC_TIME_STR = "The current date is {datetime_info}."
MOST_BASIC_PROMPT = "You are a helpful AI assistant."
DANSWER_DATETIME_REPLACEMENT = "DANSWER_DATETIME_REPLACEMENT"
BASIC_TIME_STR = "The current date is {datetime_info}."


def get_current_llm_day_time(
@@ -37,36 +38,23 @@ def get_current_llm_day_time(
    return f"{formatted_datetime}"


def build_date_time_string() -> str:
    return ADDITIONAL_INFO.format(
        datetime_info=_BASIC_TIME_STR.format(datetime_info=get_current_llm_day_time())
    )


def handle_onyx_date_awareness(
    prompt_str: str,
    prompt_config: PromptConfig,
    add_additional_info_if_no_tag: bool = False,
) -> str:
    """
    If there is a [[CURRENT_DATETIME]] tag, replace it with the current date and time no matter what.
    If the prompt is datetime aware, and there are no [[CURRENT_DATETIME]] tags, add it to the prompt.
    do nothing otherwise.
    This can later be expanded to support other tags.
    """

    if _DANSWER_DATETIME_REPLACEMENT_PAT in prompt_str:
def add_date_time_to_prompt(prompt_str: str) -> str:
    if DANSWER_DATETIME_REPLACEMENT in prompt_str:
        return prompt_str.replace(
            _DANSWER_DATETIME_REPLACEMENT_PAT,
            DANSWER_DATETIME_REPLACEMENT,
            get_current_llm_day_time(full_sentence=False, include_day_of_week=True),
        )
    any_tag_present = any(
        _DANSWER_DATETIME_REPLACEMENT_PAT in text
        for text in [prompt_str, prompt_config.system_prompt, prompt_config.task_prompt]
    )
    if add_additional_info_if_no_tag and not any_tag_present:
        return prompt_str + build_date_time_string()
    return prompt_str

    if prompt_str:
        return prompt_str + ADDITIONAL_INFO.format(
            datetime_info=get_current_llm_day_time()
        )
    else:
        return (
            MOST_BASIC_PROMPT
            + " "
            + BASIC_TIME_STR.format(datetime_info=get_current_llm_day_time())
        )


def build_task_prompt_reminders(

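A small sketch of the `[[CURRENT_DATETIME]]` substitution described above; the datetime format and the fallback sentence are assumptions, since the real wording lives in the prompt config:

```python
from datetime import datetime

DATETIME_PAT = "[[CURRENT_DATETIME]]"


def apply_datetime_tag(prompt: str, add_if_missing: bool = False) -> str:
    now = datetime.now().strftime("%A, %B %d, %Y %H:%M")
    if DATETIME_PAT in prompt:
        # the tag always wins: replace it wherever it appears
        return prompt.replace(DATETIME_PAT, now)
    if add_if_missing:
        # no tag present, optionally append the date as extra context
        return f"{prompt}\n\nThe current date is {now}."
    return prompt


print(apply_datetime_tag("Today is [[CURRENT_DATETIME]]."))
```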
@@ -17,8 +17,6 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT

class RedisConnectorPermissionSyncPayload(BaseModel):
    id: str
    submitted: datetime
    started: datetime | None
    celery_task_id: str | None

@@ -43,12 +41,6 @@ class RedisConnectorPermissionSync:
    TASKSET_PREFIX = f"{PREFIX}_taskset"  # connectorpermissions_taskset
    SUBTASK_PREFIX = f"{PREFIX}+sub"  # connectorpermissions+sub

    # used to signal the overall workflow is still active
    # it's impossible to get the exact state of the system at a single point in time
    # so we need a signal with a TTL to bridge gaps in our checks
    ACTIVE_PREFIX = PREFIX + "_active"
    ACTIVE_TTL = 3600

    def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
        self.tenant_id: str | None = tenant_id
        self.id = id
@@ -62,7 +54,6 @@ class RedisConnectorPermissionSync:
        self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"

        self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
        self.active_key = f"{self.ACTIVE_PREFIX}_{id}"

    def taskset_clear(self) -> None:
        self.redis.delete(self.taskset_key)
@@ -116,20 +107,6 @@ class RedisConnectorPermissionSync:

        self.redis.set(self.fence_key, payload.model_dump_json())

    def set_active(self) -> None:
        """This sets a signal to keep the permissioning flow from getting cleaned up within
        the expiration time.

        The slack in timing is needed to avoid race conditions where simply checking
        the celery queue and task status could result in race conditions."""
        self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)

    def active(self) -> bool:
        if self.redis.exists(self.active_key):
            return True

        return False

    @property
    def generator_complete(self) -> int | None:
        """the fence payload is an int representing the starting number of
@@ -196,7 +173,6 @@ class RedisConnectorPermissionSync:
        return len(async_results)

    def reset(self) -> None:
        self.redis.delete(self.active_key)
        self.redis.delete(self.generator_progress_key)
        self.redis.delete(self.generator_complete_key)
        self.redis.delete(self.taskset_key)
@@ -211,9 +187,6 @@ class RedisConnectorPermissionSync:
    @staticmethod
    def reset_all(r: redis.Redis) -> None:
        """Deletes all redis values for all connectors"""
        for key in r.scan_iter(RedisConnectorPermissionSync.ACTIVE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPermissionSync.TASKSET_PREFIX + "*"):
            r.delete(key)

@@ -11,7 +11,6 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT

class RedisConnectorExternalGroupSyncPayload(BaseModel):
    submitted: datetime
    started: datetime | None
    celery_task_id: str | None

@@ -136,12 +135,6 @@ class RedisConnectorExternalGroupSync:
    ) -> int | None:
        pass

    def reset(self) -> None:
        self.redis.delete(self.generator_progress_key)
        self.redis.delete(self.generator_complete_key)
        self.redis.delete(self.taskset_key)
        self.redis.delete(self.fence_key)

    @staticmethod
    def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
        taskset_key = f"{RedisConnectorExternalGroupSync.TASKSET_PREFIX}_{id}"

@@ -30,17 +30,10 @@ class RedisConnectorIndex:
    GENERATOR_LOCK_PREFIX = "da_lock:indexing"

    TERMINATE_PREFIX = PREFIX + "_terminate"  # connectorindexing_terminate
    TERMINATE_TTL = 600

    # used to signal the overall workflow is still active
    # it's impossible to get the exact state of the system at a single point in time
    # so we need a signal with a TTL to bridge gaps in our checks
    # it's difficult to prevent
    ACTIVE_PREFIX = PREFIX + "_active"
    ACTIVE_TTL = 3600

    # used to signal that the watchdog is running
    WATCHDOG_PREFIX = PREFIX + "_watchdog"
    WATCHDOG_TTL = 300

    def __init__(
        self,
@@ -66,7 +59,6 @@ class RedisConnectorIndex:
        )
        self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
        self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
        self.watchdog_key = f"{self.WATCHDOG_PREFIX}_{id}/{search_settings_id}"

    @classmethod
    def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -118,24 +110,7 @@ class RedisConnectorIndex:
        """This sets a signal. It does not block!"""
        # We shouldn't need very long to terminate the spawned task.
        # 10 minute TTL is good.
        self.redis.set(
            f"{self.terminate_key}_{celery_task_id}", 0, ex=self.TERMINATE_TTL
        )

    def set_watchdog(self, value: bool) -> None:
        """Signal the state of the watchdog."""
        if not value:
            self.redis.delete(self.watchdog_key)
            return

        self.redis.set(self.watchdog_key, 0, ex=self.WATCHDOG_TTL)

    def watchdog_signaled(self) -> bool:
        """Check the state of the watchdog."""
        if self.redis.exists(self.watchdog_key):
            return True

        return False
        self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)

    def set_active(self) -> None:
        """This sets a signal to keep the indexing flow from getting cleaned up within
@@ -143,7 +118,7 @@ class RedisConnectorIndex:

        The slack in timing is needed to avoid race conditions where simply checking
        the celery queue and task status could result in race conditions."""
        self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
        self.redis.set(self.active_key, 0, ex=3600)

    def active(self) -> bool:
        if self.redis.exists(self.active_key):
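The `set_active` / `active` pair above is a TTL-based liveness signal rather than a lock: the value is irrelevant, only the key's existence and expiry matter. A tiny sketch of the idea; the key name and TTL below are illustrative:

```python
import redis

ACTIVE_TTL = 3600  # an hour of slack to bridge gaps between checks


def set_active(r: redis.Redis, active_key: str) -> None:
    # refresh the signal; it silently disappears if the worker stops renewing it
    r.set(active_key, 0, ex=ACTIVE_TTL)


def is_active(r: redis.Redis, active_key: str) -> bool:
    return bool(r.exists(active_key))
```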