mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
81 Commits
random_doc
...
potential_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
48998a5204 | ||
|
|
697f8bc1c6 | ||
|
|
3ba65214b8 | ||
|
|
6687d5d499 | ||
|
|
ec78f78f3c | ||
|
|
ed253e469a | ||
|
|
e3aafd95af | ||
|
|
3a704f1950 | ||
|
|
2bf8a7aee5 | ||
|
|
c2f3302aa0 | ||
|
|
7f4d1f27a0 | ||
|
|
b70db15622 | ||
|
|
e9492ce9ec | ||
|
|
35574369ed | ||
|
|
eff433bdc5 | ||
|
|
3260d793d1 | ||
|
|
1a7aca06b9 | ||
|
|
c6434db7eb | ||
|
|
667b9e04c5 | ||
|
|
29c84d7707 | ||
|
|
17c915b11b | ||
|
|
95ca592d6d | ||
|
|
e39a27fd6b | ||
|
|
26d3c952c6 | ||
|
|
53683e2f3c | ||
|
|
0c0113a481 | ||
|
|
c0f381e471 | ||
|
|
5ed83f1148 | ||
|
|
9db7b67a6c | ||
|
|
2850048c6b | ||
|
|
61058e5fcd | ||
|
|
c87261cda7 | ||
|
|
e030b0a6fc | ||
|
|
61136975ad | ||
|
|
0c74bbf9ed | ||
|
|
12b2126e69 | ||
|
|
037943c6ff | ||
|
|
f9485b1325 | ||
|
|
552a0630fe | ||
|
|
5bf520d8b8 | ||
|
|
7dc5a77946 | ||
|
|
03abd4a1bc | ||
|
|
16d6d708f6 | ||
|
|
9740ed32b5 | ||
|
|
b56877cc2e | ||
|
|
da5c83a96d | ||
|
|
818225c60e | ||
|
|
d78a1fe9c6 | ||
|
|
05b3e594b5 | ||
|
|
5a4d007cf9 | ||
|
|
3b25a2dd84 | ||
|
|
baee4c5f22 | ||
|
|
5e32f9d922 | ||
|
|
1454e7e07d | ||
|
|
6848337445 | ||
|
|
519fbd897e | ||
|
|
217569104b | ||
|
|
4c184bb7f0 | ||
|
|
a222fae7c8 | ||
|
|
94788cda53 | ||
|
|
fb931ee4de | ||
|
|
bc2c56dfb6 | ||
|
|
ae37f01f62 | ||
|
|
ef31e14518 | ||
|
|
9b0cba367e | ||
|
|
48ac690a70 | ||
|
|
bfa4fbd691 | ||
|
|
58fdc86d41 | ||
|
|
6ff452a2e1 | ||
|
|
e9b892301b | ||
|
|
a202e2bf9d | ||
|
|
3bc4e0d12f | ||
|
|
2fc41cd5df | ||
|
|
8c42ff2ff8 | ||
|
|
6ccb3f085a | ||
|
|
a0a1b431be | ||
|
|
f137fc78a6 | ||
|
|
396f096dda | ||
|
|
e04b2d6ff3 | ||
|
|
cbd8b094bd | ||
|
|
5c7487e91f |
@@ -65,8 +65,10 @@ jobs:
|
||||
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
|
||||
@@ -12,7 +12,32 @@ env:
|
||||
BUILDKIT_PROGRESS: plain
|
||||
|
||||
jobs:
|
||||
# 1) Preliminary job to check if the changed files are relevant
|
||||
check_model_server_changes:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
changed: ${{ steps.check.outputs.changed }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check if relevant files changed
|
||||
id: check
|
||||
run: |
|
||||
# Default to "false"
|
||||
echo "changed=false" >> $GITHUB_OUTPUT
|
||||
|
||||
# Compare the previous commit (github.event.before) to the current one (github.sha)
|
||||
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
|
||||
# set changed=true
|
||||
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
|
||||
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
|
||||
echo "changed=true" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
build-amd64:
|
||||
needs: [check_model_server_changes]
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
|
||||
steps:
|
||||
@@ -52,6 +77,8 @@ jobs:
|
||||
provenance: false
|
||||
|
||||
build-arm64:
|
||||
needs: [check_model_server_changes]
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
|
||||
steps:
|
||||
@@ -91,7 +118,8 @@ jobs:
|
||||
provenance: false
|
||||
|
||||
merge-and-scan:
|
||||
needs: [build-amd64, build-arm64]
|
||||
needs: [build-amd64, build-arm64, check_model_server_changes]
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Login to Docker Hub
|
||||
|
||||
40
.github/workflows/pr-integration-tests.yml
vendored
40
.github/workflows/pr-integration-tests.yml
vendored
@@ -94,16 +94,19 @@ jobs:
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
MULTI_TENANT=true \
|
||||
AUTH_TYPE=basic \
|
||||
AUTH_TYPE=cloud \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
DEV_MODE=true \
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack up -d
|
||||
id: start_docker_multi_tenant
|
||||
|
||||
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
|
||||
- name: Run Multi-Tenant Integration Tests
|
||||
run: |
|
||||
echo "Waiting for 3 minutes to ensure API server is ready..."
|
||||
sleep 180
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network danswer-stack_default \
|
||||
--name test-runner \
|
||||
@@ -119,6 +122,10 @@ jobs:
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
-e REQUIRE_EMAIL_VERIFICATION=false \
|
||||
-e DISABLE_TELEMETRY=true \
|
||||
-e IMAGE_TAG=test \
|
||||
-e DEV_MODE=true \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/multitenant_tests
|
||||
continue-on-error: true
|
||||
@@ -126,17 +133,17 @@ jobs:
|
||||
|
||||
- name: Check multi-tenant test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
if [ ${{ steps.run_multitenant_tests.outcome }} == 'failure' ]; then
|
||||
echo "Multi-tenant integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
echo "All multi-tenant integration tests passed successfully."
|
||||
fi
|
||||
|
||||
- name: Stop multi-tenant Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
@@ -216,27 +223,30 @@ jobs:
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
# save before stopping the containers so the logs can be captured
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
|
||||
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Stop Docker containers
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-logs
|
||||
name: docker-all-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
- name: Stop Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
@@ -44,6 +44,9 @@ env:
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
|
||||
# Gitbook
|
||||
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
|
||||
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
|
||||
@@ -133,3 +133,4 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
|
||||
## ⭐Star History
|
||||
|
||||
[](https://star-history.com/#onyx-dot-app/onyx&Date)
|
||||
|
||||
|
||||
@@ -35,7 +35,9 @@ RUN apt-get update && \
|
||||
libuuid1=2.38.1-5+deb12u1 \
|
||||
libxmlsec1-dev \
|
||||
pkg-config \
|
||||
gcc && \
|
||||
gcc \
|
||||
nano \
|
||||
vim && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
@@ -101,7 +103,8 @@ COPY ./alembic_tenants /app/alembic_tenants
|
||||
COPY ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
|
||||
# Escape hatch
|
||||
# Escape hatch scripts
|
||||
COPY ./scripts/debugging /app/scripts/debugging
|
||||
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
|
||||
# Put logo in assets
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""set built in to default
|
||||
|
||||
Revision ID: 2cdeff6d8c93
|
||||
Revises: f5437cc136c5
|
||||
Create Date: 2025-02-11 14:57:51.308775
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2cdeff6d8c93"
|
||||
down_revision = "f5437cc136c5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Prior to this migration / point in the codebase history,
|
||||
# built in personas were implicitly treated as default personas (with no option to change this)
|
||||
# This migration makes that explicit
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET is_default_persona = TRUE
|
||||
WHERE builtin_persona = TRUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -5,7 +5,6 @@ Revises: 47e5bef3a1d7
|
||||
Create Date: 2024-11-06 13:15:53.302644
|
||||
|
||||
"""
|
||||
import logging
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
@@ -20,13 +19,8 @@ down_revision = "47e5bef3a1d7"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
logger.info(f"{revision}: create_table: slack_bot")
|
||||
# Create new slack_bot table
|
||||
op.create_table(
|
||||
"slack_bot",
|
||||
@@ -63,7 +57,6 @@ def upgrade() -> None:
|
||||
)
|
||||
|
||||
# Handle existing Slack bot tokens first
|
||||
logger.info(f"{revision}: Checking for existing Slack bot.")
|
||||
bot_token = None
|
||||
app_token = None
|
||||
first_row_id = None
|
||||
@@ -71,15 +64,12 @@ def upgrade() -> None:
|
||||
try:
|
||||
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
|
||||
except Exception:
|
||||
logger.warning("No existing Slack bot tokens found.")
|
||||
tokens = {}
|
||||
|
||||
bot_token = tokens.get("bot_token")
|
||||
app_token = tokens.get("app_token")
|
||||
|
||||
if bot_token and app_token:
|
||||
logger.info(f"{revision}: Found bot and app tokens.")
|
||||
|
||||
session = Session(bind=op.get_bind())
|
||||
new_slack_bot = SlackBot(
|
||||
name="Slack Bot (Migrated)",
|
||||
@@ -170,10 +160,9 @@ def upgrade() -> None:
|
||||
# Clean up old tokens if they existed
|
||||
try:
|
||||
if bot_token and app_token:
|
||||
logger.info(f"{revision}: Removing old bot and app tokens.")
|
||||
get_kv_store().delete("slack_bot_tokens_config_key")
|
||||
except Exception:
|
||||
logger.warning("tried to delete tokens in dynamic config but failed")
|
||||
pass
|
||||
# Rename the table
|
||||
op.rename_table(
|
||||
"slack_bot_config__standard_answer_category",
|
||||
@@ -190,8 +179,6 @@ def upgrade() -> None:
|
||||
# Drop the table with CASCADE to handle dependent objects
|
||||
op.execute("DROP TABLE slack_bot_config CASCADE")
|
||||
|
||||
logger.info(f"{revision}: Migration complete.")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Recreate the old slack_bot_config table
|
||||
@@ -273,7 +260,7 @@ def downgrade() -> None:
|
||||
}
|
||||
get_kv_store().store("slack_bot_tokens_config_key", tokens)
|
||||
except Exception:
|
||||
logger.warning("Failed to save tokens back to KV store")
|
||||
pass
|
||||
|
||||
# Drop the new tables in reverse order
|
||||
op.drop_table("slack_channel_config")
|
||||
|
||||
@@ -52,7 +52,11 @@ def upgrade() -> None:
|
||||
slack_bot_id, persona_id, channel_config, enable_auto_filters, is_default
|
||||
) VALUES (
|
||||
:bot_id, NULL,
|
||||
'{"channel_name": null, "respond_member_group_list": [], "answer_filters": [], "follow_up_tags": []}',
|
||||
'{"channel_name": null, '
|
||||
'"respond_member_group_list": [], '
|
||||
'"answer_filters": [], '
|
||||
'"follow_up_tags": [], '
|
||||
'"respond_tag_only": true}',
|
||||
FALSE, TRUE
|
||||
)
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Add background errors table
|
||||
|
||||
Revision ID: f39c5794c10a
|
||||
Revises: 2cdeff6d8c93
|
||||
Create Date: 2025-02-12 17:11:14.527876
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f39c5794c10a"
|
||||
down_revision = "2cdeff6d8c93"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"background_error",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("message", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["cc_pair_id"],
|
||||
["connector_credential_pair.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("background_error")
|
||||
@@ -0,0 +1,53 @@
|
||||
"""delete non-search assistants
|
||||
|
||||
Revision ID: f5437cc136c5
|
||||
Revises: eaa3b5593925
|
||||
Create Date: 2025-02-04 16:17:15.677256
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f5437cc136c5"
|
||||
down_revision = "eaa3b5593925"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Fix: split the statements into multiple op.execute() calls
|
||||
op.execute(
|
||||
"""
|
||||
WITH personas_without_search AS (
|
||||
SELECT p.id
|
||||
FROM persona p
|
||||
LEFT JOIN persona__tool pt ON p.id = pt.persona_id
|
||||
LEFT JOIN tool t ON pt.tool_id = t.id
|
||||
GROUP BY p.id
|
||||
HAVING COUNT(CASE WHEN t.in_code_tool_id = 'run_search' THEN 1 END) = 0
|
||||
)
|
||||
UPDATE slack_channel_config
|
||||
SET persona_id = NULL
|
||||
WHERE is_default = TRUE AND persona_id IN (SELECT id FROM personas_without_search)
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
WITH personas_without_search AS (
|
||||
SELECT p.id
|
||||
FROM persona p
|
||||
LEFT JOIN persona__tool pt ON p.id = pt.persona_id
|
||||
LEFT JOIN tool t ON pt.tool_id = t.id
|
||||
GROUP BY p.id
|
||||
HAVING COUNT(CASE WHEN t.in_code_tool_id = 'run_search' THEN 1 END) = 0
|
||||
)
|
||||
DELETE FROM slack_channel_config
|
||||
WHERE is_default = FALSE AND persona_id IN (SELECT id FROM personas_without_search)
|
||||
"""
|
||||
)
|
||||
@@ -1,44 +1,46 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
beat_cloud_tasks as base_beat_system_tasks,
|
||||
)
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
|
||||
beat_task_templates as base_beat_task_templates,
|
||||
)
|
||||
from onyx.background.celery.tasks.beat_schedule import generate_cloud_tasks
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
tasks_to_schedule as base_tasks_to_schedule,
|
||||
get_tasks_to_schedule as base_get_tasks_to_schedule,
|
||||
)
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
ee_cloud_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
ee_beat_system_tasks: list[dict] = []
|
||||
|
||||
ee_beat_task_templates: list[dict] = []
|
||||
ee_beat_task_templates.extend(
|
||||
[
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
},
|
||||
},
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
ee_tasks_to_schedule: list[dict] = []
|
||||
|
||||
@@ -65,9 +67,14 @@ if not MULTI_TENANT:
|
||||
]
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule
|
||||
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
|
||||
beat_system_tasks = ee_beat_system_tasks + base_beat_system_tasks
|
||||
beat_task_templates = ee_beat_task_templates + base_beat_task_templates
|
||||
cloud_tasks = generate_cloud_tasks(
|
||||
beat_system_tasks, beat_task_templates, beat_multiplier
|
||||
)
|
||||
return cloud_tasks
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return ee_tasks_to_schedule + base_tasks_to_schedule
|
||||
return ee_tasks_to_schedule + base_get_tasks_to_schedule()
|
||||
|
||||
@@ -77,3 +77,5 @@ POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
|
||||
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
|
||||
|
||||
ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"
|
||||
|
||||
GATED_TENANTS_KEY = "gated_tenants"
|
||||
|
||||
@@ -2,8 +2,11 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
|
||||
|
||||
def make_persona_private(
|
||||
@@ -12,6 +15,9 @@ def make_persona_private(
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
|
||||
dedupe the inputs, the commit will exception."""
|
||||
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -20,11 +26,22 @@ def make_persona_private(
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
if user_ids:
|
||||
for user_uuid in user_ids:
|
||||
db_session.add(Persona__User(persona_id=persona_id, user_id=user_uuid))
|
||||
user_ids_set = set(user_ids)
|
||||
for user_id in user_ids_set:
|
||||
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
|
||||
|
||||
create_notification(
|
||||
user_id=user_id,
|
||||
notif_type=NotificationType.PERSONA_SHARED,
|
||||
db_session=db_session,
|
||||
additional_data=PersonaSharedNotificationData(
|
||||
persona_id=persona_id,
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
if group_ids:
|
||||
for group_id in group_ids:
|
||||
group_ids_set = set(group_ids)
|
||||
for group_id in group_ids_set:
|
||||
db_session.add(
|
||||
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
|
||||
)
|
||||
|
||||
@@ -218,14 +218,14 @@ def fetch_user_groups_for_user(
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def construct_document_select_by_usergroup(
|
||||
def construct_document_id_select_by_usergroup(
|
||||
user_group_id: int,
|
||||
) -> Select:
|
||||
"""This returns a statement that should be executed using
|
||||
.yield_per() to minimize overhead. The primary consumers of this function
|
||||
are background processing task generators."""
|
||||
stmt = (
|
||||
select(Document)
|
||||
select(Document.id)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
Document.id == DocumentByConnectorCredentialPair.id,
|
||||
|
||||
@@ -365,7 +365,9 @@ def confluence_doc_sync(
|
||||
|
||||
slim_docs = []
|
||||
logger.debug("Fetching all slim documents from confluence")
|
||||
for doc_batch in confluence_connector.retrieve_all_slim_documents():
|
||||
for doc_batch in confluence_connector.retrieve_all_slim_documents(
|
||||
callback=callback
|
||||
):
|
||||
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
@@ -10,7 +11,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _build_group_member_email_map(
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_client: OnyxConfluence, cc_pair_id: int
|
||||
) -> dict[str, set[str]]:
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
for user_result in confluence_client.paginated_cql_user_retrieval():
|
||||
@@ -18,8 +19,11 @@ def _build_group_member_email_map(
|
||||
|
||||
user = user_result.get("user", {})
|
||||
if not user:
|
||||
logger.warning(f"user result missing user field: {user_result}")
|
||||
msg = f"user result missing user field: {user_result}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
logger.error(msg)
|
||||
continue
|
||||
|
||||
email = user.get("email")
|
||||
if not email:
|
||||
# This field is only present in Confluence Server
|
||||
@@ -32,7 +36,12 @@ def _build_group_member_email_map(
|
||||
)
|
||||
if not email:
|
||||
# If we still don't have an email, skip this user
|
||||
logger.warning(f"user result missing email field: {user_result}")
|
||||
msg = f"user result missing email field: {user_result}"
|
||||
if user.get("type") == "app":
|
||||
logger.warning(msg)
|
||||
else:
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
logger.error(msg)
|
||||
continue
|
||||
|
||||
all_users_groups: set[str] = set()
|
||||
@@ -42,11 +51,18 @@ def _build_group_member_email_map(
|
||||
group_member_emails.setdefault(group_id, set()).add(email)
|
||||
all_users_groups.add(group_id)
|
||||
|
||||
if not group_member_emails:
|
||||
logger.warning(f"No groups found for user with email: {email}")
|
||||
if not all_users_groups:
|
||||
msg = f"No groups found for user with email: {email}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
logger.error(msg)
|
||||
else:
|
||||
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
|
||||
|
||||
if not group_member_emails:
|
||||
msg = "No groups found for any users."
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
logger.error(msg)
|
||||
|
||||
return group_member_emails
|
||||
|
||||
|
||||
@@ -61,6 +77,7 @@ def confluence_group_sync(
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(
|
||||
confluence_client=confluence_client,
|
||||
cc_pair_id=cc_pair.id,
|
||||
)
|
||||
onyx_groups: list[ExternalUserGroup] = []
|
||||
all_found_emails = set()
|
||||
|
||||
@@ -15,6 +15,7 @@ logger = setup_logger()
|
||||
def _get_slim_doc_generator(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
gmail_connector: GmailConnector,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
current_time = datetime.now(timezone.utc)
|
||||
start_time = (
|
||||
@@ -24,7 +25,9 @@ def _get_slim_doc_generator(
|
||||
)
|
||||
|
||||
return gmail_connector.retrieve_all_slim_documents(
|
||||
start=start_time, end=current_time.timestamp()
|
||||
start=start_time,
|
||||
end=current_time.timestamp(),
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
|
||||
@@ -40,7 +43,9 @@ def gmail_doc_sync(
|
||||
gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
|
||||
gmail_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
slim_doc_generator = _get_slim_doc_generator(cc_pair, gmail_connector)
|
||||
slim_doc_generator = _get_slim_doc_generator(
|
||||
cc_pair, gmail_connector, callback=callback
|
||||
)
|
||||
|
||||
document_external_access: list[DocExternalAccess] = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
|
||||
@@ -21,6 +21,7 @@ _PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {}
|
||||
def _get_slim_doc_generator(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
current_time = datetime.now(timezone.utc)
|
||||
start_time = (
|
||||
@@ -30,7 +31,9 @@ def _get_slim_doc_generator(
|
||||
)
|
||||
|
||||
return google_drive_connector.retrieve_all_slim_documents(
|
||||
start=start_time, end=current_time.timestamp()
|
||||
start=start_time,
|
||||
end=current_time.timestamp(),
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -20,19 +20,11 @@ def _get_slack_document_ids_and_channels(
|
||||
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_documents()
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
|
||||
|
||||
channel_doc_map: dict[str, list[str]] = {}
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"_get_slack_document_ids_and_channels: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("_get_slack_document_ids_and_channels", 1)
|
||||
|
||||
if doc_metadata.perm_sync_data is None:
|
||||
continue
|
||||
channel_id = doc_metadata.perm_sync_data["channel_id"]
|
||||
@@ -40,6 +32,14 @@ def _get_slack_document_ids_and_channels(
|
||||
channel_doc_map[channel_id] = []
|
||||
channel_doc_map[channel_id].append(doc_metadata.id)
|
||||
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"_get_slack_document_ids_and_channels: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("_get_slack_document_ids_and_channels", 1)
|
||||
|
||||
return channel_doc_map
|
||||
|
||||
|
||||
|
||||
@@ -64,6 +64,7 @@ async def _get_tenant_id_from_request(
|
||||
|
||||
try:
|
||||
# Look up token data in Redis
|
||||
|
||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||
|
||||
if not token_data:
|
||||
@@ -87,13 +88,14 @@ async def _get_tenant_id_from_request(
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
|
||||
return tenant_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
finally:
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
|
||||
# As a final step, check for explicit tenant_id cookie
|
||||
tenant_id_cookie = request.cookies.get(TENANT_ID_COOKIE_NAME)
|
||||
if tenant_id_cookie and is_valid_schema_name(tenant_id_cookie):
|
||||
|
||||
@@ -83,6 +83,7 @@ def handle_search_request(
|
||||
user=user,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
skip_query_analysis=False,
|
||||
db_session=db_session,
|
||||
bypass_acl=False,
|
||||
)
|
||||
|
||||
@@ -18,11 +18,16 @@ from ee.onyx.server.tenants.anonymous_user_path import (
|
||||
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
@@ -39,12 +44,9 @@ from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.settings.store import store_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
@@ -126,37 +128,29 @@ async def login_as_anonymous_user(
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
) -> None:
|
||||
) -> ProductGatingResponse:
|
||||
"""
|
||||
Gating the product means that the product is not available to the tenant.
|
||||
They will be directed to the billing page.
|
||||
We gate the product when
|
||||
1) User has ended free trial without adding payment method
|
||||
2) User's card has declined
|
||||
We gate the product when their subscription has ended.
|
||||
"""
|
||||
tenant_id = product_gating_request.tenant_id
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
store_product_gating(
|
||||
product_gating_request.tenant_id, product_gating_request.application_status
|
||||
)
|
||||
return ProductGatingResponse(updated=True, error=None)
|
||||
|
||||
settings = load_settings()
|
||||
settings.product_gating = product_gating_request.product_gating
|
||||
store_settings(settings)
|
||||
|
||||
if product_gating_request.notification:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
create_notification(None, product_gating_request.notification, db_session)
|
||||
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to gate product")
|
||||
return ProductGatingResponse(updated=False, error=str(e))
|
||||
|
||||
|
||||
@router.get("/billing-information", response_model=BillingInformation)
|
||||
@router.get("/billing-information")
|
||||
async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation:
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
return BillingInformation(
|
||||
**fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
|
||||
)
|
||||
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
@@ -169,9 +163,10 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/cloud-settings",
|
||||
return_url=f"{WEB_DOMAIN}/admin/billing",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
@@ -180,6 +175,20 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create resubscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/impersonate")
|
||||
async def impersonate_user(
|
||||
impersonate_request: ImpersonateRequest,
|
||||
|
||||
@@ -6,6 +6,7 @@ import stripe
|
||||
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -14,6 +15,19 @@ stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def fetch_stripe_checkout_session(tenant_id: str) -> str:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.post(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.json()["sessionId"]
|
||||
|
||||
|
||||
def fetch_tenant_stripe_information(tenant_id: str) -> dict:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
@@ -27,7 +41,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
|
||||
return response.json()
|
||||
|
||||
|
||||
def fetch_billing_information(tenant_id: str) -> dict:
|
||||
def fetch_billing_information(tenant_id: str) -> BillingInformation:
|
||||
logger.info("Fetching billing information")
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
@@ -38,7 +52,7 @@ def fetch_billing_information(tenant_id: str) -> dict:
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
billing_info = response.json()
|
||||
billing_info = BillingInformation(**response.json())
|
||||
return billing_info
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.server.settings.models import GatingType
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
|
||||
|
||||
class CheckoutSessionCreationRequest(BaseModel):
|
||||
@@ -15,15 +16,24 @@ class CreateTenantRequest(BaseModel):
|
||||
|
||||
class ProductGatingRequest(BaseModel):
|
||||
tenant_id: str
|
||||
product_gating: GatingType
|
||||
notification: NotificationType | None = None
|
||||
application_status: ApplicationStatus
|
||||
|
||||
|
||||
class SubscriptionStatusResponse(BaseModel):
|
||||
subscribed: bool
|
||||
|
||||
|
||||
class BillingInformation(BaseModel):
|
||||
stripe_subscription_id: str
|
||||
status: str
|
||||
current_period_start: datetime
|
||||
current_period_end: datetime
|
||||
number_of_seats: int
|
||||
cancel_at_period_end: bool
|
||||
canceled_at: datetime | None
|
||||
trial_start: datetime | None
|
||||
trial_end: datetime | None
|
||||
seats: int
|
||||
subscription_status: str
|
||||
billing_start: str
|
||||
billing_end: str
|
||||
payment_method_enabled: bool
|
||||
|
||||
|
||||
@@ -48,3 +58,12 @@ class TenantDeletionPayload(BaseModel):
|
||||
|
||||
class AnonymousUserPath(BaseModel):
|
||||
anonymous_user_path: str | None
|
||||
|
||||
|
||||
class ProductGatingResponse(BaseModel):
|
||||
updated: bool
|
||||
error: str | None
|
||||
|
||||
|
||||
class SubscriptionSessionResponse(BaseModel):
|
||||
sessionId: str
|
||||
|
||||
51
backend/ee/onyx/server/tenants/product_gating.py
Normal file
51
backend/ee/onyx/server/tenants/product_gating.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from typing import cast
|
||||
|
||||
from ee.onyx.configs.app_configs import GATED_TENANTS_KEY
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.settings.store import store_settings
|
||||
from onyx.setup import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None:
|
||||
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
# Store the full status
|
||||
status_key = f"tenant:{tenant_id}:status"
|
||||
redis_client.set(status_key, status.value)
|
||||
|
||||
# Maintain the GATED_ACCESS set
|
||||
if status == ApplicationStatus.GATED_ACCESS:
|
||||
redis_client.sadd(GATED_TENANTS_KEY, tenant_id)
|
||||
else:
|
||||
redis_client.srem(GATED_TENANTS_KEY, tenant_id)
|
||||
|
||||
|
||||
def store_product_gating(tenant_id: str, application_status: ApplicationStatus) -> None:
|
||||
try:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
settings = load_settings()
|
||||
settings.application_status = application_status
|
||||
store_settings(settings)
|
||||
|
||||
# Store gated tenant information in Redis
|
||||
update_tenant_gating(tenant_id, application_status)
|
||||
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to gate product")
|
||||
raise
|
||||
|
||||
|
||||
def get_gated_tenants() -> set[str]:
|
||||
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))
|
||||
@@ -24,6 +24,7 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
|
||||
from onyx.auth.users import exceptions
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
@@ -85,7 +86,8 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
|
||||
# Provision tenant on data plane
|
||||
await provision_tenant(tenant_id, email)
|
||||
# Notify control plane
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
|
||||
@@ -28,3 +28,9 @@ class EmbeddingModelTextType:
|
||||
@staticmethod
|
||||
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
|
||||
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
|
||||
|
||||
|
||||
class GPUStatus:
|
||||
CUDA = "cuda"
|
||||
MAC_MPS = "mps"
|
||||
NONE = "none"
|
||||
|
||||
@@ -12,6 +12,7 @@ import voyageai # type: ignore
|
||||
from cohere import AsyncClient as CohereAsyncClient
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
from litellm import aembedding
|
||||
from litellm.exceptions import RateLimitError
|
||||
@@ -320,6 +321,7 @@ async def embed_text(
|
||||
prefix: str | None,
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
gpu_type: str = "UNKNOWN",
|
||||
) -> list[Embedding]:
|
||||
if not all(texts):
|
||||
logger.error("Empty strings provided for embedding")
|
||||
@@ -373,8 +375,11 @@ async def embed_text(
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info(
|
||||
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
|
||||
f"with provider {provider_type} in {elapsed:.2f}"
|
||||
f"event=embedding_provider "
|
||||
f"texts={len(texts)} "
|
||||
f"chars={total_chars} "
|
||||
f"provider={provider_type} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
elif model_name is not None:
|
||||
logger.info(
|
||||
@@ -403,6 +408,14 @@ async def embed_text(
|
||||
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
|
||||
f"with local model {model_name} in {elapsed:.2f}"
|
||||
)
|
||||
logger.info(
|
||||
f"event=embedding_model "
|
||||
f"texts={len(texts)} "
|
||||
f"chars={total_chars} "
|
||||
f"model={model_name} "
|
||||
f"gpu={gpu_type} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
else:
|
||||
logger.error("Neither model name nor provider specified for embedding")
|
||||
raise ValueError(
|
||||
@@ -455,8 +468,15 @@ async def litellm_rerank(
|
||||
|
||||
|
||||
@router.post("/bi-encoder-embed")
|
||||
async def process_embed_request(
|
||||
async def route_bi_encoder_embed(
|
||||
request: Request,
|
||||
embed_request: EmbedRequest,
|
||||
) -> EmbedResponse:
|
||||
return await process_embed_request(embed_request, request.app.state.gpu_type)
|
||||
|
||||
|
||||
async def process_embed_request(
|
||||
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
|
||||
) -> EmbedResponse:
|
||||
if not embed_request.texts:
|
||||
raise HTTPException(status_code=400, detail="No texts to be embedded")
|
||||
@@ -484,6 +504,7 @@ async def process_embed_request(
|
||||
api_url=embed_request.api_url,
|
||||
api_version=embed_request.api_version,
|
||||
prefix=prefix,
|
||||
gpu_type=gpu_type,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except RateLimitError as e:
|
||||
|
||||
@@ -16,6 +16,7 @@ from model_server.custom_models import router as custom_models_router
|
||||
from model_server.custom_models import warm_up_intent_model
|
||||
from model_server.encoders import router as encoders_router
|
||||
from model_server.management_endpoints import router as management_router
|
||||
from model_server.utils import get_gpu_type
|
||||
from onyx import __version__
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
@@ -58,12 +59,10 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
if torch.cuda.is_available():
|
||||
logger.notice("CUDA GPU is available")
|
||||
elif torch.backends.mps.is_available():
|
||||
logger.notice("Mac MPS is available")
|
||||
else:
|
||||
logger.notice("GPU is not available, using CPU")
|
||||
gpu_type = get_gpu_type()
|
||||
logger.notice(f"Torch GPU Detection: gpu_type={gpu_type}")
|
||||
|
||||
app.state.gpu_type = gpu_type
|
||||
|
||||
if TEMP_HF_CACHE_PATH.is_dir():
|
||||
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import torch
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Response
|
||||
|
||||
from model_server.constants import GPUStatus
|
||||
from model_server.utils import get_gpu_type
|
||||
|
||||
router = APIRouter(prefix="/api")
|
||||
|
||||
|
||||
@@ -11,10 +13,7 @@ async def healthcheck() -> Response:
|
||||
|
||||
|
||||
@router.get("/gpu-status")
|
||||
async def gpu_status() -> dict[str, bool | str]:
|
||||
if torch.cuda.is_available():
|
||||
return {"gpu_available": True, "type": "cuda"}
|
||||
elif torch.backends.mps.is_available():
|
||||
return {"gpu_available": True, "type": "mps"}
|
||||
else:
|
||||
return {"gpu_available": False, "type": "none"}
|
||||
async def route_gpu_status() -> dict[str, bool | str]:
|
||||
gpu_type = get_gpu_type()
|
||||
gpu_available = gpu_type != GPUStatus.NONE
|
||||
return {"gpu_available": gpu_available, "type": gpu_type}
|
||||
|
||||
@@ -8,6 +8,9 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from model_server.constants import GPUStatus
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -58,3 +61,12 @@ def simple_log_function_time(
|
||||
return cast(F, wrapped_sync_func)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_gpu_type() -> str:
|
||||
if torch.cuda.is_available():
|
||||
return GPUStatus.CUDA
|
||||
if torch.backends.mps.is_available():
|
||||
return GPUStatus.MAC_MPS
|
||||
|
||||
return GPUStatus.NONE
|
||||
|
||||
@@ -85,7 +85,7 @@ if __name__ == "__main__":
|
||||
|
||||
graph = basic_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = BasicInput(_unused=True)
|
||||
input = BasicInput(unused=True)
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
with get_session_context_manager() as db_session:
|
||||
config, _ = get_test_config(
|
||||
|
||||
@@ -17,7 +17,7 @@ from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
class BasicInput(BaseModel):
|
||||
# Langgraph needs a nonempty input, but we pass in all static
|
||||
# data through a RunnableConfig.
|
||||
_unused: bool = True
|
||||
unused: bool = True
|
||||
|
||||
|
||||
## Graph Output State
|
||||
|
||||
@@ -9,7 +9,6 @@ class CoreState(BaseModel):
|
||||
This is the core state that is shared across all subgraphs.
|
||||
"""
|
||||
|
||||
base_question: str = ""
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
@@ -18,4 +17,4 @@ class SubgraphCoreState(BaseModel):
|
||||
This is the core state that is shared across all subgraphs.
|
||||
"""
|
||||
|
||||
log_messages: Annotated[list[str], add]
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
@@ -12,14 +12,43 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
|
||||
SubQuestionAnswerCheckUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
binary_string_test,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_POSITIVE_VALUE_STR,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import AgentLLMErrorType
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="LLM Timeout Error. The sub-answer will be treated as 'relevant'",
|
||||
rate_limit="LLM Rate Limit Error. The sub-answer will be treated as 'relevant'",
|
||||
general_error="General LLM Error. The sub-answer will be treated as 'relevant'",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def check_sub_answer(
|
||||
state: AnswerQuestionState, config: RunnableConfig
|
||||
) -> SubQuestionAnswerCheckUpdate:
|
||||
@@ -53,14 +82,40 @@ def check_sub_answer(
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
response = list(
|
||||
fast_llm.stream(
|
||||
agent_error: AgentErrorLog | None = None
|
||||
response: BaseMessage | None = None
|
||||
try:
|
||||
response = fast_llm.invoke(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
|
||||
)
|
||||
)
|
||||
|
||||
quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
answer_quality = "yes" in quality_str.lower()
|
||||
quality_str: str = cast(str, response.content)
|
||||
answer_quality = binary_string_test(
|
||||
text=quality_str, positive_value=AGENT_POSITIVE_VALUE_STR
|
||||
)
|
||||
log_result = f"Answer quality: {quality_str}"
|
||||
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
answer_quality = True
|
||||
log_result = agent_error.error_result
|
||||
logger.error("LLM Timeout Error - check sub answer")
|
||||
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
|
||||
answer_quality = True
|
||||
log_result = agent_error.error_result
|
||||
logger.error("LLM Rate Limit Error - check sub answer")
|
||||
|
||||
return SubQuestionAnswerCheckUpdate(
|
||||
answer_quality=answer_quality,
|
||||
@@ -69,7 +124,7 @@ def check_sub_answer(
|
||||
graph_component="initial - generate individual sub answer",
|
||||
node_name="check sub answer",
|
||||
node_start_time=node_start_time,
|
||||
result=f"Answer quality: {quality_str}",
|
||||
result=log_result,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -16,6 +16,23 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_sub_question_answer_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
dedup_sort_inference_section_list,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AgentLLMErrorType,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
LLM_ANSWER_ERROR_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
@@ -30,12 +47,23 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="LLM Timeout Error. A sub-answer could not be constructed and the sub-question will be ignored.",
|
||||
rate_limit="LLM Rate Limit Error. A sub-answer could not be constructed and the sub-question will be ignored.",
|
||||
general_error="General LLM Error. A sub-answer could not be constructed and the sub-question will be ignored.",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def generate_sub_answer(
|
||||
state: AnswerQuestionState,
|
||||
config: RunnableConfig,
|
||||
@@ -51,12 +79,17 @@ def generate_sub_answer(
|
||||
state.verified_reranked_documents
|
||||
level, question_num = parse_question_id(state.question_id)
|
||||
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
|
||||
|
||||
context_docs = dedup_sort_inference_section_list(context_docs)
|
||||
|
||||
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
|
||||
graph_config.inputs.search_request.persona
|
||||
).contextualized_prompt
|
||||
|
||||
if len(context_docs) == 0:
|
||||
answer_str = NO_RECOVERED_DOCS
|
||||
cited_documents: list = []
|
||||
log_results = "No documents retrieved"
|
||||
write_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
@@ -79,41 +112,67 @@ def generate_sub_answer(
|
||||
|
||||
response: list[str | list[str | dict[str, Any]]] = []
|
||||
dispatch_timings: list[float] = []
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
try:
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
start_stream_token = datetime.now()
|
||||
write_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=level,
|
||||
level_question_num=question_num,
|
||||
answer_type="agent_sub_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
start_stream_token = datetime.now()
|
||||
write_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=level,
|
||||
level_question_num=question_num,
|
||||
answer_type="agent_sub_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(content)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(content)
|
||||
|
||||
answer_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
logger.debug(
|
||||
f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
|
||||
)
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
logger.error("LLM Timeout Error - generate sub answer")
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
logger.error("LLM Rate Limit Error - generate sub answer")
|
||||
|
||||
answer_citation_ids = get_answer_citation_ids(answer_str)
|
||||
cited_documents = [
|
||||
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
|
||||
]
|
||||
if agent_error:
|
||||
answer_str = LLM_ANSWER_ERROR_MESSAGE
|
||||
cited_documents = []
|
||||
log_results = (
|
||||
agent_error.error_result
|
||||
or "Sub-answer generation failed due to LLM error"
|
||||
)
|
||||
|
||||
else:
|
||||
answer_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
answer_citation_ids = get_answer_citation_ids(answer_str)
|
||||
cited_documents = [
|
||||
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
|
||||
]
|
||||
log_results = None
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
@@ -131,7 +190,7 @@ def generate_sub_answer(
|
||||
graph_component="initial - generate individual sub answer",
|
||||
node_name="generate sub answer",
|
||||
node_start_time=node_start_time,
|
||||
result="",
|
||||
result=log_results or "",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -42,10 +42,8 @@ class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):
|
||||
|
||||
|
||||
class SubQuestionAnsweringInput(SubgraphCoreState):
|
||||
question: str = ""
|
||||
question_id: str = (
|
||||
"" # 0_0 is original question, everything else is <level>_<question_num>.
|
||||
)
|
||||
question: str
|
||||
question_id: str
|
||||
# level 0 is original question and first decomposition, level 1 is follow up, etc
|
||||
# question_num is a unique number per original question per level.
|
||||
|
||||
|
||||
@@ -26,14 +26,31 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
get_answer_generation_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AgentLLMErrorType,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_deduplicated_structured_subquestion_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
@@ -42,12 +59,16 @@ from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_ci
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
|
||||
)
|
||||
@@ -56,8 +77,16 @@ from onyx.prompts.agent_search import (
|
||||
)
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="LLM Timeout Error. The initial answer could not be generated.",
|
||||
rate_limit="LLM Rate Limit Error. The initial answer could not be generated.",
|
||||
general_error="General LLM Error. The initial answer could not be generated.",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def generate_initial_answer(
|
||||
state: SubQuestionRetrievalState,
|
||||
config: RunnableConfig,
|
||||
@@ -73,15 +102,19 @@ def generate_initial_answer(
|
||||
question = graph_config.inputs.search_request.query
|
||||
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
|
||||
|
||||
sub_questions_cited_documents = state.cited_documents
|
||||
# get all documents cited in sub-questions
|
||||
structured_subquestion_docs = get_deduplicated_structured_subquestion_documents(
|
||||
state.sub_question_results
|
||||
)
|
||||
|
||||
orig_question_retrieval_documents = state.orig_question_retrieved_documents
|
||||
|
||||
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
|
||||
consolidated_context_docs = structured_subquestion_docs.cited_documents
|
||||
counter = 0
|
||||
for original_doc_number, original_doc in enumerate(
|
||||
orig_question_retrieval_documents
|
||||
):
|
||||
if original_doc_number not in sub_questions_cited_documents:
|
||||
if original_doc_number not in structured_subquestion_docs.cited_documents:
|
||||
if (
|
||||
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
@@ -90,15 +123,18 @@ def generate_initial_answer(
|
||||
counter += 1
|
||||
|
||||
# sort docs by their scores - though the scores refer to different questions
|
||||
relevant_docs = dedup_inference_sections(
|
||||
consolidated_context_docs, consolidated_context_docs
|
||||
)
|
||||
relevant_docs = dedup_inference_section_list(consolidated_context_docs)
|
||||
|
||||
sub_questions: list[str] = []
|
||||
streamed_documents = (
|
||||
relevant_docs
|
||||
if len(relevant_docs) > 0
|
||||
else state.orig_question_retrieved_documents[:15]
|
||||
|
||||
# Create the list of documents to stream out. Start with the
|
||||
# ones that wil be in the context (or, if len == 0, use docs
|
||||
# that were retrieved for the original question)
|
||||
answer_generation_documents = get_answer_generation_documents(
|
||||
relevant_docs=relevant_docs,
|
||||
context_documents=structured_subquestion_docs.context_documents,
|
||||
original_question_docs=orig_question_retrieval_documents,
|
||||
max_docs=AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER,
|
||||
)
|
||||
|
||||
# Use the query info from the base document retrieval
|
||||
@@ -108,11 +144,13 @@ def generate_initial_answer(
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
|
||||
relevance_list = relevance_from_docs(relevant_docs)
|
||||
relevance_list = relevance_from_docs(
|
||||
answer_generation_documents.streaming_documents
|
||||
)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=streamed_documents,
|
||||
final_context_sections=streamed_documents,
|
||||
reranked_sections=answer_generation_documents.streaming_documents,
|
||||
final_context_sections=answer_generation_documents.context_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
@@ -128,7 +166,7 @@ def generate_initial_answer(
|
||||
writer,
|
||||
)
|
||||
|
||||
if len(relevant_docs) == 0:
|
||||
if len(answer_generation_documents.context_documents) == 0:
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
@@ -194,7 +232,7 @@ def generate_initial_answer(
|
||||
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
doc_context = format_docs(relevant_docs)
|
||||
doc_context = format_docs(answer_generation_documents.context_documents)
|
||||
doc_context = trim_prompt_piece(
|
||||
config=model.config,
|
||||
prompt_piece=doc_context,
|
||||
@@ -224,30 +262,82 @@ def generate_initial_answer(
|
||||
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
start_stream_token = datetime.now()
|
||||
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
try:
|
||||
for message in model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
start_stream_token = datetime.now()
|
||||
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=0,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
logger.error("LLM Timeout Error - generate initial answer")
|
||||
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
logger.error("LLM Rate Limit Error - generate initial answer")
|
||||
|
||||
if agent_error:
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=0,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
StreamingError(
|
||||
error=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
return InitialAnswerUpdate(
|
||||
initial_answer=None,
|
||||
answer_error=AgentErrorLog(
|
||||
error_message=agent_error.error_message or "An LLM error occurred",
|
||||
error_type=agent_error.error_type,
|
||||
error_result=agent_error.error_result,
|
||||
),
|
||||
initial_agent_stats=None,
|
||||
generated_sub_questions=sub_questions,
|
||||
agent_base_end_time=None,
|
||||
agent_base_metrics=None,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate initial answer",
|
||||
node_name="generate initial answer",
|
||||
node_start_time=node_start_time,
|
||||
result=agent_error.error_result or "An LLM error occurred",
|
||||
)
|
||||
],
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
logger.debug(
|
||||
f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"
|
||||
|
||||
@@ -10,8 +10,10 @@ from onyx.agents.agent_search.deep_search.main.states import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def validate_initial_answer(
|
||||
state: SubQuestionRetrievalState,
|
||||
) -> InitialAnswerQualityUpdate:
|
||||
@@ -25,7 +27,7 @@ def validate_initial_answer(
|
||||
f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
|
||||
)
|
||||
|
||||
verdict = True
|
||||
verdict = True # not actually required as already streamed out. Refinement will do similar
|
||||
|
||||
return InitialAnswerQualityUpdate(
|
||||
initial_answer_quality_eval=verdict,
|
||||
|
||||
@@ -12,8 +12,9 @@ from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states
|
||||
from onyx.agents.agent_search.deep_search.main.models import (
|
||||
AgentRefinedMetrics,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.operations import dispatch_subquestion
|
||||
from onyx.agents.agent_search.deep_search.main.operations import (
|
||||
dispatch_subquestion,
|
||||
dispatch_subquestion_sep,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialQuestionDecompositionUpdate,
|
||||
@@ -22,6 +23,8 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
@@ -32,17 +35,30 @@ from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH_ASSUMING_REFINEMENT,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_QUESTION_DECOMPOSITION_PROMPT,
|
||||
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="LLM Timeout Error. Sub-questions could not be generated.",
|
||||
rate_limit="LLM Rate Limit Error. Sub-questions could not be generated.",
|
||||
general_error="General LLM Error. Sub-questions could not be generated.",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def decompose_orig_question(
|
||||
state: SubQuestionRetrievalState,
|
||||
config: RunnableConfig,
|
||||
@@ -84,15 +100,15 @@ def decompose_orig_question(
|
||||
]
|
||||
)
|
||||
|
||||
decomposition_prompt = (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH.format(
|
||||
question=question, sample_doc_str=sample_doc_str, history=history
|
||||
)
|
||||
decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH_ASSUMING_REFINEMENT.format(
|
||||
question=question, sample_doc_str=sample_doc_str, history=history
|
||||
)
|
||||
|
||||
else:
|
||||
decomposition_prompt = INITIAL_QUESTION_DECOMPOSITION_PROMPT.format(
|
||||
question=question, history=history
|
||||
decomposition_prompt = (
|
||||
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT.format(
|
||||
question=question, history=history
|
||||
)
|
||||
)
|
||||
|
||||
# Start decomposition
|
||||
@@ -109,31 +125,44 @@ def decompose_orig_question(
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
# dispatches custom events for subquestion tokens, adding in subquestion ids.
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(msg), dispatch_subquestion(0, writer)
|
||||
)
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
level=0,
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
|
||||
deomposition_response = merge_content(*streamed_tokens)
|
||||
try:
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(0, writer),
|
||||
sep_callback=dispatch_subquestion_sep(0, writer),
|
||||
)
|
||||
|
||||
# this call should only return strings. Commenting out for efficiency
|
||||
# assert [type(tok) == str for tok in streamed_tokens]
|
||||
decomposition_response = merge_content(*streamed_tokens)
|
||||
|
||||
# use no-op cast() instead of str() which runs code
|
||||
# list_of_subquestions = clean_and_parse_list_string(cast(str, response))
|
||||
list_of_subqs = cast(str, deomposition_response).split("\n")
|
||||
list_of_subqs = cast(str, decomposition_response).split("\n")
|
||||
|
||||
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
|
||||
initial_sub_questions = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
|
||||
log_result = f"decomposed original question into {len(initial_sub_questions)} subquestions"
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
level=0,
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
except LLMTimeoutError as e:
|
||||
logger.error("LLM Timeout Error - decompose orig question")
|
||||
raise e # fail loudly on this critical step
|
||||
except LLMRateLimitError as e:
|
||||
logger.error("LLM Rate Limit Error - decompose orig question")
|
||||
raise e
|
||||
|
||||
return InitialQuestionDecompositionUpdate(
|
||||
initial_sub_questions=decomp_list,
|
||||
initial_sub_questions=initial_sub_questions,
|
||||
agent_start_time=agent_start_time,
|
||||
agent_refined_start_time=None,
|
||||
agent_refined_end_time=None,
|
||||
@@ -147,7 +176,7 @@ def decompose_orig_question(
|
||||
graph_component="initial - generate sub answers",
|
||||
node_name="decompose original question",
|
||||
node_start_time=node_start_time,
|
||||
result=f"decomposed original question into {len(decomp_list)} subquestions",
|
||||
result=log_result,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -26,8 +26,8 @@ from onyx.agents.agent_search.deep_search.main.nodes.decide_refinement_need impo
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.extract_entities_terms import (
|
||||
extract_entities_terms,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.generate_refined_answer import (
|
||||
generate_refined_answer,
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.generate_validate_refined_answer import (
|
||||
generate_validate_refined_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.ingest_refined_sub_answers import (
|
||||
ingest_refined_sub_answers,
|
||||
@@ -126,8 +126,8 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
|
||||
# Node to generate the refined answer
|
||||
graph.add_node(
|
||||
node="generate_refined_answer",
|
||||
action=generate_refined_answer,
|
||||
node="generate_validate_refined_answer",
|
||||
action=generate_validate_refined_answer,
|
||||
)
|
||||
|
||||
# Early node to extract the entities and terms from the initial answer,
|
||||
@@ -215,11 +215,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
|
||||
graph.add_edge(
|
||||
start_key="ingest_refined_sub_answers",
|
||||
end_key="generate_refined_answer",
|
||||
end_key="generate_validate_refined_answer",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_refined_answer",
|
||||
start_key="generate_validate_refined_answer",
|
||||
end_key="compare_answers",
|
||||
)
|
||||
graph.add_edge(
|
||||
@@ -252,9 +252,7 @@ if __name__ == "__main__":
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
|
||||
inputs = MainInput(
|
||||
base_question=graph_config.inputs.search_request.query, log_messages=[]
|
||||
)
|
||||
inputs = MainInput(log_messages=[])
|
||||
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
@@ -10,16 +11,51 @@ from onyx.agents.agent_search.deep_search.main.states import (
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
binary_string_test,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_POSITIVE_VALUE_STR,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AgentLLMErrorType,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="The LLM timed out, and the answers could not be compared.",
|
||||
rate_limit="The LLM encountered a rate limit, and the answers could not be compared.",
|
||||
general_error="The LLM encountered an error, and the answers could not be compared.",
|
||||
)
|
||||
|
||||
_ANSWER_QUALITY_NOT_SUFFICIENT_MESSAGE = (
|
||||
"Answer quality is not sufficient, so stay with the initial answer."
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def compare_answers(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> InitialRefinedAnswerComparisonUpdate:
|
||||
@@ -34,21 +70,75 @@ def compare_answers(
|
||||
initial_answer = state.initial_answer
|
||||
refined_answer = state.refined_answer
|
||||
|
||||
# if answer quality is not sufficient, then stay with the initial answer
|
||||
if not state.refined_answer_quality:
|
||||
write_custom_event(
|
||||
"refined_answer_improvement",
|
||||
RefinedAnswerImprovement(
|
||||
refined_answer_improvement=False,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
return InitialRefinedAnswerComparisonUpdate(
|
||||
refined_answer_improvement_eval=False,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="compare answers",
|
||||
node_start_time=node_start_time,
|
||||
result=_ANSWER_QUALITY_NOT_SUFFICIENT_MESSAGE,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
compare_answers_prompt = INITIAL_REFINED_ANSWER_COMPARISON_PROMPT.format(
|
||||
question=question, initial_answer=initial_answer, refined_answer=refined_answer
|
||||
)
|
||||
|
||||
msg = [HumanMessage(content=compare_answers_prompt)]
|
||||
|
||||
agent_error: AgentErrorLog | None = None
|
||||
# Get the rewritten queries in a defined format
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
resp: BaseMessage | None = None
|
||||
refined_answer_improvement: bool | None = None
|
||||
# no need to stream this
|
||||
resp = model.invoke(msg)
|
||||
try:
|
||||
resp = model.invoke(
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
|
||||
)
|
||||
|
||||
refined_answer_improvement = (
|
||||
isinstance(resp.content, str) and "yes" in resp.content.lower()
|
||||
)
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
logger.error("LLM Timeout Error - compare answers")
|
||||
# continue as True in this support step
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
logger.error("LLM Rate Limit Error - compare answers")
|
||||
# continue as True in this support step
|
||||
|
||||
if agent_error or resp is None:
|
||||
refined_answer_improvement = True
|
||||
if agent_error:
|
||||
log_result = agent_error.error_result
|
||||
else:
|
||||
log_result = "An answer could not be generated."
|
||||
|
||||
else:
|
||||
refined_answer_improvement = binary_string_test(
|
||||
text=cast(str, resp.content),
|
||||
positive_value=AGENT_POSITIVE_VALUE_STR,
|
||||
)
|
||||
log_result = f"Answer comparison: {refined_answer_improvement}"
|
||||
|
||||
write_custom_event(
|
||||
"refined_answer_improvement",
|
||||
@@ -65,7 +155,7 @@ def compare_answers(
|
||||
graph_component="main",
|
||||
node_name="compare answers",
|
||||
node_start_time=node_start_time,
|
||||
result=f"Answer comparison: {refined_answer_improvement}",
|
||||
result=log_result,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -9,8 +9,9 @@ from langgraph.types import StreamWriter
|
||||
from onyx.agents.agent_search.deep_search.main.models import (
|
||||
RefinementSubQuestion,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.operations import dispatch_subquestion
|
||||
from onyx.agents.agent_search.deep_search.main.operations import (
|
||||
dispatch_subquestion,
|
||||
dispatch_subquestion_sep,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
@@ -20,6 +21,18 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AgentLLMErrorType,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
format_entity_term_extraction,
|
||||
@@ -29,12 +42,31 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
|
||||
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT_W_INITIAL_SUBQUESTION_ANSWERS,
|
||||
)
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_ANSWERED_SUBQUESTIONS_DIVIDER = "\n\n---\n\n"
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="The LLM timed out. The sub-questions could not be generated.",
|
||||
rate_limit="The LLM encountered a rate limit. The sub-questions could not be generated.",
|
||||
general_error="The LLM encountered an error. The sub-questions could not be generated.",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def create_refined_sub_questions(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> RefinedQuestionDecompositionUpdate:
|
||||
@@ -71,8 +103,10 @@ def create_refined_sub_questions(
|
||||
|
||||
initial_question_answers = state.sub_question_results
|
||||
|
||||
addressed_question_list = [
|
||||
x.question for x in initial_question_answers if x.verified_high_quality
|
||||
addressed_subquestions_with_answers = [
|
||||
f"Subquestion: {x.question}\nSubanswer:\n{x.answer}"
|
||||
for x in initial_question_answers
|
||||
if x.verified_high_quality and x.answer
|
||||
]
|
||||
|
||||
failed_question_list = [
|
||||
@@ -81,12 +115,14 @@ def create_refined_sub_questions(
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT.format(
|
||||
content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT_W_INITIAL_SUBQUESTION_ANSWERS.format(
|
||||
question=question,
|
||||
history=history,
|
||||
entity_term_extraction_str=entity_term_extraction_str,
|
||||
base_answer=base_answer,
|
||||
answered_sub_questions="\n - ".join(addressed_question_list),
|
||||
answered_subquestions_with_answers=_ANSWERED_SUBQUESTIONS_DIVIDER.join(
|
||||
addressed_subquestions_with_answers
|
||||
),
|
||||
failed_sub_questions="\n - ".join(failed_question_list),
|
||||
),
|
||||
)
|
||||
@@ -95,27 +131,65 @@ def create_refined_sub_questions(
|
||||
# Grader
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(msg), dispatch_subquestion(1, writer)
|
||||
)
|
||||
response = merge_content(*streamed_tokens)
|
||||
agent_error: AgentErrorLog | None = None
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
try:
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(1, writer),
|
||||
sep_callback=dispatch_subquestion_sep(1, writer),
|
||||
)
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
logger.error("LLM Timeout Error - create refined sub questions")
|
||||
|
||||
if isinstance(response, str):
|
||||
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
|
||||
else:
|
||||
raise ValueError("LLM response is not a string")
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
logger.error("LLM Rate Limit Error - create refined sub questions")
|
||||
|
||||
refined_sub_question_dict = {}
|
||||
for sub_question_num, sub_question in enumerate(parsed_response):
|
||||
refined_sub_question = RefinementSubQuestion(
|
||||
sub_question=sub_question,
|
||||
sub_question_id=make_question_id(1, sub_question_num + 1),
|
||||
verified=False,
|
||||
answered=False,
|
||||
answer="",
|
||||
if agent_error:
|
||||
refined_sub_question_dict: dict[int, RefinementSubQuestion] = {}
|
||||
log_result = agent_error.error_result
|
||||
write_custom_event(
|
||||
"refined_sub_question_creation_error",
|
||||
StreamingError(
|
||||
error="Your LLM was not able to create refined sub questions in time and timed out. Please try again.",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
|
||||
else:
|
||||
response = merge_content(*streamed_tokens)
|
||||
|
||||
if isinstance(response, str):
|
||||
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
|
||||
else:
|
||||
raise ValueError("LLM response is not a string")
|
||||
|
||||
refined_sub_question_dict = {}
|
||||
for sub_question_num, sub_question in enumerate(parsed_response):
|
||||
refined_sub_question = RefinementSubQuestion(
|
||||
sub_question=sub_question,
|
||||
sub_question_id=make_question_id(1, sub_question_num + 1),
|
||||
verified=False,
|
||||
answered=False,
|
||||
answer="",
|
||||
)
|
||||
|
||||
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
|
||||
|
||||
log_result = f"Created {len(refined_sub_question_dict)} refined sub questions"
|
||||
|
||||
return RefinedQuestionDecompositionUpdate(
|
||||
refined_sub_questions=refined_sub_question_dict,
|
||||
@@ -125,7 +199,7 @@ def create_refined_sub_questions(
|
||||
graph_component="main",
|
||||
node_name="create refined sub questions",
|
||||
node_start_time=node_start_time,
|
||||
result=f"Created {len(refined_sub_question_dict)} refined sub questions",
|
||||
result=log_result,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -11,8 +11,10 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def decide_refinement_need(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> RequireRefinemenEvalUpdate:
|
||||
@@ -26,6 +28,19 @@ def decide_refinement_need(
|
||||
|
||||
decision = True # TODO: just for current testing purposes
|
||||
|
||||
if state.answer_error:
|
||||
return RequireRefinemenEvalUpdate(
|
||||
require_refined_answer_eval=False,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="decide refinement need",
|
||||
node_start_time=node_start_time,
|
||||
result="Timeout Error",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
log_messages = [
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
|
||||
@@ -21,11 +21,16 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
|
||||
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
|
||||
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def extract_entities_terms(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> EntityTermExtractionUpdate:
|
||||
@@ -81,6 +86,7 @@ def extract_entities_terms(
|
||||
# Grader
|
||||
llm_response = fast_llm.invoke(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
|
||||
@@ -11,27 +11,49 @@ from onyx.agents.agent_search.deep_search.main.models import (
|
||||
AgentRefinedMetrics,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
|
||||
from onyx.agents.agent_search.deep_search.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
RefinedAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
binary_string_test_after_answer_separator,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
get_prompt_enrichment_components,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InferenceSection
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
get_answer_generation_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import AGENT_ANSWER_SEPARATOR
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_POSITIVE_VALUE_STR,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AgentLLMErrorType,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_sections,
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_deduplicated_structured_subquestion_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
@@ -43,26 +65,50 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
REFINED_ANSWER_VALIDATION_PROMPT,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
SUB_QUESTION_ANSWER_TEMPLATE_REFINED,
|
||||
)
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="The LLM timed out. The refined answer could not be generated.",
|
||||
rate_limit="The LLM encountered a rate limit. The refined answer could not be generated.",
|
||||
general_error="The LLM encountered an error. The refined answer could not be generated.",
|
||||
)
|
||||
|
||||
|
||||
def generate_refined_answer(
|
||||
@log_function_time(print_only=True)
|
||||
def generate_validate_refined_answer(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> RefinedAnswerUpdate:
|
||||
"""
|
||||
LangGraph node to generate the refined answer.
|
||||
LangGraph node to generate the refined answer and validate it.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
@@ -76,19 +122,24 @@ def generate_refined_answer(
|
||||
)
|
||||
|
||||
verified_reranked_documents = state.verified_reranked_documents
|
||||
sub_questions_cited_documents = state.cited_documents
|
||||
|
||||
# get all documents cited in sub-questions
|
||||
structured_subquestion_docs = get_deduplicated_structured_subquestion_documents(
|
||||
state.sub_question_results
|
||||
)
|
||||
|
||||
original_question_verified_documents = (
|
||||
state.orig_question_verified_reranked_documents
|
||||
)
|
||||
original_question_retrieved_documents = state.orig_question_retrieved_documents
|
||||
|
||||
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
|
||||
consolidated_context_docs = structured_subquestion_docs.cited_documents
|
||||
|
||||
counter = 0
|
||||
for original_doc_number, original_doc in enumerate(
|
||||
original_question_verified_documents
|
||||
):
|
||||
if original_doc_number not in sub_questions_cited_documents:
|
||||
if original_doc_number not in structured_subquestion_docs.cited_documents:
|
||||
if (
|
||||
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
or len(consolidated_context_docs)
|
||||
@@ -99,14 +150,16 @@ def generate_refined_answer(
|
||||
counter += 1
|
||||
|
||||
# sort docs by their scores - though the scores refer to different questions
|
||||
relevant_docs = dedup_inference_sections(
|
||||
consolidated_context_docs, consolidated_context_docs
|
||||
)
|
||||
relevant_docs = dedup_inference_section_list(consolidated_context_docs)
|
||||
|
||||
streaming_docs = (
|
||||
relevant_docs
|
||||
if len(relevant_docs) > 0
|
||||
else original_question_retrieved_documents[:15]
|
||||
# Create the list of documents to stream out. Start with the
|
||||
# ones that wil be in the context (or, if len == 0, use docs
|
||||
# that were retrieved for the original question)
|
||||
answer_generation_documents = get_answer_generation_documents(
|
||||
relevant_docs=relevant_docs,
|
||||
context_documents=structured_subquestion_docs.context_documents,
|
||||
original_question_docs=original_question_retrieved_documents,
|
||||
max_docs=AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER,
|
||||
)
|
||||
|
||||
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
|
||||
@@ -114,11 +167,13 @@ def generate_refined_answer(
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
# stream refined answer docs, or original question docs if no relevant docs are found
|
||||
relevance_list = relevance_from_docs(relevant_docs)
|
||||
relevance_list = relevance_from_docs(
|
||||
answer_generation_documents.streaming_documents
|
||||
)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=streaming_docs,
|
||||
final_context_sections=streaming_docs,
|
||||
reranked_sections=answer_generation_documents.streaming_documents,
|
||||
final_context_sections=answer_generation_documents.context_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
@@ -199,7 +254,7 @@ def generate_refined_answer(
|
||||
)
|
||||
|
||||
model = graph_config.tooling.fast_llm
|
||||
relevant_docs_str = format_docs(relevant_docs)
|
||||
relevant_docs_str = format_docs(answer_generation_documents.context_documents)
|
||||
relevant_docs_str = trim_prompt_piece(
|
||||
model.config,
|
||||
relevant_docs_str,
|
||||
@@ -231,28 +286,80 @@ def generate_refined_answer(
|
||||
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
start_stream_token = datetime.now()
|
||||
try:
|
||||
for message in model.stream(
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
|
||||
start_stream_token = datetime.now()
|
||||
write_custom_event(
|
||||
"refined_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=1,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
logger.error("LLM Timeout Error - generate refined answer")
|
||||
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
logger.error("LLM Rate Limit Error - generate refined answer")
|
||||
|
||||
if agent_error:
|
||||
write_custom_event(
|
||||
"refined_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=1,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
"initial_agent_answer",
|
||||
StreamingError(
|
||||
error=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
return RefinedAnswerUpdate(
|
||||
refined_answer=None,
|
||||
refined_answer_quality=False, # TODO: replace this with the actual check value
|
||||
refined_agent_stats=None,
|
||||
agent_refined_end_time=None,
|
||||
agent_refined_metrics=AgentRefinedMetrics(
|
||||
refined_doc_boost_factor=0.0,
|
||||
refined_question_boost_factor=0.0,
|
||||
duration_s=None,
|
||||
),
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="generate refined answer",
|
||||
node_start_time=node_start_time,
|
||||
result=agent_error.error_result or "An LLM error occurred",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
|
||||
@@ -261,54 +368,43 @@ def generate_refined_answer(
|
||||
response = merge_content(*streamed_tokens)
|
||||
answer = cast(str, response)
|
||||
|
||||
# run a validation step for the refined answer only
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REFINED_ANSWER_VALIDATION_PROMPT.format(
|
||||
question=question,
|
||||
history=prompt_enrichment_components.history,
|
||||
answered_sub_questions=sub_question_answer_str,
|
||||
relevant_docs=relevant_docs_str,
|
||||
proposed_answer=answer,
|
||||
persona_specification=persona_contextualized_prompt,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
validation_response = model.invoke(
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
|
||||
)
|
||||
refined_answer_quality = binary_string_test_after_answer_separator(
|
||||
text=cast(str, validation_response.content),
|
||||
positive_value=AGENT_POSITIVE_VALUE_STR,
|
||||
separator=AGENT_ANSWER_SEPARATOR,
|
||||
)
|
||||
except LLMTimeoutError:
|
||||
refined_answer_quality = True
|
||||
logger.error("LLM Timeout Error - validate refined answer")
|
||||
|
||||
except LLMRateLimitError:
|
||||
refined_answer_quality = True
|
||||
logger.error("LLM Rate Limit Error - validate refined answer")
|
||||
|
||||
refined_agent_stats = RefinedAgentStats(
|
||||
revision_doc_efficiency=refined_doc_effectiveness,
|
||||
revision_question_efficiency=revision_question_efficiency,
|
||||
)
|
||||
|
||||
logger.debug(f"\n\n---INITIAL ANSWER ---\n\n Answer:\n Agent: {initial_answer}")
|
||||
logger.debug("-" * 10)
|
||||
logger.debug(f"\n\n---REVISED AGENT ANSWER ---\n\n Answer:\n Agent: {answer}")
|
||||
|
||||
logger.debug("-" * 100)
|
||||
|
||||
if state.initial_agent_stats:
|
||||
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
|
||||
"utilized_chunk_ratio", "--"
|
||||
)
|
||||
initial_support_boost_factor = (
|
||||
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
|
||||
)
|
||||
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
|
||||
"num_verified_documents", "--"
|
||||
)
|
||||
initial_verified_docs_avg_score = (
|
||||
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
|
||||
)
|
||||
initial_sub_questions_verified_docs = (
|
||||
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
|
||||
)
|
||||
|
||||
logger.debug("INITIAL AGENT STATS")
|
||||
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
|
||||
logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
|
||||
logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
|
||||
logger.debug(
|
||||
f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
|
||||
)
|
||||
if refined_agent_stats:
|
||||
logger.debug("-" * 10)
|
||||
logger.debug("REFINED AGENT STATS")
|
||||
logger.debug(
|
||||
f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
|
||||
)
|
||||
|
||||
agent_refined_end_time = datetime.now()
|
||||
if state.agent_refined_start_time:
|
||||
agent_refined_duration = (
|
||||
@@ -325,7 +421,7 @@ def generate_refined_answer(
|
||||
|
||||
return RefinedAnswerUpdate(
|
||||
refined_answer=answer,
|
||||
refined_answer_quality=True, # TODO: replace this with the actual check value
|
||||
refined_answer_quality=refined_answer_quality,
|
||||
refined_agent_stats=refined_agent_stats,
|
||||
agent_refined_end_time=agent_refined_end_time,
|
||||
agent_refined_metrics=agent_refined_metrics,
|
||||
@@ -9,6 +9,9 @@ from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
SubQuestionAnswerResults,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
@@ -34,6 +37,22 @@ def dispatch_subquestion(
|
||||
return _helper
|
||||
|
||||
|
||||
def dispatch_subquestion_sep(level: int, writer: StreamWriter) -> Callable[[int], None]:
|
||||
def _helper(sep_num: int) -> None:
|
||||
write_custom_event(
|
||||
"stream_finished",
|
||||
StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
level=level,
|
||||
level_question_num=sep_num,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
return _helper
|
||||
|
||||
|
||||
def calculate_initial_agent_stats(
|
||||
decomp_answer_results: list[SubQuestionAnswerResults],
|
||||
original_question_stats: AgentChunkRetrievalStats,
|
||||
|
||||
@@ -17,6 +17,7 @@ from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
@@ -76,6 +77,7 @@ class InitialAnswerUpdate(LoggerUpdate):
|
||||
"""
|
||||
|
||||
initial_answer: str | None = None
|
||||
answer_error: AgentErrorLog | None = None
|
||||
initial_agent_stats: InitialAgentResultStats | None = None
|
||||
generated_sub_questions: list[str] = []
|
||||
agent_base_end_time: datetime | None = None
|
||||
@@ -88,6 +90,7 @@ class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
|
||||
"""
|
||||
|
||||
refined_answer: str | None = None
|
||||
answer_error: AgentErrorLog | None = None
|
||||
refined_agent_stats: RefinedAgentStats | None = None
|
||||
refined_answer_quality: bool = False
|
||||
|
||||
|
||||
@@ -16,16 +16,44 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
|
||||
QueryExpansionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AgentLLMErrorType,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
QUERY_REWRITING_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="Query rewriting failed due to LLM timeout - the original question will be used.",
|
||||
rate_limit="Query rewriting failed due to LLM rate limit - the original question will be used.",
|
||||
general_error="Query rewriting failed due to LLM error - the original question will be used.",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def expand_queries(
|
||||
state: ExpandedRetrievalInput,
|
||||
config: RunnableConfig,
|
||||
@@ -54,13 +82,43 @@ def expand_queries(
|
||||
)
|
||||
]
|
||||
|
||||
llm_response_list = dispatch_separated(
|
||||
llm.stream(prompt=msg), dispatch_subquery(level, question_num, writer)
|
||||
)
|
||||
agent_error: AgentErrorLog | None = None
|
||||
llm_response_list: list[BaseMessage_Content] = []
|
||||
llm_response = ""
|
||||
rewritten_queries = []
|
||||
|
||||
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
|
||||
try:
|
||||
llm_response_list = dispatch_separated(
|
||||
llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
|
||||
),
|
||||
dispatch_subquery(level, question_num, writer),
|
||||
)
|
||||
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[
|
||||
0
|
||||
].content
|
||||
rewritten_queries = llm_response.split("\n")
|
||||
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
|
||||
|
||||
rewritten_queries = llm_response.split("\n")
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
logger.error("LLM Timeout Error - expand queries")
|
||||
log_result = agent_error.error_result
|
||||
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
logger.error("LLM Rate Limit Error - expand queries")
|
||||
log_result = agent_error.error_result
|
||||
# use subquestion as query if query generation fails
|
||||
|
||||
return QueryExpansionUpdate(
|
||||
expanded_queries=rewritten_queries,
|
||||
@@ -69,7 +127,7 @@ def expand_queries(
|
||||
graph_component="shared - expanded retrieval",
|
||||
node_name="expand queries",
|
||||
node_start_time=node_start_time,
|
||||
result=f"Number of expanded queries: {len(rewritten_queries)}",
|
||||
result=log_result,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -21,12 +21,15 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
|
||||
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.context.search.pipeline import retrieval_preprocessing
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.postprocessing.postprocessing import rerank_sections
|
||||
from onyx.context.search.postprocessing.postprocessing import should_rerank
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def rerank_documents(
|
||||
state: ExpandedRetrievalState, config: RunnableConfig
|
||||
) -> DocRerankingUpdate:
|
||||
@@ -39,6 +42,8 @@ def rerank_documents(
|
||||
|
||||
# Rerank post retrieval and verification. First, create a search query
|
||||
# then create the list of reranked sections
|
||||
# If no question defined/question is None in the state, use the original
|
||||
# question from the search request as query
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = (
|
||||
@@ -47,39 +52,28 @@ def rerank_documents(
|
||||
assert (
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
with get_session_context_manager() as db_session:
|
||||
# we ignore some of the user specified fields since this search is
|
||||
# internal to agentic search, but we still want to pass through
|
||||
# persona (for stuff like document sets) and rerank settings
|
||||
# (to not make an unnecessary db call).
|
||||
search_request = SearchRequest(
|
||||
query=question,
|
||||
persona=graph_config.inputs.search_request.persona,
|
||||
rerank_settings=graph_config.inputs.search_request.rerank_settings,
|
||||
)
|
||||
_search_query = retrieval_preprocessing(
|
||||
search_request=search_request,
|
||||
user=graph_config.tooling.search_tool.user, # bit of a hack
|
||||
llm=graph_config.tooling.fast_llm,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# skip section filtering
|
||||
# Note that these are passed in values from the API and are overrides which are typically None
|
||||
rerank_settings = graph_config.inputs.search_request.rerank_settings
|
||||
|
||||
if (
|
||||
_search_query.rerank_settings
|
||||
and _search_query.rerank_settings.rerank_model_name
|
||||
and _search_query.rerank_settings.num_rerank > 0
|
||||
and len(verified_documents) > 0
|
||||
):
|
||||
if rerank_settings is None:
|
||||
with get_session_context_manager() as db_session:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
|
||||
if should_rerank(rerank_settings) and len(verified_documents) > 0:
|
||||
if len(verified_documents) > 1:
|
||||
reranked_documents = rerank_sections(
|
||||
_search_query,
|
||||
verified_documents,
|
||||
query_str=question,
|
||||
# if runnable, then rerank_settings is not None
|
||||
rerank_settings=cast(RerankingDetails, rerank_settings),
|
||||
sections_to_rerank=verified_documents,
|
||||
)
|
||||
else:
|
||||
num = "No" if len(verified_documents) == 0 else "One"
|
||||
logger.warning(f"{num} verified document(s) found, skipping reranking")
|
||||
logger.warning(
|
||||
f"{len(verified_documents)} verified document(s) found, skipping reranking"
|
||||
)
|
||||
reranked_documents = verified_documents
|
||||
else:
|
||||
logger.warning("No reranking settings found, using unranked documents")
|
||||
|
||||
@@ -23,12 +23,15 @@ from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def retrieve_documents(
|
||||
state: RetrievalInput, config: RunnableConfig
|
||||
) -> DocRetrievalUpdate:
|
||||
@@ -67,9 +70,12 @@ def retrieve_documents(
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=query_to_retrieve,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
override_kwargs=SearchToolOverrideKwargs(
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
skip_query_analysis=not state.base_search,
|
||||
),
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
@@ -10,14 +12,38 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
|
||||
DocVerificationUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
binary_string_test,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_POSITIVE_VALUE_STR,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
DOCUMENT_VERIFICATION_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="The LLM timed out. The document could not be verified. The document will be treated as 'relevant'",
|
||||
rate_limit="The LLM encountered a rate limit. The document could not be verified. The document will be treated as 'relevant'",
|
||||
general_error="The LLM encountered an error. The document could not be verified. The document will be treated as 'relevant'",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def verify_documents(
|
||||
state: DocVerificationInput, config: RunnableConfig
|
||||
) -> DocVerificationUpdate:
|
||||
@@ -26,12 +52,14 @@ def verify_documents(
|
||||
|
||||
Args:
|
||||
state (DocVerificationInput): The current state
|
||||
config (RunnableConfig): Configuration containing ProSearchConfig
|
||||
config (RunnableConfig): Configuration containing AgentSearchConfig
|
||||
|
||||
Updates:
|
||||
verified_documents: list[InferenceSection]
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state.question
|
||||
retrieved_document_to_verify = state.retrieved_document_to_verify
|
||||
document_content = retrieved_document_to_verify.combined_content
|
||||
@@ -51,12 +79,40 @@ def verify_documents(
|
||||
)
|
||||
]
|
||||
|
||||
response = fast_llm.invoke(msg)
|
||||
response: BaseMessage | None = None
|
||||
|
||||
verified_documents = []
|
||||
if isinstance(response.content, str) and "yes" in response.content.lower():
|
||||
verified_documents.append(retrieved_document_to_verify)
|
||||
verified_documents = [
|
||||
retrieved_document_to_verify
|
||||
] # default is to treat document as relevant
|
||||
|
||||
try:
|
||||
response = fast_llm.invoke(
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
if not binary_string_test(
|
||||
text=response.content, positive_value=AGENT_POSITIVE_VALUE_STR
|
||||
):
|
||||
verified_documents = []
|
||||
|
||||
except LLMTimeoutError:
|
||||
# In this case, we decide to continue and don't raise an error, as
|
||||
# little harm in letting some docs through that are less relevant.
|
||||
logger.error("LLM Timeout Error - verify documents")
|
||||
|
||||
except LLMRateLimitError:
|
||||
# In this case, we decide to continue and don't raise an error, as
|
||||
# little harm in letting some docs through that are less relevant.
|
||||
logger.error("LLM Rate Limit Error - verify documents")
|
||||
|
||||
return DocVerificationUpdate(
|
||||
verified_documents=verified_documents,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="shared - expanded retrieval",
|
||||
node_name="verify documents",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -21,9 +21,13 @@ from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class ExpandedRetrievalInput(SubgraphCoreState):
|
||||
question: str = ""
|
||||
base_search: bool = False
|
||||
# exception from 'no default value'for LangGraph input states
|
||||
# Here, sub_question_id default None implies usage for the
|
||||
# original question. This is sometimes needed for nested sub-graphs
|
||||
|
||||
sub_question_id: str | None = None
|
||||
question: str
|
||||
base_search: bool
|
||||
|
||||
|
||||
## Update/Return States
|
||||
@@ -34,7 +38,7 @@ class QueryExpansionUpdate(LoggerUpdate, BaseModel):
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
class DocVerificationUpdate(BaseModel):
|
||||
class DocVerificationUpdate(LoggerUpdate, BaseModel):
|
||||
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
|
||||
|
||||
@@ -88,4 +92,4 @@ class DocVerificationInput(ExpandedRetrievalInput):
|
||||
|
||||
|
||||
class RetrievalInput(ExpandedRetrievalInput):
|
||||
query_to_retrieve: str = ""
|
||||
query_to_retrieve: str
|
||||
|
||||
@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
|
||||
main_graph_builder as main_graph_builder_a,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
MainInput as MainInput_a,
|
||||
MainInput as MainInput,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
@@ -21,6 +21,7 @@ from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
@@ -33,6 +34,7 @@ from onyx.llm.factory import get_default_llms
|
||||
from onyx.tools.tool_runner import ToolCallKickoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_COMPILED_GRAPH: CompiledStateGraph | None = None
|
||||
@@ -72,13 +74,15 @@ def _parse_agent_event(
|
||||
return cast(AnswerPacket, event["data"])
|
||||
elif event["name"] == "refined_answer_improvement":
|
||||
return cast(RefinedAnswerImprovement, event["data"])
|
||||
elif event["name"] == "refined_sub_question_creation_error":
|
||||
return cast(StreamingError, event["data"])
|
||||
return None
|
||||
|
||||
|
||||
def manage_sync_streaming(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
graph_input: BasicInput | MainInput_a,
|
||||
graph_input: BasicInput | MainInput,
|
||||
) -> Iterable[StreamEvent]:
|
||||
message_id = config.persistence.message_id if config.persistence else None
|
||||
for event in compiled_graph.stream(
|
||||
@@ -92,7 +96,7 @@ def manage_sync_streaming(
|
||||
def run_graph(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
input: BasicInput | MainInput_a,
|
||||
input: BasicInput | MainInput,
|
||||
) -> AnswerStream:
|
||||
config.behavior.perform_initial_search_decomposition = (
|
||||
INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||
@@ -123,9 +127,7 @@ def run_main_graph(
|
||||
) -> AnswerStream:
|
||||
compiled_graph = load_compiled_graph()
|
||||
|
||||
input = MainInput_a(
|
||||
base_question=config.inputs.search_request.query, log_messages=[]
|
||||
)
|
||||
input = MainInput(log_messages=[])
|
||||
|
||||
# Agent search is not a Tool per se, but this is helpful for the frontend
|
||||
yield ToolCallKickoff(
|
||||
@@ -140,7 +142,7 @@ def run_basic_graph(
|
||||
) -> AnswerStream:
|
||||
graph = basic_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = BasicInput()
|
||||
input = BasicInput(unused=True)
|
||||
return run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
@@ -172,9 +174,7 @@ if __name__ == "__main__":
|
||||
# search_request.persona = get_persona_by_id(1, None, db_session)
|
||||
# config.perform_initial_search_path_decision = False
|
||||
config.behavior.perform_initial_search_decomposition = True
|
||||
input = MainInput_a(
|
||||
base_question=config.inputs.search_request.query, log_messages=[]
|
||||
)
|
||||
input = MainInput(log_messages=[])
|
||||
|
||||
tool_responses: list = []
|
||||
for output in run_graph(compiled_graph, config, input):
|
||||
|
||||
@@ -7,6 +7,7 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
AgentPromptEnrichmentComponents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_persona_agent_prompt_expressions,
|
||||
)
|
||||
@@ -40,13 +41,7 @@ def build_sub_question_answer_prompt(
|
||||
|
||||
date_str = build_date_time_string()
|
||||
|
||||
# TODO: This should include document metadata and title
|
||||
docs_format_list = [
|
||||
f"Document Number: [D{doc_num + 1}]\nContent: {doc.combined_content}\n\n"
|
||||
for doc_num, doc in enumerate(docs)
|
||||
]
|
||||
|
||||
docs_str = "\n\n".join(docs_format_list)
|
||||
docs_str = format_docs(docs)
|
||||
|
||||
docs_str = trim_prompt_piece(
|
||||
config,
|
||||
@@ -150,3 +145,38 @@ def get_prompt_enrichment_components(
|
||||
history=history,
|
||||
date_str=date_str,
|
||||
)
|
||||
|
||||
|
||||
def binary_string_test(text: str, positive_value: str = "yes") -> bool:
|
||||
"""
|
||||
Tests if a string contains a positive value (case-insensitive).
|
||||
|
||||
Args:
|
||||
text: The string to test
|
||||
positive_value: The value to look for (defaults to "yes")
|
||||
|
||||
Returns:
|
||||
True if the positive value is found in the text
|
||||
"""
|
||||
return positive_value.lower() in text.lower()
|
||||
|
||||
|
||||
def binary_string_test_after_answer_separator(
|
||||
text: str, positive_value: str = "yes", separator: str = "Answer:"
|
||||
) -> bool:
|
||||
"""
|
||||
Tests if a string contains a positive value (case-insensitive).
|
||||
|
||||
Args:
|
||||
text: The string to test
|
||||
positive_value: The value to look for (defaults to "yes")
|
||||
|
||||
Returns:
|
||||
True if the positive value is found in the text
|
||||
"""
|
||||
|
||||
if separator not in text:
|
||||
return False
|
||||
relevant_text = text.split(f"{separator}")[-1]
|
||||
|
||||
return binary_string_test(relevant_text, positive_value)
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import numpy as np
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AnswerGenerationDocuments
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -96,3 +100,106 @@ def get_fit_scores(
|
||||
)
|
||||
|
||||
return fit_eval
|
||||
|
||||
|
||||
def get_answer_generation_documents(
|
||||
relevant_docs: list[InferenceSection],
|
||||
context_documents: list[InferenceSection],
|
||||
original_question_docs: list[InferenceSection],
|
||||
max_docs: int,
|
||||
) -> AnswerGenerationDocuments:
|
||||
"""
|
||||
Create a deduplicated list of documents to stream, prioritizing relevant docs.
|
||||
|
||||
Args:
|
||||
relevant_docs: Primary documents to include
|
||||
context_documents: Additional context documents to append
|
||||
original_question_docs: Original question documents to append
|
||||
max_docs: Maximum number of documents to return
|
||||
|
||||
Returns:
|
||||
List of deduplicated documents, limited to max_docs
|
||||
"""
|
||||
# get relevant_doc ids
|
||||
relevant_doc_ids = [doc.center_chunk.document_id for doc in relevant_docs]
|
||||
|
||||
# Start with relevant docs or fallback to original question docs
|
||||
streaming_documents = relevant_docs.copy()
|
||||
|
||||
# Use a set for O(1) lookups of document IDs
|
||||
seen_doc_ids = {doc.center_chunk.document_id for doc in streaming_documents}
|
||||
|
||||
# Combine additional documents to check in one iteration
|
||||
additional_docs = context_documents + original_question_docs
|
||||
for doc_idx, doc in enumerate(additional_docs):
|
||||
doc_id = doc.center_chunk.document_id
|
||||
if doc_id not in seen_doc_ids:
|
||||
streaming_documents.append(doc)
|
||||
seen_doc_ids.add(doc_id)
|
||||
|
||||
streaming_documents = dedup_inference_section_list(streaming_documents)
|
||||
|
||||
relevant_streaming_docs = [
|
||||
doc
|
||||
for doc in streaming_documents
|
||||
if doc.center_chunk.document_id in relevant_doc_ids
|
||||
]
|
||||
relevant_streaming_docs = dedup_sort_inference_section_list(relevant_streaming_docs)
|
||||
|
||||
additional_streaming_docs = [
|
||||
doc
|
||||
for doc in streaming_documents
|
||||
if doc.center_chunk.document_id not in relevant_doc_ids
|
||||
]
|
||||
additional_streaming_docs = dedup_sort_inference_section_list(
|
||||
additional_streaming_docs
|
||||
)
|
||||
|
||||
for doc in additional_streaming_docs:
|
||||
if doc.center_chunk.score:
|
||||
doc.center_chunk.score += -2.0
|
||||
else:
|
||||
doc.center_chunk.score = -2.0
|
||||
|
||||
sorted_streaming_documents = relevant_streaming_docs + additional_streaming_docs
|
||||
|
||||
return AnswerGenerationDocuments(
|
||||
streaming_documents=sorted_streaming_documents[:max_docs],
|
||||
context_documents=relevant_streaming_docs[:max_docs],
|
||||
)
|
||||
|
||||
|
||||
def dedup_sort_inference_section_list(
|
||||
sections: list[InferenceSection],
|
||||
) -> list[InferenceSection]:
|
||||
"""Deduplicates InferenceSections by document_id and sorts by score.
|
||||
|
||||
Args:
|
||||
sections: List of InferenceSections to deduplicate and sort
|
||||
|
||||
Returns:
|
||||
Deduplicated list of InferenceSections sorted by score in descending order
|
||||
"""
|
||||
# dedupe/merge with existing framework
|
||||
sections = dedup_inference_section_list(sections)
|
||||
|
||||
# Use dict to deduplicate by document_id, keeping highest scored version
|
||||
unique_sections: dict[str, InferenceSection] = {}
|
||||
for section in sections:
|
||||
doc_id = section.center_chunk.document_id
|
||||
if doc_id not in unique_sections:
|
||||
unique_sections[doc_id] = section
|
||||
continue
|
||||
|
||||
# Keep version with higher score
|
||||
existing_score = unique_sections[doc_id].center_chunk.score or 0
|
||||
new_score = section.center_chunk.score or 0
|
||||
if new_score > existing_score:
|
||||
unique_sections[doc_id] = section
|
||||
|
||||
# Sort by score in descending order, handling None scores
|
||||
sorted_sections = sorted(
|
||||
unique_sections.values(), key=lambda x: x.center_chunk.score or 0, reverse=True
|
||||
)
|
||||
|
||||
return sorted_sections
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
from enum import Enum
|
||||
|
||||
AGENT_LLM_TIMEOUT_MESSAGE = "The agent timed out. Please try again."
|
||||
AGENT_LLM_ERROR_MESSAGE = "The agent encountered an error. Please try again."
|
||||
AGENT_LLM_RATELIMIT_MESSAGE = (
|
||||
"The agent encountered a rate limit error. Please try again."
|
||||
)
|
||||
LLM_ANSWER_ERROR_MESSAGE = "The question was not answered due to an LLM error."
|
||||
|
||||
AGENT_POSITIVE_VALUE_STR = "yes"
|
||||
AGENT_NEGATIVE_VALUE_STR = "no"
|
||||
|
||||
AGENT_ANSWER_SEPARATOR = "Answer:"
|
||||
|
||||
|
||||
class AgentLLMErrorType(str, Enum):
|
||||
TIMEOUT = "timeout"
|
||||
RATE_LIMIT = "rate_limit"
|
||||
GENERAL_ERROR = "general_error"
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main.models import (
|
||||
@@ -56,6 +58,12 @@ class InitialAgentResultStats(BaseModel):
|
||||
agent_effectiveness: dict[str, float | int | None]
|
||||
|
||||
|
||||
class AgentErrorLog(BaseModel):
|
||||
error_message: str
|
||||
error_type: str
|
||||
error_result: str
|
||||
|
||||
|
||||
class RefinedAgentStats(BaseModel):
|
||||
revision_doc_efficiency: float | None
|
||||
revision_question_efficiency: float | None
|
||||
@@ -110,6 +118,11 @@ class SubQuestionAnswerResults(BaseModel):
|
||||
sub_question_retrieval_stats: AgentChunkRetrievalStats
|
||||
|
||||
|
||||
class StructuredSubquestionDocuments(BaseModel):
|
||||
cited_documents: list[InferenceSection]
|
||||
context_documents: list[InferenceSection]
|
||||
|
||||
|
||||
class CombinedAgentMetrics(BaseModel):
|
||||
timings: AgentTimings
|
||||
base_metrics: AgentBaseMetrics | None
|
||||
@@ -126,3 +139,17 @@ class AgentPromptEnrichmentComponents(BaseModel):
|
||||
persona_prompts: PersonaPromptExpressions
|
||||
history: str
|
||||
date_str: str
|
||||
|
||||
|
||||
class LLMNodeErrorStrings(BaseModel):
|
||||
timeout: str = "LLM Timeout Error"
|
||||
rate_limit: str = "LLM Rate Limit Error"
|
||||
general_error: str = "General LLM Error"
|
||||
|
||||
|
||||
class AnswerGenerationDocuments(BaseModel):
|
||||
streaming_documents: list[InferenceSection]
|
||||
context_documents: list[InferenceSection]
|
||||
|
||||
|
||||
BaseMessage_Content = str | list[str | dict[str, Any]]
|
||||
|
||||
@@ -12,6 +12,13 @@ def dedup_inference_sections(
|
||||
return deduped
|
||||
|
||||
|
||||
def dedup_inference_section_list(
|
||||
list: list[InferenceSection],
|
||||
) -> list[InferenceSection]:
|
||||
deduped = _merge_sections(list)
|
||||
return deduped
|
||||
|
||||
|
||||
def dedup_question_answer_results(
|
||||
question_answer_results_1: list[SubQuestionAnswerResults],
|
||||
question_answer_results_2: list[SubQuestionAnswerResults],
|
||||
|
||||
@@ -20,10 +20,18 @@ from onyx.agents.agent_search.models import GraphInputs
|
||||
from onyx.agents.agent_search.models import GraphPersistence
|
||||
from onyx.agents.agent_search.models import GraphSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import PersonaPromptExpressions
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
StructuredSubquestionDocuments,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswerResults
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationConfig
|
||||
@@ -34,6 +42,9 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
)
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
@@ -46,6 +57,8 @@ from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.persona import Persona
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.prompts.agent_search import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
@@ -58,6 +71,7 @@ from onyx.prompts.agent_search import (
|
||||
)
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
@@ -65,8 +79,9 @@ from onyx.tools.tool_implementations.search.search_tool import (
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
BaseMessage_Content = str | list[str | dict[str, Any]]
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# Post-processing
|
||||
@@ -218,7 +233,10 @@ def get_test_config(
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
)
|
||||
|
||||
chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
|
||||
chat_session_id = (
|
||||
os.environ.get("ONYX_AS_CHAT_SESSION_ID")
|
||||
or "00000000-0000-0000-0000-000000000000"
|
||||
)
|
||||
assert (
|
||||
chat_session_id is not None
|
||||
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
|
||||
@@ -295,6 +313,7 @@ def _dispatch_nonempty(
|
||||
def dispatch_separated(
|
||||
tokens: Iterator[BaseMessage],
|
||||
dispatch_event: Callable[[str, int], None],
|
||||
sep_callback: Callable[[int], None] | None = None,
|
||||
sep: str = DISPATCH_SEP_CHAR,
|
||||
) -> list[BaseMessage_Content]:
|
||||
num = 1
|
||||
@@ -304,6 +323,10 @@ def dispatch_separated(
|
||||
if sep in content:
|
||||
sub_question_parts = content.split(sep)
|
||||
_dispatch_nonempty(sub_question_parts[0], dispatch_event, num)
|
||||
|
||||
if sep_callback:
|
||||
sep_callback(num)
|
||||
|
||||
num += 1
|
||||
_dispatch_nonempty(
|
||||
"".join(sub_question_parts[1:]).strip(), dispatch_event, num
|
||||
@@ -312,6 +335,9 @@ def dispatch_separated(
|
||||
_dispatch_nonempty(content, dispatch_event, num)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
if sep_callback:
|
||||
sep_callback(num)
|
||||
|
||||
return streamed_tokens
|
||||
|
||||
|
||||
@@ -333,8 +359,12 @@ def retrieve_search_docs(
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=question,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
override_kwargs=SearchToolOverrideKwargs(
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
retrieved_sections_callback=None,
|
||||
skip_query_analysis=False,
|
||||
),
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
@@ -364,8 +394,24 @@ def summarize_history(
|
||||
)
|
||||
)
|
||||
|
||||
history_response = llm.invoke(history_context_prompt)
|
||||
try:
|
||||
history_response = llm.invoke(
|
||||
history_context_prompt,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
)
|
||||
except LLMTimeoutError:
|
||||
logger.error("LLM Timeout Error - summarize history")
|
||||
return (
|
||||
history # this is what is done at this point anyway, so we default to this
|
||||
)
|
||||
except LLMRateLimitError:
|
||||
logger.error("LLM Rate Limit Error - summarize history")
|
||||
return (
|
||||
history # this is what is done at this point anyway, so we default to this
|
||||
)
|
||||
|
||||
assert isinstance(history_response.content, str)
|
||||
|
||||
return history_response.content
|
||||
|
||||
|
||||
@@ -431,3 +477,27 @@ def remove_document_citations(text: str) -> str:
|
||||
# \d+ - one or more digits
|
||||
# \] - literal ] character
|
||||
return re.sub(r"\[(?:D|Q)?\d+\]", "", text)
|
||||
|
||||
|
||||
def get_deduplicated_structured_subquestion_documents(
|
||||
sub_question_results: list[SubQuestionAnswerResults],
|
||||
) -> StructuredSubquestionDocuments:
|
||||
"""
|
||||
Extract and deduplicate all cited documents from sub-question results.
|
||||
|
||||
Args:
|
||||
sub_question_results: List of sub-question results containing cited documents
|
||||
|
||||
Returns:
|
||||
Deduplicated list of cited documents
|
||||
"""
|
||||
cited_docs = [
|
||||
doc for result in sub_question_results for doc in result.cited_documents
|
||||
]
|
||||
context_docs = [
|
||||
doc for result in sub_question_results for doc in result.context_documents
|
||||
]
|
||||
return StructuredSubquestionDocuments(
|
||||
cited_documents=dedup_inference_section_list(cited_docs),
|
||||
context_documents=dedup_inference_section_list(context_docs),
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import smtplib
|
||||
from datetime import datetime
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from textwrap import dedent
|
||||
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
@@ -13,23 +13,150 @@ from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
|
||||
from onyx.db.models import User
|
||||
|
||||
HTML_EMAIL_TEMPLATE = """\
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width" />
|
||||
<title>{title}</title>
|
||||
<style>
|
||||
body, table, td, a {{
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
|
||||
text-size-adjust: 100%;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
-webkit-font-smoothing: antialiased;
|
||||
-webkit-text-size-adjust: none;
|
||||
}}
|
||||
body {{
|
||||
background-color: #f7f7f7;
|
||||
color: #333;
|
||||
}}
|
||||
.body-content {{
|
||||
color: #333;
|
||||
}}
|
||||
.email-container {{
|
||||
width: 100%;
|
||||
max-width: 600px;
|
||||
margin: 0 auto;
|
||||
background-color: #ffffff;
|
||||
border-radius: 6px;
|
||||
overflow: hidden;
|
||||
border: 1px solid #eaeaea;
|
||||
}}
|
||||
.header {{
|
||||
background-color: #000000;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
}}
|
||||
.header img {{
|
||||
max-width: 140px;
|
||||
}}
|
||||
.body-content {{
|
||||
padding: 20px 30px;
|
||||
}}
|
||||
.title {{
|
||||
font-size: 20px;
|
||||
font-weight: bold;
|
||||
margin: 0 0 10px;
|
||||
}}
|
||||
.message {{
|
||||
font-size: 16px;
|
||||
line-height: 1.5;
|
||||
margin: 0 0 20px;
|
||||
}}
|
||||
.cta-button {{
|
||||
display: inline-block;
|
||||
padding: 12px 20px;
|
||||
background-color: #000000;
|
||||
color: #ffffff !important;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
font-weight: 500;
|
||||
}}
|
||||
.footer {{
|
||||
font-size: 13px;
|
||||
color: #6A7280;
|
||||
text-align: center;
|
||||
padding: 20px;
|
||||
}}
|
||||
.footer a {{
|
||||
color: #6b7280;
|
||||
text-decoration: underline;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<table role="presentation" class="email-container" cellpadding="0" cellspacing="0">
|
||||
<tr>
|
||||
<td class="header">
|
||||
<img
|
||||
style="background-color: #ffffff; border-radius: 8px;"
|
||||
src="https://www.onyx.app/logos/customer/onyx.png"
|
||||
alt="Onyx Logo"
|
||||
>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="body-content">
|
||||
<h1 class="title">{heading}</h1>
|
||||
<div class="message">
|
||||
{message}
|
||||
</div>
|
||||
{cta_block}
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="footer">
|
||||
© {year} Onyx. All rights reserved.
|
||||
<br>
|
||||
Have questions? Join our Slack community <a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA">here</a>.
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def build_html_email(
|
||||
heading: str, message: str, cta_text: str | None = None, cta_link: str | None = None
|
||||
) -> str:
|
||||
if cta_text and cta_link:
|
||||
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
|
||||
else:
|
||||
cta_block = ""
|
||||
return HTML_EMAIL_TEMPLATE.format(
|
||||
title=heading,
|
||||
heading=heading,
|
||||
message=message,
|
||||
cta_block=cta_block,
|
||||
year=datetime.now().year,
|
||||
)
|
||||
|
||||
|
||||
def send_email(
|
||||
user_email: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
html_body: str,
|
||||
text_body: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
if not EMAIL_CONFIGURED:
|
||||
raise ValueError("Email is not configured.")
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
msg.attach(MIMEText(body))
|
||||
part_text = MIMEText(text_body, "plain")
|
||||
part_html = MIMEText(html_body, "html")
|
||||
|
||||
msg.attach(part_text)
|
||||
msg.attach(part_html)
|
||||
|
||||
try:
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
@@ -40,26 +167,44 @@ def send_email(
|
||||
raise e
|
||||
|
||||
|
||||
def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
# Example usage of the reusable HTML
|
||||
subject = "Your Onyx Subscription Has Been Canceled"
|
||||
heading = "Subscription Canceled"
|
||||
message = (
|
||||
"<p>We’re sorry to see you go.</p>"
|
||||
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
|
||||
"<p>If you change your mind, you can always come back!</p>"
|
||||
)
|
||||
cta_text = "Renew Subscription"
|
||||
cta_link = "https://www.onyx.app/pricing"
|
||||
html_content = build_html_email(heading, message, cta_text, cta_link)
|
||||
text_content = (
|
||||
"We're sorry to see you go.\n"
|
||||
"Your subscription has been canceled and will end on your next billing date.\n"
|
||||
"If you change your mind, visit https://www.onyx.app/pricing"
|
||||
)
|
||||
send_email(user_email, subject, html_content, text_content)
|
||||
|
||||
|
||||
def send_user_email_invite(user_email: str, current_user: User) -> None:
|
||||
subject = "Invitation to Join Onyx Organization"
|
||||
body = dedent(
|
||||
f"""\
|
||||
Hello,
|
||||
|
||||
You have been invited to join an organization on Onyx.
|
||||
|
||||
To join the organization, please visit the following link:
|
||||
|
||||
{WEB_DOMAIN}/auth/signup?email={user_email}
|
||||
|
||||
You'll be asked to set a password or login with Google to complete your registration.
|
||||
|
||||
Best regards,
|
||||
The Onyx Team
|
||||
"""
|
||||
heading = "You've Been Invited!"
|
||||
message = (
|
||||
f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
|
||||
"<p>To join the organization, please click the button below to set a password "
|
||||
"or login with Google and complete your registration.</p>"
|
||||
)
|
||||
|
||||
send_email(user_email, subject, body, current_user.email)
|
||||
cta_text = "Join Organization"
|
||||
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
|
||||
html_content = build_html_email(heading, message, cta_text, cta_link)
|
||||
text_content = (
|
||||
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
|
||||
"To join the organization, please visit the following link:\n"
|
||||
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
|
||||
"You'll be asked to set a password or login with Google to complete your registration."
|
||||
)
|
||||
send_email(user_email, subject, html_content, text_content)
|
||||
|
||||
|
||||
def send_forgot_password_email(
|
||||
@@ -68,13 +213,15 @@ def send_forgot_password_email(
|
||||
mail_from: str = EMAIL_FROM,
|
||||
tenant_id: str | None = None,
|
||||
) -> None:
|
||||
# Builds a forgot password email with or without fancy HTML
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
if tenant_id:
|
||||
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
|
||||
# Keep search param same name as cookie for simplicity
|
||||
body = f"Click the following link to reset your password: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
|
||||
html_content = build_html_email("Reset Your Password", message)
|
||||
text_content = f"Click the following link to reset your password: {link}"
|
||||
send_email(user_email, subject, html_content, text_content, mail_from)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
@@ -82,7 +229,12 @@ def send_user_verification_email(
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
# Builds a verification email
|
||||
subject = "Onyx Email Verification"
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
body = f"Click the following link to verify your email address: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
message = (
|
||||
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
|
||||
)
|
||||
html_content = build_html_email("Verify Your Email", message)
|
||||
text_content = f"Click the following link to verify your email address: {link}"
|
||||
send_email(user_email, subject, html_content, text_content, mail_from)
|
||||
|
||||
@@ -1,41 +1,56 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery.beat import PersistentScheduler # type: ignore
|
||||
from celery.signals import beat_init
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.beat")
|
||||
|
||||
|
||||
class DynamicTenantScheduler(PersistentScheduler):
|
||||
"""This scheduler is useful because we can dynamically adjust task generation rates
|
||||
through it."""
|
||||
|
||||
RELOAD_INTERVAL = 60
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
logger.info("Initializing DynamicTenantScheduler")
|
||||
super().__init__(*args, **kwargs)
|
||||
self._reload_interval = timedelta(minutes=2)
|
||||
|
||||
self.last_beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
|
||||
self._reload_interval = timedelta(
|
||||
seconds=DynamicTenantScheduler.RELOAD_INTERVAL
|
||||
)
|
||||
self._last_reload = self.app.now() - self._reload_interval
|
||||
|
||||
# Let the parent class handle store initialization
|
||||
self.setup_schedule()
|
||||
self._try_updating_schedule()
|
||||
logger.info(f"Set reload interval to {self._reload_interval}")
|
||||
task_logger.info(
|
||||
f"DynamicTenantScheduler initialized: reload_interval={self._reload_interval}"
|
||||
)
|
||||
|
||||
def setup_schedule(self) -> None:
|
||||
logger.info("Setting up initial schedule")
|
||||
super().setup_schedule()
|
||||
logger.info("Initial schedule setup complete")
|
||||
|
||||
def tick(self) -> float:
|
||||
retval = super().tick()
|
||||
@@ -44,36 +59,35 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
self._last_reload is None
|
||||
or (now - self._last_reload) > self._reload_interval
|
||||
):
|
||||
logger.info("Reload interval reached, initiating task update")
|
||||
task_logger.debug("Reload interval reached, initiating task update")
|
||||
try:
|
||||
self._try_updating_schedule()
|
||||
except (AttributeError, KeyError) as e:
|
||||
logger.exception(f"Failed to process task configuration: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error updating tasks: {str(e)}")
|
||||
except (AttributeError, KeyError):
|
||||
task_logger.exception("Failed to process task configuration")
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected error updating tasks")
|
||||
|
||||
self._last_reload = now
|
||||
logger.info("Task update completed, reset reload timer")
|
||||
|
||||
return retval
|
||||
|
||||
def _generate_schedule(
|
||||
self, tenant_ids: list[str] | list[None]
|
||||
self, tenant_ids: list[str] | list[None], beat_multiplier: float
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Given a list of tenant id's, generates a new beat schedule for celery."""
|
||||
logger.info("Fetching tasks to schedule")
|
||||
|
||||
new_schedule: dict[str, dict[str, Any]] = {}
|
||||
|
||||
if MULTI_TENANT:
|
||||
# cloud tasks only need the single task beat across all tenants
|
||||
# cloud tasks are system wide and thus only need to be on the beat schedule
|
||||
# once for all tenants
|
||||
get_cloud_tasks_to_schedule = fetch_versioned_implementation(
|
||||
"onyx.background.celery.tasks.beat_schedule",
|
||||
"get_cloud_tasks_to_schedule",
|
||||
)
|
||||
|
||||
cloud_tasks_to_schedule: list[
|
||||
dict[str, Any]
|
||||
] = get_cloud_tasks_to_schedule()
|
||||
cloud_tasks_to_schedule: list[dict[str, Any]] = get_cloud_tasks_to_schedule(
|
||||
beat_multiplier
|
||||
)
|
||||
for task in cloud_tasks_to_schedule:
|
||||
task_name = task["name"]
|
||||
cloud_task = {
|
||||
@@ -82,11 +96,14 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
"kwargs": task.get("kwargs", {}),
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(f"Adding options to task {task_name}: {options}")
|
||||
task_logger.debug(f"Adding options to task {task_name}: {options}")
|
||||
cloud_task["options"] = options
|
||||
new_schedule[task_name] = cloud_task
|
||||
|
||||
# regular task beats are multiplied across all tenants
|
||||
# note that currently this just schedules for a single tenant in self hosted
|
||||
# and doesn't do anything in the cloud because it's much more scalable
|
||||
# to schedule a single cloud beat task to dispatch per tenant tasks.
|
||||
get_tasks_to_schedule = fetch_versioned_implementation(
|
||||
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
|
||||
)
|
||||
@@ -95,7 +112,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
|
||||
logger.info(
|
||||
task_logger.debug(
|
||||
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
|
||||
)
|
||||
continue
|
||||
@@ -104,14 +121,14 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
task_name = task["name"]
|
||||
tenant_task_name = f"{task['name']}-{tenant_id}"
|
||||
|
||||
logger.debug(f"Creating task configuration for {tenant_task_name}")
|
||||
task_logger.debug(f"Creating task configuration for {tenant_task_name}")
|
||||
tenant_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": {"tenant_id": tenant_id},
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(
|
||||
task_logger.debug(
|
||||
f"Adding options to task {tenant_task_name}: {options}"
|
||||
)
|
||||
tenant_task["options"] = options
|
||||
@@ -121,44 +138,57 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
|
||||
def _try_updating_schedule(self) -> None:
|
||||
"""Only updates the actual beat schedule on the celery app when it changes"""
|
||||
do_update = False
|
||||
|
||||
logger.info("_try_updating_schedule starting")
|
||||
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
task_logger.debug("_try_updating_schedule starting")
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.info(f"Found {len(tenant_ids)} IDs")
|
||||
task_logger.debug(f"Found {len(tenant_ids)} IDs")
|
||||
|
||||
# get current schedule and extract current tenants
|
||||
current_schedule = self.schedule.items()
|
||||
|
||||
# there are no more per tenant beat tasks, so comment this out
|
||||
# NOTE: we may not actualy need this scheduler any more and should
|
||||
# test reverting to a regular beat schedule implementation
|
||||
# get potential new state
|
||||
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
|
||||
if beat_multiplier_raw is not None:
|
||||
try:
|
||||
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
|
||||
beat_multiplier = float(beat_multiplier_bytes.decode())
|
||||
except ValueError:
|
||||
task_logger.error(
|
||||
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
|
||||
)
|
||||
|
||||
# current_tenants = set()
|
||||
# for task_name, _ in current_schedule:
|
||||
# task_name = cast(str, task_name)
|
||||
# if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
|
||||
# continue
|
||||
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)
|
||||
|
||||
# if "_" in task_name:
|
||||
# # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
|
||||
# # -> "12345678-abcd-efgh-ijkl-12345678"
|
||||
# current_tenants.add(task_name.split("_")[-1])
|
||||
# logger.info(f"Found {len(current_tenants)} existing items in schedule")
|
||||
# if the schedule or beat multiplier has changed, update
|
||||
while True:
|
||||
if beat_multiplier != self.last_beat_multiplier:
|
||||
do_update = True
|
||||
break
|
||||
|
||||
# for tenant_id in tenant_ids:
|
||||
# if tenant_id not in current_tenants:
|
||||
# logger.info(f"Processing new tenant: {tenant_id}")
|
||||
if not DynamicTenantScheduler._compare_schedules(
|
||||
current_schedule, new_schedule
|
||||
):
|
||||
do_update = True
|
||||
break
|
||||
|
||||
new_schedule = self._generate_schedule(tenant_ids)
|
||||
break
|
||||
|
||||
if DynamicTenantScheduler._compare_schedules(current_schedule, new_schedule):
|
||||
logger.info(
|
||||
"_try_updating_schedule: Current schedule is up to date, no changes needed"
|
||||
if not do_update:
|
||||
# exit early if nothing changed
|
||||
task_logger.info(
|
||||
f"_try_updating_schedule - Schedule unchanged: "
|
||||
f"tasks={len(new_schedule)} "
|
||||
f"beat_multiplier={beat_multiplier}"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
# schedule needs updating
|
||||
task_logger.debug(
|
||||
"Schedule update required",
|
||||
extra={
|
||||
"new_tasks": len(new_schedule),
|
||||
@@ -185,11 +215,19 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
# Ensure changes are persisted
|
||||
self.sync()
|
||||
|
||||
logger.info("_try_updating_schedule: Schedule updated successfully")
|
||||
task_logger.info(
|
||||
f"_try_updating_schedule - Schedule updated: "
|
||||
f"prev_num_tasks={len(current_schedule)} "
|
||||
f"prev_beat_multiplier={self.last_beat_multiplier} "
|
||||
f"tasks={len(new_schedule)} "
|
||||
f"beat_multiplier={beat_multiplier}"
|
||||
)
|
||||
|
||||
self.last_beat_multiplier = beat_multiplier
|
||||
|
||||
@staticmethod
|
||||
def _compare_schedules(schedule1: dict, schedule2: dict) -> bool:
|
||||
"""Compare schedules to determine if an update is needed.
|
||||
"""Compare schedules by task name only to determine if an update is needed.
|
||||
True if equivalent, False if not."""
|
||||
current_tasks = set(name for name, _ in schedule1)
|
||||
new_tasks = set(schedule2.keys())
|
||||
@@ -201,7 +239,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
|
||||
@beat_init.connect
|
||||
def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("beat_init signal received.")
|
||||
task_logger.info("beat_init signal received.")
|
||||
|
||||
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
|
||||
@@ -84,8 +84,10 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
EXTRA_CONCURRENCY = 4 # small extra fudge factor for connection limits
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
@@ -142,7 +144,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
# As currently designed, when this worker starts as "primary", we reinitialize redis
|
||||
# to a clean state (for our purposes, anyway)
|
||||
r.delete(OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
r.delete(OnyxRedisConstants.ACTIVE_FENCES)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
@@ -18,242 +19,203 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# hack to slow down task dispatch in the cloud until
|
||||
# we have a better implementation (backpressure, etc)
|
||||
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 4
|
||||
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
|
||||
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
|
||||
|
||||
# tasks that only run in the cloud
|
||||
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
|
||||
# by the DynamicTenantScheduler
|
||||
cloud_tasks_to_schedule = [
|
||||
# tasks that run in either self-hosted on cloud
|
||||
beat_task_templates: list[dict] = []
|
||||
|
||||
beat_task_templates.extend(
|
||||
[
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-pruning",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-background-processes",
|
||||
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
"schedule": timedelta(minutes=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
# Only add the LLM model update task if the API URL is configured
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
beat_task_templates.append(
|
||||
{
|
||||
"name": "check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"schedule": timedelta(hours=1), # Check every hour
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
cloud_task: dict[str, Any] = {}
|
||||
|
||||
# constant options for cloud beat task generators
|
||||
task_schedule: timedelta = task["schedule"]
|
||||
cloud_task["schedule"] = task_schedule
|
||||
cloud_task["options"] = {}
|
||||
cloud_task["options"]["priority"] = OnyxCeleryPriority.HIGHEST
|
||||
cloud_task["options"]["expires"] = BEAT_EXPIRES_DEFAULT
|
||||
|
||||
# settings dependent on the original task
|
||||
cloud_task["name"] = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_{task['name']}"
|
||||
cloud_task["task"] = OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR
|
||||
cloud_task["kwargs"] = {}
|
||||
cloud_task["kwargs"]["task_name"] = task["task"]
|
||||
|
||||
optional_fields = ["queue", "priority", "expires"]
|
||||
for field in optional_fields:
|
||||
if field in task["options"]:
|
||||
cloud_task["kwargs"][field] = task["options"][field]
|
||||
|
||||
return cloud_task
|
||||
|
||||
|
||||
# tasks that only run in the cloud and are system wide
|
||||
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be seen
|
||||
# by the DynamicTenantScheduler as system wide task and not a per tenant task
|
||||
beat_cloud_tasks: list[dict] = [
|
||||
# cloud specific tasks
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
|
||||
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
|
||||
"schedule": timedelta(hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-alembic",
|
||||
"task": OnyxCeleryTask.CLOUD_MONITOR_ALEMBIC,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
# remaining tasks are cloud generators for per tenant tasks
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-celery-queues",
|
||||
"task": OnyxCeleryTask.CLOUD_MONITOR_CELERY_QUEUES,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-prune",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=30 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-background-processes",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(minutes=5 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
cloud_tasks_to_schedule.append(
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(
|
||||
hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER
|
||||
), # Check every hour
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# tasks that run in either self-hosted on cloud
|
||||
# tasks that only run self hosted
|
||||
tasks_to_schedule: list[dict] = []
|
||||
|
||||
if not MULTI_TENANT:
|
||||
tasks_to_schedule.extend(
|
||||
[
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"name": "monitor-celery-queues",
|
||||
"task": OnyxCeleryTask.MONITOR_CELERY_QUEUES,
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-pruning",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-background-processes",
|
||||
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
"schedule": timedelta(minutes=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
# Only add the LLM model update task if the API URL is configured
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"schedule": timedelta(hours=1), # Check every hour
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return cloud_tasks_to_schedule
|
||||
def generate_cloud_tasks(
|
||||
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
beat_tasks: system wide tasks that can be sent as is
|
||||
beat_templates: task templates that will be transformed into per tenant tasks via
|
||||
the cloud_beat_task_generator
|
||||
beat_multiplier: a multiplier that can be applied on top of the task schedule
|
||||
to speed up or slow down the task generation rate. useful in production.
|
||||
|
||||
Returns a list of cloud tasks, which consists of incoming tasks + tasks generated
|
||||
from incoming templates.
|
||||
"""
|
||||
|
||||
if beat_multiplier <= 0:
|
||||
raise ValueError("beat_multiplier must be positive!")
|
||||
|
||||
cloud_tasks: list[dict] = []
|
||||
|
||||
# generate our tenant aware cloud tasks from the templates
|
||||
for beat_template in beat_templates:
|
||||
cloud_task = make_cloud_generator_task(beat_template)
|
||||
cloud_tasks.append(cloud_task)
|
||||
|
||||
# factor in the cloud multiplier for the above
|
||||
for cloud_task in cloud_tasks:
|
||||
cloud_task["schedule"] = cloud_task["schedule"] * beat_multiplier
|
||||
|
||||
# add the fixed cloud/system beat tasks. No multiplier for these.
|
||||
cloud_tasks.extend(copy.deepcopy(beat_tasks))
|
||||
return cloud_tasks
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
|
||||
return generate_cloud_tasks(beat_cloud_tasks, beat_task_templates, beat_multiplier)
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -12,18 +16,35 @@ from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.connector_credential_pair import add_deletion_failure_message
|
||||
from onyx.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.index_attempt import delete_index_attempts
|
||||
from onyx.db.search_settings import get_all_search_settings
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
|
||||
|
||||
class TaskDependencyError(RuntimeError):
|
||||
@@ -42,6 +63,7 @@ def check_for_connector_deletion_task(
|
||||
self: Task, *, tenant_id: str | None
|
||||
) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
@@ -77,6 +99,18 @@ def check_for_connector_deletion_task(
|
||||
# clear the stop signal if it exists ... no longer needed
|
||||
redis_connector.stop.set_fence(False)
|
||||
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not r.exists(key_bytes):
|
||||
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -179,11 +213,14 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
if tasks_generated is None:
|
||||
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
|
||||
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
)
|
||||
try:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("insert_sync_record exceptioned.")
|
||||
|
||||
except TaskDependencyError:
|
||||
redis_connector.delete.set_fence(None)
|
||||
@@ -209,3 +246,158 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
|
||||
return tasks_generated
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
fence_data = redis_connector.delete.payload
|
||||
if not fence_data:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return
|
||||
|
||||
if fence_data.num_tasks is None:
|
||||
# the fence is setting up but isn't ready yet
|
||||
return
|
||||
|
||||
remaining = redis_connector.delete.get_remaining()
|
||||
task_logger.info(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
|
||||
)
|
||||
if remaining > 0:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=remaining,
|
||||
)
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
doc_ids = get_document_ids_for_connector_credential_pair(
|
||||
db_session, cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
if len(doc_ids) > 0:
|
||||
# NOTE(rkuo): if this happens, documents somehow got added while
|
||||
# deletion was in progress. Likely a bug gating off pruning and indexing
|
||||
# work before deletion starts.
|
||||
task_logger.warning(
|
||||
"Connector deletion - documents still found after taskset completion. "
|
||||
"Clearing the current deletion attempt and allowing deletion to restart: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"docs_deleted={fence_data.num_tasks} "
|
||||
f"docs_remaining={len(doc_ids)}"
|
||||
)
|
||||
|
||||
# We don't want to waive off why we get into this state, but resetting
|
||||
# our attempt and letting the deletion restart is a good way to recover
|
||||
redis_connector.delete.reset()
|
||||
raise RuntimeError(
|
||||
"Connector deletion - documents still found after taskset completion"
|
||||
)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"onyx.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Connector deletion - Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair_id, error_message)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.FAILED,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
task_logger.exception(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Connector deletion succeeded: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector={cc_pair.connector_id} "
|
||||
f"credential={cc_pair.credential_id} "
|
||||
f"docs_deleted={fence_data.num_tasks}"
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
|
||||
@@ -175,6 +175,24 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
)
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=300)
|
||||
|
||||
# use a lookup table to find active fences. We still have to verify the fence
|
||||
# exists since it is an optimization and not the source of truth.
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not r.exists(key_bytes):
|
||||
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(
|
||||
tenant_id, key_bytes, r, db_session
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -228,12 +246,15 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_PERMISSIONS,
|
||||
)
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_PERMISSIONS,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("insert_sync_record exceptioned.")
|
||||
|
||||
# set a basic fence to start
|
||||
redis_connector.permissions.set_active()
|
||||
@@ -257,11 +278,10 @@ def try_creating_permissions_sync_task(
|
||||
)
|
||||
|
||||
# fill in the celery task id
|
||||
redis_connector.permissions.set_active()
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
|
||||
payload_id = payload.celery_task_id
|
||||
payload_id = payload.id
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
|
||||
return None
|
||||
@@ -290,6 +310,8 @@ def connector_permission_sync_generator_task(
|
||||
This task assumes that the task has already been properly fenced
|
||||
"""
|
||||
|
||||
payload_id: str | None = None
|
||||
|
||||
LoggerContextVars.reset()
|
||||
|
||||
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
|
||||
@@ -332,9 +354,12 @@ def connector_permission_sync_generator_task(
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
payload_id = payload.id
|
||||
|
||||
logger.info(
|
||||
f"connector_permission_sync_generator_task - Fence found, continuing...: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
f"fence={redis_connector.permissions.fence_key} "
|
||||
f"payload_id={payload.id}"
|
||||
)
|
||||
break
|
||||
|
||||
@@ -342,6 +367,7 @@ def connector_permission_sync_generator_task(
|
||||
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
@@ -413,7 +439,9 @@ def connector_permission_sync_generator_task(
|
||||
redis_connector.permissions.generator_complete = tasks_generated
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")
|
||||
task_logger.exception(
|
||||
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
|
||||
)
|
||||
|
||||
redis_connector.permissions.generator_clear()
|
||||
redis_connector.permissions.taskset_clear()
|
||||
@@ -423,6 +451,10 @@ def connector_permission_sync_generator_task(
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"Permission sync finished: cc_pair={cc_pair_id} payload_id={payload.id}"
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
|
||||
@@ -446,14 +478,15 @@ def update_external_document_permissions_task(
|
||||
)
|
||||
doc_id = document_external_access.doc_id
|
||||
external_access = document_external_access.external_access
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Add the users to the DB if they don't exist
|
||||
batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
emails=list(external_access.external_user_emails),
|
||||
continue_on_error=True,
|
||||
)
|
||||
# Then we upsert the document's external permissions in postgres
|
||||
# Then upsert the document's external permissions
|
||||
created_new_doc = upsert_document_external_perms(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
@@ -477,11 +510,11 @@ def update_external_document_permissions_task(
|
||||
f"action=update_permissions "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Exception in update_external_document_permissions_task: "
|
||||
f"connector_id={connector_id} "
|
||||
f"doc_id={doc_id}"
|
||||
f"connector_id={connector_id} doc_id={doc_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -659,7 +692,7 @@ def validate_permission_sync_fence(
|
||||
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
)
|
||||
|
||||
# we're only active if tasks_scanned > 0 and tasks_not_in_celery == 0
|
||||
# we're active if there are still tasks to run and those tasks all exist in celery
|
||||
if tasks_scanned > 0 and tasks_not_in_celery == 0:
|
||||
redis_connector.permissions.set_active()
|
||||
return
|
||||
@@ -680,7 +713,8 @@ def validate_permission_sync_fence(
|
||||
"validate_permission_sync_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
f"fence={fence_key} "
|
||||
f"payload_id={payload.id}"
|
||||
)
|
||||
|
||||
redis_connector.permissions.reset()
|
||||
@@ -741,7 +775,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
raise
|
||||
|
||||
|
||||
"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""
|
||||
"""Monitoring CCPair permissions utils"""
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
|
||||
@@ -2,15 +2,17 @@ import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from pydantic import ValidationError
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
|
||||
@@ -24,15 +26,17 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
@@ -49,7 +53,8 @@ from onyx.redis.redis_connector_ext_group_sync import (
|
||||
RedisConnectorExternalGroupSyncPayload,
|
||||
)
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -67,18 +72,26 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if external group sync is due."""
|
||||
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
# skip external group sync if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
task_logger.error(
|
||||
f"Recieved non-sync CC Pair {cc_pair.id} for external "
|
||||
f"group sync. Actual access type: {cc_pair.access_type}"
|
||||
)
|
||||
return False
|
||||
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
task_logger.debug(
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - "
|
||||
f"CC Pair is being deleted"
|
||||
)
|
||||
return False
|
||||
|
||||
# If there is not group sync function for the connector, we don't run the sync
|
||||
# This is fine because all sources dont necessarily have a concept of groups
|
||||
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
|
||||
task_logger.debug(
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - "
|
||||
f"no group sync function for {cc_pair.connector.source}"
|
||||
)
|
||||
return False
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
@@ -107,11 +120,11 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
# r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
@@ -120,6 +133,9 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
task_logger.warning(
|
||||
f"Failed to acquire beat lock for external group sync: {tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -149,30 +165,32 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
|
||||
lock_beat.reacquire()
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
tasks_created = try_creating_external_group_sync_task(
|
||||
payload_id = try_creating_external_group_sync_task(
|
||||
self.app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not tasks_created:
|
||||
if not payload_id:
|
||||
continue
|
||||
|
||||
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
|
||||
task_logger.info(
|
||||
f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}"
|
||||
)
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
# lock_beat.reacquire()
|
||||
# if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
|
||||
# # clear any indexing fences that don't have associated celery tasks in progress
|
||||
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# # or be currently executing
|
||||
# try:
|
||||
# validate_external_group_sync_fences(
|
||||
# tenant_id, self.app, r, r_celery, lock_beat
|
||||
# )
|
||||
# except Exception:
|
||||
# task_logger.exception(
|
||||
# "Exception while validating external group sync fences"
|
||||
# )
|
||||
lock_beat.reacquire()
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
|
||||
# clear fences that don't have associated celery tasks in progress
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
validate_external_group_sync_fences(
|
||||
tenant_id, self.app, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Exception while validating external group sync fences"
|
||||
)
|
||||
|
||||
# r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60)
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=300)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -191,35 +209,46 @@ def try_creating_external_group_sync_task(
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
) -> str | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Returns None if no syncing is required."""
|
||||
payload_id: str | None = None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Dont kick off a new sync if the previous one is still running
|
||||
if redis_connector.external_group_sync.fenced:
|
||||
logger.warning(
|
||||
f"Skipping external group sync for CC Pair {cc_pair_id} - already running."
|
||||
)
|
||||
return None
|
||||
|
||||
redis_connector.external_group_sync.generator_clear()
|
||||
redis_connector.external_group_sync.taskset_clear()
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_GROUP,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("insert_sync_record exceptioned.")
|
||||
|
||||
# Signal active before creating fence
|
||||
redis_connector.external_group_sync.set_active()
|
||||
|
||||
payload = RedisConnectorExternalGroupSyncPayload(
|
||||
id=make_short_id(),
|
||||
submitted=datetime.now(timezone.utc),
|
||||
started=None,
|
||||
celery_task_id=None,
|
||||
)
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
|
||||
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
|
||||
|
||||
@@ -234,27 +263,17 @@ def try_creating_external_group_sync_task(
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_GROUP,
|
||||
)
|
||||
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
|
||||
payload_id = payload.id
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return 1
|
||||
return payload_id
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -285,22 +304,26 @@ def connector_external_group_sync_generator_task(
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.external_group_sync.fence_key}"
|
||||
)
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
if not redis_connector.external_group_sync.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"connector_external_group_sync_generator_task - fence not found: "
|
||||
f"fence={redis_connector.external_group_sync.fence_key}"
|
||||
)
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
payload = redis_connector.external_group_sync.payload # The payload must exist
|
||||
if not payload:
|
||||
raise ValueError(
|
||||
"connector_external_group_sync_generator_task: payload invalid or not found"
|
||||
)
|
||||
msg = "connector_external_group_sync_generator_task: payload invalid or not found"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
logger.info(
|
||||
@@ -312,7 +335,8 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
logger.info(
|
||||
f"connector_external_group_sync_generator_task - Fence found, continuing...: "
|
||||
f"fence={redis_connector.external_group_sync.fence_key}"
|
||||
f"fence={redis_connector.external_group_sync.fence_key} "
|
||||
f"payload_id={payload.id}"
|
||||
)
|
||||
break
|
||||
|
||||
@@ -324,9 +348,9 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
msg = f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
task_logger.error(msg)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -347,9 +371,9 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if ext_group_sync_func is None:
|
||||
raise ValueError(
|
||||
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
msg = f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
logger.info(
|
||||
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
@@ -380,9 +404,9 @@ def connector_external_group_sync_generator_task(
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run external group sync: cc_pair={cc_pair_id}"
|
||||
)
|
||||
msg = f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
|
||||
task_logger.exception(msg)
|
||||
emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
@@ -401,32 +425,41 @@ def connector_external_group_sync_generator_task(
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"External group sync finished: cc_pair={cc_pair_id} payload_id={payload.id}"
|
||||
)
|
||||
|
||||
|
||||
def validate_external_group_sync_fences(
|
||||
tenant_id: str | None,
|
||||
celery_app: Celery,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
reserved_sync_tasks = celery_get_unacked_task_ids(
|
||||
reserved_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
|
||||
# validate all existing indexing jobs
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
|
||||
count=SCAN_ITER_COUNT_DEFAULT,
|
||||
):
|
||||
# validate all existing external group sync tasks
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if not key_str.startswith(RedisConnectorExternalGroupSync.FENCE_PREFIX):
|
||||
continue
|
||||
|
||||
validate_external_group_sync_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
reserved_tasks,
|
||||
r_celery,
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
validate_external_group_sync_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
reserved_sync_tasks,
|
||||
r_celery,
|
||||
db_session,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
@@ -435,7 +468,6 @@ def validate_external_group_sync_fence(
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
This can happen if the indexing worker hard crashes or is terminated.
|
||||
@@ -464,9 +496,11 @@ def validate_external_group_sync_fence(
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
msg = (
|
||||
f"validate_external_group_sync_fence - could not parse id from {fence_key}"
|
||||
)
|
||||
emit_background_error(msg)
|
||||
task_logger.error(msg)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
@@ -478,26 +512,28 @@ def validate_external_group_sync_fence(
|
||||
if not redis_connector.external_group_sync.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector.external_group_sync.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
# the fence is just barely set up.
|
||||
# if redis_connector_index.active():
|
||||
# return
|
||||
|
||||
# it would be odd to get here as there isn't that much that can go wrong during
|
||||
# initial fence setup, but it's still worth making sure we can recover
|
||||
logger.info(
|
||||
try:
|
||||
payload = redis_connector.external_group_sync.payload
|
||||
except ValidationError:
|
||||
msg = (
|
||||
"validate_external_group_sync_fence - "
|
||||
f"Resetting fence in basic state without any activity: fence={fence_key}"
|
||||
"Resetting fence because fence schema is out of date: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
task_logger.exception(msg)
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
|
||||
redis_connector.external_group_sync.reset()
|
||||
return
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
if not payload.celery_task_id:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
found = celery_find_task(
|
||||
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
@@ -523,11 +559,15 @@ def validate_external_group_sync_fence(
|
||||
# return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
logger.warning(
|
||||
"validate_external_group_sync_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
emit_background_error(
|
||||
message=(
|
||||
"validate_external_group_sync_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key} "
|
||||
f"payload_id={payload.id}"
|
||||
),
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.reset()
|
||||
|
||||
@@ -6,13 +6,18 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import sentry_sdk
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.result import AsyncResult
|
||||
from celery.states import READY_STATES
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
@@ -30,6 +35,7 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
@@ -37,6 +43,7 @@ from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingMode
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
@@ -47,9 +54,12 @@ from onyx.db.swap_index import check_index_swap
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_utils import is_fence
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
@@ -60,6 +70,150 @@ from shared_configs.configs import SENTRY_DSN
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"Connector indexing: could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
parts = composite_id.split("/")
|
||||
if len(parts) != 2:
|
||||
return
|
||||
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
if not redis_connector_index.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
elapsed_started_str = None
|
||||
if payload.started:
|
||||
elapsed_started = datetime.now(timezone.utc) - payload.started
|
||||
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
|
||||
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
# the task is still setting up
|
||||
return
|
||||
|
||||
# never use any blocking methods on the result from inside a task!
|
||||
result: AsyncResult = AsyncResult(payload.celery_task_id)
|
||||
|
||||
# inner/outer/inner double check pattern to avoid race conditions when checking for
|
||||
# bad state
|
||||
|
||||
# Verify: if the generator isn't complete, the task must not be in READY state
|
||||
# inner = get_completion / generator_complete not signaled
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None: # inner signal not set ... possible error
|
||||
task_state = result.state
|
||||
if (
|
||||
task_state in READY_STATES
|
||||
): # outer signal in terminal state ... possible error
|
||||
# Now double check!
|
||||
if redis_connector_index.get_completion() is None:
|
||||
# inner signal still not set (and cannot change when outer result_state is READY)
|
||||
# Task is finished but generator complete isn't set.
|
||||
# We have a problem! Worker may have crashed.
|
||||
task_result = str(result.result)
|
||||
task_traceback = str(result.traceback)
|
||||
|
||||
msg = (
|
||||
f"Connector indexing aborted or exceptioned: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"celery_task={payload.celery_task_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"result.state={task_state} "
|
||||
f"result.result={task_result} "
|
||||
f"result.traceback={task_traceback}"
|
||||
)
|
||||
task_logger.warning(msg)
|
||||
|
||||
try:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session, payload.index_attempt_id
|
||||
)
|
||||
if index_attempt:
|
||||
if (
|
||||
index_attempt.status != IndexingStatus.CANCELED
|
||||
and index_attempt.status != IndexingStatus.FAILED
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason=msg,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Connector indexing - Transient exception marking index attempt as failed: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
if redis_connector_index.watchdog_signaled():
|
||||
# if the generator is complete, don't clean up until the watchdog has exited
|
||||
task_logger.info(
|
||||
f"Connector indexing - Delaying finalization until watchdog has exited: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
soft_time_limit=300,
|
||||
@@ -91,6 +245,25 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
try:
|
||||
locked = True
|
||||
|
||||
# SPECIAL 0/3: sync lookup table for active fences
|
||||
# we want to run this less frequently than the overall task
|
||||
if not redis_client.exists(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE):
|
||||
# build a lookup table of existing fences
|
||||
# this is just a migration concern and should be unnecessary once
|
||||
# lookup tables are rolled out
|
||||
for key_bytes in redis_client_replica.scan_iter(
|
||||
count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if is_fence(key_bytes) and not redis_client.sismember(
|
||||
OnyxRedisConstants.ACTIVE_FENCES, key_bytes
|
||||
):
|
||||
logger.warning(f"Adding {key_bytes} to the lookup table.")
|
||||
redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
|
||||
redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
|
||||
|
||||
# 1/3: KICKOFF
|
||||
|
||||
# check for search settings swap
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
@@ -197,6 +370,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
# 2/3: VALIDATE
|
||||
|
||||
# Fail any index attempts in the DB that don't have fences
|
||||
# This shouldn't ever happen!
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -236,6 +411,26 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
task_logger.exception("Exception while validating indexing fences")
|
||||
|
||||
redis_client.set(OnyxRedisSignals.BLOCK_VALIDATE_INDEXING_FENCES, 1, ex=60)
|
||||
|
||||
# 3/3: FINALIZE
|
||||
lock_beat.reacquire()
|
||||
keys = cast(
|
||||
set[Any], redis_client_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES)
|
||||
)
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not redis_client.exists(key_bytes):
|
||||
redis_client.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(
|
||||
tenant_id, key_bytes, redis_client_replica, db_session
|
||||
)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -423,8 +618,8 @@ def connector_indexing_task(
|
||||
# define a callback class
|
||||
callback = IndexingCallback(
|
||||
os.getppid(),
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector_index.generator_progress_key,
|
||||
redis_connector,
|
||||
redis_connector_index,
|
||||
lock,
|
||||
r,
|
||||
)
|
||||
|
||||
@@ -99,16 +99,16 @@ class IndexingCallback(IndexingHeartbeatInterface):
|
||||
def __init__(
|
||||
self,
|
||||
parent_pid: int,
|
||||
stop_key: str,
|
||||
generator_progress_key: str,
|
||||
redis_connector: RedisConnector,
|
||||
redis_connector_index: RedisConnectorIndex,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.parent_pid = parent_pid
|
||||
self.redis_connector: RedisConnector = redis_connector
|
||||
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
|
||||
self.redis_lock: RedisLock = redis_lock
|
||||
self.stop_key: str = stop_key
|
||||
self.generator_progress_key: str = generator_progress_key
|
||||
self.redis_client = redis_client
|
||||
self.started: datetime = datetime.now(timezone.utc)
|
||||
self.redis_lock.reacquire()
|
||||
@@ -120,7 +120,7 @@ class IndexingCallback(IndexingHeartbeatInterface):
|
||||
self.last_parent_check = time.monotonic()
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_client.exists(self.stop_key):
|
||||
if self.redis_connector.stop.fenced:
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -143,6 +143,8 @@ class IndexingCallback(IndexingHeartbeatInterface):
|
||||
# self.last_parent_check = now
|
||||
|
||||
try:
|
||||
self.redis_connector.prune.set_active()
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - self.last_lock_monotonic >= (
|
||||
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
|
||||
@@ -165,7 +167,9 @@ class IndexingCallback(IndexingHeartbeatInterface):
|
||||
redis_lock_dump(self.redis_lock, self.redis_client)
|
||||
raise
|
||||
|
||||
self.redis_client.incrby(self.generator_progress_key, amount)
|
||||
self.redis_client.incrby(
|
||||
self.redis_connector_index.generator_progress_key, amount
|
||||
)
|
||||
|
||||
|
||||
def validate_indexing_fence(
|
||||
|
||||
@@ -17,7 +17,8 @@ from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -189,9 +190,9 @@ def _build_connector_start_latency_metric(
|
||||
desired_start_time = cc_pair.connector.time_created
|
||||
else:
|
||||
if not cc_pair.connector.refresh_freq:
|
||||
task_logger.error(
|
||||
"Found non-initial index attempt for connector "
|
||||
"without refresh_freq. This should never happen."
|
||||
task_logger.debug(
|
||||
"Connector has no refresh_freq and this is a non-initial index attempt. "
|
||||
"Assuming user manually triggered indexing, so we'll skip start latency metric."
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -420,6 +421,7 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]
|
||||
- Throughput (docs/min) (only if success)
|
||||
- Raw start/end times for each sync
|
||||
"""
|
||||
|
||||
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
|
||||
|
||||
# Get all sync records that ended in the last hour
|
||||
@@ -587,6 +589,10 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]
|
||||
entity = db_session.scalar(
|
||||
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
|
||||
)
|
||||
else:
|
||||
# Only user groups and document set sync records have
|
||||
# an associated entity we can use for latency metrics
|
||||
continue
|
||||
|
||||
if entity is None:
|
||||
task_logger.error(
|
||||
@@ -717,7 +723,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
|
||||
name=OnyxCeleryTask.CLOUD_MONITOR_ALEMBIC,
|
||||
)
|
||||
def cloud_check_alembic() -> bool | None:
|
||||
"""A task to verify that all tenants are on the same alembic revision.
|
||||
@@ -777,7 +783,7 @@ def cloud_check_alembic() -> bool | None:
|
||||
|
||||
tenant_to_revision[tenant_id] = result_scalar
|
||||
except Exception:
|
||||
task_logger.warning(f"Tenant {tenant_id} has no revision!")
|
||||
task_logger.error(f"Tenant {tenant_id} has no revision!")
|
||||
tenant_to_revision[tenant_id] = ALEMBIC_NULL_REVISION
|
||||
|
||||
# get the total count of each revision
|
||||
@@ -847,3 +853,55 @@ def cloud_check_alembic() -> bool | None:
|
||||
f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLOUD_MONITOR_CELERY_QUEUES, ignore_result=True, bind=True
|
||||
)
|
||||
def cloud_monitor_celery_queues(
|
||||
self: Task,
|
||||
) -> None:
|
||||
return monitor_celery_queues_helper(self)
|
||||
|
||||
|
||||
@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
|
||||
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
|
||||
return monitor_celery_queues_helper(self)
|
||||
|
||||
|
||||
def monitor_celery_queues_helper(
|
||||
task: Task,
|
||||
) -> None:
|
||||
"""A task to monitor all celery queue lengths."""
|
||||
|
||||
r_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_indexing = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery)
|
||||
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
|
||||
n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
n_external_group_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
n_permissions_upsert = celery_get_queue_length(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
|
||||
n_indexing_prefetched = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"indexing_prefetched={len(n_indexing_prefetched)} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
)
|
||||
|
||||
@@ -1,28 +1,39 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from pydantic import ValidationError
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
|
||||
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.connector import mark_ccpair_as_pruned
|
||||
@@ -35,10 +46,15 @@ from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import pruning_ctx
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -93,6 +109,8 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
)
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
@@ -104,32 +122,68 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
return None
|
||||
|
||||
try:
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
# the entire task needs to run frequently in order to finalize pruning
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
# but pruning only kicks off once per hour
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
continue
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
if not _is_pruning_due(cc_pair):
|
||||
continue
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
tasks_created = try_creating_prune_generator_task(
|
||||
self.app, cc_pair, db_session, r, tenant_id
|
||||
)
|
||||
if not tasks_created:
|
||||
continue
|
||||
if not _is_pruning_due(cc_pair):
|
||||
continue
|
||||
|
||||
task_logger.info(f"Pruning queued: cc_pair={cc_pair.id}")
|
||||
payload_id = try_creating_prune_generator_task(
|
||||
self.app, cc_pair, db_session, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}"
|
||||
)
|
||||
r.set(OnyxRedisSignals.BLOCK_PRUNING, 1, ex=3600)
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
lock_beat.reacquire()
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES):
|
||||
# clear any permission fences that don't have associated celery tasks in progress
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while validating pruning fences")
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES, 1, ex=300)
|
||||
|
||||
# use a lookup table to find active fences. We still have to verify the fence
|
||||
# exists since it is an optimization and not the source of truth.
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not r.exists(key_bytes):
|
||||
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -149,7 +203,7 @@ def try_creating_prune_generator_task(
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
) -> str | None:
|
||||
"""Checks for any conditions that should block the pruning generator task from being
|
||||
created, then creates the task.
|
||||
|
||||
@@ -168,7 +222,7 @@ def try_creating_prune_generator_task(
|
||||
|
||||
# we need to serialize starting pruning since it can be triggered either via
|
||||
# celery beat or manually (API call)
|
||||
lock = r.lock(
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_prune_generator_task",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -200,7 +254,30 @@ def try_creating_prune_generator_task(
|
||||
|
||||
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
|
||||
|
||||
celery_app.send_task(
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
try:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair.id,
|
||||
sync_type=SyncType.PRUNING,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("insert_sync_record exceptioned.")
|
||||
|
||||
# signal active before the fence is set
|
||||
redis_connector.prune.set_active()
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorPrunePayload(
|
||||
id=make_short_id(),
|
||||
submitted=datetime.now(timezone.utc),
|
||||
started=None,
|
||||
celery_task_id=None,
|
||||
)
|
||||
redis_connector.prune.set_fence(payload)
|
||||
|
||||
result = celery_app.send_task(
|
||||
OnyxCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair.id,
|
||||
@@ -213,16 +290,11 @@ def try_creating_prune_generator_task(
|
||||
priority=OnyxCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair.id,
|
||||
sync_type=SyncType.PRUNING,
|
||||
)
|
||||
# fill in the celery task id
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector.prune.set_fence(payload)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
redis_connector.prune.set_fence(True)
|
||||
payload_id = payload.id
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
|
||||
return None
|
||||
@@ -230,7 +302,7 @@ def try_creating_prune_generator_task(
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return 1
|
||||
return payload_id
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -252,6 +324,8 @@ def connector_pruning_generator_task(
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list"""
|
||||
|
||||
payload_id: str | None = None
|
||||
|
||||
LoggerContextVars.reset()
|
||||
|
||||
pruning_ctx_dict = pruning_ctx.get()
|
||||
@@ -265,6 +339,46 @@ def connector_pruning_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
# before the primary worker can finalize the fence
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
f"connector_prune_generator_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.prune.fence_key}"
|
||||
)
|
||||
|
||||
if not redis_connector.prune.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
f"connector_prune_generator_task - fence not found: "
|
||||
f"fence={redis_connector.prune.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector.prune.payload # The payload must exist
|
||||
if not payload:
|
||||
raise ValueError(
|
||||
"connector_prune_generator_task: payload invalid or not found"
|
||||
)
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_prune_generator_task - Waiting for fence: "
|
||||
f"fence={redis_connector.prune.fence_key}"
|
||||
)
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
payload_id = payload.id
|
||||
|
||||
logger.info(
|
||||
f"connector_prune_generator_task - Fence found, continuing...: "
|
||||
f"fence={redis_connector.prune.fence_key} "
|
||||
f"payload_id={payload.id}"
|
||||
)
|
||||
break
|
||||
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
@@ -294,6 +408,18 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
return
|
||||
|
||||
payload = redis_connector.prune.payload
|
||||
if not payload:
|
||||
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
|
||||
|
||||
new_payload = RedisConnectorPrunePayload(
|
||||
id=payload.id,
|
||||
submitted=payload.submitted,
|
||||
started=datetime.now(timezone.utc),
|
||||
celery_task_id=payload.celery_task_id,
|
||||
)
|
||||
redis_connector.prune.set_fence(new_payload)
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning generator running connector: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -307,10 +433,13 @@ def connector_pruning_generator_task(
|
||||
cc_pair.credential,
|
||||
)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
|
||||
callback = IndexingCallback(
|
||||
0,
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector.prune.generator_progress_key,
|
||||
redis_connector,
|
||||
redis_connector_index,
|
||||
lock,
|
||||
r,
|
||||
)
|
||||
@@ -357,7 +486,9 @@ def connector_pruning_generator_task(
|
||||
redis_connector.prune.generator_complete = tasks_generated
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
|
||||
f"Pruning exceptioned: cc_pair={cc_pair_id} "
|
||||
f"connector={connector_id} "
|
||||
f"payload_id={payload_id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.reset()
|
||||
@@ -366,10 +497,12 @@ def connector_pruning_generator_task(
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")
|
||||
task_logger.info(
|
||||
f"Pruning generator finished: cc_pair={cc_pair_id} payload_id={payload_id}"
|
||||
)
|
||||
|
||||
|
||||
"""Monitoring pruning utils, called in monitor_vespa_sync"""
|
||||
"""Monitoring pruning utils"""
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
@@ -415,4 +548,184 @@ def monitor_ccpair_pruning_taskset(
|
||||
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
redis_connector.prune.set_fence(None)
|
||||
|
||||
|
||||
def validate_pruning_fences(
|
||||
tenant_id: str | None,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
# building lookup table can be expensive, so we won't bother
|
||||
# validating until the queue is small
|
||||
PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024
|
||||
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
|
||||
if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN:
|
||||
return
|
||||
|
||||
# the queue for a single pruning generator task
|
||||
reserved_generator_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
|
||||
# the queue for a reasonably large set of lightweight deletion tasks
|
||||
queued_upsert_tasks = celery_get_queued_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
|
||||
# Use replica for this because the worst thing that happens
|
||||
# is that we don't run the validation on this pass
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if not key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
continue
|
||||
|
||||
validate_pruning_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
reserved_generator_tasks,
|
||||
queued_upsert_tasks,
|
||||
r,
|
||||
r_celery,
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
return
|
||||
|
||||
|
||||
def validate_pruning_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
queued_tasks: set[str],
|
||||
r: Redis,
|
||||
r_celery: Redis,
|
||||
) -> None:
|
||||
"""See validate_indexing_fence for an overall idea of validation flows.
|
||||
|
||||
queued_tasks: the celery queue of lightweight permission sync tasks
|
||||
reserved_tasks: prefetched tasks for sync task generator
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"validate_pruning_fence - could not parse id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
# parse out metadata and initialize the helper class with it
|
||||
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector.prune.fenced:
|
||||
return
|
||||
|
||||
# in the cloud, the payload format may have changed ...
|
||||
# it's a little sloppy, but just reset the fence for now if that happens
|
||||
# TODO: add intentional cleanup/abort logic
|
||||
try:
|
||||
payload = redis_connector.prune.payload
|
||||
except ValidationError:
|
||||
task_logger.exception(
|
||||
"validate_pruning_fence - "
|
||||
"Resetting fence because fence schema is out of date: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.prune.reset()
|
||||
return
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
if not payload.celery_task_id:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
# either the generator task must be in flight or its subtasks must be
|
||||
found = celery_find_task(
|
||||
payload.celery_task_id,
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING,
|
||||
r_celery,
|
||||
)
|
||||
if found:
|
||||
# the celery task exists in the redis queue
|
||||
redis_connector.prune.set_active()
|
||||
return
|
||||
|
||||
if payload.celery_task_id in reserved_tasks:
|
||||
# the celery task was prefetched and is reserved within a worker
|
||||
redis_connector.prune.set_active()
|
||||
return
|
||||
|
||||
# look up every task in the current taskset in the celery queue
|
||||
# every entry in the taskset should have an associated entry in the celery task queue
|
||||
# because we get the celery tasks first, the entries in our own pruning taskset
|
||||
# should be roughly a subset of the tasks in celery
|
||||
|
||||
# this check isn't very exact, but should be sufficient over a period of time
|
||||
# A single successful check over some number of attempts is sufficient.
|
||||
|
||||
# TODO: if the number of tasks in celery is much lower than than the taskset length
|
||||
# we might be able to shortcut the lookup since by definition some of the tasks
|
||||
# must not exist in celery.
|
||||
|
||||
tasks_scanned = 0
|
||||
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
|
||||
|
||||
for member in r.sscan_iter(redis_connector.prune.taskset_key):
|
||||
tasks_scanned += 1
|
||||
|
||||
member_bytes = cast(bytes, member)
|
||||
member_str = member_bytes.decode("utf-8")
|
||||
if member_str in queued_tasks:
|
||||
continue
|
||||
|
||||
if member_str in reserved_tasks:
|
||||
continue
|
||||
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
task_logger.info(
|
||||
"validate_pruning_fence task check: "
|
||||
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
)
|
||||
|
||||
# we're active if there are still tasks to run and those tasks all exist in celery
|
||||
if tasks_scanned > 0 and tasks_not_in_celery == 0:
|
||||
redis_connector.prune.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector.prune.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
task_logger.warning(
|
||||
"validate_pruning_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key} "
|
||||
f"payload_id={payload.id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.reset()
|
||||
return
|
||||
|
||||
@@ -8,6 +8,7 @@ from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
from tenacity import RetryError
|
||||
|
||||
from ee.onyx.server.tenants.product_gating import get_gated_tenants
|
||||
from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
@@ -252,7 +253,11 @@ def cloud_beat_task_generator(
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
gated_tenants = get_gated_tenants()
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id in gated_tenants:
|
||||
continue
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
|
||||
lock_beat.reacquire()
|
||||
@@ -270,6 +275,7 @@ def cloud_beat_task_generator(
|
||||
queue=queue,
|
||||
priority=priority,
|
||||
expires=expires,
|
||||
ignore_result=True,
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -13,8 +9,6 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.result import AsyncResult
|
||||
from celery.states import READY_STATES
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -22,47 +16,27 @@ from tenacity import RetryError
|
||||
|
||||
from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
|
||||
monitor_ccpair_permissions_taskset,
|
||||
)
|
||||
from onyx.background.celery.tasks.pruning.tasks import monitor_ccpair_pruning_taskset
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.connector_credential_pair import add_deletion_failure_message
|
||||
from onyx.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import count_documents_by_needs_sync
|
||||
from onyx.db.document import get_document
|
||||
from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document import mark_document_as_synced
|
||||
from onyx.db.document_set import delete_document_set
|
||||
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from onyx.db.document_set import fetch_document_sets
|
||||
from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.document_set import get_document_set_by_id
|
||||
from onyx.db.document_set import mark_document_set_as_synced
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.index_attempt import delete_index_attempts
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
@@ -72,20 +46,14 @@ from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from onyx.redis.redis_connector_credential_pair import (
|
||||
RedisGlobalConnectorCredentialPair,
|
||||
)
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
@@ -94,7 +62,6 @@ from onyx.utils.variable_functionality import (
|
||||
)
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -111,9 +78,14 @@ logger = setup_logger()
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
# Useful for debugging timing issues with reacquisitions. TODO: remove once more generalized logging is in place
|
||||
task_logger.info("check_for_vespa_sync_task started")
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
@@ -125,6 +97,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
return None
|
||||
|
||||
try:
|
||||
# 1/3: KICKOFF
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_stale_document_sync_tasks(
|
||||
self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
|
||||
@@ -151,9 +124,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
# endregion
|
||||
|
||||
# check if any user groups are not synced
|
||||
lock_beat.reacquire()
|
||||
if global_version.is_ee_version():
|
||||
lock_beat.reacquire()
|
||||
|
||||
try:
|
||||
fetch_user_groups = fetch_versioned_implementation(
|
||||
"onyx.db.user_group", "fetch_user_groups"
|
||||
@@ -179,6 +151,35 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
|
||||
# 2/3: VALIDATE: TODO
|
||||
|
||||
# 3/3: FINALIZE
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not r.exists(key_bytes):
|
||||
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
|
||||
monitor_connector_taskset(r)
|
||||
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -339,11 +340,15 @@ def try_generate_document_set_sync_tasks(
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
)
|
||||
try:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("insert_sync_record exceptioned.")
|
||||
|
||||
# set this only after all tasks have been added
|
||||
rds.set_fence(tasks_generated)
|
||||
return tasks_generated
|
||||
@@ -411,11 +416,15 @@ def try_generate_user_group_sync_tasks(
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
)
|
||||
try:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("insert_sync_record exceptioned.")
|
||||
|
||||
# set this only after all tasks have been added
|
||||
rug.set_fence(tasks_generated)
|
||||
|
||||
@@ -487,484 +496,23 @@ def monitor_document_set_taskset(
|
||||
task_logger.info(
|
||||
f"Successfully synced document set: document_set={document_set_id}"
|
||||
)
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
|
||||
rds.reset()
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
fence_data = redis_connector.delete.payload
|
||||
if not fence_data:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return
|
||||
|
||||
if fence_data.num_tasks is None:
|
||||
# the fence is setting up but isn't ready yet
|
||||
return
|
||||
|
||||
remaining = redis_connector.delete.get_remaining()
|
||||
task_logger.info(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
|
||||
)
|
||||
if remaining > 0:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=remaining,
|
||||
)
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
doc_ids = get_document_ids_for_connector_credential_pair(
|
||||
db_session, cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
if len(doc_ids) > 0:
|
||||
# NOTE(rkuo): if this happens, documents somehow got added while
|
||||
# deletion was in progress. Likely a bug gating off pruning and indexing
|
||||
# work before deletion starts.
|
||||
task_logger.warning(
|
||||
"Connector deletion - documents still found after taskset completion. "
|
||||
"Clearing the current deletion attempt and allowing deletion to restart: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"docs_deleted={fence_data.num_tasks} "
|
||||
f"docs_remaining={len(doc_ids)}"
|
||||
)
|
||||
|
||||
# We don't want to waive off why we get into this state, but resetting
|
||||
# our attempt and letting the deletion restart is a good way to recover
|
||||
redis_connector.delete.reset()
|
||||
raise RuntimeError(
|
||||
"Connector deletion - documents still found after taskset completion"
|
||||
)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"onyx.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Connector deletion - Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair_id, error_message)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.FAILED,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Connector deletion succeeded: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector={cc_pair.connector_id} "
|
||||
f"credential={cc_pair.credential_id} "
|
||||
f"docs_deleted={fence_data.num_tasks}"
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"Connector indexing: could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
parts = composite_id.split("/")
|
||||
if len(parts) != 2:
|
||||
return
|
||||
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
if not redis_connector_index.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
elapsed_started_str = None
|
||||
if payload.started:
|
||||
elapsed_started = datetime.now(timezone.utc) - payload.started
|
||||
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
|
||||
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
# the task is still setting up
|
||||
return
|
||||
|
||||
# never use any blocking methods on the result from inside a task!
|
||||
result: AsyncResult = AsyncResult(payload.celery_task_id)
|
||||
|
||||
# inner/outer/inner double check pattern to avoid race conditions when checking for
|
||||
# bad state
|
||||
|
||||
# Verify: if the generator isn't complete, the task must not be in READY state
|
||||
# inner = get_completion / generator_complete not signaled
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None: # inner signal not set ... possible error
|
||||
task_state = result.state
|
||||
if (
|
||||
task_state in READY_STATES
|
||||
): # outer signal in terminal state ... possible error
|
||||
# Now double check!
|
||||
if redis_connector_index.get_completion() is None:
|
||||
# inner signal still not set (and cannot change when outer result_state is READY)
|
||||
# Task is finished but generator complete isn't set.
|
||||
# We have a problem! Worker may have crashed.
|
||||
task_result = str(result.result)
|
||||
task_traceback = str(result.traceback)
|
||||
|
||||
msg = (
|
||||
f"Connector indexing aborted or exceptioned: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"celery_task={payload.celery_task_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"result.state={task_state} "
|
||||
f"result.result={task_result} "
|
||||
f"result.traceback={task_traceback}"
|
||||
)
|
||||
task_logger.warning(msg)
|
||||
|
||||
try:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session, payload.index_attempt_id
|
||||
)
|
||||
if index_attempt:
|
||||
if (
|
||||
index_attempt.status != IndexingStatus.CANCELED
|
||||
and index_attempt.status != IndexingStatus.FAILED
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason=msg,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Connector indexing - Transient exception marking index attempt as failed: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
if redis_connector_index.watchdog_signaled():
|
||||
# if the generator is complete, don't clean up until the watchdog has exited
|
||||
task_logger.info(
|
||||
f"Connector indexing - Delaying finalization until watchdog has exited: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
"""This is a celery beat task that monitors and finalizes various long running tasks.
|
||||
|
||||
The name monitor_vespa_sync is a bit of a misnomer since it checks many different tasks
|
||||
now. Should change that at some point.
|
||||
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
For many tasks, the count is 0, that means all tasks finished and we should clean up.
|
||||
|
||||
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
|
||||
do anything too expensive in this function!
|
||||
|
||||
Returns True if the task actually did work, False if it exited early to prevent overlap
|
||||
"""
|
||||
task_logger.info(f"monitor_vespa_sync starting: tenant={tenant_id}")
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Replica usage notes
|
||||
#
|
||||
# False negatives are OK. (aka fail to to see a key that exists on the master).
|
||||
# We simply skip the monitoring work and it will be caught on the next pass.
|
||||
#
|
||||
# False positives are not OK, and are possible if we clear a fence on the master and
|
||||
# then read from the replica. In this case, monitoring work could be done on a fence
|
||||
# that no longer exists. To avoid this, we scan from the replica, but double check
|
||||
# the result on the master.
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# prevent overlapping tasks
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
# print current queue lengths
|
||||
time.monotonic()
|
||||
# we don't need every tenant polling redis for this info.
|
||||
if not MULTI_TENANT or random.randint(1, 10) == 10:
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_indexing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
n_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery
|
||||
)
|
||||
n_deletion = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
n_pruning = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
n_external_group_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
n_permissions_upsert = celery_get_queue_length(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
"update_sync_record_status exceptioned. "
|
||||
f"document_set_id={document_set_id} "
|
||||
"Resetting document set regardless."
|
||||
)
|
||||
|
||||
prefetched = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"indexing_prefetched={len(prefetched)} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
)
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE):
|
||||
# build a lookup table of existing fences
|
||||
# this is just a migration concern and should be unnecessary once
|
||||
# lookup tables are rolled out
|
||||
for key_bytes in r_replica.scan_iter(count=SCAN_ITER_COUNT_DEFAULT):
|
||||
if is_fence(key_bytes) and not r.sismember(
|
||||
OnyxRedisConstants.ACTIVE_FENCES, key_bytes
|
||||
):
|
||||
logger.warning(f"Adding {key_bytes} to the lookup table.")
|
||||
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
|
||||
|
||||
# use a lookup table to find active fences. We still have to verify the fence
|
||||
# exists since it is an optimization and not the source of truth.
|
||||
keys = cast(set[Any], r.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not r.exists(key_bytes):
|
||||
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
|
||||
monitor_connector_taskset(r)
|
||||
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
elif key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(
|
||||
tenant_id, key_bytes, r, db_session
|
||||
)
|
||||
else:
|
||||
pass
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
task_logger.exception("monitor_vespa_sync exceptioned.")
|
||||
return False
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"monitor_vespa_sync - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
# f"timings={timings}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, r)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
|
||||
return True
|
||||
rds.reset()
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -1064,23 +612,3 @@ def vespa_metadata_sync_task(
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def is_fence(key_bytes: bytes) -> bool:
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
|
||||
return True
|
||||
if key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisUserGroup.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
13
backend/onyx/background/error_logging.py
Normal file
13
backend/onyx/background/error_logging.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from onyx.db.background_error import create_background_error
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
def emit_background_error(
|
||||
message: str,
|
||||
cc_pair_id: int | None = None,
|
||||
) -> None:
|
||||
"""Currently just saves a row in the background_errors table.
|
||||
|
||||
In the future, could create notifications based on the severity."""
|
||||
with get_session_with_tenant() as db_session:
|
||||
create_background_error(db_session, message, cc_pair_id)
|
||||
@@ -140,6 +140,7 @@ def build_citations_user_message(
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
all_doc_useful: bool,
|
||||
history_message: str = "",
|
||||
context_type: str = "context documents",
|
||||
) -> HumanMessage:
|
||||
multilingual_expansion = get_multilingual_expansion()
|
||||
task_prompt_with_reminder = build_task_prompt_reminders(
|
||||
@@ -156,6 +157,7 @@ def build_citations_user_message(
|
||||
optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT
|
||||
|
||||
user_prompt = CITATIONS_PROMPT.format(
|
||||
context_type=context_type,
|
||||
optional_ignore_statement=optional_ignore,
|
||||
context_docs_str=context_docs_str,
|
||||
task_prompt=task_prompt_with_reminder,
|
||||
@@ -165,6 +167,7 @@ def build_citations_user_message(
|
||||
else:
|
||||
# if no context docs provided, assume we're in the tool calling flow
|
||||
user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
|
||||
context_type=context_type,
|
||||
task_prompt=task_prompt_with_reminder,
|
||||
user_query=query,
|
||||
history_block=history_block,
|
||||
|
||||
@@ -3,6 +3,24 @@ import os
|
||||
INITIAL_SEARCH_DECOMPOSITION_ENABLED = True
|
||||
ALLOW_REFINEMENT = True
|
||||
|
||||
AGENT_DEFAULT_RETRIEVAL_HITS = 15
|
||||
AGENT_DEFAULT_RERANKING_HITS = 10
|
||||
AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8
|
||||
AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION = 3
|
||||
AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = 5
|
||||
|
||||
AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER = 25
|
||||
AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER = 35
|
||||
|
||||
|
||||
AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS = 5
|
||||
AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
|
||||
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
|
||||
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
|
||||
|
||||
INITIAL_SEARCH_DECOMPOSITION_ENABLED = True
|
||||
ALLOW_REFINEMENT = True
|
||||
|
||||
AGENT_DEFAULT_RETRIEVAL_HITS = 15
|
||||
AGENT_DEFAULT_RERANKING_HITS = 10
|
||||
AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8
|
||||
@@ -13,9 +31,21 @@ AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
|
||||
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
|
||||
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
|
||||
|
||||
#####
|
||||
# Agent Configs
|
||||
#####
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = 30 # in seconds
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = 10 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = 25 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = 4 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = 1 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = 3 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = 12 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = 8 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = 6 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = 25 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = 8 # in seconds
|
||||
|
||||
|
||||
AGENT_RETRIEVAL_STATS = (
|
||||
@@ -77,4 +107,151 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
|
||||
or AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH
|
||||
) # 2000
|
||||
|
||||
AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER = int(
|
||||
os.environ.get("AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER")
|
||||
or AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
|
||||
) # 25
|
||||
|
||||
AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER = int(
|
||||
os.environ.get("AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER")
|
||||
or AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
|
||||
) # 35
|
||||
|
||||
|
||||
AGENT_RETRIEVAL_STATS = (
|
||||
not os.environ.get("AGENT_RETRIEVAL_STATS") == "False"
|
||||
) or True # default True
|
||||
|
||||
|
||||
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
|
||||
os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS
|
||||
) # 15
|
||||
|
||||
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
|
||||
os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS
|
||||
) # 15
|
||||
|
||||
# Reranking agent configs
|
||||
# Reranking stats - no influence on flow outside of stats collection
|
||||
AGENT_RERANKING_STATS = (
|
||||
not os.environ.get("AGENT_RERANKING_STATS") == "True"
|
||||
) or False # default False
|
||||
|
||||
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
|
||||
os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS
|
||||
) # 15
|
||||
|
||||
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = int(
|
||||
os.environ.get("AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS")
|
||||
or AGENT_DEFAULT_RERANKING_HITS
|
||||
) # 10
|
||||
|
||||
AGENT_NUM_DOCS_FOR_DECOMPOSITION = int(
|
||||
os.environ.get("AGENT_NUM_DOCS_FOR_DECOMPOSITION")
|
||||
or AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION
|
||||
) # 3
|
||||
|
||||
AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = int(
|
||||
os.environ.get("AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION")
|
||||
or AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION
|
||||
) # 5
|
||||
|
||||
AGENT_EXPLORATORY_SEARCH_RESULTS = int(
|
||||
os.environ.get("AGENT_EXPLORATORY_SEARCH_RESULTS")
|
||||
or AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS
|
||||
) # 5
|
||||
|
||||
AGENT_MIN_ORIG_QUESTION_DOCS = int(
|
||||
os.environ.get("AGENT_MIN_ORIG_QUESTION_DOCS")
|
||||
or AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS
|
||||
) # 3
|
||||
|
||||
AGENT_MAX_ANSWER_CONTEXT_DOCS = int(
|
||||
os.environ.get("AGENT_MAX_ANSWER_CONTEXT_DOCS")
|
||||
or AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS
|
||||
) # 8
|
||||
|
||||
|
||||
AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
|
||||
os.environ.get("AGENT_MAX_STATIC_HISTORY_WORD_LENGTH")
|
||||
or AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH
|
||||
) # 2000
|
||||
|
||||
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION
|
||||
) # 25
|
||||
|
||||
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
|
||||
) # 3
|
||||
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION
|
||||
) # 30
|
||||
|
||||
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION
|
||||
) # 8
|
||||
|
||||
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
|
||||
) # 12
|
||||
|
||||
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION
|
||||
) # 25
|
||||
|
||||
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
|
||||
) # 25
|
||||
|
||||
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
|
||||
) # 8
|
||||
|
||||
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION
|
||||
) # 6
|
||||
|
||||
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION
|
||||
) # 1
|
||||
|
||||
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION
|
||||
) # 4
|
||||
|
||||
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
|
||||
) # 8
|
||||
|
||||
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
|
||||
) # 8
|
||||
|
||||
GRAPH_VERSION_NAME: str = "a"
|
||||
|
||||
@@ -107,9 +107,9 @@ CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
|
||||
CELERY_PRUNING_LOCK_TIMEOUT = 3600 # 1 hour (in seconds)
|
||||
|
||||
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 3600 # 1 hour (in seconds)
|
||||
|
||||
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
@@ -125,6 +125,7 @@ class DocumentSource(str, Enum):
|
||||
GMAIL = "gmail"
|
||||
REQUESTTRACKER = "requesttracker"
|
||||
GITHUB = "github"
|
||||
GITBOOK = "gitbook"
|
||||
GITLAB = "gitlab"
|
||||
GURU = "guru"
|
||||
BOOKSTACK = "bookstack"
|
||||
@@ -263,6 +264,11 @@ class PostgresAdvisoryLocks(Enum):
|
||||
|
||||
|
||||
class OnyxCeleryQueues:
|
||||
# "celery" is the default queue defined by celery and also the queue
|
||||
# we are running in the primary worker to run system tasks
|
||||
# Tasks running in this queue should be designed specifically to run quickly
|
||||
PRIMARY = "celery"
|
||||
|
||||
# Light queue
|
||||
VESPA_METADATA_SYNC = "vespa_metadata_sync"
|
||||
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
|
||||
@@ -293,7 +299,6 @@ class OnyxRedisLocks:
|
||||
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
|
||||
|
||||
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
|
||||
@@ -319,6 +324,8 @@ class OnyxRedisSignals:
|
||||
BLOCK_VALIDATE_PERMISSION_SYNC_FENCES = (
|
||||
"signal:block_validate_permission_sync_fences"
|
||||
)
|
||||
BLOCK_PRUNING = "signal:block_pruning"
|
||||
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
|
||||
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
|
||||
|
||||
|
||||
@@ -340,12 +347,18 @@ ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud"
|
||||
# the tenant id we use for system level redis operations
|
||||
ONYX_CLOUD_TENANT_ID = "cloud"
|
||||
|
||||
# the redis namespace for runtime variables
|
||||
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
|
||||
|
||||
|
||||
class OnyxCeleryTask:
|
||||
DEFAULT = "celery"
|
||||
|
||||
CLOUD_BEAT_TASK_GENERATOR = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_generate_beat_tasks"
|
||||
CLOUD_CHECK_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_alembic"
|
||||
CLOUD_MONITOR_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_alembic"
|
||||
CLOUD_MONITOR_CELERY_QUEUES = (
|
||||
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues"
|
||||
)
|
||||
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
@@ -355,8 +368,8 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
|
||||
@@ -65,10 +65,25 @@ class AirtableConnector(LoadConnector):
|
||||
base_id: str,
|
||||
table_name_or_id: str,
|
||||
treat_all_non_attachment_fields_as_metadata: bool = False,
|
||||
view_id: str | None = None,
|
||||
share_id: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
"""Initialize an AirtableConnector.
|
||||
|
||||
Args:
|
||||
base_id: The ID of the Airtable base to connect to
|
||||
table_name_or_id: The name or ID of the table to index
|
||||
treat_all_non_attachment_fields_as_metadata: If True, all fields except attachments will be treated as metadata.
|
||||
If False, only fields with types in DEFAULT_METADATA_FIELD_TYPES will be treated as metadata.
|
||||
view_id: Optional ID of a specific view to use
|
||||
share_id: Optional ID of a "share" to use for generating record URLs (https://airtable.com/developers/web/api/list-shares)
|
||||
batch_size: Number of records to process in each batch
|
||||
"""
|
||||
self.base_id = base_id
|
||||
self.table_name_or_id = table_name_or_id
|
||||
self.view_id = view_id
|
||||
self.share_id = share_id
|
||||
self.batch_size = batch_size
|
||||
self._airtable_client: AirtableApi | None = None
|
||||
self.treat_all_non_attachment_fields_as_metadata = (
|
||||
@@ -85,6 +100,39 @@ class AirtableConnector(LoadConnector):
|
||||
raise AirtableClientNotSetUpError()
|
||||
return self._airtable_client
|
||||
|
||||
@classmethod
|
||||
def _get_record_url(
|
||||
cls,
|
||||
base_id: str,
|
||||
table_id: str,
|
||||
record_id: str,
|
||||
share_id: str | None,
|
||||
view_id: str | None,
|
||||
field_id: str | None = None,
|
||||
attachment_id: str | None = None,
|
||||
) -> str:
|
||||
"""Constructs the URL for a record, optionally including field and attachment IDs
|
||||
|
||||
Full possible structure is:
|
||||
|
||||
https://airtable.com/BASE_ID/SHARE_ID/TABLE_ID/VIEW_ID/RECORD_ID/FIELD_ID/ATTACHMENT_ID
|
||||
"""
|
||||
# If we have a shared link, use that view for better UX
|
||||
if share_id:
|
||||
base_url = f"https://airtable.com/{base_id}/{share_id}/{table_id}"
|
||||
else:
|
||||
base_url = f"https://airtable.com/{base_id}/{table_id}"
|
||||
|
||||
if view_id:
|
||||
base_url = f"{base_url}/{view_id}"
|
||||
|
||||
base_url = f"{base_url}/{record_id}"
|
||||
|
||||
if field_id and attachment_id:
|
||||
return f"{base_url}/{field_id}/{attachment_id}?blocks=hide"
|
||||
|
||||
return base_url
|
||||
|
||||
def _extract_field_values(
|
||||
self,
|
||||
field_id: str,
|
||||
@@ -110,8 +158,10 @@ class AirtableConnector(LoadConnector):
|
||||
if field_type == "multipleRecordLinks":
|
||||
return []
|
||||
|
||||
# default link to use for non-attachment fields
|
||||
default_link = f"https://airtable.com/{base_id}/{table_id}/{record_id}"
|
||||
# Get the base URL for this record
|
||||
default_link = self._get_record_url(
|
||||
base_id, table_id, record_id, self.share_id, self.view_id or view_id
|
||||
)
|
||||
|
||||
if field_type == "multipleAttachments":
|
||||
attachment_texts: list[tuple[str, str]] = []
|
||||
@@ -165,17 +215,16 @@ class AirtableConnector(LoadConnector):
|
||||
extension=file_ext,
|
||||
)
|
||||
if attachment_text:
|
||||
# slightly nicer loading experience if we can specify the view ID
|
||||
if view_id:
|
||||
attachment_link = (
|
||||
f"https://airtable.com/{base_id}/{table_id}/{view_id}/{record_id}"
|
||||
f"/{field_id}/{attachment_id}?blocks=hide"
|
||||
)
|
||||
else:
|
||||
attachment_link = (
|
||||
f"https://airtable.com/{base_id}/{table_id}/{record_id}"
|
||||
f"/{field_id}/{attachment_id}?blocks=hide"
|
||||
)
|
||||
# Use the helper method to construct attachment URLs
|
||||
attachment_link = self._get_record_url(
|
||||
base_id,
|
||||
table_id,
|
||||
record_id,
|
||||
self.share_id,
|
||||
self.view_id or view_id,
|
||||
field_id,
|
||||
attachment_id,
|
||||
)
|
||||
attachment_texts.append(
|
||||
(f"{filename}:\n{attachment_text}", attachment_link)
|
||||
)
|
||||
|
||||
@@ -61,6 +61,7 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
batch = bookstack_client.get(endpoint, params=params).get("data", [])
|
||||
|
||||
doc_batch = [transformer(bookstack_client, item) for item in batch]
|
||||
|
||||
return doc_batch, len(batch)
|
||||
@@ -197,20 +198,31 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
for endpoint, transform in transform_by_endpoint.items():
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, num_results = self._get_doc_batch(
|
||||
batch_size=self.batch_size,
|
||||
bookstack_client=self.bookstack_client,
|
||||
endpoint=endpoint,
|
||||
transformer=transform,
|
||||
start_ind=start_ind,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
start_ind += num_results
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
try:
|
||||
doc_batch, num_results = self._get_doc_batch(
|
||||
batch_size=self.batch_size,
|
||||
bookstack_client=self.bookstack_client,
|
||||
endpoint=endpoint,
|
||||
transformer=transform,
|
||||
start_ind=start_ind,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
start_ind += num_results
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
if num_results < self.batch_size:
|
||||
if num_results < self.batch_size:
|
||||
break
|
||||
else:
|
||||
time.sleep(0.2)
|
||||
except Exception as e:
|
||||
# Handle case where user hasn't properly set up permissions for the API key and we
|
||||
# fail on a specific resource (e.g. /books, /chapters, etc.)
|
||||
|
||||
if (
|
||||
"BookStack Client request failed with status 403: Forbidden"
|
||||
in str(e)
|
||||
):
|
||||
raise
|
||||
break
|
||||
else:
|
||||
time.sleep(0.2)
|
||||
|
||||
@@ -27,6 +27,7 @@ from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -319,6 +320,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
|
||||
@@ -386,4 +388,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
|
||||
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
|
||||
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"retrieve_all_slim_documents: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("retrieve_all_slim_documents", 1)
|
||||
|
||||
yield doc_metadata_list
|
||||
|
||||
@@ -20,6 +20,7 @@ from onyx.connectors.egnyte.connector import EgnyteConnector
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.fireflies.connector import FirefliesConnector
|
||||
from onyx.connectors.freshdesk.connector import FreshdeskConnector
|
||||
from onyx.connectors.gitbook.connector import GitbookConnector
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.connectors.gitlab.connector import GitlabConnector
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
@@ -71,6 +72,7 @@ def identify_connector_class(
|
||||
DocumentSource.GITHUB: GithubConnector,
|
||||
DocumentSource.GMAIL: GmailConnector,
|
||||
DocumentSource.GITLAB: GitlabConnector,
|
||||
DocumentSource.GITBOOK: GitbookConnector,
|
||||
DocumentSource.GOOGLE_DRIVE: GoogleDriveConnector,
|
||||
DocumentSource.BOOKSTACK: BookstackConnector,
|
||||
DocumentSource.CONFLUENCE: ConfluenceConnector,
|
||||
|
||||
0
backend/onyx/connectors/gitbook/__init__.py
Normal file
0
backend/onyx/connectors/gitbook/__init__.py
Normal file
279
backend/onyx/connectors/gitbook/connector.py
Normal file
279
backend/onyx/connectors/gitbook/connector.py
Normal file
@@ -0,0 +1,279 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
GITBOOK_API_BASE = "https://api.gitbook.com/v1/"
|
||||
|
||||
|
||||
class GitbookApiClient:
|
||||
def __init__(self, access_token: str) -> None:
|
||||
self.access_token = access_token
|
||||
|
||||
def get(self, endpoint: str, params: dict[str, Any] | None = None) -> Any:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.access_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
url = urljoin(GITBOOK_API_BASE, endpoint.lstrip("/"))
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def get_page_content(self, space_id: str, page_id: str) -> dict[str, Any]:
|
||||
return self.get(f"/spaces/{space_id}/content/page/{page_id}")
|
||||
|
||||
|
||||
def _extract_text_from_document(document: dict[str, Any]) -> str:
|
||||
"""Extract text content from GitBook document structure by parsing the document nodes
|
||||
into markdown format."""
|
||||
|
||||
def parse_leaf(leaf: dict[str, Any]) -> str:
|
||||
text = leaf.get("text", "")
|
||||
leaf.get("marks", [])
|
||||
return text
|
||||
|
||||
def parse_text_node(node: dict[str, Any]) -> str:
|
||||
text = ""
|
||||
for leaf in node.get("leaves", []):
|
||||
text += parse_leaf(leaf)
|
||||
return text
|
||||
|
||||
def parse_block_node(node: dict[str, Any]) -> str:
|
||||
block_type = node.get("type", "")
|
||||
result = ""
|
||||
|
||||
if block_type == "heading-1":
|
||||
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
|
||||
result = f"# {text}\n\n"
|
||||
|
||||
elif block_type == "heading-2":
|
||||
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
|
||||
result = f"## {text}\n\n"
|
||||
|
||||
elif block_type == "heading-3":
|
||||
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
|
||||
result = f"### {text}\n\n"
|
||||
|
||||
elif block_type == "heading-4":
|
||||
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
|
||||
result = f"#### {text}\n\n"
|
||||
|
||||
elif block_type == "heading-5":
|
||||
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
|
||||
result = f"##### {text}\n\n"
|
||||
|
||||
elif block_type == "heading-6":
|
||||
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
|
||||
result = f"###### {text}\n\n"
|
||||
|
||||
elif block_type == "list-unordered":
|
||||
for list_item in node.get("nodes", []):
|
||||
paragraph = list_item.get("nodes", [])[0]
|
||||
text = "".join(parse_text_node(n) for n in paragraph.get("nodes", []))
|
||||
result += f"* {text}\n"
|
||||
result += "\n"
|
||||
|
||||
elif block_type == "paragraph":
|
||||
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
|
||||
result = f"{text}\n\n"
|
||||
|
||||
elif block_type == "list-tasks":
|
||||
for task_item in node.get("nodes", []):
|
||||
checked = task_item.get("data", {}).get("checked", False)
|
||||
paragraph = task_item.get("nodes", [])[0]
|
||||
text = "".join(parse_text_node(n) for n in paragraph.get("nodes", []))
|
||||
checkbox = "[x]" if checked else "[ ]"
|
||||
result += f"- {checkbox} {text}\n"
|
||||
result += "\n"
|
||||
|
||||
elif block_type == "code":
|
||||
for code_line in node.get("nodes", []):
|
||||
if code_line.get("type") == "code-line":
|
||||
text = "".join(
|
||||
parse_text_node(n) for n in code_line.get("nodes", [])
|
||||
)
|
||||
result += f"{text}\n"
|
||||
result += "\n"
|
||||
|
||||
elif block_type == "blockquote":
|
||||
for quote_node in node.get("nodes", []):
|
||||
if quote_node.get("type") == "paragraph":
|
||||
text = "".join(
|
||||
parse_text_node(n) for n in quote_node.get("nodes", [])
|
||||
)
|
||||
result += f"> {text}\n"
|
||||
result += "\n"
|
||||
|
||||
elif block_type == "table":
|
||||
records = node.get("data", {}).get("records", {})
|
||||
definition = node.get("data", {}).get("definition", {})
|
||||
view = node.get("data", {}).get("view", {})
|
||||
|
||||
columns = view.get("columns", [])
|
||||
|
||||
header_cells = []
|
||||
for col_id in columns:
|
||||
col_def = definition.get(col_id, {})
|
||||
header_cells.append(col_def.get("title", ""))
|
||||
|
||||
result = "| " + " | ".join(header_cells) + " |\n"
|
||||
result += "|" + "---|" * len(header_cells) + "\n"
|
||||
|
||||
sorted_records = sorted(
|
||||
records.items(), key=lambda x: x[1].get("orderIndex", "")
|
||||
)
|
||||
|
||||
for record_id, record_data in sorted_records:
|
||||
values = record_data.get("values", {})
|
||||
row_cells = []
|
||||
for col_id in columns:
|
||||
fragment_id = values.get(col_id, "")
|
||||
fragment_text = ""
|
||||
for fragment in node.get("fragments", []):
|
||||
if fragment.get("fragment") == fragment_id:
|
||||
for frag_node in fragment.get("nodes", []):
|
||||
if frag_node.get("type") == "paragraph":
|
||||
fragment_text = "".join(
|
||||
parse_text_node(n)
|
||||
for n in frag_node.get("nodes", [])
|
||||
)
|
||||
break
|
||||
row_cells.append(fragment_text)
|
||||
result += "| " + " | ".join(row_cells) + " |\n"
|
||||
|
||||
result += "\n"
|
||||
return result
|
||||
|
||||
if not document or "document" not in document:
|
||||
return ""
|
||||
|
||||
markdown = ""
|
||||
nodes = document["document"].get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
markdown += parse_block_node(node)
|
||||
|
||||
return markdown
|
||||
|
||||
|
||||
def _convert_page_to_document(
|
||||
client: GitbookApiClient, space_id: str, page: dict[str, Any]
|
||||
) -> Document:
|
||||
page_id = page["id"]
|
||||
page_content = client.get_page_content(space_id, page_id)
|
||||
|
||||
return Document(
|
||||
id=f"gitbook-{space_id}-{page_id}",
|
||||
sections=[
|
||||
Section(
|
||||
link=page.get("urls", {}).get("app", ""),
|
||||
text=_extract_text_from_document(page_content),
|
||||
)
|
||||
],
|
||||
source=DocumentSource.GITBOOK,
|
||||
semantic_identifier=page.get("title", ""),
|
||||
doc_updated_at=datetime.fromisoformat(page["updatedAt"]).replace(
|
||||
tzinfo=timezone.utc
|
||||
),
|
||||
metadata={
|
||||
"path": page.get("path", ""),
|
||||
"type": page.get("type", ""),
|
||||
"kind": page.get("kind", ""),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class GitbookConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
space_id: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.space_id = space_id
|
||||
self.batch_size = batch_size
|
||||
self.access_token: str | None = None
|
||||
self.client: GitbookApiClient | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
access_token = credentials.get("gitbook_api_key")
|
||||
if not access_token:
|
||||
raise ConnectorMissingCredentialError("GitBook access token")
|
||||
self.access_token = access_token
|
||||
self.client = GitbookApiClient(access_token)
|
||||
|
||||
def _fetch_all_pages(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
if not self.client:
|
||||
raise ConnectorMissingCredentialError("GitBook")
|
||||
|
||||
try:
|
||||
content = self.client.get(f"/spaces/{self.space_id}/content")
|
||||
pages = content.get("pages", [])
|
||||
|
||||
current_batch: list[Document] = []
|
||||
for page in pages:
|
||||
updated_at = datetime.fromisoformat(page["updatedAt"])
|
||||
|
||||
if start and updated_at < start:
|
||||
if current_batch:
|
||||
yield current_batch
|
||||
return
|
||||
if end and updated_at > end:
|
||||
continue
|
||||
|
||||
current_batch.append(
|
||||
_convert_page_to_document(self.client, self.space_id, page)
|
||||
)
|
||||
|
||||
if len(current_batch) >= self.batch_size:
|
||||
yield current_batch
|
||||
current_batch = []
|
||||
|
||||
if current_batch:
|
||||
yield current_batch
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Error fetching GitBook content: {str(e)}")
|
||||
raise
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_all_pages()
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
return self._fetch_all_pages(start_datetime, end_datetime)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
connector = GitbookConnector(
|
||||
space_id=os.environ["GITBOOK_SPACE_ID"],
|
||||
)
|
||||
connector.load_credentials({"gitbook_api_key": os.environ["GITBOOK_API_KEY"]})
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
@@ -30,6 +30,7 @@ from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
@@ -321,6 +322,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
query = _build_time_range_query(time_range_start, time_range_end)
|
||||
doc_batch = []
|
||||
@@ -343,6 +345,15 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
if len(doc_batch) > SLIM_BATCH_SIZE:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"retrieve_all_slim_documents: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("retrieve_all_slim_documents", 1)
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
@@ -368,9 +379,10 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
try:
|
||||
yield from self._fetch_slim_threads(start, end)
|
||||
yield from self._fetch_slim_threads(start, end, callback=callback)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
|
||||
@@ -42,6 +42,7 @@ from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
@@ -301,7 +302,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
if e.status_code == 401:
|
||||
# fail gracefully, let the other impersonations continue
|
||||
# one user without access shouldn't block the entire connector
|
||||
logger.exception(
|
||||
logger.warning(
|
||||
f"User '{user_email}' does not have access to the drive APIs."
|
||||
)
|
||||
return
|
||||
@@ -564,6 +565,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_batch = []
|
||||
for file in self._fetch_drive_items(
|
||||
@@ -576,15 +578,26 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
if len(slim_batch) >= SLIM_BATCH_SIZE:
|
||||
yield slim_batch
|
||||
slim_batch = []
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"_extract_slim_docs_from_google_drive: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("_extract_slim_docs_from_google_drive", 1)
|
||||
|
||||
yield slim_batch
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
try:
|
||||
yield from self._extract_slim_docs_from_google_drive(start, end)
|
||||
yield from self._extract_slim_docs_from_google_drive(
|
||||
start, end, callback=callback
|
||||
)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
|
||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
|
||||
|
||||
SecondsSinceUnixEpoch = float
|
||||
@@ -63,6 +64,7 @@ class SlimConnector(BaseConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -91,6 +91,7 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
f"&response_type=code"
|
||||
f"&scope=read"
|
||||
f"&state={state}"
|
||||
f"&prompt=consent" # prompts user for access; allows choosing workspace
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.connectors.onyx_jira.utils import build_jira_url
|
||||
from onyx.connectors.onyx_jira.utils import extract_jira_project
|
||||
from onyx.connectors.onyx_jira.utils import extract_text_from_adf
|
||||
from onyx.connectors.onyx_jira.utils import get_comment_strs
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -144,7 +145,8 @@ def fetch_jira_issues_batch(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=issue.fields.summary,
|
||||
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
|
||||
title=f"{issue.key} {issue.fields.summary}",
|
||||
doc_updated_at=time_str_to_utc(issue.fields.updated),
|
||||
primary_owners=list(people) or None,
|
||||
# TODO add secondary_owners (commenters) if needed
|
||||
@@ -245,6 +247,7 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
jql = f"project = {self.quoted_jira_project}"
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_record
|
||||
from onyx.connectors.salesforce.sqlite_functions import init_db
|
||||
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -176,6 +177,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
for parent_object_type in self.parent_object_list:
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -242,6 +243,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
for post_id in get_all_post_ids(self.slab_bot_token):
|
||||
|
||||
@@ -27,6 +27,7 @@ from onyx.connectors.slack.utils import get_message_link
|
||||
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.utils import make_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.utils import SlackTextCleaner
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -98,6 +99,7 @@ def get_channel_messages(
|
||||
channel: dict[str, Any],
|
||||
oldest: str | None = None,
|
||||
latest: str | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> Generator[list[MessageType], None, None]:
|
||||
"""Get all messages in a channel"""
|
||||
# join so that the bot can access messages
|
||||
@@ -115,6 +117,11 @@ def get_channel_messages(
|
||||
oldest=oldest,
|
||||
latest=latest,
|
||||
):
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("get_channel_messages: Stop signal detected")
|
||||
|
||||
callback.progress("get_channel_messages", 0)
|
||||
yield cast(list[MessageType], result["messages"])
|
||||
|
||||
|
||||
@@ -325,6 +332,7 @@ def _get_all_doc_ids(
|
||||
channels: list[str] | None = None,
|
||||
channel_name_regex_enabled: bool = False,
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""
|
||||
Get all document ids in the workspace, channel by channel
|
||||
@@ -342,6 +350,7 @@ def _get_all_doc_ids(
|
||||
channel_message_batches = get_channel_messages(
|
||||
client=client,
|
||||
channel=channel,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
message_ts_set: set[str] = set()
|
||||
@@ -390,6 +399,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
@@ -398,6 +408,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
|
||||
client=self.client,
|
||||
channels=self.channels,
|
||||
channel_name_regex_enabled=self.channel_regex_enabled,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
def poll_source(
|
||||
|
||||
@@ -39,19 +39,6 @@ def get_message_link(
|
||||
return permalink
|
||||
|
||||
|
||||
def _make_slack_api_call_logged(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., SlackResponse]:
|
||||
@wraps(call)
|
||||
def logged_call(**kwargs: Any) -> SlackResponse:
|
||||
logger.debug(f"Making call to Slack API '{call.__name__}' with args '{kwargs}'")
|
||||
result = call(**kwargs)
|
||||
logger.debug(f"Call to Slack API '{call.__name__}' returned '{result}'")
|
||||
return result
|
||||
|
||||
return logged_call
|
||||
|
||||
|
||||
def _make_slack_api_call_paginated(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., Generator[dict[str, Any], None, None]]:
|
||||
@@ -127,18 +114,14 @@ def make_slack_api_rate_limited(
|
||||
def make_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> SlackResponse:
|
||||
return basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
|
||||
)(**kwargs)
|
||||
return basic_retry_wrapper(make_slack_api_rate_limited(call))(**kwargs)
|
||||
|
||||
|
||||
def make_paginated_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
return _make_slack_api_call_paginated(
|
||||
basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
|
||||
)
|
||||
basic_retry_wrapper(make_slack_api_rate_limited(call))
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
@@ -405,6 +406,7 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
if self.content_type == "articles":
|
||||
|
||||
@@ -51,6 +51,7 @@ class SearchPipeline:
|
||||
user: User | None,
|
||||
llm: LLM,
|
||||
fast_llm: LLM,
|
||||
skip_query_analysis: bool,
|
||||
db_session: Session,
|
||||
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
|
||||
retrieval_metrics_callback: (
|
||||
@@ -61,10 +62,13 @@ class SearchPipeline:
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
prompt_config: PromptConfig | None = None,
|
||||
):
|
||||
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
|
||||
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
|
||||
self.search_request = search_request
|
||||
self.user = user
|
||||
self.llm = llm
|
||||
self.fast_llm = fast_llm
|
||||
self.skip_query_analysis = skip_query_analysis
|
||||
self.db_session = db_session
|
||||
self.bypass_acl = bypass_acl
|
||||
self.retrieval_metrics_callback = retrieval_metrics_callback
|
||||
@@ -106,6 +110,7 @@ class SearchPipeline:
|
||||
search_request=self.search_request,
|
||||
user=self.user,
|
||||
llm=self.llm,
|
||||
skip_query_analysis=self.skip_query_analysis,
|
||||
db_session=self.db_session,
|
||||
bypass_acl=self.bypass_acl,
|
||||
)
|
||||
@@ -160,6 +165,12 @@ class SearchPipeline:
|
||||
that have a corresponding chunk.
|
||||
|
||||
This step should be fast for any document index implementation.
|
||||
|
||||
Current implementation timing is approximately broken down in timing as:
|
||||
- 200 ms to get the embedding of the query
|
||||
- 15 ms to get chunks from the document index
|
||||
- possibly more to get additional surrounding chunks
|
||||
- possibly more for query expansion (multilingual)
|
||||
"""
|
||||
if self._retrieved_sections is not None:
|
||||
return self._retrieved_sections
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import MAX_METRICS_CONTENT
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RerankMetricsContainer
|
||||
from onyx.context.search.models import SearchQuery
|
||||
from onyx.document_index.document_index_utils import (
|
||||
@@ -77,7 +78,8 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def semantic_reranking(
|
||||
query: SearchQuery,
|
||||
query_str: str,
|
||||
rerank_settings: RerankingDetails,
|
||||
chunks: list[InferenceChunk],
|
||||
model_min: int = CROSS_ENCODER_RANGE_MIN,
|
||||
model_max: int = CROSS_ENCODER_RANGE_MAX,
|
||||
@@ -88,11 +90,9 @@ def semantic_reranking(
|
||||
|
||||
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
|
||||
"""
|
||||
rerank_settings = query.rerank_settings
|
||||
|
||||
if not rerank_settings or not rerank_settings.rerank_model_name:
|
||||
# Should never reach this part of the flow without reranking settings
|
||||
raise RuntimeError("Reranking flow should not be running")
|
||||
assert (
|
||||
rerank_settings.rerank_model_name
|
||||
), "Reranking flow cannot run without a specific model"
|
||||
|
||||
chunks_to_rerank = chunks[: rerank_settings.num_rerank]
|
||||
|
||||
@@ -107,7 +107,7 @@ def semantic_reranking(
|
||||
f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
|
||||
for chunk in chunks_to_rerank
|
||||
]
|
||||
sim_scores_floats = cross_encoder.predict(query=query.query, passages=passages)
|
||||
sim_scores_floats = cross_encoder.predict(query=query_str, passages=passages)
|
||||
|
||||
# Old logic to handle multiple cross-encoders preserved but not used
|
||||
sim_scores = [numpy.array(sim_scores_floats)]
|
||||
@@ -165,8 +165,20 @@ def semantic_reranking(
|
||||
return list(ranked_chunks), list(ranked_indices)
|
||||
|
||||
|
||||
def should_rerank(rerank_settings: RerankingDetails | None) -> bool:
|
||||
"""Based on the RerankingDetails model, only run rerank if the following conditions are met:
|
||||
- rerank_model_name is not None
|
||||
- num_rerank is greater than 0
|
||||
"""
|
||||
if not rerank_settings:
|
||||
return False
|
||||
|
||||
return bool(rerank_settings.rerank_model_name and rerank_settings.num_rerank > 0)
|
||||
|
||||
|
||||
def rerank_sections(
|
||||
query: SearchQuery,
|
||||
query_str: str,
|
||||
rerank_settings: RerankingDetails,
|
||||
sections_to_rerank: list[InferenceSection],
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> list[InferenceSection]:
|
||||
@@ -181,16 +193,13 @@ def rerank_sections(
|
||||
"""
|
||||
chunks_to_rerank = [section.center_chunk for section in sections_to_rerank]
|
||||
|
||||
if not query.rerank_settings:
|
||||
# Should never reach this part of the flow without reranking settings
|
||||
raise RuntimeError("Reranking settings not found")
|
||||
|
||||
ranked_chunks, _ = semantic_reranking(
|
||||
query=query,
|
||||
query_str=query_str,
|
||||
rerank_settings=rerank_settings,
|
||||
chunks=chunks_to_rerank,
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
lower_chunks = chunks_to_rerank[query.rerank_settings.num_rerank :]
|
||||
lower_chunks = chunks_to_rerank[rerank_settings.num_rerank :]
|
||||
|
||||
# Scores from rerank cannot be meaningfully combined with scores without rerank
|
||||
# However the ordering is still important
|
||||
@@ -260,16 +269,13 @@ def search_postprocessing(
|
||||
|
||||
rerank_task_id = None
|
||||
sections_yielded = False
|
||||
if (
|
||||
search_query.rerank_settings
|
||||
and search_query.rerank_settings.rerank_model_name
|
||||
and search_query.rerank_settings.num_rerank > 0
|
||||
):
|
||||
if should_rerank(search_query.rerank_settings):
|
||||
post_processing_tasks.append(
|
||||
FunctionCall(
|
||||
rerank_sections,
|
||||
(
|
||||
search_query,
|
||||
search_query.query,
|
||||
search_query.rerank_settings, # Cannot be None here
|
||||
retrieved_sections,
|
||||
rerank_metrics_callback,
|
||||
),
|
||||
|
||||
@@ -50,11 +50,11 @@ def retrieval_preprocessing(
|
||||
search_request: SearchRequest,
|
||||
user: User | None,
|
||||
llm: LLM,
|
||||
skip_query_analysis: bool,
|
||||
db_session: Session,
|
||||
bypass_acl: bool = False,
|
||||
skip_query_analysis: bool = False,
|
||||
base_recency_decay: float = BASE_RECENCY_DECAY,
|
||||
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
|
||||
base_recency_decay: float = BASE_RECENCY_DECAY,
|
||||
bypass_acl: bool = False,
|
||||
) -> SearchQuery:
|
||||
"""Logic is as follows:
|
||||
Any global disables apply first
|
||||
@@ -146,7 +146,7 @@ def retrieval_preprocessing(
|
||||
is_keyword, extracted_keywords = (
|
||||
parallel_results[run_query_analysis.result_id]
|
||||
if run_query_analysis
|
||||
else (None, None)
|
||||
else (False, None)
|
||||
)
|
||||
|
||||
all_query_terms = query.split()
|
||||
|
||||
10
backend/onyx/db/background_error.py
Normal file
10
backend/onyx/db/background_error.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import BackgroundError
|
||||
|
||||
|
||||
def create_background_error(
|
||||
db_session: Session, message: str, cc_pair_id: int | None
|
||||
) -> None:
|
||||
db_session.add(BackgroundError(message=message, cc_pair_id=cc_pair_id))
|
||||
db_session.commit()
|
||||
@@ -350,13 +350,14 @@ def delete_chat_session(
|
||||
user_id: UUID | None,
|
||||
chat_session_id: UUID,
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
hard_delete: bool = HARD_DELETE_CHATS,
|
||||
) -> None:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
if chat_session.deleted:
|
||||
if chat_session.deleted and not include_deleted:
|
||||
raise ValueError("Cannot delete an already deleted chat session")
|
||||
|
||||
if hard_delete:
|
||||
@@ -380,7 +381,15 @@ def delete_chat_sessions_older_than(days_old: int, db_session: Session) -> None:
|
||||
).fetchall()
|
||||
|
||||
for user_id, session_id in old_sessions:
|
||||
delete_chat_session(user_id, session_id, db_session, hard_delete=True)
|
||||
try:
|
||||
delete_chat_session(
|
||||
user_id, session_id, db_session, include_deleted=True, hard_delete=True
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"delete_chat_session exceptioned. "
|
||||
f"user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
|
||||
|
||||
def get_chat_message(
|
||||
@@ -893,14 +902,18 @@ def translate_db_sub_questions_to_server_objects(
|
||||
question=sub_question.sub_question,
|
||||
answer=sub_question.sub_answer,
|
||||
sub_queries=sub_queries,
|
||||
context_docs=get_retrieval_docs_from_search_docs(verified_docs),
|
||||
context_docs=get_retrieval_docs_from_search_docs(
|
||||
verified_docs, sort_by_score=False
|
||||
),
|
||||
)
|
||||
)
|
||||
return sub_questions
|
||||
|
||||
|
||||
def get_retrieval_docs_from_search_docs(
|
||||
search_docs: list[SearchDoc], remove_doc_content: bool = False
|
||||
search_docs: list[SearchDoc],
|
||||
remove_doc_content: bool = False,
|
||||
sort_by_score: bool = True,
|
||||
) -> RetrievalDocs:
|
||||
top_documents = [
|
||||
translate_db_search_doc_to_server_search_doc(
|
||||
@@ -908,7 +921,8 @@ def get_retrieval_docs_from_search_docs(
|
||||
)
|
||||
for db_doc in search_docs
|
||||
]
|
||||
top_documents = sorted(top_documents, key=lambda doc: doc.score, reverse=True) # type: ignore
|
||||
if sort_by_score:
|
||||
top_documents = sorted(top_documents, key=lambda doc: doc.score, reverse=True) # type: ignore
|
||||
return RetrievalDocs(top_documents=top_documents)
|
||||
|
||||
|
||||
@@ -1018,7 +1032,7 @@ def log_agent_sub_question_results(
|
||||
sub_question = sub_question_answer_result.question
|
||||
sub_answer = sub_question_answer_result.answer
|
||||
sub_document_results = _create_citation_format_list(
|
||||
sub_question_answer_result.verified_reranked_documents
|
||||
sub_question_answer_result.context_documents
|
||||
)
|
||||
|
||||
sub_question_object = AgentSubQuestion(
|
||||
|
||||
@@ -105,6 +105,32 @@ def construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
return stmt
|
||||
|
||||
|
||||
def construct_document_id_select_for_connector_credential_pair_by_needs_sync(
|
||||
connector_id: int, credential_id: int
|
||||
) -> Select:
|
||||
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
)
|
||||
|
||||
stmt = (
|
||||
select(DbDocument.id)
|
||||
.where(
|
||||
DbDocument.id.in_(initial_doc_ids_stmt),
|
||||
or_(
|
||||
DbDocument.last_modified
|
||||
> DbDocument.last_synced, # last_modified is newer than last_synced
|
||||
DbDocument.last_synced.is_(None), # never synced
|
||||
),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
return stmt
|
||||
|
||||
|
||||
def get_all_documents_needing_vespa_sync_for_cc_pair(
|
||||
db_session: Session, cc_pair_id: int
|
||||
) -> list[DbDocument]:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user