Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-03-01 13:45:44 +00:00)

Compare commits: hackathon-… ... refactor-m… (123 commits)
Commits (SHA1; author and date columns were not captured):
ccde845e47, cad3517f85, 191577fa19, a7d140cb5d, 4ef7e44c95, e7bd58cc85, dd18291d51, 9a5ea03cd1, eee3054b45, 5eea47cb1c,
c830364c15, 04f3ba1f3d, 84f76fbee7, 00aeb3b280, 8c30085a9e, 419e82f9f4, 8330e5d8f4, e06c60a1a3, e7eef67893, b5209edffa,
07ad4dc022, 06e1a2c1a5, 083c152878, 06f11a0a06, fabfcddadb, c1d4b08132, f3f47d0709, fe26a1bfcc, 554cd0f891, f87d3e9849,
72cdada893, c442ebaff6, 56f16d107e, 0157ae099a, 565fb42457, a50a8b4a12, 4baf4e7d96, 8b7ab2eb66, 1f75f3633e, 650884d76a,
8722bdb414, 71037678c3, 68de1015e1, e2b3a6e144, 4f04b09efa, 5c4f44d258, 19652ad60e, 70c96b6ab3, 65076b916f, 06bc0e51db,
508b456b40, bf1e2a2661, 991d5e4203, d21f012b04, 86b7beab01, b4eaa81d8b, ff2a4c8723, 51027fd259, 7e3fd2b12a, d2fef6f0b7,
bd06147d26, 1f3cc9ed6e, 6086d9e51a, e0de24f64e, 08b6b1f8b3, afed1a4b37, bca18cacdf, 335db91803, 67c488ff1f, deb7f13962,
e2d3d65c60, b78a6834f5, 4abe90aa2c, de9568844b, 34268f9806, ed75678837, 3bb58a3dd3, 4b02feef31, 6a4d49f02e, d1736187d3,
0e79b96091, ae302d473d, feca4fda78, f7ed7cd3cd, 8377ab3ef2, 95c23bf870, e49fb8f56d, adf48de652, bca2500438, 89f925662f,
b64c6d5d40, 36c63950a6, 3f31340e6f, 6ac2258c2e, b4d3b43e8a, ca281b71e3, 9bd5a1de7a, d3c5a4fba0, f50006ee63, e0092024af,
675ef524b0, 240367c775, f0ed063860, bcf0ef0c87, 0c7a245a46, 583d82433a, 391e710b6e, 004e56a91b, 103300798f, 8349d6f0ea,
cd63bf6da9, 5f03e85195, cbdbfcab5e, 6918611287, b0639add8f, 7af10308d7, 5e14f23507, 0bf3a5c609, 82724826ce, f9e061926a,
8afd07ff7a, 6523a38255, 264878a1c9
2  .github/CODEOWNERS (vendored)

@@ -1 +1,3 @@
* @onyx-dot-app/onyx-core-team
# Helm charts Owners
/helm/ @justin-tahara
40  .github/workflows/helm-chart-releases.yml (vendored, new file)

@@ -0,0 +1,40 @@
name: Release Onyx Helm Charts

on:
  push:
    branches:
      - main

permissions: write-all

jobs:
  release:
    permissions:
      contents: write
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Configure Git
        run: |
          git config user.name "$GITHUB_ACTOR"
          git config user.email "$GITHUB_ACTOR@users.noreply.github.com"

      - name: Install Helm
        uses: azure/setup-helm@v4
        with:
          version: v3.12.1

      - name: Add Required Helm Repositories
        run: |
          helm repo add bitnami https://charts.bitnami.com/bitnami
          helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
          helm repo update

      - name: Run chart-releaser
        uses: helm/chart-releaser-action@v1.7.0
        env:
          CR_TOKEN: "${{ secrets.GITHUB_TOKEN }}"
@@ -13,6 +13,14 @@ env:
  # MinIO
  S3_ENDPOINT_URL: "http://localhost:9004"

  # Confluence
  CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
  CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
  CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
  CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
  CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
  CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}

jobs:
  discover-test-dirs:
    runs-on: ubuntu-latest
38  .github/workflows/pr-labeler.yml (vendored, new file)

@@ -0,0 +1,38 @@
name: PR Labeler

on:
  pull_request_target:
    branches:
      - main
    types:
      - opened
      - reopened
      - synchronize
      - edited

permissions:
  contents: read
  pull-requests: write

jobs:
  validate_pr_title:
    runs-on: ubuntu-latest
    steps:
      - name: Check PR title for Conventional Commits
        env:
          PR_TITLE: ${{ github.event.pull_request.title }}
        run: |
          echo "PR Title: $PR_TITLE"
          if [[ ! "$PR_TITLE" =~ ^(feat|fix|docs|test|ci|refactor|perf|chore|revert|build)(\(.+\))?:\ .+ ]]; then
            echo "::error::❌ Your PR title does not follow the Conventional Commits format.
          This check ensures that all pull requests use clear, consistent titles that help automate changelogs and improve project history.

          Please update your PR title to follow the Conventional Commits style.
          Here is a link to a blog explaining the reason why we've included the Conventional Commits style into our PR titles: https://xfuture-blog.com/working-with-conventional-commits

          **Here are some examples of valid PR titles:**
          - feat: add user authentication
          - fix(login): handle null password error
          - docs(readme): update installation instructions"
            exit 1
          fi
2  .github/workflows/pr-python-checks.yml (vendored)

@@ -47,7 +47,7 @@ jobs:
            -i /local/openapi.json \
            -g python \
            -o /local/onyx_openapi_client \
            --package-name onyx_openapi_client
            --package-name onyx_openapi_client \

      - name: Run MyPy
        run: |
@@ -16,8 +16,8 @@ env:
  # Confluence
  CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
  CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
  CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
  CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
  CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
  CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
  CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}

@@ -53,6 +53,12 @@ env:
  # Hubspot
  HUBSPOT_ACCESS_TOKEN: ${{ secrets.HUBSPOT_ACCESS_TOKEN }}

  # IMAP
  IMAP_HOST: ${{ secrets.IMAP_HOST }}
  IMAP_USERNAME: ${{ secrets.IMAP_USERNAME }}
  IMAP_PASSWORD: ${{ secrets.IMAP_PASSWORD }}
  IMAP_MAILBOXES: ${{ secrets.IMAP_MAILBOXES }}

  # Airtable
  AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
  AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
3  .vscode/env_template.txt (vendored)

@@ -45,8 +45,9 @@ PYTHONPATH=../backend
PYTHONUNBUFFERED=1


# Internet Search
# Internet Search
BING_API_KEY=<REPLACE THIS>
EXA_API_KEY=<REPLACE THIS>


# Enable the full set of Danswer Enterprise Edition features
106  .vscode/launch.template.jsonc (vendored)

@@ -24,8 +24,8 @@
        "Celery primary",
        "Celery light",
        "Celery heavy",
        "Celery indexing",
        "Celery user files indexing",
        "Celery docfetching",
        "Celery docprocessing",
        "Celery beat",
        "Celery monitoring"
      ],
@@ -46,8 +46,8 @@
        "Celery primary",
        "Celery light",
        "Celery heavy",
        "Celery indexing",
        "Celery user files indexing",
        "Celery docfetching",
        "Celery docprocessing",
        "Celery beat",
        "Celery monitoring"
      ],
@@ -226,35 +226,66 @@
      "consoleTitle": "Celery heavy Console"
    },
    {
      "name": "Celery indexing",
      "name": "Celery docfetching",
      "type": "debugpy",
      "request": "launch",
      "module": "celery",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.vscode/.env",
      "env": {
        "ENABLE_MULTIPASS_INDEXING": "false",
        "LOG_LEVEL": "DEBUG",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
        "LOG_LEVEL": "DEBUG",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      },
      "args": [
        "-A",
        "onyx.background.celery.versioned_apps.indexing",
        "worker",
        "--pool=threads",
        "--concurrency=1",
        "--prefetch-multiplier=1",
        "--loglevel=INFO",
        "--hostname=indexing@%n",
        "-Q",
        "connector_indexing"
        "-A",
        "onyx.background.celery.versioned_apps.docfetching",
        "worker",
        "--pool=threads",
        "--concurrency=1",
        "--prefetch-multiplier=1",
        "--loglevel=INFO",
        "--hostname=docfetching@%n",
        "-Q",
        "connector_doc_fetching,user_files_indexing"
      ],
      "presentation": {
        "group": "2"
        "group": "2"
      },
      "consoleTitle": "Celery indexing Console"
    },
      "consoleTitle": "Celery docfetching Console",
      "justMyCode": false
    },
    {
      "name": "Celery docprocessing",
      "type": "debugpy",
      "request": "launch",
      "module": "celery",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.vscode/.env",
      "env": {
        "ENABLE_MULTIPASS_INDEXING": "false",
        "LOG_LEVEL": "DEBUG",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      },
      "args": [
        "-A",
        "onyx.background.celery.versioned_apps.docprocessing",
        "worker",
        "--pool=threads",
        "--concurrency=6",
        "--prefetch-multiplier=1",
        "--loglevel=INFO",
        "--hostname=docprocessing@%n",
        "-Q",
        "docprocessing"
      ],
      "presentation": {
        "group": "2"
      },
      "consoleTitle": "Celery docprocessing Console",
      "justMyCode": false
    },
    {
      "name": "Celery monitoring",
      "type": "debugpy",
@@ -303,35 +334,6 @@
      },
      "consoleTitle": "Celery beat Console"
    },
    {
      "name": "Celery user files indexing",
      "type": "debugpy",
      "request": "launch",
      "module": "celery",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.vscode/.env",
      "env": {
        "LOG_LEVEL": "DEBUG",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      },
      "args": [
        "-A",
        "onyx.background.celery.versioned_apps.indexing",
        "worker",
        "--pool=threads",
        "--concurrency=1",
        "--prefetch-multiplier=1",
        "--loglevel=INFO",
        "--hostname=user_files_indexing@%n",
        "-Q",
        "user_files_indexing"
      ],
      "presentation": {
        "group": "2"
      },
      "consoleTitle": "Celery user files indexing Console"
    },
    {
      "name": "Pytest",
      "consoleName": "Pytest",
@@ -426,7 +428,7 @@
      },
      "args": [
        "--filename",
        "generated/openapi.json",
        "generated/openapi.json"
      ]
    },
    {
@@ -59,6 +59,7 @@ Onyx being a fully functional app, relies on some external software, specificall
- [Postgres](https://www.postgresql.org/) (Relational DB)
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
- [Redis](https://redis.io/) (Cache)
- [MinIO](https://min.io/) (File Store)
- [Nginx](https://nginx.org/) (Not needed for development flows generally)

> **Note:**
@@ -171,10 +172,10 @@ Otherwise, you can follow the instructions below to run the application for deve

You will need Docker installed to run these containers.

First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:

```bash
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache minio
```

(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
@@ -23,7 +23,7 @@ from sqlalchemy.sql.schema import SchemaItem
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import (
    MULTI_TENANT,
    POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE,
    POSTGRES_DEFAULT_SCHEMA,
    TENANT_ID_PREFIX,
)
from onyx.db.models import Base

@@ -271,7 +271,7 @@ async def run_async_migrations() -> None:
    ) = get_schema_options()

    if not schemas and not MULTI_TENANT:
        schemas = [POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE]
        schemas = [POSTGRES_DEFAULT_SCHEMA]

    # without init_engine, subsequent engine calls fail hard intentionally
    SqlEngine.init_engine(pool_size=20, max_overflow=5)
@@ -0,0 +1,72 @@
"""add federated connector tables

Revision ID: 0816326d83aa
Revises: 12635f6655b7
Create Date: 2025-06-29 14:09:45.109518

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "0816326d83aa"
down_revision = "12635f6655b7"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create federated_connector table
    op.create_table(
        "federated_connector",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("source", sa.String(), nullable=False),
        sa.Column("credentials", sa.LargeBinary(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )

    # Create federated_connector_oauth_token table
    op.create_table(
        "federated_connector_oauth_token",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("federated_connector_id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("token", sa.LargeBinary(), nullable=False),
        sa.Column("expires_at", sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(
            ["federated_connector_id"], ["federated_connector.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )

    # Create federated_connector__document_set table
    op.create_table(
        "federated_connector__document_set",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("federated_connector_id", sa.Integer(), nullable=False),
        sa.Column("document_set_id", sa.Integer(), nullable=False),
        sa.Column("entities", postgresql.JSONB(), nullable=False),
        sa.ForeignKeyConstraint(
            ["federated_connector_id"], ["federated_connector.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(
            ["document_set_id"], ["document_set.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint(
            "federated_connector_id",
            "document_set_id",
            name="uq_federated_connector_document_set",
        ),
    )


def downgrade() -> None:
    # Drop tables in reverse order due to foreign key dependencies
    op.drop_table("federated_connector__document_set")
    op.drop_table("federated_connector_oauth_token")
    op.drop_table("federated_connector")
596  backend/alembic/versions/12635f6655b7_drive_canonical_ids.py (Normal file)
@@ -0,0 +1,596 @@
"""drive-canonical-ids

Revision ID: 12635f6655b7
Revises: 58c50ef19f08
Create Date: 2025-06-20 14:44:54.241159

"""

from alembic import op
import sqlalchemy as sa
from urllib.parse import urlparse, urlunparse
from httpx import HTTPStatusError
import httpx
from onyx.document_index.factory import get_default_document_index
from onyx.db.search_settings import SearchSettings
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import (
    replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.utils.logger import setup_logger
import os

logger = setup_logger()

# revision identifiers, used by Alembic.
revision = "12635f6655b7"
down_revision = "58c50ef19f08"
branch_labels = None
depends_on = None

SKIP_CANON_DRIVE_IDS = os.environ.get("SKIP_CANON_DRIVE_IDS", "true").lower() == "true"


def active_search_settings() -> tuple[SearchSettings, SearchSettings | None]:
    result = op.get_bind().execute(
        sa.text(
            """
            SELECT * FROM search_settings WHERE status = 'PRESENT' ORDER BY id DESC LIMIT 1
            """
        )
    )
    search_settings_fetch = result.fetchall()
    search_settings = (
        SearchSettings(**search_settings_fetch[0]._asdict())
        if search_settings_fetch
        else None
    )

    result2 = op.get_bind().execute(
        sa.text(
            """
            SELECT * FROM search_settings WHERE status = 'FUTURE' ORDER BY id DESC LIMIT 1
            """
        )
    )
    search_settings_future_fetch = result2.fetchall()
    search_settings_future = (
        SearchSettings(**search_settings_future_fetch[0]._asdict())
        if search_settings_future_fetch
        else None
    )

    if not isinstance(search_settings, SearchSettings):
        raise RuntimeError(
            "current search settings is of type " + str(type(search_settings))
        )
    if (
        not isinstance(search_settings_future, SearchSettings)
        and search_settings_future is not None
    ):
        raise RuntimeError(
            "future search settings is of type " + str(type(search_settings_future))
        )

    return search_settings, search_settings_future


def normalize_google_drive_url(url: str) -> str:
    """Remove query parameters from Google Drive URLs to create canonical document IDs.
    NOTE: copied from drive doc_conversion.py
    """
    parsed_url = urlparse(url)
    parsed_url = parsed_url._replace(query="")
    spl_path = parsed_url.path.split("/")
    if spl_path and (spl_path[-1] in ["edit", "view", "preview"]):
        spl_path.pop()
        parsed_url = parsed_url._replace(path="/".join(spl_path))
    # Remove query parameters and reconstruct URL
    return urlunparse(parsed_url)


def get_google_drive_documents_from_database() -> list[dict]:
    """Get all Google Drive documents from the database."""
    bind = op.get_bind()
    result = bind.execute(
        sa.text(
            """
            SELECT d.id
            FROM document d
            JOIN document_by_connector_credential_pair dcc ON d.id = dcc.id
            JOIN connector_credential_pair cc ON dcc.connector_id = cc.connector_id
                AND dcc.credential_id = cc.credential_id
            JOIN connector c ON cc.connector_id = c.id
            WHERE c.source = 'GOOGLE_DRIVE'
            """
        )
    )

    documents = []
    for row in result:
        documents.append({"document_id": row.id})

    return documents


def update_document_id_in_database(
    old_doc_id: str, new_doc_id: str, index_name: str
) -> None:
    """Update document IDs in all relevant database tables using copy-and-swap approach."""
    bind = op.get_bind()

    # print(f"Updating database tables for document {old_doc_id} -> {new_doc_id}")

    # Check if new document ID already exists
    result = bind.execute(
        sa.text("SELECT COUNT(*) FROM document WHERE id = :new_id"),
        {"new_id": new_doc_id},
    )
    row = result.fetchone()
    if row and row[0] > 0:
        # print(f"Document with ID {new_doc_id} already exists, deleting old one")
        delete_document_from_db(old_doc_id, index_name)
        return

    # Step 1: Create a new document row with the new ID (copy all fields from old row)
    # Use a conservative approach to handle columns that might not exist in all installations
    try:
        bind.execute(
            sa.text(
                """
                INSERT INTO document (id, from_ingestion_api, boost, hidden, semantic_id,
                    link, doc_updated_at, primary_owners, secondary_owners,
                    external_user_emails, external_user_group_ids, is_public,
                    chunk_count, last_modified, last_synced, kg_stage, kg_processing_time)
                SELECT :new_id, from_ingestion_api, boost, hidden, semantic_id,
                    link, doc_updated_at, primary_owners, secondary_owners,
                    external_user_emails, external_user_group_ids, is_public,
                    chunk_count, last_modified, last_synced, kg_stage, kg_processing_time
                FROM document
                WHERE id = :old_id
                """
            ),
            {"new_id": new_doc_id, "old_id": old_doc_id},
        )
        # print(f"Successfully updated database tables for document {old_doc_id} -> {new_doc_id}")
    except Exception as e:
        # If the full INSERT fails, try a more basic version with only core columns
        logger.warning(f"Full INSERT failed, trying basic version: {e}")
        bind.execute(
            sa.text(
                """
                INSERT INTO document (id, from_ingestion_api, boost, hidden, semantic_id,
                    link, doc_updated_at, primary_owners, secondary_owners)
                SELECT :new_id, from_ingestion_api, boost, hidden, semantic_id,
                    link, doc_updated_at, primary_owners, secondary_owners
                FROM document
                WHERE id = :old_id
                """
            ),
            {"new_id": new_doc_id, "old_id": old_doc_id},
        )

    # Step 2: Update all foreign key references to point to the new ID

    # Update document_by_connector_credential_pair table
    bind.execute(
        sa.text(
            "UPDATE document_by_connector_credential_pair SET id = :new_id WHERE id = :old_id"
        ),
        {"new_id": new_doc_id, "old_id": old_doc_id},
    )
    # print(f"Successfully updated document_by_connector_credential_pair table for document {old_doc_id} -> {new_doc_id}")

    # Update search_doc table (stores search results for chat replay)
    # This is critical for agent functionality
    bind.execute(
        sa.text(
            "UPDATE search_doc SET document_id = :new_id WHERE document_id = :old_id"
        ),
        {"new_id": new_doc_id, "old_id": old_doc_id},
    )
    # print(f"Successfully updated search_doc table for document {old_doc_id} -> {new_doc_id}")
    # Update document_retrieval_feedback table (user feedback on documents)
    bind.execute(
        sa.text(
            "UPDATE document_retrieval_feedback SET document_id = :new_id WHERE document_id = :old_id"
        ),
        {"new_id": new_doc_id, "old_id": old_doc_id},
    )
    # print(f"Successfully updated document_retrieval_feedback table for document {old_doc_id} -> {new_doc_id}")
    # Update document__tag table (document-tag relationships)
    bind.execute(
        sa.text(
            "UPDATE document__tag SET document_id = :new_id WHERE document_id = :old_id"
        ),
        {"new_id": new_doc_id, "old_id": old_doc_id},
    )
    # print(f"Successfully updated document__tag table for document {old_doc_id} -> {new_doc_id}")
    # Update user_file table (user uploaded files linked to documents)
    bind.execute(
        sa.text(
            "UPDATE user_file SET document_id = :new_id WHERE document_id = :old_id"
        ),
        {"new_id": new_doc_id, "old_id": old_doc_id},
    )
    # print(f"Successfully updated user_file table for document {old_doc_id} -> {new_doc_id}")
    # Update KG and chunk_stats tables (these may not exist in all installations)
    try:
        # Update kg_entity table
        bind.execute(
            sa.text(
                "UPDATE kg_entity SET document_id = :new_id WHERE document_id = :old_id"
            ),
            {"new_id": new_doc_id, "old_id": old_doc_id},
        )
        # print(f"Successfully updated kg_entity table for document {old_doc_id} -> {new_doc_id}")
        # Update kg_entity_extraction_staging table
        bind.execute(
            sa.text(
                "UPDATE kg_entity_extraction_staging SET document_id = :new_id WHERE document_id = :old_id"
            ),
            {"new_id": new_doc_id, "old_id": old_doc_id},
        )
        # print(f"Successfully updated kg_entity_extraction_staging table for document {old_doc_id} -> {new_doc_id}")
        # Update kg_relationship table
        bind.execute(
            sa.text(
                "UPDATE kg_relationship SET source_document = :new_id WHERE source_document = :old_id"
            ),
            {"new_id": new_doc_id, "old_id": old_doc_id},
        )
        # print(f"Successfully updated kg_relationship table for document {old_doc_id} -> {new_doc_id}")
        # Update kg_relationship_extraction_staging table
        bind.execute(
            sa.text(
                "UPDATE kg_relationship_extraction_staging SET source_document = :new_id WHERE source_document = :old_id"
            ),
            {"new_id": new_doc_id, "old_id": old_doc_id},
        )
        # print(f"Successfully updated kg_relationship_extraction_staging table for document {old_doc_id} -> {new_doc_id}")
        # Update chunk_stats table
        bind.execute(
            sa.text(
                "UPDATE chunk_stats SET document_id = :new_id WHERE document_id = :old_id"
            ),
            {"new_id": new_doc_id, "old_id": old_doc_id},
        )
        # print(f"Successfully updated chunk_stats table for document {old_doc_id} -> {new_doc_id}")
        # Update chunk_stats ID field which includes document_id
        bind.execute(
            sa.text(
                """
                UPDATE chunk_stats
                SET id = REPLACE(id, :old_id, :new_id)
                WHERE id LIKE :old_id_pattern
                """
            ),
            {
                "new_id": new_doc_id,
                "old_id": old_doc_id,
                "old_id_pattern": f"{old_doc_id}__%",
            },
        )
        # print(f"Successfully updated chunk_stats ID field for document {old_doc_id} -> {new_doc_id}")
    except Exception as e:
        logger.warning(f"Some KG/chunk tables may not exist or failed to update: {e}")

    # Step 3: Delete the old document row (this should now be safe since all FKs point to new row)
    bind.execute(
        sa.text("DELETE FROM document WHERE id = :old_id"), {"old_id": old_doc_id}
    )
    # print(f"Successfully deleted document {old_doc_id} from database")


def _visit_chunks(
    *,
    http_client: httpx.Client,
    index_name: str,
    selection: str,
    continuation: str | None = None,
) -> tuple[list[dict], str | None]:
    """Helper that calls the /document/v1 visit API once and returns (docs, next_token)."""

    # Use the same URL as the document API, but with visit-specific params
    base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)

    params: dict[str, str] = {
        "selection": selection,
        "wantedDocumentCount": "1000",
    }
    if continuation:
        params["continuation"] = continuation

    # print(f"Visiting chunks for selection '{selection}' with params {params}")
    resp = http_client.get(base_url, params=params, timeout=None)
    # print(f"Visited chunks for document {selection}")
    resp.raise_for_status()

    payload = resp.json()
    return payload.get("documents", []), payload.get("continuation")


def delete_document_chunks_from_vespa(index_name: str, doc_id: str) -> None:
    """Delete all chunks for *doc_id* from Vespa using continuation-token paging (no offset)."""

    total_deleted = 0
    # Use exact match instead of contains - Document Selector Language doesn't support contains
    selection = f'{index_name}.document_id=="{doc_id}"'

    with get_vespa_http_client() as http_client:
        continuation: str | None = None
        while True:
            docs, continuation = _visit_chunks(
                http_client=http_client,
                index_name=index_name,
                selection=selection,
                continuation=continuation,
            )

            if not docs:
                break

            for doc in docs:
                vespa_full_id = doc.get("id")
                if not vespa_full_id:
                    continue

                vespa_doc_uuid = vespa_full_id.split("::")[-1]
                delete_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"

                try:
                    resp = http_client.delete(delete_url)
                    resp.raise_for_status()
                    total_deleted += 1
                except Exception as e:
                    print(f"Failed to delete chunk {vespa_doc_uuid}: {e}")

            if not continuation:
                break


def update_document_id_in_vespa(
    index_name: str, old_doc_id: str, new_doc_id: str
) -> None:
    """Update all chunks' document_id field from *old_doc_id* to *new_doc_id* using continuation paging."""

    clean_new_doc_id = replace_invalid_doc_id_characters(new_doc_id)

    # Use exact match instead of contains - Document Selector Language doesn't support contains
    selection = f'{index_name}.document_id=="{old_doc_id}"'

    with get_vespa_http_client() as http_client:
        continuation: str | None = None
        while True:
            # print(f"Visiting chunks for document {old_doc_id} -> {new_doc_id}")
            docs, continuation = _visit_chunks(
                http_client=http_client,
                index_name=index_name,
                selection=selection,
                continuation=continuation,
            )

            if not docs:
                break

            for doc in docs:
                vespa_full_id = doc.get("id")
                if not vespa_full_id:
                    continue

                vespa_doc_uuid = vespa_full_id.split("::")[-1]
                vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"

                update_request = {
                    "fields": {"document_id": {"assign": clean_new_doc_id}}
                }

                try:
                    resp = http_client.put(vespa_url, json=update_request)
                    resp.raise_for_status()
                except Exception as e:
                    print(f"Failed to update chunk {vespa_doc_uuid}: {e}")
                    raise

            if not continuation:
                break


def delete_document_from_db(current_doc_id: str, index_name: str) -> None:
    # Delete all foreign key references first, then delete the document
    try:
        bind = op.get_bind()

        # Delete from agent-related tables first (order matters due to foreign keys)
        # Delete from agent__sub_query__search_doc first since it references search_doc
        bind.execute(
            sa.text(
                """
                DELETE FROM agent__sub_query__search_doc
                WHERE search_doc_id IN (
                    SELECT id FROM search_doc WHERE document_id = :doc_id
                )
                """
            ),
            {"doc_id": current_doc_id},
        )

        # Delete from chat_message__search_doc
        bind.execute(
            sa.text(
                """
                DELETE FROM chat_message__search_doc
                WHERE search_doc_id IN (
                    SELECT id FROM search_doc WHERE document_id = :doc_id
                )
                """
            ),
            {"doc_id": current_doc_id},
        )

        # Now we can safely delete from search_doc
        bind.execute(
            sa.text("DELETE FROM search_doc WHERE document_id = :doc_id"),
            {"doc_id": current_doc_id},
        )

        # Delete from document_by_connector_credential_pair
        bind.execute(
            sa.text(
                "DELETE FROM document_by_connector_credential_pair WHERE id = :doc_id"
            ),
            {"doc_id": current_doc_id},
        )

        # Delete from other tables that reference this document
        bind.execute(
            sa.text(
                "DELETE FROM document_retrieval_feedback WHERE document_id = :doc_id"
            ),
            {"doc_id": current_doc_id},
        )

        bind.execute(
            sa.text("DELETE FROM document__tag WHERE document_id = :doc_id"),
            {"doc_id": current_doc_id},
        )

        bind.execute(
            sa.text("DELETE FROM user_file WHERE document_id = :doc_id"),
            {"doc_id": current_doc_id},
        )

        # Delete from KG tables if they exist
        try:
            bind.execute(
                sa.text("DELETE FROM kg_entity WHERE document_id = :doc_id"),
                {"doc_id": current_doc_id},
            )

            bind.execute(
                sa.text(
                    "DELETE FROM kg_entity_extraction_staging WHERE document_id = :doc_id"
                ),
                {"doc_id": current_doc_id},
            )

            bind.execute(
                sa.text("DELETE FROM kg_relationship WHERE source_document = :doc_id"),
                {"doc_id": current_doc_id},
            )

            bind.execute(
                sa.text(
                    "DELETE FROM kg_relationship_extraction_staging WHERE source_document = :doc_id"
                ),
                {"doc_id": current_doc_id},
            )

            bind.execute(
                sa.text("DELETE FROM chunk_stats WHERE document_id = :doc_id"),
                {"doc_id": current_doc_id},
            )

            bind.execute(
                sa.text("DELETE FROM chunk_stats WHERE id LIKE :doc_id_pattern"),
                {"doc_id_pattern": f"{current_doc_id}__%"},
            )

        except Exception as e:
            logger.warning(
                f"Some KG/chunk tables may not exist or failed to delete from: {e}"
            )

        # Finally delete the document itself
        bind.execute(
            sa.text("DELETE FROM document WHERE id = :doc_id"),
            {"doc_id": current_doc_id},
        )

        # Delete chunks from vespa
        delete_document_chunks_from_vespa(index_name, current_doc_id)

    except Exception as e:
        print(f"Failed to delete duplicate document {current_doc_id}: {e}")
        # Continue with other documents instead of failing the entire migration


def upgrade() -> None:
    if SKIP_CANON_DRIVE_IDS:
        return
    current_search_settings, future_search_settings = active_search_settings()
    document_index = get_default_document_index(
        current_search_settings,
        future_search_settings,
    )

    # Get the index name
    if hasattr(document_index, "index_name"):
        index_name = document_index.index_name
    else:
        # Default index name if we can't get it from the document_index
        index_name = "danswer_index"

    # Get all Google Drive documents from the database (this is faster and more reliable)
    gdrive_documents = get_google_drive_documents_from_database()

    if not gdrive_documents:
        return

    # Track normalized document IDs to detect duplicates
    all_normalized_doc_ids = set()
    updated_count = 0

    for doc_info in gdrive_documents:
        current_doc_id = doc_info["document_id"]
        normalized_doc_id = normalize_google_drive_url(current_doc_id)

        print(f"Processing document {current_doc_id} -> {normalized_doc_id}")
        # Check for duplicates
        if normalized_doc_id in all_normalized_doc_ids:
            # print(f"Deleting duplicate document {current_doc_id}")
            delete_document_from_db(current_doc_id, index_name)
            continue

        all_normalized_doc_ids.add(normalized_doc_id)

        # If the document ID already doesn't have query parameters, skip it
        if current_doc_id == normalized_doc_id:
            # print(f"Skipping document {current_doc_id} -> {normalized_doc_id} because it already has no query parameters")
            continue

        try:
            # Update both database and Vespa in order
            # Database first to ensure consistency
            update_document_id_in_database(
                current_doc_id, normalized_doc_id, index_name
            )

            # For Vespa, we can now use the original document IDs since we're using contains matching
            update_document_id_in_vespa(index_name, current_doc_id, normalized_doc_id)
            updated_count += 1
            # print(f"Finished updating document {current_doc_id} -> {normalized_doc_id}")
        except Exception as e:
            print(f"Failed to update document {current_doc_id}: {e}")

            if isinstance(e, HTTPStatusError):
                print(f"HTTPStatusError: {e}")
                print(f"Response: {e.response.text}")
                print(f"Status: {e.response.status_code}")
                print(f"Headers: {e.response.headers}")
                print(f"Request: {e.request.url}")
                print(f"Request headers: {e.request.headers}")
            # Note: Rollback is complex with copy-and-swap approach since the old document is already deleted
            # In case of failure, manual intervention may be required
            # Continue with other documents instead of failing the entire migration
            continue

    logger.info(f"Migration complete. Updated {updated_count} Google Drive documents")


def downgrade() -> None:
    # this is a one way migration, so no downgrade.
    # It wouldn't make sense to store the extra query parameters
    # and duplicate documents to allow a reversal.
    pass
@@ -144,27 +144,34 @@ def upgrade() -> None:

def downgrade() -> None:
    op.execute("TRUNCATE TABLE index_attempt")
    op.add_column(
        "index_attempt",
        sa.Column("input_type", sa.VARCHAR(), autoincrement=False, nullable=False),
    )
    op.add_column(
        "index_attempt",
        sa.Column("source", sa.VARCHAR(), autoincrement=False, nullable=False),
    )
    op.add_column(
        "index_attempt",
        sa.Column(
            "connector_specific_config",
            postgresql.JSONB(astext_type=sa.Text()),
            autoincrement=False,
            nullable=False,
        ),
    )

    # Check if the constraint exists before dropping
    conn = op.get_bind()
    inspector = sa.inspect(conn)
    existing_columns = {col["name"] for col in inspector.get_columns("index_attempt")}

    if "input_type" not in existing_columns:
        op.add_column(
            "index_attempt",
            sa.Column("input_type", sa.VARCHAR(), autoincrement=False, nullable=False),
        )

    if "source" not in existing_columns:
        op.add_column(
            "index_attempt",
            sa.Column("source", sa.VARCHAR(), autoincrement=False, nullable=False),
        )

    if "connector_specific_config" not in existing_columns:
        op.add_column(
            "index_attempt",
            sa.Column(
                "connector_specific_config",
                postgresql.JSONB(astext_type=sa.Text()),
                autoincrement=False,
                nullable=False,
            ),
        )

    # Check if the constraint exists before dropping
    constraints = inspector.get_foreign_keys("index_attempt")

    if any(
@@ -183,8 +190,12 @@ def downgrade() -> None:
            "fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
        )

    op.drop_column("index_attempt", "credential_id")
    op.drop_column("index_attempt", "connector_id")
    op.drop_table("connector_credential_pair")
    op.drop_table("credential")
    op.drop_table("connector")
    if "credential_id" in existing_columns:
        op.drop_column("index_attempt", "credential_id")

    if "connector_id" in existing_columns:
        op.drop_column("index_attempt", "connector_id")

    op.execute("DROP TABLE IF EXISTS connector_credential_pair CASCADE")
    op.execute("DROP TABLE IF EXISTS credential CASCADE")
    op.execute("DROP TABLE IF EXISTS connector CASCADE")
@@ -0,0 +1,115 @@
"""add_indexing_coordination

Revision ID: 2f95e36923e6
Revises: 0816326d83aa
Create Date: 2025-07-10 16:17:57.762182

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "2f95e36923e6"
down_revision = "0816326d83aa"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add database-based coordination fields (replacing Redis fencing)
    op.add_column(
        "index_attempt", sa.Column("celery_task_id", sa.String(), nullable=True)
    )
    op.add_column(
        "index_attempt",
        sa.Column(
            "cancellation_requested",
            sa.Boolean(),
            nullable=False,
            server_default="false",
        ),
    )

    # Add batch coordination fields (replacing FileStore state)
    op.add_column(
        "index_attempt", sa.Column("total_batches", sa.Integer(), nullable=True)
    )
    op.add_column(
        "index_attempt",
        sa.Column(
            "completed_batches", sa.Integer(), nullable=False, server_default="0"
        ),
    )
    op.add_column(
        "index_attempt",
        sa.Column(
            "total_failures_batch_level",
            sa.Integer(),
            nullable=False,
            server_default="0",
        ),
    )
    op.add_column(
        "index_attempt",
        sa.Column("total_chunks", sa.Integer(), nullable=False, server_default="0"),
    )

    # Progress tracking for stall detection
    op.add_column(
        "index_attempt",
        sa.Column("last_progress_time", sa.DateTime(timezone=True), nullable=True),
    )
    op.add_column(
        "index_attempt",
        sa.Column(
            "last_batches_completed_count",
            sa.Integer(),
            nullable=False,
            server_default="0",
        ),
    )

    # Heartbeat tracking for worker liveness detection
    op.add_column(
        "index_attempt",
        sa.Column(
            "heartbeat_counter", sa.Integer(), nullable=False, server_default="0"
        ),
    )
    op.add_column(
        "index_attempt",
        sa.Column(
            "last_heartbeat_value", sa.Integer(), nullable=False, server_default="0"
        ),
    )
    op.add_column(
        "index_attempt",
        sa.Column("last_heartbeat_time", sa.DateTime(timezone=True), nullable=True),
    )

    # Add index for coordination queries
    op.create_index(
        "ix_index_attempt_active_coordination",
        "index_attempt",
        ["connector_credential_pair_id", "search_settings_id", "status"],
    )


def downgrade() -> None:
    # Remove the new index
    op.drop_index("ix_index_attempt_active_coordination", table_name="index_attempt")

    # Remove the new columns
    op.drop_column("index_attempt", "last_batches_completed_count")
    op.drop_column("index_attempt", "last_progress_time")
    op.drop_column("index_attempt", "last_heartbeat_time")
    op.drop_column("index_attempt", "last_heartbeat_value")
    op.drop_column("index_attempt", "heartbeat_counter")
    op.drop_column("index_attempt", "total_chunks")
    op.drop_column("index_attempt", "total_failures_batch_level")
    op.drop_column("index_attempt", "completed_batches")
    op.drop_column("index_attempt", "total_batches")
    op.drop_column("index_attempt", "cancellation_requested")
    op.drop_column("index_attempt", "celery_task_id")
@@ -9,7 +9,7 @@ Create Date: 2025-06-22 17:33:25.833733
from alembic import op
from sqlalchemy.orm import Session
from sqlalchemy import text
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

# revision identifiers, used by Alembic.
revision = "36e9220ab794"
@@ -66,7 +66,7 @@ def upgrade() -> None:

            -- Set name and name trigrams
            NEW.name = name;
            NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
            NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
            RETURN NEW;
        END;
        $$ LANGUAGE plpgsql;
@@ -111,7 +111,7 @@ def upgrade() -> None:
            UPDATE "{tenant_id}".kg_entity
            SET
                name = doc_name,
                name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
                name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
            WHERE document_id = NEW.id;
            RETURN NEW;
        END;
@@ -0,0 +1,30 @@
"""add_doc_metadata_field_in_document_model

Revision ID: 3fc5d75723b3
Revises: 2f95e36923e6
Create Date: 2025-07-28 18:45:37.985406

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "3fc5d75723b3"
down_revision = "2f95e36923e6"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "document",
        sa.Column(
            "doc_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
        ),
    )


def downgrade() -> None:
    op.drop_column("document", "doc_metadata")
@@ -15,7 +15,7 @@ from datetime import datetime, timedelta
from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA


# revision identifiers, used by Alembic.
@@ -80,6 +80,7 @@ def upgrade() -> None:
        )
    )

    op.execute("DROP TABLE IF EXISTS kg_config CASCADE")
    op.create_table(
        "kg_config",
        sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True),
@@ -123,6 +124,7 @@ def upgrade() -> None:
        ],
    )

    op.execute("DROP TABLE IF EXISTS kg_entity_type CASCADE")
    op.create_table(
        "kg_entity_type",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
@@ -156,6 +158,7 @@ def upgrade() -> None:
        ),
    )

    op.execute("DROP TABLE IF EXISTS kg_relationship_type CASCADE")
    # Create KGRelationshipType table
    op.create_table(
        "kg_relationship_type",
@@ -194,6 +197,7 @@ def upgrade() -> None:
        ),
    )

    op.execute("DROP TABLE IF EXISTS kg_relationship_type_extraction_staging CASCADE")
    # Create KGRelationshipTypeExtractionStaging table
    op.create_table(
        "kg_relationship_type_extraction_staging",
@@ -227,6 +231,8 @@ def upgrade() -> None:
        ),
    )

    op.execute("DROP TABLE IF EXISTS kg_entity CASCADE")

    # Create KGEntity table
    op.create_table(
        "kg_entity",
@@ -281,6 +287,7 @@ def upgrade() -> None:
        "ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"]
    )

    op.execute("DROP TABLE IF EXISTS kg_entity_extraction_staging CASCADE")
    # Create KGEntityExtractionStaging table
    op.create_table(
        "kg_entity_extraction_staging",
@@ -330,6 +337,7 @@ def upgrade() -> None:
        ["name", "entity_type_id_name"],
    )

    op.execute("DROP TABLE IF EXISTS kg_relationship CASCADE")
    # Create KGRelationship table
    op.create_table(
        "kg_relationship",
@@ -371,6 +379,7 @@ def upgrade() -> None:
        "ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"]
    )

    op.execute("DROP TABLE IF EXISTS kg_relationship_extraction_staging CASCADE")
    # Create KGRelationshipExtractionStaging table
    op.create_table(
        "kg_relationship_extraction_staging",
@@ -414,6 +423,7 @@ def upgrade() -> None:
        ["source_node", "target_node"],
    )

    op.execute("DROP TABLE IF EXISTS kg_term CASCADE")
    # Create KGTerm table
    op.create_table(
        "kg_term",
@@ -468,7 +478,7 @@ def upgrade() -> None:
    # Create GIN index for clustering and normalization
    op.execute(
        "CREATE INDEX IF NOT EXISTS idx_kg_entity_clustering_trigrams "
        f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.gin_trgm_ops)"
        f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA}.gin_trgm_ops)"
    )
    op.execute(
        "CREATE INDEX IF NOT EXISTS idx_kg_entity_normalization_trigrams "
@@ -508,7 +518,7 @@ def upgrade() -> None:

            -- Set name and name trigrams
            NEW.name = name;
            NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
            NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
            RETURN NEW;
        END;
        $$ LANGUAGE plpgsql;
@@ -553,7 +563,7 @@ def upgrade() -> None:
            UPDATE kg_entity
            SET
                name = doc_name,
                name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
                name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
            WHERE document_id = NEW.id;
            RETURN NEW;
        END;
@@ -0,0 +1,132 @@
"""add file names to file connector config

Revision ID: 62c3a055a141
Revises: 3fc5d75723b3
Create Date: 2025-07-30 17:01:24.417551

"""

from alembic import op
import sqlalchemy as sa
import json
import os
import logging


# revision identifiers, used by Alembic.
revision = "62c3a055a141"
down_revision = "3fc5d75723b3"
branch_labels = None
depends_on = None

SKIP_FILE_NAME_MIGRATION = (
    os.environ.get("SKIP_FILE_NAME_MIGRATION", "true").lower() == "true"
)

logger = logging.getLogger("alembic.runtime.migration")


def upgrade() -> None:
    if SKIP_FILE_NAME_MIGRATION:
        logger.info(
            "Skipping file name migration. Hint: set SKIP_FILE_NAME_MIGRATION=false to run this migration"
        )
        return
    logger.info("Running file name migration")
    # Get connection
    conn = op.get_bind()

    # Get all FILE connectors with their configs
    file_connectors = conn.execute(
        sa.text(
            """
            SELECT id, connector_specific_config
            FROM connector
            WHERE source = 'FILE'
            """
        )
    ).fetchall()

    for connector_id, config in file_connectors:
        # Parse config if it's a string
        if isinstance(config, str):
            config = json.loads(config)

        # Get file_locations list
        file_locations = config.get("file_locations", [])

        # Get display names for each file_id
        file_names = []
        for file_id in file_locations:
            result = conn.execute(
                sa.text(
                    """
                    SELECT display_name
                    FROM file_record
                    WHERE file_id = :file_id
                    """
                ),
                {"file_id": file_id},
            ).fetchone()

            if result:
                file_names.append(result[0])
            else:
                file_names.append(file_id)  # Should not happen

        # Add file_names to config
        new_config = dict(config)
        new_config["file_names"] = file_names

        # Update the connector
        conn.execute(
            sa.text(
                """
                UPDATE connector
                SET connector_specific_config = :new_config
                WHERE id = :connector_id
                """
            ),
            {"connector_id": connector_id, "new_config": json.dumps(new_config)},
        )


def downgrade() -> None:
    # Get connection
    conn = op.get_bind()

    # Remove file_names from all FILE connectors
    file_connectors = conn.execute(
        sa.text(
            """
            SELECT id, connector_specific_config
            FROM connector
            WHERE source = 'FILE'
            """
        )
    ).fetchall()

    for connector_id, config in file_connectors:
        # Parse config if it's a string
        if isinstance(config, str):
            config = json.loads(config)

        # Remove file_names if it exists
        if "file_names" in config:
            new_config = dict(config)
            del new_config["file_names"]

            # Update the connector
            conn.execute(
                sa.text(
                    """
                    UPDATE connector
                    SET connector_specific_config = :new_config
                    WHERE id = :connector_id
                    """
                ),
                {
                    "connector_id": connector_id,
                    "new_config": json.dumps(new_config),
                },
            )
@@ -159,7 +159,7 @@ def _migrate_files_to_postgres() -> None:

    # only create external store if we have files to migrate. This line
    # makes it so we need to have S3/MinIO configured to run this migration.
    external_store = get_s3_file_store(db_session=session)
    external_store = get_s3_file_store()

    for i, file_id in enumerate(files_to_migrate, 1):
        print(f"Migrating file {i}/{total_files}: {file_id}")
@@ -219,7 +219,7 @@ def _migrate_files_to_external_storage() -> None:
    # Get database session
    bind = op.get_bind()
    session = Session(bind=bind)
    external_store = get_s3_file_store(db_session=session)
    external_store = get_s3_file_store()

    # Find all files currently stored in PostgreSQL (lobj_oid is not null)
    result = session.execute(
@@ -236,6 +236,9 @@ def _migrate_files_to_external_storage() -> None:
        print("No files found in PostgreSQL storage to migrate.")
        return

    # might need to move this above the if statement when creating a new multi-tenant
    # system. VERY extreme edge case.
    external_store.initialize()
    print(f"Found {total_files} files to migrate from PostgreSQL to external storage.")

    _set_tenant_contextvar(session)
@@ -18,11 +18,13 @@ depends_on: None = None


def upgrade() -> None:
    op.execute("DROP TABLE IF EXISTS document CASCADE")
    op.create_table(
        "document",
        sa.Column("id", sa.String(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )
    op.execute("DROP TABLE IF EXISTS chunk CASCADE")
    op.create_table(
        "chunk",
        sa.Column("id", sa.String(), nullable=False),
@@ -43,6 +45,7 @@ def upgrade() -> None:
        ),
        sa.PrimaryKeyConstraint("id", "document_store_type"),
    )
    op.execute("DROP TABLE IF EXISTS deletion_attempt CASCADE")
    op.create_table(
        "deletion_attempt",
        sa.Column("id", sa.Integer(), nullable=False),
@@ -84,6 +87,7 @@ def upgrade() -> None:
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.execute("DROP TABLE IF EXISTS document_by_connector_credential_pair CASCADE")
    op.create_table(
        "document_by_connector_credential_pair",
        sa.Column("id", sa.String(), nullable=False),
@@ -106,7 +110,10 @@ def upgrade() -> None:


def downgrade() -> None:
    # upstream tables first
    op.drop_table("document_by_connector_credential_pair")
    op.drop_table("deletion_attempt")
    op.drop_table("chunk")
    op.drop_table("document")

    # Alembic op.drop_table() has no "cascade" flag – issue raw SQL
    op.execute("DROP TABLE IF EXISTS document CASCADE")
@@ -91,7 +91,7 @@ def export_query_history_task(
    with get_session_with_current_tenant() as db_session:
        try:
            stream.seek(0)
            get_default_file_store(db_session).save_file(
            get_default_file_store().save_file(
                content=stream,
                display_name=report_name,
                file_origin=FileOrigin.QUERY_HISTORY_CSV,
@@ -47,6 +47,7 @@ from onyx.connectors.factory import validate_ccpair_for_user
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document import get_documents_for_connector_credential_pair_limited_columns
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
@@ -58,7 +59,9 @@ from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import is_retryable_sqlalchemy_error
|
||||
from onyx.db.utils import SortOrder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
@@ -422,7 +425,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
+ f"_{redis_connector.cc_pair_id}",
|
||||
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
@@ -498,16 +501,31 @@ def connector_permission_sync_generator_task(
|
||||
# this can be used to determine documents that are "missing" and thus
|
||||
# should no longer be accessible. The decision as to whether we should find
|
||||
# every document during the doc sync process is connector-specific.
|
||||
def fetch_all_existing_docs_fn() -> list[str]:
|
||||
return get_document_ids_for_connector_credential_pair(
|
||||
def fetch_all_existing_docs_fn(
|
||||
sort_order: SortOrder | None = None,
|
||||
) -> list[DocumentRow]:
|
||||
result = get_documents_for_connector_credential_pair_limited_columns(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
return list(result)
|
||||
|
||||
def fetch_all_existing_docs_ids_fn() -> list[str]:
|
||||
result = get_document_ids_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
)
|
||||
return result
|
||||
|
||||
doc_sync_func = sync_config.doc_sync_config.doc_sync_func
|
||||
document_external_accesses = doc_sync_func(
|
||||
cc_pair, fetch_all_existing_docs_fn, callback
|
||||
cc_pair,
|
||||
fetch_all_existing_docs_fn,
|
||||
fetch_all_existing_docs_ids_fn,
|
||||
callback,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
|
||||
@@ -383,7 +383,7 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
+ f"_{redis_connector.cc_pair_id}",
|
||||
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
|
||||
@@ -71,6 +71,19 @@ GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# GitHub
|
||||
#####
|
||||
# In seconds, default is 5 minutes
|
||||
GITHUB_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("GITHUB_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
# In seconds, default is 5 minutes
|
||||
GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
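These settings follow the same environment-override pattern as the other sync frequencies; a minimal sketch of the parsing behaviour (the value 600 is only illustrative):

import os

# Unset or empty -> 300 seconds (5 minutes); e.g. GITHUB_PERMISSION_DOC_SYNC_FREQUENCY=600 -> 600.
doc_sync_frequency = int(os.environ.get("GITHUB_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60)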
||||
|
||||
|
||||
#####
|
||||
# Slack
|
||||
#####
|
||||
|
||||
@@ -114,7 +114,6 @@ def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]:
|
||||
|
||||
|
||||
def get_usage_report_data(
|
||||
db_session: Session,
|
||||
report_display_name: str,
|
||||
) -> IO:
|
||||
"""
|
||||
@@ -128,7 +127,7 @@ def get_usage_report_data(
|
||||
Returns:
|
||||
The usage report data.
|
||||
"""
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
# usage report may be very large, so don't load it all into memory
|
||||
return file_store.read_file(
|
||||
file_id=report_display_name, mode="b", use_tempfile=True
|
||||
|
||||
@@ -128,11 +128,14 @@ def validate_object_creation_for_user(
|
||||
target_group_ids: list[int] | None = None,
|
||||
object_is_public: bool | None = None,
|
||||
object_is_perm_sync: bool | None = None,
|
||||
object_is_owned_by_user: bool = False,
|
||||
object_is_new: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
All users can create/edit permission synced objects if they don't specify a group
|
||||
All admin actions are allowed.
|
||||
Prevents non-admins from creating/editing:
|
||||
Curators and global curators can create public objects.
|
||||
Prevents other non-admins from creating/editing:
|
||||
- public objects
|
||||
- objects with no groups
|
||||
- objects that belong to a group they don't curate
|
||||
@@ -143,13 +146,23 @@ def validate_object_creation_for_user(
|
||||
if not user or user.role == UserRole.ADMIN:
|
||||
return
|
||||
|
||||
if object_is_public:
|
||||
detail = "User does not have permission to create public credentials"
|
||||
# Allow curators and global curators to create public objects
|
||||
# w/o associated groups IF the object is new/owned by them
|
||||
if (
|
||||
object_is_public
|
||||
and user.role in [UserRole.CURATOR, UserRole.GLOBAL_CURATOR]
|
||||
and (object_is_new or object_is_owned_by_user)
|
||||
):
|
||||
return
|
||||
|
||||
if object_is_public and user.role == UserRole.BASIC:
|
||||
detail = "User does not have permission to create public objects"
|
||||
logger.error(detail)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
if not target_group_ids:
|
||||
detail = "Curators must specify 1+ groups"
|
||||
logger.error(detail)
|
||||
|
||||
@@ -18,9 +18,9 @@
|
||||
<!-- <document type="danswer_chunk" mode="index" /> -->
|
||||
{{ document_elements }}
|
||||
</documents>
|
||||
<nodes count="75">
|
||||
<resources vcpu="8.0" memory="64.0Gb" architecture="arm64" storage-type="local"
|
||||
disk="474.0Gb" />
|
||||
<nodes count="60">
|
||||
<resources vcpu="8.0" memory="128.0Gb" architecture="arm64" storage-type="local"
|
||||
disk="475.0Gb" />
|
||||
</nodes>
|
||||
<engine>
|
||||
<proton>
|
||||
|
||||
@@ -6,6 +6,7 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
|
||||
from collections.abc import Generator
|
||||
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from ee.onyx.external_permissions.utils import generic_doc_sync
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -25,6 +26,7 @@ CONFLUENCE_DOC_SYNC_LABEL = "confluence_doc_sync"
|
||||
def confluence_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
@@ -43,7 +45,7 @@ def confluence_doc_sync(
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
|
||||
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
|
||||
callback=callback,
|
||||
doc_source=DocumentSource.CONFLUENCE,
|
||||
slim_connector=confluence_connector,
|
||||
|
||||
backend/ee/onyx/external_permissions/github/doc_sync.py (new file, 294 lines)
@@ -0,0 +1,294 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from github import Github
|
||||
from github.Repository import Repository
|
||||
|
||||
from ee.onyx.external_permissions.github.utils import fetch_repository_team_slugs
|
||||
from ee.onyx.external_permissions.github.utils import form_collaborators_group_id
|
||||
from ee.onyx.external_permissions.github.utils import form_organization_group_id
|
||||
from ee.onyx.external_permissions.github.utils import (
|
||||
form_outside_collaborators_group_id,
|
||||
)
|
||||
from ee.onyx.external_permissions.github.utils import get_external_access_permission
|
||||
from ee.onyx.external_permissions.github.utils import get_repository_visibility
|
||||
from ee.onyx.external_permissions.github.utils import GitHubVisibility
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.github.connector import DocMetadata
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import SortOrder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
GITHUB_DOC_SYNC_LABEL = "github_doc_sync"
|
||||
|
||||
|
||||
def github_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
Sync GitHub documents with external access permissions.
|
||||
|
||||
This function checks each repository for visibility/team changes and updates
|
||||
document permissions accordingly without using checkpoints.
|
||||
"""
|
||||
logger.info(f"Starting GitHub document sync for CC pair ID: {cc_pair.id}")
|
||||
|
||||
# Initialize GitHub connector with credentials
|
||||
github_connector: GithubConnector = GithubConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
|
||||
github_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
logger.info("GitHub connector credentials loaded successfully")
|
||||
|
||||
if not github_connector.github_client:
|
||||
logger.error("GitHub client initialization failed")
|
||||
raise ValueError("github_client is required")
|
||||
|
||||
# Get all repositories from GitHub API
|
||||
logger.info("Fetching all repositories from GitHub API")
|
||||
try:
|
||||
repos = []
|
||||
if github_connector.repositories:
|
||||
if "," in github_connector.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = github_connector.get_github_repos(
|
||||
github_connector.github_client
|
||||
)
|
||||
else:
|
||||
# Single repository
|
||||
repos = [
|
||||
github_connector.get_github_repo(github_connector.github_client)
|
||||
]
|
||||
else:
|
||||
# All repositories
|
||||
repos = github_connector.get_all_repos(github_connector.github_client)
|
||||
|
||||
logger.info(f"Found {len(repos)} repositories to check")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch repositories: {e}")
|
||||
raise
|
||||
|
||||
repo_to_doc_list_map: dict[str, list[DocumentRow]] = {}
|
||||
# sort order is ascending because we want to get the oldest documents first
|
||||
existing_docs: list[DocumentRow] = fetch_all_existing_docs_fn(
|
||||
sort_order=SortOrder.ASC
|
||||
)
|
||||
logger.info(f"Found {len(existing_docs)} documents to check")
|
||||
for doc in existing_docs:
|
||||
try:
|
||||
doc_metadata = DocMetadata.model_validate_json(json.dumps(doc.doc_metadata))
|
||||
if doc_metadata.repo not in repo_to_doc_list_map:
|
||||
repo_to_doc_list_map[doc_metadata.repo] = []
|
||||
repo_to_doc_list_map[doc_metadata.repo].append(doc)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse doc metadata: {e} for doc {doc.id}")
|
||||
continue
|
||||
logger.info(f"Found {len(repo_to_doc_list_map)} documents to check")
|
||||
# Process each repository individually
|
||||
for repo in repos:
|
||||
try:
|
||||
logger.info(f"Processing repository: {repo.id} (name: {repo.name})")
|
||||
repo_doc_list: list[DocumentRow] = repo_to_doc_list_map.get(
|
||||
repo.full_name, []
|
||||
)
|
||||
if not repo_doc_list:
|
||||
logger.warning(
|
||||
f"No documents found for repository {repo.id} ({repo.name})"
|
||||
)
|
||||
continue
|
||||
|
||||
current_external_group_ids = repo_doc_list[0].external_user_group_ids or []
|
||||
# Check if repository has any permission changes
|
||||
has_changes = _check_repository_for_changes(
|
||||
repo=repo,
|
||||
github_client=github_connector.github_client,
|
||||
current_external_group_ids=current_external_group_ids,
|
||||
)
|
||||
|
||||
if has_changes:
|
||||
logger.info(
|
||||
f"Repository {repo.id} ({repo.name}) has changes, updating documents"
|
||||
)
|
||||
|
||||
# Get new external access permissions for this repository
|
||||
new_external_access = get_external_access_permission(
|
||||
repo, github_connector.github_client
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Found {len(repo_doc_list)} documents for repository {repo.full_name}"
|
||||
)
|
||||
|
||||
# Yield updated external access for each document
|
||||
for doc in repo_doc_list:
|
||||
if callback:
|
||||
callback.progress(GITHUB_DOC_SYNC_LABEL, 1)
|
||||
|
||||
yield DocExternalAccess(
|
||||
doc_id=doc.id,
|
||||
external_access=new_external_access,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Repository {repo.id} ({repo.name}) has no changes, skipping"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing repository {repo.id} ({repo.name}): {e}")
|
||||
|
||||
logger.info(f"GitHub document sync completed for CC pair ID: {cc_pair.id}")
|
||||
|
||||
|
||||
def _check_repository_for_changes(
|
||||
repo: Repository,
|
||||
github_client: Github,
|
||||
current_external_group_ids: list[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Check if repository has any permission changes (visibility or team updates).
|
||||
"""
|
||||
logger.info(f"Checking repository {repo.id} ({repo.name}) for changes")
|
||||
|
||||
# Check for repository visibility changes using the sample document data
|
||||
if _is_repo_visibility_changed_from_groups(
|
||||
repo=repo,
|
||||
current_external_group_ids=current_external_group_ids,
|
||||
):
|
||||
logger.info(f"Repository {repo.id} ({repo.name}) has visibility changes")
|
||||
return True
|
||||
|
||||
# Check for team membership changes if repository is private
|
||||
if get_repository_visibility(
|
||||
repo
|
||||
) == GitHubVisibility.PRIVATE and _teams_updated_from_groups(
|
||||
repo=repo,
|
||||
github_client=github_client,
|
||||
current_external_group_ids=current_external_group_ids,
|
||||
):
|
||||
logger.info(f"Repository {repo.id} ({repo.name}) has team changes")
|
||||
return True
|
||||
|
||||
logger.info(f"Repository {repo.id} ({repo.name}) has no changes")
|
||||
return False
|
||||
|
||||
|
||||
def _is_repo_visibility_changed_from_groups(
|
||||
repo: Repository,
|
||||
current_external_group_ids: list[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Check if repository visibility has changed by analyzing existing external group IDs.
|
||||
|
||||
Args:
|
||||
repo: GitHub repository object
|
||||
current_external_group_ids: List of external group IDs from existing document
|
||||
|
||||
Returns:
|
||||
True if visibility has changed
|
||||
"""
|
||||
current_repo_visibility = get_repository_visibility(repo)
|
||||
logger.info(f"Current repository visibility: {current_repo_visibility.value}")
|
||||
|
||||
# Build expected group IDs for current visibility
|
||||
collaborators_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=form_collaborators_group_id(repo.id),
|
||||
)
|
||||
|
||||
org_group_id = None
|
||||
if repo.organization:
|
||||
org_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=form_organization_group_id(repo.organization.id),
|
||||
)
|
||||
|
||||
# Determine existing visibility from group IDs
|
||||
has_collaborators_group = collaborators_group_id in current_external_group_ids
|
||||
has_org_group = org_group_id and org_group_id in current_external_group_ids
|
||||
|
||||
if has_collaborators_group:
|
||||
existing_repo_visibility = GitHubVisibility.PRIVATE
|
||||
elif has_org_group:
|
||||
existing_repo_visibility = GitHubVisibility.INTERNAL
|
||||
else:
|
||||
existing_repo_visibility = GitHubVisibility.PUBLIC
|
||||
|
||||
logger.info(f"Inferred existing visibility: {existing_repo_visibility.value}")
|
||||
|
||||
visibility_changed = existing_repo_visibility != current_repo_visibility
|
||||
if visibility_changed:
|
||||
logger.info(
|
||||
f"Visibility changed for repo {repo.id} ({repo.name}): "
|
||||
f"{existing_repo_visibility.value} -> {current_repo_visibility.value}"
|
||||
)
|
||||
|
||||
return visibility_changed
|
||||
|
||||
|
||||
def _teams_updated_from_groups(
|
||||
repo: Repository,
|
||||
github_client: Github,
|
||||
current_external_group_ids: list[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Check if repository team memberships have changed using existing group IDs.
|
||||
"""
|
||||
# Fetch current team slugs for the repository
|
||||
current_teams = fetch_repository_team_slugs(repo=repo, github_client=github_client)
|
||||
logger.info(
|
||||
f"Current teams for repository {repo.id} (name: {repo.name}): {current_teams}"
|
||||
)
|
||||
|
||||
# Build group IDs to exclude from team comparison (non-team groups)
|
||||
collaborators_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=form_collaborators_group_id(repo.id),
|
||||
)
|
||||
outside_collaborators_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=form_outside_collaborators_group_id(repo.id),
|
||||
)
|
||||
non_team_group_ids = {collaborators_group_id, outside_collaborators_group_id}
|
||||
|
||||
# Extract existing team IDs from current external group IDs
|
||||
existing_team_ids = set()
|
||||
for group_id in current_external_group_ids:
|
||||
# Skip all non-team groups, keep only team groups
|
||||
if group_id not in non_team_group_ids:
|
||||
existing_team_ids.add(group_id)
|
||||
|
||||
# Note: existing_team_ids from DB are already prefixed (e.g., "github__team-slug")
|
||||
# but current_teams from API are raw team slugs, so we need to add the prefix
|
||||
current_team_ids = set()
|
||||
for team_slug in current_teams:
|
||||
team_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=team_slug,
|
||||
)
|
||||
current_team_ids.add(team_group_id)
|
||||
|
||||
logger.info(
|
||||
f"Existing team IDs: {existing_team_ids}, Current team IDs: {current_team_ids}"
|
||||
)
|
||||
|
||||
# Compare actual team IDs to detect changes
|
||||
teams_changed = current_team_ids != existing_team_ids
|
||||
if teams_changed:
|
||||
logger.info(
|
||||
f"Team changes detected for repo {repo.id} (name: {repo.name}): "
|
||||
f"existing={existing_team_ids}, current={current_team_ids}"
|
||||
)
|
||||
|
||||
return teams_changed
|
||||
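For orientation, a small self-contained restatement of the visibility inference used by _is_repo_visibility_changed_from_groups above (group IDs are made up; the real code builds them via build_ext_group_name_for_onyx):

def infer_existing_visibility(group_ids: list[str], collaborators_id: str, org_id: str | None) -> str:
    # Mirrors the checks above: collaborators group -> private, org group -> internal, else public.
    if collaborators_id in group_ids:
        return "private"
    if org_id and org_id in group_ids:
        return "internal"
    return "public"

assert infer_existing_visibility(["github__1234_collaborators"], "github__1234_collaborators", None) == "private"
assert infer_existing_visibility([], "github__1234_collaborators", "github__99_organization") == "public"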
backend/ee/onyx/external_permissions/github/group_sync.py (new file, 46 lines)
@@ -0,0 +1,46 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from github import Repository
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.github.utils import get_external_user_group
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def github_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
github_connector: GithubConnector = GithubConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
github_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
if not github_connector.github_client:
|
||||
raise ValueError("github_client is required")
|
||||
|
||||
logger.info("Starting GitHub group sync...")
|
||||
repos: list[Repository.Repository] = []
|
||||
if github_connector.repositories:
|
||||
if "," in github_connector.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = github_connector.get_github_repos(github_connector.github_client)
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
repos = [github_connector.get_github_repo(github_connector.github_client)]
|
||||
else:
|
||||
# All repositories
|
||||
repos = github_connector.get_all_repos(github_connector.github_client)
|
||||
|
||||
for repo in repos:
|
||||
try:
|
||||
for external_group in get_external_user_group(
|
||||
repo, github_connector.github_client
|
||||
):
|
||||
logger.info(f"External group: {external_group}")
|
||||
yield external_group
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing repository {repo.id} ({repo.name}): {e}")
|
||||
backend/ee/onyx/external_permissions/github/utils.py (new file, 488 lines)
@@ -0,0 +1,488 @@
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
|
||||
from github import Github
|
||||
from github import RateLimitExceededException
|
||||
from github.GithubException import GithubException
|
||||
from github.NamedUser import NamedUser
|
||||
from github.Organization import Organization
|
||||
from github.PaginatedList import PaginatedList
|
||||
from github.Repository import Repository
|
||||
from github.Team import Team
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class GitHubVisibility(Enum):
|
||||
"""GitHub repository visibility options."""
|
||||
|
||||
PUBLIC = "public"
|
||||
PRIVATE = "private"
|
||||
INTERNAL = "internal"
|
||||
|
||||
|
||||
MAX_RETRY_COUNT = 3
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# Higher-order function to wrap GitHub operations with retry and exception handling
|
||||
|
||||
|
||||
def _run_with_retry(
|
||||
operation: Callable[[], T],
|
||||
description: str,
|
||||
github_client: Github,
|
||||
retry_count: int = 0,
|
||||
) -> Optional[T]:
|
||||
"""Execute a GitHub operation with retry on rate limit and exception handling."""
|
||||
logger.debug(f"Starting operation '{description}', attempt {retry_count + 1}")
|
||||
try:
|
||||
result = operation()
|
||||
logger.debug(f"Operation '{description}' completed successfully")
|
||||
return result
|
||||
except RateLimitExceededException:
|
||||
if retry_count < MAX_RETRY_COUNT:
|
||||
sleep_after_rate_limit_exception(github_client)
|
||||
logger.warning(
|
||||
f"Rate limit exceeded while {description}. Retrying... "
|
||||
f"(attempt {retry_count + 1}/{MAX_RETRY_COUNT})"
|
||||
)
|
||||
return _run_with_retry(
|
||||
operation, description, github_client, retry_count + 1
|
||||
)
|
||||
else:
|
||||
error_msg = f"Max retries exceeded for {description}"
|
||||
logger.exception(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
except GithubException as e:
|
||||
logger.warning(f"GitHub API error during {description}: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error during {description}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
"""Represents a GitHub user with their basic information."""
|
||||
|
||||
login: str
|
||||
name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
|
||||
|
||||
class TeamInfo(BaseModel):
|
||||
"""Represents a GitHub team with its members."""
|
||||
|
||||
name: str
|
||||
slug: str
|
||||
members: List[UserInfo]
|
||||
|
||||
|
||||
def _fetch_organization_members(
|
||||
github_client: Github, org_name: str, retry_count: int = 0
|
||||
) -> List[UserInfo]:
|
||||
"""Fetch all organization members including owners and regular members."""
|
||||
org_members: List[UserInfo] = []
|
||||
logger.info(f"Fetching organization members for {org_name}")
|
||||
|
||||
org = _run_with_retry(
|
||||
lambda: github_client.get_organization(org_name),
|
||||
f"get organization {org_name}",
|
||||
github_client,
|
||||
)
|
||||
if not org:
|
||||
logger.error(f"Failed to fetch organization {org_name}")
|
||||
raise RuntimeError(f"Failed to fetch organization {org_name}")
|
||||
|
||||
member_objs: PaginatedList[NamedUser] | list[NamedUser] = (
|
||||
_run_with_retry(
|
||||
lambda: org.get_members(filter_="all"),
|
||||
f"get members for organization {org_name}",
|
||||
github_client,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
for member in member_objs:
|
||||
user_info = UserInfo(login=member.login, name=member.name, email=member.email)
|
||||
org_members.append(user_info)
|
||||
|
||||
logger.info(f"Fetched {len(org_members)} members for organization {org_name}")
|
||||
return org_members
|
||||
|
||||
|
||||
def _fetch_repository_teams_detailed(
|
||||
repo: Repository, github_client: Github, retry_count: int = 0
|
||||
) -> List[TeamInfo]:
|
||||
"""Fetch teams with access to the repository and their members."""
|
||||
teams_data: List[TeamInfo] = []
|
||||
logger.info(f"Fetching teams for repository {repo.full_name}")
|
||||
|
||||
team_objs: PaginatedList[Team] | list[Team] = (
|
||||
_run_with_retry(
|
||||
lambda: repo.get_teams(),
|
||||
f"get teams for repository {repo.full_name}",
|
||||
github_client,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
for team in team_objs:
|
||||
logger.info(
|
||||
f"Processing team {team.name} (slug: {team.slug}) for repository {repo.full_name}"
|
||||
)
|
||||
|
||||
members: PaginatedList[NamedUser] | list[NamedUser] = (
|
||||
_run_with_retry(
|
||||
lambda: team.get_members(),
|
||||
f"get members for team {team.name}",
|
||||
github_client,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
team_members = []
|
||||
for m in members:
|
||||
user_info = UserInfo(login=m.login, name=m.name, email=m.email)
|
||||
team_members.append(user_info)
|
||||
|
||||
team_info = TeamInfo(name=team.name, slug=team.slug, members=team_members)
|
||||
teams_data.append(team_info)
|
||||
logger.info(f"Team {team.name} has {len(team_members)} members")
|
||||
|
||||
logger.info(f"Fetched {len(teams_data)} teams for repository {repo.full_name}")
|
||||
return teams_data
|
||||
|
||||
|
||||
def fetch_repository_team_slugs(
|
||||
repo: Repository, github_client: Github, retry_count: int = 0
|
||||
) -> List[str]:
|
||||
"""Fetch team slugs with access to the repository."""
|
||||
logger.info(f"Fetching team slugs for repository {repo.full_name}")
|
||||
teams_data: List[str] = []
|
||||
|
||||
team_objs: PaginatedList[Team] | list[Team] = (
|
||||
_run_with_retry(
|
||||
lambda: repo.get_teams(),
|
||||
f"get teams for repository {repo.full_name}",
|
||||
github_client,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
for team in team_objs:
|
||||
teams_data.append(team.slug)
|
||||
|
||||
logger.info(f"Fetched {len(teams_data)} team slugs for repository {repo.full_name}")
|
||||
return teams_data
|
||||
|
||||
|
||||
def _get_collaborators_and_outside_collaborators(
|
||||
github_client: Github,
|
||||
repo: Repository,
|
||||
) -> Tuple[List[UserInfo], List[UserInfo]]:
|
||||
"""Fetch and categorize collaborators into regular and outside collaborators."""
|
||||
collaborators: List[UserInfo] = []
|
||||
outside_collaborators: List[UserInfo] = []
|
||||
logger.info(f"Fetching collaborators for repository {repo.full_name}")
|
||||
|
||||
repo_collaborators: PaginatedList[NamedUser] | list[NamedUser] = (
|
||||
_run_with_retry(
|
||||
lambda: repo.get_collaborators(),
|
||||
f"get collaborators for repository {repo.full_name}",
|
||||
github_client,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
for collaborator in repo_collaborators:
|
||||
is_outside = False
|
||||
|
||||
# Check if collaborator is outside the organization
|
||||
if repo.organization:
|
||||
org: Organization | None = _run_with_retry(
|
||||
lambda: github_client.get_organization(repo.organization.login),
|
||||
f"get organization {repo.organization.login}",
|
||||
github_client,
|
||||
)
|
||||
|
||||
if org is not None:
|
||||
org_obj = org
|
||||
membership = _run_with_retry(
|
||||
lambda: org_obj.has_in_members(collaborator),
|
||||
f"check membership for {collaborator.login} in org {org_obj.login}",
|
||||
github_client,
|
||||
)
|
||||
is_outside = membership is not None and not membership
|
||||
|
||||
info = UserInfo(
|
||||
login=collaborator.login, name=collaborator.name, email=collaborator.email
|
||||
)
|
||||
if repo.organization and is_outside:
|
||||
outside_collaborators.append(info)
|
||||
else:
|
||||
collaborators.append(info)
|
||||
|
||||
logger.info(
|
||||
f"Categorized {len(collaborators)} regular and {len(outside_collaborators)} outside collaborators for {repo.full_name}"
|
||||
)
|
||||
return collaborators, outside_collaborators
|
||||
|
||||
|
||||
def form_collaborators_group_id(repository_id: int) -> str:
|
||||
"""Generate group ID for repository collaborators."""
|
||||
if not repository_id:
|
||||
logger.exception("Repository ID is required to generate collaborators group ID")
|
||||
raise ValueError("Repository ID must be set to generate group ID.")
|
||||
group_id = f"{repository_id}_collaborators"
|
||||
return group_id
|
||||
|
||||
|
||||
def form_organization_group_id(organization_id: int) -> str:
|
||||
"""Generate group ID for organization using organization ID."""
|
||||
if not organization_id:
|
||||
logger.exception(
|
||||
"Organization ID is required to generate organization group ID"
|
||||
)
|
||||
raise ValueError("Organization ID must be set to generate group ID.")
|
||||
group_id = f"{organization_id}_organization"
|
||||
return group_id
|
||||
|
||||
|
||||
def form_outside_collaborators_group_id(repository_id: int) -> str:
|
||||
"""Generate group ID for outside collaborators."""
|
||||
if not repository_id:
|
||||
logger.exception(
|
||||
"Repository ID is required to generate outside collaborators group ID"
|
||||
)
|
||||
raise ValueError("Repository ID must be set to generate group ID.")
|
||||
group_id = f"{repository_id}_outside_collaborators"
|
||||
return group_id
|
||||
|
||||
|
||||
def get_repository_visibility(repo: Repository) -> GitHubVisibility:
|
||||
"""
|
||||
Get the visibility of a repository.
|
||||
Returns GitHubVisibility enum member.
|
||||
"""
|
||||
if hasattr(repo, "visibility"):
|
||||
visibility = repo.visibility
|
||||
logger.info(
|
||||
f"Repository {repo.full_name} visibility from attribute: {visibility}"
|
||||
)
|
||||
try:
|
||||
return GitHubVisibility(visibility)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Unknown visibility '{visibility}' for repo {repo.full_name}, defaulting to private"
|
||||
)
|
||||
return GitHubVisibility.PRIVATE
|
||||
|
||||
logger.info(f"Repository {repo.full_name} is private")
|
||||
return GitHubVisibility.PRIVATE
|
||||
|
||||
|
||||
def get_external_access_permission(
|
||||
repo: Repository, github_client: Github, add_prefix: bool = False
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
Get the external access permission for a repository.
|
||||
Uses group-based permissions for efficiency and scalability.
|
||||
|
||||
add_prefix: When this method is called during the initial permission sync via the connector,
|
||||
the group ID isn't prefixed with the source while inserting the document record.
|
||||
So in that case, set add_prefix to True, allowing the method itself to handle
|
||||
prefixing. However, when the same method is invoked from doc_sync, our system
|
||||
already adds the prefix to the group ID while processing the ExternalAccess object.
|
||||
"""
|
||||
# We maintain collaborators, and outside collaborators as two separate groups
|
||||
# instead of adding individual user emails to ExternalAccess.external_user_emails for two reasons:
|
||||
# 1. Changes in repo collaborators (additions/removals) would require updating all documents.
|
||||
# 2. Repo permissions can change without updating the repo's updated_at timestamp,
|
||||
# forcing full permission syncs for all documents every time, which is inefficient.
|
||||
|
||||
repo_visibility = get_repository_visibility(repo)
|
||||
logger.info(
|
||||
f"Generating ExternalAccess for {repo.full_name}: visibility={repo_visibility.value}"
|
||||
)
|
||||
|
||||
if repo_visibility == GitHubVisibility.PUBLIC:
|
||||
logger.info(
|
||||
f"Repository {repo.full_name} is public - allowing access to all users"
|
||||
)
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=True,
|
||||
)
|
||||
elif repo_visibility == GitHubVisibility.PRIVATE:
|
||||
logger.info(
|
||||
f"Repository {repo.full_name} is private - setting up restricted access"
|
||||
)
|
||||
|
||||
collaborators_group_id = form_collaborators_group_id(repo.id)
|
||||
outside_collaborators_group_id = form_outside_collaborators_group_id(repo.id)
|
||||
if add_prefix:
|
||||
collaborators_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=collaborators_group_id,
|
||||
)
|
||||
outside_collaborators_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=outside_collaborators_group_id,
|
||||
)
|
||||
group_ids = {collaborators_group_id, outside_collaborators_group_id}
|
||||
|
||||
team_slugs = fetch_repository_team_slugs(repo, github_client)
|
||||
if add_prefix:
|
||||
team_slugs = [
|
||||
build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=slug,
|
||||
)
|
||||
for slug in team_slugs
|
||||
]
|
||||
group_ids.update(team_slugs)
|
||||
|
||||
logger.info(f"ExternalAccess groups for {repo.full_name}: {group_ids}")
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=group_ids,
|
||||
is_public=False,
|
||||
)
|
||||
else:
|
||||
# Internal repositories - accessible to organization members
|
||||
logger.info(
|
||||
f"Repository {repo.full_name} is internal - accessible to org members"
|
||||
)
|
||||
org_group_id = form_organization_group_id(repo.organization.id)
|
||||
if add_prefix:
|
||||
org_group_id = build_ext_group_name_for_onyx(
|
||||
source=DocumentSource.GITHUB,
|
||||
ext_group_name=org_group_id,
|
||||
)
|
||||
group_ids = {org_group_id}
|
||||
logger.info(f"ExternalAccess groups for {repo.full_name}: {group_ids}")
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=group_ids,
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
|
||||
def get_external_user_group(
|
||||
repo: Repository, github_client: Github
|
||||
) -> list[ExternalUserGroup]:
|
||||
"""
|
||||
Get the external user group for a repository.
|
||||
Creates ExternalUserGroup objects with actual user emails for each permission group.
|
||||
"""
|
||||
repo_visibility = get_repository_visibility(repo)
|
||||
logger.info(
|
||||
f"Generating ExternalUserGroups for {repo.full_name}: visibility={repo_visibility.value}"
|
||||
)
|
||||
|
||||
if repo_visibility == GitHubVisibility.PRIVATE:
|
||||
logger.info(f"Processing private repository {repo.full_name}")
|
||||
|
||||
collaborators, outside_collaborators = (
|
||||
_get_collaborators_and_outside_collaborators(github_client, repo)
|
||||
)
|
||||
teams = _fetch_repository_teams_detailed(repo, github_client)
|
||||
external_user_groups = []
|
||||
|
||||
user_emails = set()
|
||||
for collab in collaborators:
|
||||
if collab.email:
|
||||
user_emails.add(collab.email)
|
||||
else:
|
||||
logger.error(f"Collaborator {collab.login} has no email")
|
||||
|
||||
if user_emails:
|
||||
collaborators_group = ExternalUserGroup(
|
||||
id=form_collaborators_group_id(repo.id),
|
||||
user_emails=list(user_emails),
|
||||
)
|
||||
external_user_groups.append(collaborators_group)
|
||||
logger.info(f"Created collaborators group with {len(user_emails)} emails")
|
||||
|
||||
# Create group for outside collaborators
|
||||
user_emails = set()
|
||||
for collab in outside_collaborators:
|
||||
if collab.email:
|
||||
user_emails.add(collab.email)
|
||||
else:
|
||||
logger.error(f"Outside collaborator {collab.login} has no email")
|
||||
|
||||
if user_emails:
|
||||
outside_collaborators_group = ExternalUserGroup(
|
||||
id=form_outside_collaborators_group_id(repo.id),
|
||||
user_emails=list(user_emails),
|
||||
)
|
||||
external_user_groups.append(outside_collaborators_group)
|
||||
logger.info(
|
||||
f"Created outside collaborators group with {len(user_emails)} emails"
|
||||
)
|
||||
|
||||
# Create groups for teams
|
||||
for team in teams:
|
||||
user_emails = set()
|
||||
for member in team.members:
|
||||
if member.email:
|
||||
user_emails.add(member.email)
|
||||
else:
|
||||
logger.error(f"Team member {member.login} has no email")
|
||||
|
||||
if user_emails:
|
||||
team_group = ExternalUserGroup(
|
||||
id=team.slug,
|
||||
user_emails=list(user_emails),
|
||||
)
|
||||
external_user_groups.append(team_group)
|
||||
logger.info(
|
||||
f"Created team group {team.name} with {len(user_emails)} emails"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created {len(external_user_groups)} ExternalUserGroups for private repository {repo.full_name}"
|
||||
)
|
||||
return external_user_groups
|
||||
|
||||
if repo_visibility == GitHubVisibility.INTERNAL:
|
||||
logger.info(f"Processing internal repository {repo.full_name}")
|
||||
|
||||
org_group_id = form_organization_group_id(repo.organization.id)
|
||||
org_members = _fetch_organization_members(
|
||||
github_client, repo.organization.login
|
||||
)
|
||||
|
||||
user_emails = set()
|
||||
for member in org_members:
|
||||
if member.email:
|
||||
user_emails.add(member.email)
|
||||
else:
|
||||
logger.error(f"Org member {member.login} has no email")
|
||||
|
||||
org_group = ExternalUserGroup(
|
||||
id=org_group_id,
|
||||
user_emails=list(user_emails),
|
||||
)
|
||||
logger.info(
|
||||
f"Created organization group with {len(user_emails)} emails for internal repository {repo.full_name}"
|
||||
)
|
||||
return [org_group]
|
||||
|
||||
logger.info(f"Repository {repo.full_name} is public - no user groups needed")
|
||||
return []
|
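A short, hedged illustration of the add_prefix behaviour documented above, using the helpers defined in this file (the repository id is made up; the exact prefixed form is whatever build_ext_group_name_for_onyx produces):

from ee.onyx.external_permissions.github.utils import form_collaborators_group_id
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource

raw_id = form_collaborators_group_id(1234)  # "1234_collaborators"
prefixed_id = build_ext_group_name_for_onyx(
    source=DocumentSource.GITHUB, ext_group_name=raw_id
)  # e.g. "github__1234_collaborators"
# get_external_access_permission(..., add_prefix=True) returns the prefixed form (connector path);
# with add_prefix=False the doc sync pipeline applies the prefix afterwards.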
||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
@@ -35,6 +36,7 @@ def _get_slim_doc_generator(
|
||||
def gmail_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
|
||||
@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.google_drive.permission_retrieval import (
|
||||
get_permissions_by_ids,
|
||||
)
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
@@ -40,8 +41,28 @@ def _get_slim_doc_generator(
|
||||
)
|
||||
|
||||
|
||||
def _merge_permissions_lists(
|
||||
permission_lists: list[list[GoogleDrivePermission]],
|
||||
) -> list[GoogleDrivePermission]:
|
||||
"""
|
||||
Merge a list of permission lists into a single list of permissions.
|
||||
"""
|
||||
seen_permission_ids: set[str] = set()
|
||||
merged_permissions: list[GoogleDrivePermission] = []
|
||||
for permission_list in permission_lists:
|
||||
for permission in permission_list:
|
||||
if permission.id not in seen_permission_ids:
|
||||
merged_permissions.append(permission)
|
||||
seen_permission_ids.add(permission.id)
|
||||
|
||||
return merged_permissions
|
||||
|
||||
|
||||
def get_external_access_for_raw_gdrive_file(
|
||||
file: GoogleDriveFileType, company_domain: str, drive_service: GoogleDriveService
|
||||
file: GoogleDriveFileType,
|
||||
company_domain: str,
|
||||
retriever_drive_service: GoogleDriveService | None,
|
||||
admin_drive_service: GoogleDriveService,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
Get the external access for a raw Google Drive file.
|
||||
@@ -62,11 +83,28 @@ def get_external_access_for_raw_gdrive_file(
|
||||
GoogleDrivePermission.from_drive_permission(p) for p in permissions
|
||||
]
|
||||
elif permission_ids:
|
||||
permissions_list = get_permissions_by_ids(
|
||||
drive_service=drive_service,
|
||||
doc_id=doc_id,
|
||||
permission_ids=permission_ids,
|
||||
|
||||
def _get_permissions(
|
||||
drive_service: GoogleDriveService,
|
||||
) -> list[GoogleDrivePermission]:
|
||||
return get_permissions_by_ids(
|
||||
drive_service=drive_service,
|
||||
doc_id=doc_id,
|
||||
permission_ids=permission_ids,
|
||||
)
|
||||
|
||||
permissions_list = _get_permissions(
|
||||
retriever_drive_service or admin_drive_service
|
||||
)
|
||||
if len(permissions_list) != len(permission_ids) and retriever_drive_service:
|
||||
logger.warning(
|
||||
f"Failed to get all permissions for file {doc_id} with retriever service, "
|
||||
"trying admin service"
|
||||
)
|
||||
backup_permissions_list = _get_permissions(admin_drive_service)
|
||||
permissions_list = _merge_permissions_lists(
|
||||
[permissions_list, backup_permissions_list]
|
||||
)
|
||||
|
||||
folder_ids_to_inherit_permissions_from: set[str] = set()
|
||||
user_emails: set[str] = set()
|
||||
@@ -132,6 +170,7 @@ def get_external_access_for_raw_gdrive_file(
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
|
||||
@@ -44,11 +44,17 @@ def _get_all_folders(
|
||||
|
||||
TODO: tweak things so we can fetch deltas.
|
||||
"""
|
||||
MAX_FAILED_PERCENTAGE = 0.5
|
||||
|
||||
all_folders: list[FolderInfo] = []
|
||||
seen_folder_ids: set[str] = set()
|
||||
|
||||
user_emails = google_drive_connector._get_all_user_emails()
|
||||
for user_email in user_emails:
|
||||
def _get_all_folders_for_user(
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
skip_folders_without_permissions: bool,
|
||||
user_email: str,
|
||||
) -> None:
|
||||
"""Helper to get folders for a specific user + update shared seen_folder_ids"""
|
||||
drive_service = get_drive_service(
|
||||
google_drive_connector.creds,
|
||||
user_email,
|
||||
@@ -98,6 +104,20 @@ def _get_all_folders(
|
||||
)
|
||||
)
|
||||
|
||||
failed_count = 0
|
||||
user_emails = google_drive_connector._get_all_user_emails()
|
||||
for user_email in user_emails:
|
||||
try:
|
||||
_get_all_folders_for_user(
|
||||
google_drive_connector, skip_folders_without_permissions, user_email
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Error getting folders for user {user_email}")
|
||||
failed_count += 1
|
||||
|
||||
if failed_count > MAX_FAILED_PERCENTAGE * len(user_emails):
|
||||
raise RuntimeError("Too many failed folder fetches during group sync")
|
||||
|
||||
return all_folders
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from ee.onyx.external_permissions.utils import generic_doc_sync
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -17,6 +18,7 @@ JIRA_DOC_SYNC_TAG = "jira_doc_sync"
|
||||
def jira_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
jira_connector = JiraConnector(
|
||||
@@ -26,7 +28,7 @@ def jira_doc_sync(
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
|
||||
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
|
||||
callback=callback,
|
||||
doc_source=DocumentSource.JIRA,
|
||||
slim_connector=jira_connector,
|
||||
|
||||
@@ -5,6 +5,8 @@ from typing import Protocol
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import SortOrder
|
||||
|
||||
# Avoid circular imports
|
||||
if TYPE_CHECKING:
|
||||
@@ -15,14 +17,34 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class FetchAllDocumentsFunction(Protocol):
|
||||
"""Protocol for a function that fetches all document IDs for a connector credential pair."""
|
||||
"""Protocol for a function that fetches documents for a connector credential pair.
|
||||
|
||||
def __call__(self) -> list[str]:
|
||||
This protocol defines the interface for functions that retrieve documents
|
||||
from the database, typically used in permission synchronization workflows.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sort_order: SortOrder | None,
|
||||
) -> list[DocumentRow]:
|
||||
"""
|
||||
Returns a list of document IDs for a connector credential pair.
|
||||
Fetches documents for a connector credential pair.
|
||||
"""
|
||||
...
|
||||
|
||||
This is typically used to determine which documents should no longer be
|
||||
accessible during the document sync process.
|
||||
|
||||
class FetchAllDocumentsIdsFunction(Protocol):
|
||||
"""Protocol for a function that fetches document IDs for a connector credential pair.
|
||||
|
||||
This protocol defines the interface for functions that retrieve document IDs
|
||||
from the database, typically used in permission synchronization workflows.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Fetches document IDs for a connector credential pair.
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -32,6 +54,7 @@ DocSyncFuncType = Callable[
|
||||
[
|
||||
"ConnectorCredentialPair",
|
||||
FetchAllDocumentsFunction,
|
||||
FetchAllDocumentsIdsFunction,
|
||||
Optional["IndexingHeartbeatInterface"],
|
||||
],
|
||||
Generator["DocExternalAccess", None, None],
|
||||
|
||||
@@ -3,6 +3,7 @@ from collections.abc import Generator
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
@@ -130,6 +131,7 @@ def _get_slack_document_access(
|
||||
def slack_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
|
||||
@@ -7,12 +7,16 @@ from pydantic import BaseModel
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import JIRA_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import TEAMS_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
|
||||
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
|
||||
from ee.onyx.external_permissions.github.doc_sync import github_doc_sync
|
||||
from ee.onyx.external_permissions.github.group_sync import github_group_sync
|
||||
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
|
||||
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
|
||||
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
|
||||
@@ -20,6 +24,7 @@ from ee.onyx.external_permissions.jira.doc_sync import jira_doc_sync
|
||||
from ee.onyx.external_permissions.perm_sync_types import CensoringFuncType
|
||||
from ee.onyx.external_permissions.perm_sync_types import DocSyncFuncType
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import GroupSyncFuncType
|
||||
from ee.onyx.external_permissions.salesforce.postprocessing import (
|
||||
censor_salesforce_chunks,
|
||||
@@ -63,6 +68,7 @@ class SyncConfig(BaseModel):
|
||||
def mock_doc_sync(
|
||||
cc_pair: "ConnectorCredentialPair",
|
||||
fetch_all_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: Optional["IndexingHeartbeatInterface"],
|
||||
) -> Generator["DocExternalAccess", None, None]:
|
||||
"""Mock doc sync function for testing - returns empty list since permissions are fetched during indexing"""
|
||||
@@ -117,6 +123,18 @@ _SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
|
||||
initial_index_should_sync=False,
|
||||
),
|
||||
),
|
||||
DocumentSource.GITHUB: SyncConfig(
|
||||
doc_sync_config=DocSyncConfig(
|
||||
doc_sync_frequency=GITHUB_PERMISSION_DOC_SYNC_FREQUENCY,
|
||||
doc_sync_func=github_doc_sync,
|
||||
initial_index_should_sync=True,
|
||||
),
|
||||
group_sync_config=GroupSyncConfig(
|
||||
group_sync_frequency=GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY,
|
||||
group_sync_func=github_group_sync,
|
||||
group_sync_is_cc_pair_agnostic=False,
|
||||
),
|
||||
),
|
||||
DocumentSource.SALESFORCE: SyncConfig(
|
||||
censoring_config=CensoringConfig(
|
||||
chunk_censoring_func=censor_salesforce_chunks,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from ee.onyx.external_permissions.utils import generic_doc_sync
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -18,6 +19,7 @@ TEAMS_DOC_SYNC_LABEL = "teams_doc_sync"
|
||||
def teams_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
teams_connector = TeamsConnector(
|
||||
@@ -27,7 +29,7 @@ def teams_doc_sync(
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
|
||||
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
|
||||
callback=callback,
|
||||
doc_source=DocumentSource.TEAMS,
|
||||
slim_connector=teams_connector,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -14,7 +14,7 @@ logger = setup_logger()
|
||||
|
||||
def generic_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
doc_source: DocumentSource,
|
||||
slim_connector: SlimConnector,
|
||||
@@ -62,9 +62,9 @@ def generic_doc_sync(
|
||||
)
|
||||
|
||||
logger.info(f"Querying existing document IDs for CC Pair ID: {cc_pair.id=}")
|
||||
existing_doc_ids = set(fetch_all_existing_docs_fn())
|
||||
existing_doc_ids: list[str] = fetch_all_existing_docs_ids_fn()
|
||||
|
||||
missing_doc_ids = existing_doc_ids - newly_fetched_doc_ids
|
||||
missing_doc_ids = set(existing_doc_ids) - newly_fetched_doc_ids
|
||||
|
||||
if not missing_doc_ids:
|
||||
return
|
||||
|
||||
@@ -134,15 +134,14 @@ def ee_fetch_settings() -> EnterpriseSettings:
|
||||
def put_logo(
|
||||
file: UploadFile,
|
||||
is_logotype: bool = False,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
|
||||
upload_logo(file=file, is_logotype=is_logotype)
|
||||
|
||||
|
||||
def fetch_logo_helper(db_session: Session) -> Response:
|
||||
try:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
|
||||
if not onyx_file:
|
||||
raise ValueError("get_onyx_file returned None!")
|
||||
@@ -158,7 +157,7 @@ def fetch_logo_helper(db_session: Session) -> Response:
|
||||
|
||||
def fetch_logotype_helper(db_session: Session) -> Response:
|
||||
try:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
|
||||
if not onyx_file:
|
||||
raise ValueError("get_onyx_file returned None!")
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import IO
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
|
||||
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
|
||||
@@ -99,9 +98,7 @@ def guess_file_type(filename: str) -> str:
|
||||
return "application/octet-stream"
|
||||
|
||||
|
||||
def upload_logo(
|
||||
db_session: Session, file: UploadFile | str, is_logotype: bool = False
|
||||
) -> bool:
|
||||
def upload_logo(file: UploadFile | str, is_logotype: bool = False) -> bool:
|
||||
content: IO[Any]
|
||||
|
||||
if isinstance(file, str):
|
||||
@@ -129,7 +126,7 @@ def upload_logo(
|
||||
display_name = file.filename
|
||||
file_type = file.content_type or "image/jpeg"
|
||||
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
file_store.save_file(
|
||||
content=content,
|
||||
display_name=display_name,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
@@ -73,6 +74,7 @@ def _get_final_context_doc_indices(
|
||||
|
||||
def _convert_packet_stream_to_response(
|
||||
packets: ChatPacketStream,
|
||||
chat_session_id: UUID,
|
||||
) -> ChatBasicResponse:
|
||||
response = ChatBasicResponse()
|
||||
final_context_docs: list[LlmDoc] = []
|
||||
@@ -216,6 +218,8 @@ def _convert_packet_stream_to_response(
|
||||
if answer:
|
||||
response.answer_citationless = remove_answer_citations(answer)
|
||||
|
||||
response.chat_session_id = chat_session_id
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@@ -237,13 +241,36 @@ def handle_simplified_chat_message(
|
||||
if not chat_message_req.message:
|
||||
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
|
||||
|
||||
# Handle chat session creation if chat_session_id is not provided
|
||||
if chat_message_req.chat_session_id is None:
|
||||
if chat_message_req.persona_id is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either chat_session_id or persona_id must be provided",
|
||||
)
|
||||
|
||||
# Create a new chat session with the provided persona_id
|
||||
try:
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="", # Leave empty for simple API
|
||||
user_id=user.id if user else None,
|
||||
persona_id=chat_message_req.persona_id,
|
||||
)
|
||||
chat_session_id = new_chat_session.id
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise HTTPException(status_code=400, detail="Invalid Persona provided.")
|
||||
else:
|
||||
chat_session_id = chat_message_req.chat_session_id
|
||||
|
||||
try:
|
||||
parent_message, _ = create_chat_chain(
|
||||
chat_session_id=chat_message_req.chat_session_id, db_session=db_session
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
except Exception:
|
||||
parent_message = get_or_create_root_message(
|
||||
chat_session_id=chat_message_req.chat_session_id, db_session=db_session
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -258,7 +285,7 @@ def handle_simplified_chat_message(
|
||||
retrieval_options = chat_message_req.retrieval_options
|
||||
|
||||
full_chat_msg_info = CreateChatMessageRequest(
|
||||
chat_session_id=chat_message_req.chat_session_id,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message.id,
|
||||
message=chat_message_req.message,
|
||||
file_descriptors=[],
|
||||
@@ -283,7 +310,7 @@ def handle_simplified_chat_message(
|
||||
enforce_chat_session_id_for_search_docs=False,
|
||||
)
|
||||
|
||||
return _convert_packet_stream_to_response(packets)
|
||||
return _convert_packet_stream_to_response(packets, chat_session_id)
|
||||
|
||||
|
||||
@router.post("/send-message-simple-with-history")
|
||||
@@ -403,4 +430,4 @@ def handle_send_message_simple_with_history(
|
||||
enforce_chat_session_id_for_search_docs=False,
|
||||
)
|
||||
|
||||
return _convert_packet_stream_to_response(packets)
|
||||
return _convert_packet_stream_to_response(packets, chat_session.id)
|
||||
|
||||
@@ -41,11 +41,13 @@ class DocumentSearchRequest(ChunkContext):
|
||||
|
||||
|
||||
class BasicCreateChatMessageRequest(ChunkContext):
|
||||
"""Before creating messages, be sure to create a chat_session and get an id
|
||||
"""If a chat_session_id is not provided, a persona_id must be provided to automatically create a new chat session
|
||||
Note, for simplicity this option only allows for a single linear chain of messages
|
||||
"""
|
||||
|
||||
chat_session_id: UUID
|
||||
chat_session_id: UUID | None = None
|
||||
# Optional persona_id to create a new chat session if chat_session_id is not provided
|
||||
persona_id: int | None = None
|
||||
# New message contents
|
||||
message: str
|
||||
# Defaults to using retrieval with no additional filters
|
||||
@@ -62,6 +64,12 @@ class BasicCreateChatMessageRequest(ChunkContext):
|
||||
# If True, uses agentic search instead of basic search
|
||||
use_agentic_search: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
|
||||
if self.chat_session_id is None and self.persona_id is None:
|
||||
raise ValueError("Either chat_session_id or persona_id must be provided")
|
||||
return self
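
As a concrete illustration of the new request shape above: a client can now start a conversation with only a persona_id and then reuse the chat_session_id that ChatBasicResponse returns. A minimal client-side sketch follows; the endpoint path, host, and port are assumptions and are not shown in this diff.

import requests

ENDPOINT = "http://localhost:8080/chat/send-message-simple-api"  # path/host assumed, not from this diff

# New-session flow: no chat_session_id; persona_id satisfies the validator above.
first = requests.post(ENDPOINT, json={"persona_id": 0, "message": "What is Onyx?"}).json()
session_id = first["chat_session_id"]  # now included in ChatBasicResponse

# Follow-up flow: reuse the returned chat_session_id; persona_id is no longer needed.
requests.post(ENDPOINT, json={"chat_session_id": session_id, "message": "Tell me more."})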
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
# Last element is the new query. All previous elements are historical context
|
||||
@@ -171,6 +179,9 @@ class ChatBasicResponse(BaseModel):
|
||||
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
|
||||
agent_refined_answer_improvement: bool | None = None
|
||||
|
||||
# Chat session ID for tracking conversation continuity
|
||||
chat_session_id: UUID | None = None
|
||||
|
||||
|
||||
class OneShotQARequest(ChunkContext):
|
||||
# Supports simpler APIs that don't deal with chat histories or message edits
@@ -358,7 +358,7 @@ def get_query_history_export_status(
|
||||
# If task is None, then it's possible that the task has already finished processing.
|
||||
# Therefore, we should then check if the export file has already been stored inside of the file-store.
|
||||
# If that *also* doesn't exist, then we can return a 404.
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
|
||||
report_name = construct_query_history_report_name(request_id)
|
||||
has_file = file_store.has_file(
|
||||
@@ -385,7 +385,7 @@ def download_query_history_csv(
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
|
||||
report_name = construct_query_history_report_name(request_id)
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
has_file = file_store.has_file(
|
||||
file_id=report_name,
|
||||
file_origin=FileOrigin.QUERY_HISTORY_CSV,
|
||||
|
||||
@@ -53,7 +53,7 @@ def read_usage_report(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
try:
|
||||
file = get_usage_report_data(db_session, report_name)
|
||||
file = get_usage_report_data(report_name)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@@ -112,7 +112,7 @@ def create_new_usage_report(
|
||||
period: tuple[datetime, datetime] | None,
|
||||
) -> UsageReportMetadata:
|
||||
report_id = str(uuid.uuid4())
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
|
||||
messages_file_id = generate_chat_messages_report(
|
||||
db_session, file_store, report_id, period
|
||||
|
||||
@@ -200,10 +200,10 @@ def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None:
|
||||
store_ee_settings(final_enterprise_settings)
|
||||
|
||||
|
||||
def _seed_logo(db_session: Session, logo_path: str | None) -> None:
|
||||
def _seed_logo(logo_path: str | None) -> None:
|
||||
if logo_path:
|
||||
logger.notice("Uploading logo")
|
||||
upload_logo(db_session=db_session, file=logo_path)
|
||||
upload_logo(file=logo_path)
|
||||
|
||||
|
||||
def _seed_analytics_script(seed_config: SeedConfiguration) -> None:
|
||||
@@ -245,7 +245,7 @@ def seed_db() -> None:
|
||||
if seed_config.custom_tools is not None:
|
||||
_seed_custom_tools(db_session, seed_config.custom_tools)
|
||||
|
||||
_seed_logo(db_session, seed_config.seeded_logo_path)
|
||||
_seed_logo(seed_config.seeded_logo_path)
|
||||
_seed_enterprise_settings(seed_config)
|
||||
_seed_analytics_script(seed_config)
|
||||
|
||||
|
||||
@@ -10,10 +10,12 @@ from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
@@ -47,6 +49,26 @@ def gate_product(
|
||||
return ProductGatingResponse(updated=False, error=str(e))
|
||||
|
||||
|
||||
@router.post("/product-gating/full-sync")
|
||||
def gate_product_full_sync(
|
||||
product_gating_request: ProductGatingFullSyncRequest,
|
||||
_: None = Depends(control_plane_dep),
|
||||
) -> ProductGatingResponse:
|
||||
"""
|
||||
Bulk operation to overwrite the entire gated tenant set.
|
||||
This replaces all currently gated tenants with the provided list.
|
||||
Gated tenants are not available to access the product and will be
|
||||
directed to the billing page when their subscription has ended.
|
||||
"""
|
||||
try:
|
||||
overwrite_full_gated_set(product_gating_request.gated_tenant_ids)
|
||||
return ProductGatingResponse(updated=True, error=None)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to gate products during full sync")
|
||||
return ProductGatingResponse(updated=False, error=str(e))
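
For reference, a hedged sketch of how the control plane might call the new bulk endpoint. Only the /product-gating/full-sync suffix, the gated_tenant_ids request field, and the updated/error response fields come from this diff; the route prefix, host, and auth header are assumptions.

import requests

resp = requests.post(
    "https://onyx-api.example.com/tenants/product-gating/full-sync",  # prefix/host assumed
    json={"gated_tenant_ids": ["tenant_abc", "tenant_def"]},  # replaces the entire gated set
    headers={"Authorization": "Bearer <control-plane-token>"},  # control_plane_dep auth scheme assumed
)
assert resp.json()["updated"] is True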
@router.get("/billing-information")
|
||||
async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
|
||||
@@ -19,6 +19,10 @@ class ProductGatingRequest(BaseModel):
|
||||
application_status: ApplicationStatus
|
||||
|
||||
|
||||
class ProductGatingFullSyncRequest(BaseModel):
|
||||
gated_tenant_ids: list[str]
|
||||
|
||||
|
||||
class SubscriptionStatusResponse(BaseModel):
|
||||
subscribed: bool
|
||||
|
||||
|
||||
@@ -16,10 +16,6 @@ logger = setup_logger()
|
||||
def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None:
|
||||
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
# Store the full status
|
||||
status_key = f"tenant:{tenant_id}:status"
|
||||
redis_client.set(status_key, status.value)
|
||||
|
||||
# Maintain the GATED_ACCESS set
|
||||
if status == ApplicationStatus.GATED_ACCESS:
|
||||
redis_client.sadd(GATED_TENANTS_KEY, tenant_id)
|
||||
@@ -46,6 +42,25 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
|
||||
raise
|
||||
|
||||
|
||||
def overwrite_full_gated_set(tenant_ids: list[str]) -> None:
|
||||
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
pipeline = redis_client.pipeline()
|
||||
|
||||
# using pipeline doesn't automatically add the tenant_id prefix
|
||||
full_gated_set_key = f"{ONYX_CLOUD_TENANT_ID}:{GATED_TENANTS_KEY}"
|
||||
|
||||
# Clear the existing set
|
||||
pipeline.delete(full_gated_set_key)
|
||||
|
||||
# Add all tenant IDs to the set and set their status
|
||||
for tenant_id in tenant_ids:
|
||||
pipeline.sadd(full_gated_set_key, tenant_id)
|
||||
|
||||
# Execute all commands at once
|
||||
pipeline.execute()
|
||||
|
||||
|
||||
def get_gated_tenants() -> set[str]:
|
||||
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
|
||||
|
||||
@@ -203,6 +203,8 @@ def generate_simple_sql(
|
||||
if state.kg_entity_temp_view_name is None:
|
||||
raise ValueError("kg_entity_temp_view_name is not set")
|
||||
|
||||
sql_statement_display: str | None = None
|
||||
|
||||
## STEP 3 - articulate goals
|
||||
|
||||
stream_write_step_activities(writer, _KG_STEP_NR)
|
||||
@@ -381,7 +383,18 @@ def generate_simple_sql(
|
||||
|
||||
raise e
|
||||
|
||||
logger.debug(f"A3 - sql_statement after correction: {sql_statement}")
|
||||
# display sql statement with view names replaced by general view names
|
||||
sql_statement_display = sql_statement.replace(
|
||||
state.kg_doc_temp_view_name, "<your_allowed_docs_view_name>"
|
||||
)
|
||||
sql_statement_display = sql_statement_display.replace(
|
||||
state.kg_rel_temp_view_name, "<your_relationship_view_name>"
|
||||
)
|
||||
sql_statement_display = sql_statement_display.replace(
|
||||
state.kg_entity_temp_view_name, "<your_entity_view_name>"
|
||||
)
|
||||
|
||||
logger.debug(f"A3 - sql_statement after correction: {sql_statement_display}")
|
||||
|
||||
# Get SQL for source documents
|
||||
|
||||
@@ -409,7 +422,20 @@ def generate_simple_sql(
|
||||
"relationship_table", rel_temp_view
|
||||
)
|
||||
|
||||
logger.debug(f"A3 source_documents_sql: {source_documents_sql}")
|
||||
if source_documents_sql:
|
||||
source_documents_sql_display = source_documents_sql.replace(
|
||||
state.kg_doc_temp_view_name, "<your_allowed_docs_view_name>"
|
||||
)
|
||||
source_documents_sql_display = source_documents_sql_display.replace(
|
||||
state.kg_rel_temp_view_name, "<your_relationship_view_name>"
|
||||
)
|
||||
source_documents_sql_display = source_documents_sql_display.replace(
|
||||
state.kg_entity_temp_view_name, "<your_entity_view_name>"
|
||||
)
|
||||
else:
|
||||
source_documents_sql_display = "(No source documents SQL generated)"
|
||||
|
||||
logger.debug(f"A3 source_documents_sql: {source_documents_sql_display}")
|
||||
|
||||
scalar_result = None
|
||||
query_results = None
|
||||
@@ -435,7 +461,13 @@ def generate_simple_sql(
|
||||
rows = result.fetchall()
|
||||
query_results = [dict(row._mapping) for row in rows]
|
||||
except Exception as e:
|
||||
# TODO: raise error on frontend
|
||||
logger.error(f"Error executing SQL query: {e}")
|
||||
drop_views(
|
||||
allowed_docs_view_name=doc_temp_view,
|
||||
kg_relationships_view_name=rel_temp_view,
|
||||
kg_entity_view_name=ent_temp_view,
|
||||
)
|
||||
|
||||
raise e
|
||||
|
||||
@@ -459,8 +491,14 @@ def generate_simple_sql(
|
||||
for source_document_result in query_source_document_results
|
||||
]
|
||||
except Exception as e:
|
||||
# No stopping here, the individualized SQL query is not mandatory
|
||||
# TODO: raise error on frontend
|
||||
|
||||
drop_views(
|
||||
allowed_docs_view_name=doc_temp_view,
|
||||
kg_relationships_view_name=rel_temp_view,
|
||||
kg_entity_view_name=ent_temp_view,
|
||||
)
|
||||
|
||||
logger.error(f"Error executing Individualized SQL query: {e}")
|
||||
|
||||
else:
|
||||
@@ -493,11 +531,11 @@ def generate_simple_sql(
|
||||
if reasoning:
|
||||
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=reasoning)
|
||||
|
||||
if main_sql_statement:
|
||||
if sql_statement_display:
|
||||
stream_write_step_answer_explicit(
|
||||
writer,
|
||||
step_nr=_KG_STEP_NR,
|
||||
answer=f" \n Generated SQL: {main_sql_statement}",
|
||||
answer=f" \n Generated SQL: {sql_statement_display}",
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
@@ -51,7 +51,6 @@ def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
|
||||
else:
|
||||
continue
|
||||
history_segments.append(f"{role}:\n {msg.content}\n\n")
|
||||
|
||||
return "\n".join(history_segments)
|
||||
|
||||
|
||||
@@ -128,6 +127,7 @@ def choose_tool(
|
||||
override_kwargs: SearchToolOverrideKwargs = (
|
||||
force_use_tool.override_kwargs or SearchToolOverrideKwargs()
|
||||
)
|
||||
override_kwargs.original_query = agent_config.inputs.prompt_builder.raw_user_query
|
||||
|
||||
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
|
||||
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
|
||||
|
||||
@@ -174,7 +174,6 @@ def get_test_config(
|
||||
# The docs retrieved by this flow are already relevance-filtered
|
||||
all_docs_useful=True
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
structured_response_format=None,
|
||||
)
|
||||
|
||||
@@ -198,7 +197,7 @@ def get_test_config(
|
||||
prompt_config=prompt_config,
|
||||
llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=search_tool_config.document_pruning_config,
|
||||
document_pruning_config=search_tool_config.document_pruning_config,
|
||||
answer_style_config=search_tool_config.answer_style_config,
|
||||
selected_sections=search_tool_config.selected_sections,
|
||||
chunks_above=search_tool_config.chunks_above,
|
||||
|
||||
@@ -24,13 +24,14 @@ from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatt
|
||||
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.background.celery.celery_utils import make_probe_path
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
@@ -39,6 +40,7 @@ from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import ColoredFormatter
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import PlainFormatter
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -92,7 +94,13 @@ def on_task_prerun(
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
**other_kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
# Reset any per-task logging context so that prefixes (e.g. pruning_ctx)
|
||||
# from a previous task executed in the same worker process do not leak
|
||||
# into the next task's log messages. This fixes incorrect [CC Pair:/Index Attempt]
|
||||
# prefixes observed when a pruning task finishes and an indexing task
|
||||
# runs in the same process.
|
||||
|
||||
LoggerContextVars.reset()
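
A minimal standalone illustration (not Onyx code) of the leak this reset guards against: a context variable set by one task persists in the long-lived worker process and colors the next task's log lines unless it is cleared when the next task starts.

import contextvars

log_prefix: contextvars.ContextVar[str] = contextvars.ContextVar("log_prefix", default="")

def run_task(name: str, prefix: str | None = None) -> None:
    if prefix is not None:
        log_prefix.set(prefix)  # e.g. a pruning task sets "[CC Pair: 7]"
    print(f"{log_prefix.get()} running {name}")

run_task("pruning", prefix="[CC Pair: 7]")
run_task("indexing")  # without a reset, this still prints the "[CC Pair: 7]" prefix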
def on_task_postrun(
|
||||
@@ -145,8 +153,11 @@ def on_task_postrun(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
|
||||
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
|
||||
# NOTE: we want to remove the `Redis*` classes and prefer plain functions for
# these things going forward. In short, new code should generally look like the doc
# sync task rather than the others below
|
||||
if task_id.startswith(DOCUMENT_SYNC_PREFIX):
|
||||
r.srem(DOCUMENT_SYNC_TASKSET_KEY, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisDocumentSet.PREFIX):
|
||||
@@ -470,7 +481,8 @@ class TenantContextFilter(logging.Filter):
|
||||
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id:
|
||||
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:5]
|
||||
# Match the 8 character tenant abbreviation used in OnyxLoggingAdapter
|
||||
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:8]
|
||||
record.name = f"[t:{tenant_id}]"
|
||||
else:
|
||||
record.name = ""
backend/onyx/background/celery/apps/docfetching.py (new file, 102 lines)
@@ -0,0 +1,102 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.docfetching")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME)
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
]
|
||||
)
|
||||
@@ -12,7 +12,7 @@ from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -21,7 +21,7 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.indexing")
|
||||
celery_app.config_from_object("onyx.background.celery.configs.docprocessing")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME)
|
||||
|
||||
# rkuo: Transient errors keep happening in the indexing watchdog threads.
|
||||
# "SSL connection has been closed unexpectedly"
|
||||
@@ -108,6 +108,6 @@ for bootstep in base_bootsteps:
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.indexing",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
]
|
||||
)
|
||||
@@ -116,6 +116,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.user_file_folder_sync",
|
||||
"onyx.background.celery.tasks.indexing",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.result import AsyncResult
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
@@ -18,9 +19,7 @@ from redis.lock import Lock as RedisLock
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.background.celery.tasks.indexing.utils import (
|
||||
get_unfenced_index_attempt_ids,
|
||||
)
|
||||
from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync
|
||||
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
@@ -29,9 +28,7 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.redis.redis_connector_credential_pair import (
|
||||
RedisGlobalConnectorCredentialPair,
|
||||
)
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
@@ -156,7 +153,10 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
r.delete(OnyxRedisConstants.ACTIVE_FENCES)
|
||||
|
||||
RedisGlobalConnectorCredentialPair.reset_all(r)
|
||||
# NOTE: we want to remove the `Redis*` classes, prefer to just have functions
|
||||
# This is the preferred way to do this going forward
|
||||
reset_document_sync(r)
|
||||
|
||||
RedisDocumentSet.reset_all(r)
|
||||
RedisUserGroup.reset_all(r)
|
||||
RedisConnectorDelete.reset_all(r)
|
||||
@@ -167,24 +167,50 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
RedisConnectorExternalGroupSync.reset_all(r)
|
||||
|
||||
# mark orphaned index attempts as failed
|
||||
# This uses database coordination instead of Redis fencing
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
|
||||
for attempt_id in unfenced_attempt_ids:
|
||||
# Get potentially orphaned attempts (those with active status and task IDs)
|
||||
potentially_orphaned_ids = IndexingCoordination.get_orphaned_index_attempt_ids(
|
||||
db_session
|
||||
)
|
||||
|
||||
for attempt_id in potentially_orphaned_ids:
|
||||
attempt = get_index_attempt(db_session, attempt_id)
|
||||
if not attempt:
|
||||
|
||||
# handle case where not started or docfetching is done but indexing is not
|
||||
if (
|
||||
not attempt
|
||||
or not attempt.celery_task_id
|
||||
or attempt.total_batches is not None
|
||||
):
|
||||
continue
|
||||
|
||||
failure_reason = (
|
||||
f"Canceling leftover index attempt found on startup: "
|
||||
f"index_attempt={attempt.id} "
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id}"
|
||||
)
|
||||
logger.warning(failure_reason)
|
||||
logger.exception(
|
||||
f"Marking attempt {attempt.id} as canceled due to validation error 2"
|
||||
)
|
||||
mark_attempt_canceled(attempt.id, db_session, failure_reason)
|
||||
# Check if the Celery task actually exists
|
||||
try:
|
||||
|
||||
result: AsyncResult = AsyncResult(attempt.celery_task_id)
|
||||
|
||||
# If the task is not in PENDING state, it exists in Celery
|
||||
if result.state != "PENDING":
|
||||
continue
|
||||
|
||||
# Task is orphaned - mark as failed
|
||||
failure_reason = (
|
||||
f"Orphaned index attempt found on startup - Celery task not found: "
|
||||
f"index_attempt={attempt.id} "
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id} "
|
||||
f"celery_task_id={attempt.celery_task_id}"
|
||||
)
|
||||
logger.warning(failure_reason)
|
||||
mark_attempt_canceled(attempt.id, db_session, failure_reason)
|
||||
|
||||
except Exception:
|
||||
# If we can't check the task status, be conservative and continue
|
||||
logger.warning(
|
||||
f"Could not verify Celery task status on startup for attempt {attempt.id}, "
|
||||
f"task_id={attempt.celery_task_id}"
|
||||
)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
@@ -291,7 +317,7 @@ for bootstep in base_bootsteps:
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.indexing",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
|
||||
@@ -26,7 +26,7 @@ def celery_get_unacked_length(r: Redis) -> int:
|
||||
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
|
||||
"""Gets the set of task id's matching the given queue in the unacked hash.
|
||||
|
||||
Unacked entries belonging to the indexing queue are "prefetched", so this gives
|
||||
Unacked entries belonging to the indexing queues are "prefetched", so this gives
|
||||
us crucial visibility as to what tasks are in that state.
|
||||
"""
|
||||
tasks: set[str] = set()
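
A rough sketch of the lookup described in the docstring above, against a Redis broker. The "unacked" hash name comes from that docstring; the per-entry layout is broker-internal, so it is hidden behind a hypothetical decode_unacked_entry helper rather than spelled out here.

from redis import Redis

def unacked_task_ids_for_queue(r: Redis, queue: str) -> set[str]:
    found: set[str] = set()
    for _delivery_tag, raw in r.hscan_iter("unacked"):  # prefetched-but-unacked deliveries
        task_id, entry_queue = decode_unacked_entry(raw)  # hypothetical helper, broker-format specific
        if entry_queue == queue:
            found.add(task_id)
    return found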
backend/onyx/background/celery/configs/docfetching.py (new file, 22 lines)
@@ -0,0 +1,22 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_DOCFETCHING_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
# Docfetching worker configuration
|
||||
worker_concurrency = CELERY_WORKER_DOCFETCHING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -1,5 +1,5 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_INDEXING_CONCURRENCY
|
||||
from onyx.configs.app_configs import CELERY_WORKER_DOCPROCESSING_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
@@ -24,6 +24,6 @@ task_acks_late = shared_config.task_acks_late
|
||||
# which means a duplicate run might change the task state unexpectedly
|
||||
# task_track_started = True
|
||||
|
||||
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
|
||||
worker_concurrency = CELERY_WORKER_DOCPROCESSING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -100,24 +100,6 @@ beat_task_templates: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-background-processes",
|
||||
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
|
||||
@@ -40,9 +40,11 @@ from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.index_attempt import delete_index_attempts
|
||||
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
|
||||
from onyx.db.search_settings import get_all_search_settings
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
@@ -69,13 +71,21 @@ def revoke_tasks_blocking_deletion(
|
||||
) -> None:
|
||||
search_settings_list = get_all_search_settings(db_session)
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
try:
|
||||
index_payload = redis_connector_index.payload
|
||||
if index_payload and index_payload.celery_task_id:
|
||||
app.control.revoke(index_payload.celery_task_id)
|
||||
recent_index_attempts = get_recent_attempts_for_cc_pair(
|
||||
cc_pair_id=redis_connector.cc_pair_id,
|
||||
search_settings_id=search_settings.id,
|
||||
limit=1,
|
||||
db_session=db_session,
|
||||
)
|
||||
if (
|
||||
recent_index_attempts
|
||||
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
|
||||
and recent_index_attempts[0].celery_task_id
|
||||
):
|
||||
app.control.revoke(recent_index_attempts[0].celery_task_id)
|
||||
task_logger.info(
|
||||
f"Revoked indexing task {index_payload.celery_task_id}."
|
||||
f"Revoked indexing task {recent_index_attempts[0].celery_task_id}."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking indexing task")
|
||||
@@ -281,8 +291,16 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
# do not proceed if connector indexing or connector pruning are running
|
||||
search_settings_list = get_all_search_settings(db_session)
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
if redis_connector_index.fenced:
|
||||
recent_index_attempts = get_recent_attempts_for_cc_pair(
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=search_settings.id,
|
||||
limit=1,
|
||||
db_session=db_session,
|
||||
)
|
||||
if (
|
||||
recent_index_attempts
|
||||
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
raise TaskDependencyError(
|
||||
"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
backend/onyx/background/celery/tasks/docfetching/tasks.py (new file, 642 lines)
@@ -0,0 +1,642 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
|
||||
import sentry_sdk
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import stop_heartbeat
|
||||
from onyx.background.celery.tasks.docprocessing.tasks import ConnectorIndexingLogBuilder
|
||||
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallback
|
||||
from onyx.background.celery.tasks.models import DocProcessingContext
|
||||
from onyx.background.celery.tasks.models import IndexingWatchdogTerminalStatus
|
||||
from onyx.background.celery.tasks.models import SimpleJobResult
|
||||
from onyx.background.indexing.job_client import SimpleJob
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.job_client import SimpleJobException
|
||||
from onyx.background.indexing.run_docfetching import run_indexing_entrypoint
|
||||
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _verify_indexing_attempt(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Verify that the indexing attempt exists and is in the correct state.
|
||||
"""
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
|
||||
if not attempt:
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - IndexAttempt not found: attempt_id={index_attempt_id}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
|
||||
)
|
||||
|
||||
if attempt.connector_credential_pair_id != cc_pair_id:
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - CC pair mismatch: "
|
||||
f"expected={cc_pair_id} actual={attempt.connector_credential_pair_id}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
|
||||
)
|
||||
|
||||
if attempt.search_settings_id != search_settings_id:
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - Search settings mismatch: "
|
||||
f"expected={search_settings_id} actual={attempt.search_settings_id}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
|
||||
)
|
||||
|
||||
if attempt.status not in [
|
||||
IndexingStatus.NOT_STARTED,
|
||||
IndexingStatus.IN_PROGRESS,
|
||||
]:
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - Invalid attempt status: "
|
||||
f"attempt_id={index_attempt_id} status={attempt.status}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
|
||||
)
|
||||
|
||||
# Check for cancellation
|
||||
if IndexingCoordination.check_cancellation_requested(
|
||||
db_session, index_attempt_id
|
||||
):
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - Cancellation requested: attempt_id={index_attempt_id}",
|
||||
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"docfetching_task - IndexAttempt verified: "
|
||||
f"attempt_id={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
|
||||
def docfetching_task(
|
||||
app: Celery,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
is_ee: bool,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
This function is run in a SimpleJob as a new process. It validates the index attempt
and connector state, then calls run_indexing_entrypoint to do the actual work.
|
||||
|
||||
NOTE: if an exception is raised out of this task, the primary worker will detect
|
||||
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
|
||||
This will cause the primary worker to abort the indexing attempt and clean up.
|
||||
"""
|
||||
|
||||
# Start heartbeat for this indexing attempt
|
||||
heartbeat_thread, stop_event = start_heartbeat(index_attempt_id)
|
||||
try:
|
||||
_docfetching_task(
|
||||
app, index_attempt_id, cc_pair_id, search_settings_id, is_ee, tenant_id
|
||||
)
|
||||
finally:
|
||||
stop_heartbeat(heartbeat_thread, stop_event) # Stop heartbeat before exiting
|
||||
|
||||
|
||||
def _docfetching_task(
|
||||
app: Celery,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
is_ee: bool,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
# Since connector_indexing_proxy_task spawns a new process using this function as
|
||||
# the entrypoint, we init Sentry here.
|
||||
if SENTRY_DSN:
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task starting: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector.new_index(search_settings_id)
|
||||
|
||||
# TODO: remove all fences, cause all signals to be set in postgres
|
||||
if redis_connector.delete.fenced:
|
||||
raise SimpleJobException(
|
||||
f"Indexing will not start because connector deletion is in progress: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.delete.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION.code,
|
||||
)
|
||||
|
||||
if redis_connector.stop.fenced:
|
||||
raise SimpleJobException(
|
||||
f"Indexing will not start because a connector stop signal was detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.stop.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
|
||||
)
|
||||
|
||||
# Verify the indexing attempt exists and is valid
|
||||
# This replaces the Redis fence payload waiting
|
||||
_verify_indexing_attempt(index_attempt_id, cc_pair_id, search_settings_id)
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not attempt:
|
||||
raise SimpleJobException(
|
||||
f"Index attempt not found: index_attempt={index_attempt_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
)
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
raise SimpleJobException(
|
||||
f"cc_pair not found: cc_pair={cc_pair_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
)
|
||||
|
||||
# define a callback class
|
||||
callback = IndexingCallback(
|
||||
redis_connector,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
# This is where the heavy/real work happens
|
||||
run_indexing_entrypoint(
|
||||
app,
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
cc_pair_id,
|
||||
is_ee,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
except ConnectorValidationError:
|
||||
raise SimpleJobException(
|
||||
f"Indexing task failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}",
|
||||
code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing spawned task failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
# special bulletproofing ... truncate long exception messages
|
||||
# for exception types that require more args, this will fail
|
||||
# thus the try/except
|
||||
try:
|
||||
sanitized_e = type(e)(str(e)[:1024])
|
||||
sanitized_e.__traceback__ = e.__traceback__
|
||||
raise sanitized_e
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task finished: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
os._exit(0) # ensure process exits cleanly
|
||||
|
||||
|
||||
def process_job_result(
|
||||
job: SimpleJob,
|
||||
connector_source: str | None,
|
||||
redis_connector_index: RedisConnectorIndex,
|
||||
log_builder: ConnectorIndexingLogBuilder,
|
||||
) -> SimpleJobResult:
|
||||
result = SimpleJobResult()
|
||||
result.connector_source = connector_source
|
||||
|
||||
if job.process:
|
||||
result.exit_code = job.process.exitcode
|
||||
|
||||
if job.status != "error":
|
||||
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
|
||||
return result
|
||||
|
||||
ignore_exitcode = False
|
||||
|
||||
# In EKS, there is an edge case where successful tasks return exit
|
||||
# code 1 in the cloud due to the set_spawn_method not sticking.
|
||||
# We've since worked around this, but the following is a safe way to
|
||||
# work around this issue. Basically, we ignore the job error state
|
||||
# if the completion signal is OK.
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if ignore_exitcode:
|
||||
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
|
||||
task_logger.warning(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...",
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
)
|
||||
else:
|
||||
if result.exit_code is not None:
|
||||
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
|
||||
|
||||
result.exception_str = job.exception()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
|
||||
bind=True,
|
||||
acks_late=False,
|
||||
track_started=True,
|
||||
)
|
||||
def docfetching_proxy_task(
|
||||
self: Task,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
This task is the entrypoint for the full indexing pipeline, which is composed of two tasks:
|
||||
docfetching and docprocessing.
|
||||
This task is spawned by "try_creating_indexing_task" which is called in the "check_for_indexing" task.
|
||||
|
||||
This task spawns a new process for a new scheduled index attempt. That
|
||||
new process (which runs the docfetching_task function) does the following:
|
||||
|
||||
1) determines parameters of the indexing attempt (which connector indexing function to run,
|
||||
start and end time, from prev checkpoint or not), then run that connector. Specifically,
|
||||
connectors are responsible for reading data from an outside source and converting it to Onyx documents.
|
||||
At the moment these two steps (reading external data and converting to an Onyx document)
|
||||
are not parallelized in most connectors; that's a subject for future work.
|
||||
|
||||
Each document batch produced by step 1 is stored in the file store, and a docprocessing task is spawned
|
||||
to process it. docprocessing involves the steps listed below.
|
||||
|
||||
2) upserts documents to postgres (index_doc_batch_prepare)
|
||||
3) chunks each document (optionally adds context for contextual rag)
|
||||
4) embeds chunks (embed_chunks_with_failure_handling) via a call to the model server
|
||||
5) write chunks to vespa (write_chunks_to_vector_db_with_backoff)
|
||||
6) update document and indexing metadata in postgres
|
||||
7) pulls all document IDs from the source and compares those IDs to locally stored documents and deletes
|
||||
all locally stored IDs missing from the most recently pulled document ID list
|
||||
|
||||
Some important notes:
|
||||
Invariants:
|
||||
- docfetching proxy tasks are spawned by check_for_indexing. The proxy then runs the docfetching_task wrapped in a watchdog.
|
||||
The watchdog is responsible for monitoring the docfetching_task and marking the index attempt as failed
|
||||
if it is not making progress.
|
||||
- All docprocessing tasks are spawned by a docfetching task.
|
||||
- all docfetching tasks, docprocessing tasks, and document batches in the file store are
|
||||
associated with a specific index attempt.
|
||||
- the index attempt status is the source of truth for what is currently happening with the index attempt.
|
||||
It is coupled with the creation/running of docfetching and docprocessing tasks as much as possible.
|
||||
|
||||
How we deal with failures/ partial indexing:
|
||||
- non-checkpointed connectors/ new runs in general => delete the old document batches from the file store and do the new run
|
||||
- checkpointed connectors + resuming from checkpoint => reissue the old document batches and do a new run
|
||||
|
||||
Misc:
|
||||
- most inter-process communication is handled in postgres, some is still in redis and we're trying to remove it
|
||||
- Heartbeat spawned in docfetching and docprocessing is how check_for_indexing monitors liveness
- progress-based liveness check: if nothing is done in 3-6 hours, mark the attempt as failed
|
||||
- TODO: task level timeouts (i.e. a connector stuck in an infinite loop)
|
||||
|
||||
|
||||
Comments below are from the old version and some may no longer be valid.
|
||||
TODO(rkuo): refactor this so that there is a single return path where we canonically
|
||||
log the result of running this function.
|
||||
|
||||
Some more Richard notes:
|
||||
celery out of process task execution strategy is pool=prefork, but it uses fork,
|
||||
and forking is inherently unstable.
|
||||
|
||||
To work around this, we use pool=threads and proxy our work to a spawned task.
|
||||
|
||||
acks_late must be set to False. Otherwise, celery's visibility timeout will
|
||||
cause any task that runs longer than the timeout to be redispatched by the broker.
|
||||
There appears to be no good workaround for this, so we need to handle redispatching
|
||||
manually.
|
||||
NOTE: we try/except all db access in this function because as a watchdog, this function
|
||||
needs to be extremely stable.
|
||||
"""
# TODO: remove dependence on Redis
|
||||
start = time.monotonic()
|
||||
|
||||
result = SimpleJobResult()
|
||||
|
||||
ctx = DocProcessingContext(
|
||||
tenant_id=tenant_id,
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=search_settings_id,
|
||||
index_attempt_id=index_attempt_id,
|
||||
)
|
||||
|
||||
log_builder = ConnectorIndexingLogBuilder(ctx)
|
||||
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - starting",
|
||||
mp_start_method=str(multiprocessing.get_start_method()),
|
||||
)
|
||||
)
|
||||
|
||||
if not self.request.id:
|
||||
task_logger.error("self.request.id is None!")
|
||||
|
||||
client = SimpleJobClient()
|
||||
task_logger.info(f"submitting docfetching_task with tenant_id={tenant_id}")
|
||||
|
||||
job = client.submit(
|
||||
docfetching_task,
|
||||
self.app,
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
global_version.is_ee_version(),
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
if not job or not job.process:
|
||||
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
status=str(result.status.value),
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Ensure the process has moved out of the starting state
|
||||
num_waits = 0
|
||||
while True:
|
||||
if num_waits > 15:
|
||||
result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
status=str(result.status.value),
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
)
|
||||
job.release()
|
||||
return
|
||||
|
||||
if job.process.is_alive() or job.process.exitcode is not None:
|
||||
break
|
||||
|
||||
sleep(1)
|
||||
num_waits += 1
|
||||
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawn succeeded",
|
||||
pid=str(job.process.pid),
|
||||
)
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# Track the last time memory info was emitted
|
||||
last_memory_emit_time = 0.0
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
eager_load_cc_pair=True,
|
||||
)
|
||||
if not index_attempt:
|
||||
raise RuntimeError("Index attempt not found")
|
||||
|
||||
result.connector_source = (
|
||||
index_attempt.connector_credential_pair.connector.source.value
|
||||
)
|
||||
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
time.monotonic()
|
||||
|
||||
# if the job is done, clean up and break
|
||||
if job.done():
|
||||
try:
|
||||
result = process_job_result(
|
||||
job, result.connector_source, redis_connector_index, log_builder
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawned task exceptioned"
|
||||
)
|
||||
)
|
||||
finally:
|
||||
job.release()
|
||||
break
|
||||
|
||||
# log the memory usage for tracking down memory leaks / connector-specific memory issues
|
||||
pid = job.process.pid
|
||||
if pid is not None:
|
||||
# Only emit memory info once per minute (60 seconds)
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_memory_emit_time >= 60.0:
|
||||
emit_process_memory(
|
||||
pid,
|
||||
"indexing_worker",
|
||||
{
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"search_settings_id": search_settings_id,
|
||||
"index_attempt_id": index_attempt_id,
|
||||
},
|
||||
)
|
||||
last_memory_emit_time = current_time
|
||||
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception looking up index attempt"
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
|
||||
if isinstance(e, ConnectorValidationError):
|
||||
# No need to expose full stack trace for validation errors
|
||||
result.exception_str = str(e)
|
||||
else:
|
||||
result.exception_str = traceback.format_exc()
|
||||
|
||||
# handle exit and reporting
|
||||
elapsed = time.monotonic() - start
|
||||
if result.exception_str is not None:
|
||||
# print with exception
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
attempt = get_index_attempt(db_session, ctx.index_attempt_id)
|
||||
|
||||
# only mark failures if not already terminal,
|
||||
# otherwise we're overwriting potential real stack traces
|
||||
if attempt and not attempt.status.is_terminal():
|
||||
failure_reason = (
|
||||
f"Spawned task exceptioned: exit_code={result.exit_code}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
ctx.index_attempt_id,
|
||||
db_session,
|
||||
failure_reason=failure_reason,
|
||||
full_exception_trace=result.exception_str,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception marking index attempt as failed"
|
||||
)
|
||||
)
|
||||
|
||||
normalized_exception_str = "None"
|
||||
if result.exception_str:
|
||||
normalized_exception_str = result.exception_str.replace(
|
||||
"\n", "\\n"
|
||||
).replace('"', '\\"')
|
||||
|
||||
task_logger.warning(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
source=result.connector_source,
|
||||
status=result.status.value,
|
||||
exit_code=str(result.exit_code),
|
||||
exception=f'"{normalized_exception_str}"',
|
||||
elapsed=f"{elapsed:.2f}s",
|
||||
)
|
||||
)
|
||||
raise RuntimeError(f"Exception encountered: traceback={result.exception_str}")
|
||||
|
||||
# print without exception
|
||||
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to termination signal"
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception marking index attempt as canceled"
|
||||
)
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Indexing watchdog - activity timeout exceeded: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception marking index attempt as failed"
|
||||
)
|
||||
)
|
||||
job.cancel()
|
||||
else:
|
||||
pass
|
||||
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
source=result.connector_source,
|
||||
status=str(result.status.value),
|
||||
exit_code=str(result.exit_code),
|
||||
elapsed=f"{elapsed:.2f}s",
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,36 @@
import threading

from sqlalchemy import update

from onyx.configs.constants import INDEXING_WORKER_HEARTBEAT_INTERVAL
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import IndexAttempt


def start_heartbeat(index_attempt_id: int) -> tuple[threading.Thread, threading.Event]:
    """Start a heartbeat thread for the given index attempt"""
    stop_event = threading.Event()

    def heartbeat_loop() -> None:
        while not stop_event.wait(INDEXING_WORKER_HEARTBEAT_INTERVAL):
            try:
                with get_session_with_current_tenant() as db_session:
                    db_session.execute(
                        update(IndexAttempt)
                        .where(IndexAttempt.id == index_attempt_id)
                        .values(heartbeat_counter=IndexAttempt.heartbeat_counter + 1)
                    )
                    db_session.commit()
            except Exception:
                # Silently continue if heartbeat fails
                pass

    thread = threading.Thread(target=heartbeat_loop, daemon=True)
    thread.start()
    return thread, stop_event


def stop_heartbeat(thread: threading.Thread, stop_event: threading.Event) -> None:
    """Stop the heartbeat thread"""
    stop_event.set()
    thread.join(timeout=5)  # Wait up to 5 seconds for clean shutdown
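# Illustrative sketch (not part of the diff above): intended usage of the heartbeat
# helpers defined in the new module. do_indexing_work() is a placeholder, not a real
# function in this changeset.
def run_with_heartbeat_sketch(index_attempt_id: int) -> None:
    thread, stop_event = start_heartbeat(index_attempt_id)
    try:
        do_indexing_work()  # placeholder for the actual docfetching/docprocessing work
    finally:
        # always stop the daemon thread so the attempt's heartbeat_counter stops advancing
        stop_heartbeat(thread, stop_event)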
backend/onyx/background/celery/tasks/docprocessing/tasks.py (new file, 1247 lines; diff suppressed because it is too large)
@@ -1,10 +1,8 @@
import time
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from uuid import uuid4

import redis
from celery import Celery
from redis import Redis
from redis.exceptions import LockError
@@ -12,8 +10,6 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
@@ -21,27 +17,19 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import create_index_attempt
from onyx.db.index_attempt import delete_index_attempt
from onyx.db.index_attempt import get_all_index_attempts_by_status
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.logger import setup_logger

@@ -50,54 +38,6 @@ logger = setup_logger()
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE = 5


def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
    """Gets a list of unfenced index attempts. Should not be possible, so we'd typically
    want to clean them up.

    Unfenced = attempt not in terminal state and fence does not exist.
    """
    unfenced_attempts: list[int] = []

    # inner/outer/inner double check pattern to avoid race conditions when checking for
    # bad state
    # inner = index_attempt in non terminal state
    # outer = r.fence_key down

    # check the db for index attempts in a non terminal state
    attempts: list[IndexAttempt] = []
    attempts.extend(
        get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
    )
    attempts.extend(
        get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
    )

    for attempt in attempts:
        fence_key = RedisConnectorIndex.fence_key_with_ids(
            attempt.connector_credential_pair_id, attempt.search_settings_id
        )

        # if the fence is down / doesn't exist, possible error but not confirmed
        if r.exists(fence_key):
            continue

        # Between the time the attempts are first looked up and the time we see the fence down,
        # the attempt may have completed and taken down the fence normally.

        # We need to double check that the index attempt is still in a non terminal state
        # and matches the original state, which confirms we are really in a bad state.
        attempt_2 = get_index_attempt(db_session, attempt.id)
        if not attempt_2:
            continue

        if attempt.status != attempt_2.status:
            continue

        unfenced_attempts.append(attempt.id)

    return unfenced_attempts


class IndexingCallbackBase(IndexingHeartbeatInterface):
    PARENT_CHECK_INTERVAL = 60

@@ -123,10 +63,9 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
        self.last_parent_check = time.monotonic()

    def should_stop(self) -> bool:
        if self.redis_connector.stop.fenced:
            return True

        return False
        # Check if the associated indexing attempt has been cancelled
        # TODO: Pass index_attempt_id to the callback and check cancellation using the db
        return bool(self.redis_connector.stop.fenced)

    def progress(self, tag: str, amount: int) -> None:
        """Amount isn't used yet."""
@@ -171,186 +110,28 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
        raise


class IndexingCallback(IndexingCallbackBase):
# NOTE: we're in the process of removing all fences from indexing; this will
# eventually no longer be used. For now, it is used only for connector pausing.
class IndexingCallback(IndexingHeartbeatInterface):
    def __init__(
        self,
        parent_pid: int,
        redis_connector: RedisConnector,
        redis_lock: RedisLock,
        redis_client: Redis,
        redis_connector_index: RedisConnectorIndex,
    ):
        super().__init__(parent_pid, redis_connector, redis_lock, redis_client)
        self.redis_connector = redis_connector

        self.redis_connector_index: RedisConnectorIndex = redis_connector_index
    def should_stop(self) -> bool:
        # Check if the associated indexing attempt has been cancelled
        # TODO: Pass index_attempt_id to the callback and check cancellation using the db
        return bool(self.redis_connector.stop.fenced)

    # included to satisfy old interface
    def progress(self, tag: str, amount: int) -> None:
        self.redis_connector_index.set_active()
        self.redis_connector_index.set_connector_active()
        super().progress(tag, amount)
        self.redis_client.incrby(
            self.redis_connector_index.generator_progress_key, amount
        )
        pass


def validate_indexing_fence(
    tenant_id: str,
    key_bytes: bytes,
    reserved_tasks: set[str],
    r_celery: Redis,
    db_session: Session,
) -> None:
    """Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
    This can happen if the indexing worker hard crashes or is terminated.
    Being in this bad state means the fence will never clear without help, so this function
    gives the help.

    How this works:
    1. This function renews the active signal with a 5 minute TTL under the following conditions
    1.2. When the task is seen in the redis queue
    1.3. When the task is seen in the reserved / prefetched list

    2. Externally, the active signal is renewed when:
    2.1. The fence is created
    2.2. The indexing watchdog checks the spawned task.

    3. The TTL allows us to get through the transitions on fence startup
    and when the task starts executing.

    More TTL clarification: it is seemingly impossible to exactly query Celery for
    whether a task is in the queue or currently executing.
    1. An unknown task id is always returned as state PENDING.
    2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
    and the time it actually starts on the worker.
    """
    # if the fence doesn't exist, there's nothing to do
    fence_key = key_bytes.decode("utf-8")
    composite_id = RedisConnector.get_id_from_fence_key(fence_key)
    if composite_id is None:
        task_logger.warning(
            f"validate_indexing_fence - could not parse composite_id from {fence_key}"
        )
        return

    # parse out metadata and initialize the helper class with it
    parts = composite_id.split("/")
    if len(parts) != 2:
        return

    cc_pair_id = int(parts[0])
    search_settings_id = int(parts[1])

    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    redis_connector_index = redis_connector.new_index(search_settings_id)

    # check to see if the fence/payload exists
    if not redis_connector_index.fenced:
        return

    payload = redis_connector_index.payload
    if not payload:
        return

    # OK, there's actually something for us to validate

    if payload.celery_task_id is None:
        # the fence is just barely set up.
        if redis_connector_index.active():
            return

        # it would be odd to get here as there isn't that much that can go wrong during
        # initial fence setup, but it's still worth making sure we can recover
        logger.info(
            f"validate_indexing_fence - "
            f"Resetting fence in basic state without any activity: fence={fence_key}"
        )
        redis_connector_index.reset()
        return

    found = celery_find_task(
        payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
    )
    if found:
        # the celery task exists in the redis queue
        redis_connector_index.set_active()
        return

    if payload.celery_task_id in reserved_tasks:
        # the celery task was prefetched and is reserved within the indexing worker
        redis_connector_index.set_active()
        return

    # we may want to enable this check if using the active task list somehow isn't good enough
    # if redis_connector_index.generator_locked():
    #     logger.info(f"{payload.celery_task_id} is currently executing.")

    # if we get here, we didn't find any direct indication that the associated celery tasks exist,
    # but they still might be there due to gaps in our ability to check states during transitions
    # Checking the active signal safeguards us against these transition periods
    # (which has a duration that allows us to bridge those gaps)
    if redis_connector_index.active():
        return

    # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
    logger.warning(
        f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
        f"index_attempt={payload.index_attempt_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id} "
        f"fence={fence_key}"
    )
    if payload.index_attempt_id:
        try:
            mark_attempt_failed(
                payload.index_attempt_id,
                db_session,
                "validate_indexing_fence - Canceling index attempt due to missing celery tasks: "
                f"index_attempt={payload.index_attempt_id}",
            )
        except Exception:
            logger.exception(
                "validate_indexing_fence - Exception while marking index attempt as failed: "
                f"index_attempt={payload.index_attempt_id}",
            )

    redis_connector_index.reset()
    return


def validate_indexing_fences(
    tenant_id: str,
    r_replica: Redis,
    r_celery: Redis,
    lock_beat: RedisLock,
) -> None:
    """Validates all indexing fences for this tenant ... aka makes sure
    indexing tasks sent to celery are still in flight.
    """
    reserved_indexing_tasks = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
    )

    # Use replica for this because the worst thing that happens
    # is that we don't run the validation on this pass
    keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
    for key in keys:
        key_bytes = cast(bytes, key)
        key_str = key_bytes.decode("utf-8")
        if not key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
            continue

        with get_session_with_current_tenant() as db_session:
            validate_indexing_fence(
                tenant_id,
                key_bytes,
                reserved_indexing_tasks,
                r_celery,
                db_session,
            )

        lock_beat.reacquire()

    return
# NOTE: The validate_indexing_fence and validate_indexing_fences functions have been removed
# as they are no longer needed with database-based coordination. The new validation is
# handled by validate_active_indexing_attempts in the main indexing tasks module.


def is_in_repeated_error_state(
@@ -414,10 +195,12 @@ def should_index(
    )

    # uncomment for debugging
    # task_logger.info(f"_should_index: "
    #                  f"cc_pair={cc_pair.id} "
    #                  f"connector={cc_pair.connector_id} "
    #                  f"refresh_freq={connector.refresh_freq}")
    task_logger.info(
        f"_should_index: "
        f"cc_pair={cc_pair.id} "
        f"connector={cc_pair.connector_id} "
        f"refresh_freq={connector.refresh_freq}"
    )

    # don't kick off indexing for `NOT_APPLICABLE` sources
    if connector.source == DocumentSource.NOT_APPLICABLE:
@@ -517,7 +300,7 @@ def should_index(
    return True


def try_creating_indexing_task(
def try_creating_docfetching_task(
    celery_app: Celery,
    cc_pair: ConnectorCredentialPair,
    search_settings: SearchSettings,
@@ -531,10 +314,11 @@ def try_creating_indexing_task(

    Does not check for scheduling related conditions as this function
    is used to trigger indexing immediately.

    Now uses database-based coordination instead of Redis fencing.
    """

    LOCK_TIMEOUT = 30
    index_attempt_id: int | None = None

    # we need to serialize any attempt to trigger indexing since it can be triggered
    # either via celery beat or manually (API call)
@@ -547,61 +331,42 @@ def try_creating_indexing_task(
    if not acquired:
        return None

    redis_connector_index: RedisConnectorIndex
    index_attempt_id = None
    try:
        redis_connector = RedisConnector(tenant_id, cc_pair.id)
        redis_connector_index = redis_connector.new_index(search_settings.id)

        # skip if already indexing
        if redis_connector_index.fenced:
            return None

        # skip indexing if the cc_pair is deleting
        if redis_connector.delete.fenced:
            return None

        # Basic status checks
        db_session.refresh(cc_pair)
        if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
            return None

        # add a long running generator task to the queue
        redis_connector_index.generator_clear()
        # Generate custom task ID for tracking
        custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"

        # set a basic fence to start
        payload = RedisConnectorIndexPayload(
            index_attempt_id=None,
            started=None,
            submitted=datetime.now(timezone.utc),
            celery_task_id=None,
        )

        redis_connector_index.set_active()
        redis_connector_index.set_fence(payload)

        # create the index attempt for tracking purposes
        # code elsewhere checks for index attempts without an associated redis key
        # and cleans them up
        # therefore we must create the attempt and the task after the fence goes up
        index_attempt_id = create_index_attempt(
            cc_pair.id,
            search_settings.id,
            from_beginning=reindex,
        # Try to create a new index attempt using database coordination
        # This replaces the Redis fencing mechanism
        index_attempt_id = IndexingCoordination.try_create_index_attempt(
            db_session=db_session,
            cc_pair_id=cc_pair.id,
            search_settings_id=search_settings.id,
            celery_task_id=custom_task_id,
            from_beginning=reindex,
        )

        custom_task_id = redis_connector_index.generate_generator_task_id()
        if index_attempt_id is None:
            # Another indexing attempt is already running
            return None

        # Determine which queue to use based on whether this is a user file
        # TODO: at the moment the indexing pipeline is
        # shared between user files and connectors
        queue = (
            OnyxCeleryQueues.USER_FILES_INDEXING
            if cc_pair.is_user_file
            else OnyxCeleryQueues.CONNECTOR_INDEXING
            else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
        )

        # when the task is sent, we have yet to finish setting up the fence
        # therefore, the task must contain code that blocks until the fence is ready
        # Send the task to Celery
        result = celery_app.send_task(
            OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
            OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
            kwargs=dict(
                index_attempt_id=index_attempt_id,
                cc_pair_id=cc_pair.id,
@@ -613,14 +378,18 @@ def try_creating_indexing_task(
            priority=OnyxCeleryPriority.MEDIUM,
        )
        if not result:
            raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
            raise RuntimeError("send_task for connector_doc_fetching_task failed.")

        # now fill out the fence with the rest of the data
        redis_connector_index.set_active()
        task_logger.info(
            f"Created docfetching task: "
            f"cc_pair={cc_pair.id} "
            f"search_settings={search_settings.id} "
            f"attempt_id={index_attempt_id} "
            f"celery_task_id={custom_task_id}"
        )

        return index_attempt_id

        payload.index_attempt_id = index_attempt_id
        payload.celery_task_id = result.id
        redis_connector_index.set_fence(payload)
    except Exception:
        task_logger.exception(
            f"try_creating_indexing_task - Unexpected exception: "
@@ -628,9 +397,10 @@ def try_creating_indexing_task(
            f"search_settings={search_settings.id}"
        )

        # Clean up on failure
        if index_attempt_id is not None:
            delete_index_attempt(db_session, index_attempt_id)
            redis_connector_index.set_fence(None)
            mark_attempt_failed(index_attempt_id, db_session)

        return None
    finally:
        if lock.owned():
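# Illustrative sketch (not part of the diff above): how a caller might trigger document
# fetching with the renamed helper. The exact keyword arguments and session wiring are
# assumptions for demonstration; the function returns the new index attempt id, or None
# if another attempt is already running or the cc_pair is being deleted.
#
#   with get_session_with_current_tenant() as db_session:
#       attempt_id = try_creating_docfetching_task(
#           celery_app,
#           cc_pair,
#           search_settings,
#           reindex=False,
#           db_session=db_session,
#           tenant_id=tenant_id,
#       )
#       if attempt_id is None:
#           task_logger.info("docfetching already in flight for this cc_pair")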
(another large file diff was suppressed by the viewer and is not shown)
backend/onyx/background/celery/tasks/models.py (new file, 110 lines)
@@ -0,0 +1,110 @@
from enum import Enum

from pydantic import BaseModel


class DocProcessingContext(BaseModel):
    tenant_id: str
    cc_pair_id: int
    search_settings_id: int
    index_attempt_id: int


class IndexingWatchdogTerminalStatus(str, Enum):
    """The different statuses the watchdog can finish with.

    TODO: create broader success/failure/abort categories
    """

    UNDEFINED = "undefined"

    SUCCEEDED = "succeeded"

    SPAWN_FAILED = "spawn_failed"  # connector spawn failed
    SPAWN_NOT_ALIVE = (
        "spawn_not_alive"  # spawn succeeded but process did not come alive
    )

    BLOCKED_BY_DELETION = "blocked_by_deletion"
    BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
    FENCE_NOT_FOUND = "fence_not_found"  # fence does not exist
    FENCE_READINESS_TIMEOUT = (
        "fence_readiness_timeout"  # fence exists but wasn't ready within the timeout
    )
    FENCE_MISMATCH = "fence_mismatch"  # task and fence metadata mismatch
    TASK_ALREADY_RUNNING = "task_already_running"  # task appears to be running already
    INDEX_ATTEMPT_MISMATCH = (
        "index_attempt_mismatch"  # expected index attempt metadata not found in db
    )

    CONNECTOR_VALIDATION_ERROR = (
        "connector_validation_error"  # the connector validation failed
    )
    CONNECTOR_EXCEPTIONED = "connector_exceptioned"  # the connector itself exceptioned
    WATCHDOG_EXCEPTIONED = "watchdog_exceptioned"  # the watchdog exceptioned

    # the watchdog received a termination signal
    TERMINATED_BY_SIGNAL = "terminated_by_signal"

    # the watchdog terminated the task due to no activity
    TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"

    # NOTE: this may actually be the same as SIGKILL, but parsed differently by python
    # consolidate once we know more
    OUT_OF_MEMORY = "out_of_memory"

    PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"

    @property
    def code(self) -> int:
        _ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
            IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
            IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
            IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
            IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
            IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
            IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
            IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT: 251,
            IndexingWatchdogTerminalStatus.FENCE_MISMATCH: 252,
            IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING: 253,
            IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH: 254,
            IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED: 255,
        }

        return _ENUM_TO_CODE[self]

    @classmethod
    def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
        _CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
            -9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
            137: IndexingWatchdogTerminalStatus.OUT_OF_MEMORY,
            247: IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR,
            248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
            249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
            250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
            251: IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT,
            252: IndexingWatchdogTerminalStatus.FENCE_MISMATCH,
            253: IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING,
            254: IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH,
            255: IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED,
        }

        if code in _CODE_TO_ENUM:
            return _CODE_TO_ENUM[code]

        return IndexingWatchdogTerminalStatus.UNDEFINED


class SimpleJobResult:
    """The data we want to have when the watchdog finishes"""

    status: IndexingWatchdogTerminalStatus
    connector_source: str | None
    exit_code: int | None
    exception_str: str | None

    def __init__(self) -> None:
        self.status = IndexingWatchdogTerminalStatus.UNDEFINED
        self.connector_source = None
        self.exit_code = None
        self.exception_str = None
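# Illustrative sketch (not part of the diff above): how the exit-code mapping defined
# above is meant to round-trip. This mirrors the watchdog's use of from_code on the
# spawned process's exit code.
def _exit_code_mapping_demo() -> None:
    assert IndexingWatchdogTerminalStatus.OUT_OF_MEMORY.code == 137
    assert (
        IndexingWatchdogTerminalStatus.from_code(137)
        is IndexingWatchdogTerminalStatus.OUT_OF_MEMORY
    )
    # unrecognized exit codes fall back to UNDEFINED rather than raising
    assert (
        IndexingWatchdogTerminalStatus.from_code(1)
        is IndexingWatchdogTerminalStatus.UNDEFINED
    )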
@@ -147,7 +147,7 @@ def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
    metrics = []
    queue_mappings = {
        "celery_queue_length": "celery",
        "indexing_queue_length": "indexing",
        "docprocessing_queue_length": "docprocessing",
        "sync_queue_length": "sync",
        "deletion_queue_length": "deletion",
        "pruning_queue_length": "pruning",
@@ -882,7 +882,13 @@ def monitor_celery_queues_helper(

    r_celery = task.app.broker_connection().channel().client  # type: ignore
    n_celery = celery_get_queue_length("celery", r_celery)
    n_indexing = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery)
    n_docfetching = celery_get_queue_length(
        OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
    )
    n_docprocessing = celery_get_queue_length(OnyxCeleryQueues.DOCPROCESSING, r_celery)
    n_user_files_indexing = celery_get_queue_length(
        OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
    )
    n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
    n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
    n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
@@ -896,14 +902,20 @@ def monitor_celery_queues_helper(
        OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
    )

    n_indexing_prefetched = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
    n_docfetching_prefetched = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
    )
    n_docprocessing_prefetched = celery_get_unacked_task_ids(
        OnyxCeleryQueues.DOCPROCESSING, r_celery
    )

    task_logger.info(
        f"Queue lengths: celery={n_celery} "
        f"indexing={n_indexing} "
        f"indexing_prefetched={len(n_indexing_prefetched)} "
        f"docfetching={n_docfetching} "
        f"docfetching_prefetched={len(n_docfetching_prefetched)} "
        f"docprocessing={n_docprocessing} "
        f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
        f"user_files_indexing={n_user_files_indexing} "
        f"sync={n_sync} "
        f"deletion={n_deletion} "
        f"pruning={n_pruning} "

@@ -22,7 +22,7 @@ from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.indexing.utils import IndexingCallbackBase
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallbackBase
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
@@ -70,9 +70,9 @@ logger = setup_logger()
def _get_pruning_block_expiration() -> int:
    """
    Compute the expiration time for the pruning block signal.
    Base expiration is 3600 seconds (1 hour), multiplied by the beat multiplier only in MULTI_TENANT mode.
    Base expiration is 60 seconds (1 minute), multiplied by the beat multiplier only in MULTI_TENANT mode.
    """
    base_expiration = 3600  # seconds
    base_expiration = 60  # seconds

    if not MULTI_TENANT:
        return base_expiration
@@ -138,14 +138,14 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
        # if we've never indexed, we can't prune
        return False

    # if never pruned, use the last time the connector indexed successfully
    last_pruned = cc_pair.last_successful_index_time
    # if never pruned, use the connector creation time. We could also
    # compute the completion time of the first successful index attempt, but
    # that is a reasonably heavy operation. This is a reasonable approximation:
    # in the worst case, we'll prune a little bit earlier than we should.
    last_pruned = cc_pair.connector.time_created

    next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
    if datetime.now(timezone.utc) < next_prune:
        return False

    return True
    return datetime.now(timezone.utc) >= next_prune


@shared_task(
@@ -173,6 +173,9 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:

    # but pruning only kicks off once per hour
    if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):

        task_logger.info("Checking for pruning due")

        cc_pair_ids: list[int] = []
        with get_session_with_current_tenant() as db_session:
            cc_pairs = get_connector_credential_pairs(db_session)
@@ -187,15 +190,18 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
                cc_pair_id=cc_pair_id,
            )
            if not cc_pair:
                logger.error(f"CC pair not found: {cc_pair_id}")
                continue

            if not _is_pruning_due(cc_pair):
                logger.info(f"CC pair not due for pruning: {cc_pair_id}")
                continue

            payload_id = try_creating_prune_generator_task(
                self.app, cc_pair, db_session, r, tenant_id
            )
            if not payload_id:
                logger.info(f"Pruning not created: {cc_pair_id}")
                continue

            task_logger.info(
@@ -264,11 +270,16 @@ def try_creating_prune_generator_task(
    is used to trigger prunes immediately, e.g. via the web ui.
    """

    logger.info(f"try_creating_prune_generator_task: cc_pair={cc_pair.id}")

    redis_connector = RedisConnector(tenant_id, cc_pair.id)

    if not ALLOW_SIMULTANEOUS_PRUNING:
        count = redis_connector.prune.get_active_task_count()
        if count > 0:
            logger.info(
                f"try_creating_prune_generator_task: cc_pair={cc_pair.id} no simultaneous pruning allowed"
            )
            return None

    LOCK_TIMEOUT = 30
@@ -282,23 +293,38 @@ def try_creating_prune_generator_task(

    acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
    if not acquired:
        logger.info(
            f"try_creating_prune_generator_task: cc_pair={cc_pair.id} lock not acquired"
        )
        return None

    try:
        # skip pruning if already pruning
        if redis_connector.prune.fenced:
            logger.info(
                f"try_creating_prune_generator_task: cc_pair={cc_pair.id} already pruning"
            )
            return None

        # skip pruning if the cc_pair is deleting
        if redis_connector.delete.fenced:
            logger.info(
                f"try_creating_prune_generator_task: cc_pair={cc_pair.id} deleting"
            )
            return None

        # skip pruning if doc permissions sync is running
        if redis_connector.permissions.fenced:
            logger.info(
                f"try_creating_prune_generator_task: cc_pair={cc_pair.id} permissions sync running"
            )
            return None

        db_session.refresh(cc_pair)
        if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
            logger.info(
                f"try_creating_prune_generator_task: cc_pair={cc_pair.id} deleting"
            )
            return None

        # add a long running generator task to the queue
@@ -441,7 +467,7 @@ def connector_pruning_generator_task(
    # set thread_local=False since we don't control what thread the indexing/pruning
    # might run our callback with
    lock: RedisLock = r.lock(
        OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
        OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.cc_pair_id}",
        timeout=CELERY_PRUNING_LOCK_TIMEOUT,
        thread_local=False,
    )
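# Illustrative sketch (not part of the diff above): the pruning due-date arithmetic shown
# standalone, including the new "never pruned falls back to connector creation time"
# behavior. Treat this as an approximation of the real helper, not the canonical code.
from datetime import datetime, timedelta, timezone


def is_pruning_due_sketch(
    prune_freq_seconds: int,
    last_pruned: datetime | None,
    connector_time_created: datetime,
) -> bool:
    # if never pruned, use the connector creation time; in the worst case this prunes a
    # little earlier than strictly necessary
    effective_last_pruned = last_pruned or connector_time_created
    next_prune = effective_last_pruned + timedelta(seconds=prune_freq_seconds)
    return datetime.now(timezone.utc) >= next_prune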
backend/onyx/background/celery/tasks/vespa/document_sync.py (new file, 178 lines)
@@ -0,0 +1,178 @@
import time
from typing import cast
from uuid import uuid4

from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.document import construct_document_id_select_by_needs_sync
from onyx.db.document import count_documents_by_needs_sync
from onyx.utils.logger import setup_logger

# Redis keys for document sync tracking
DOCUMENT_SYNC_PREFIX = "documentsync"
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"

logger = setup_logger()


def is_document_sync_fenced(r: Redis) -> bool:
    """Check if document sync tasks are currently in progress."""
    return bool(r.exists(DOCUMENT_SYNC_FENCE_KEY))


def get_document_sync_payload(r: Redis) -> int | None:
    """Get the initial number of tasks that were created."""
    bytes_result = r.get(DOCUMENT_SYNC_FENCE_KEY)
    if bytes_result is None:
        return None
    return int(cast(int, bytes_result))


def get_document_sync_remaining(r: Redis) -> int:
    """Get the number of tasks still pending completion."""
    return cast(int, r.scard(DOCUMENT_SYNC_TASKSET_KEY))


def set_document_sync_fence(r: Redis, payload: int | None) -> None:
    """Set up the fence and register with active fences."""
    if payload is None:
        r.srem(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
        r.delete(DOCUMENT_SYNC_FENCE_KEY)
        return

    r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
    r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)


def delete_document_sync_taskset(r: Redis) -> None:
    """Clear the document sync taskset."""
    r.delete(DOCUMENT_SYNC_TASKSET_KEY)


def reset_document_sync(r: Redis) -> None:
    """Reset all document sync tracking data."""
    r.srem(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
    r.delete(DOCUMENT_SYNC_TASKSET_KEY)
    r.delete(DOCUMENT_SYNC_FENCE_KEY)


def generate_document_sync_tasks(
    r: Redis,
    max_tasks: int,
    celery_app: Celery,
    db_session: Session,
    lock: RedisLock,
    tenant_id: str,
) -> tuple[int, int]:
    """Generate sync tasks for all documents that need syncing.

    Args:
        r: Redis client
        max_tasks: Maximum number of tasks to generate
        celery_app: Celery application instance
        db_session: Database session
        lock: Redis lock for coordination
        tenant_id: Tenant identifier

    Returns:
        tuple[int, int]: (tasks_generated, total_docs_found)
    """
    last_lock_time = time.monotonic()
    num_tasks_sent = 0
    num_docs = 0

    # Get all documents that need syncing
    stmt = construct_document_id_select_by_needs_sync()

    for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
        doc_id = cast(str, doc_id)
        current_time = time.monotonic()

        # Reacquire lock periodically to prevent timeout
        if current_time - last_lock_time >= (CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4):
            lock.reacquire()
            last_lock_time = current_time

        num_docs += 1

        # Create a unique task ID
        custom_task_id = f"{DOCUMENT_SYNC_PREFIX}_{uuid4()}"

        # Add to the tracking taskset in Redis BEFORE creating the celery task
        r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)

        # Create the Celery task
        celery_app.send_task(
            OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
            kwargs=dict(document_id=doc_id, tenant_id=tenant_id),
            queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
            task_id=custom_task_id,
            priority=OnyxCeleryPriority.MEDIUM,
            ignore_result=True,
        )

        num_tasks_sent += 1

        if num_tasks_sent >= max_tasks:
            break

    return num_tasks_sent, num_docs


def try_generate_stale_document_sync_tasks(
    celery_app: Celery,
    max_tasks: int,
    db_session: Session,
    r: Redis,
    lock_beat: RedisLock,
    tenant_id: str,
) -> int | None:
    # the fence is up, do nothing
    if is_document_sync_fenced(r):
        return None

    # add tasks to celery and build up the task set to monitor in redis
    stale_doc_count = count_documents_by_needs_sync(db_session)
    if stale_doc_count == 0:
        logger.info("No stale documents found. Skipping sync tasks generation.")
        return None

    logger.info(
        f"Stale documents found (at least {stale_doc_count}). Generating sync tasks in one batch."
    )

    logger.info("generate_document_sync_tasks starting for all documents.")

    # Generate all tasks in one pass
    result = generate_document_sync_tasks(
        r, max_tasks, celery_app, db_session, lock_beat, tenant_id
    )

    if result is None:
        return None

    tasks_generated, total_docs = result

    if tasks_generated >= max_tasks:
        logger.info(
            f"generate_document_sync_tasks reached the task generation limit: "
            f"tasks_generated={tasks_generated} max_tasks={max_tasks}"
        )
    else:
        logger.info(
            f"generate_document_sync_tasks finished for all documents. "
            f"tasks_generated={tasks_generated} total_docs_found={total_docs}"
        )

    set_document_sync_fence(r, tasks_generated)
    return tasks_generated
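# Illustrative sketch (not part of the diff above): the intended lifecycle of the
# document sync fence/taskset helpers defined in this module. The completion side
# (each sync task removing its id from the taskset) is assumed for demonstration.
from redis import Redis


def document_sync_lifecycle_sketch(r: Redis) -> None:
    # 1. beat calls try_generate_stale_document_sync_tasks(...), which queues tasks,
    #    fills the taskset, and sets the fence to the number of tasks queued
    # 2. each completed sync task is expected to remove its id from the taskset
    # 3. a monitor pass compares remaining vs. initial and clears everything at zero
    initial = get_document_sync_payload(r)
    if initial is None:
        return
    if get_document_sync_remaining(r) == 0:
        reset_document_sync(r)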
@@ -20,14 +20,19 @@ from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocument
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_FENCE_KEY
from onyx.background.celery.tasks.vespa.document_sync import get_document_sync_payload
from onyx.background.celery.tasks.vespa.document_sync import get_document_sync_remaining
from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync
from onyx.background.celery.tasks.vespa.document_sync import (
    try_generate_stale_document_sync_tasks,
)
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import count_documents_by_needs_sync
from onyx.db.document import get_document
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import delete_document_set
@@ -47,10 +52,6 @@ from onyx.db.sync_record import update_sync_record_status
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_credential_pair import (
    RedisGlobalConnectorCredentialPair,
)
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
@@ -166,8 +167,11 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
            continue

        key_str = key_bytes.decode("utf-8")
        if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
            monitor_connector_taskset(r)
        # NOTE: removing the "Redis*" classes, prefer to just have functions to
        # do these things going forward. In short, things should generally be like the doc
        # sync task rather than the others
        if key_str == DOCUMENT_SYNC_FENCE_KEY:
            monitor_document_sync_taskset(r)
        elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
            with get_session_with_current_tenant() as db_session:
                monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
@@ -203,82 +207,6 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
    return True


def try_generate_stale_document_sync_tasks(
    celery_app: Celery,
    max_tasks: int,
    db_session: Session,
    r: Redis,
    lock_beat: RedisLock,
    tenant_id: str,
) -> int | None:
    # the fence is up, do nothing

    redis_global_ccpair = RedisGlobalConnectorCredentialPair(r)
    if redis_global_ccpair.fenced:
        return None

    redis_global_ccpair.delete_taskset()

    # add tasks to celery and build up the task set to monitor in redis
    stale_doc_count = count_documents_by_needs_sync(db_session)
    if stale_doc_count == 0:
        return None

    task_logger.info(
        f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
    )

    task_logger.info(
        "RedisConnector.generate_tasks starting by cc_pair. "
        "Documents spanning multiple cc_pairs will only be synced once."
    )

    docs_to_skip: set[str] = set()

    # rkuo: we could technically sync all stale docs in one big pass.
    # but I feel it's more understandable to group the docs by cc_pair
    total_tasks_generated = 0
    tasks_remaining = max_tasks
    cc_pairs = get_connector_credential_pairs(db_session)
    for cc_pair in cc_pairs:
        lock_beat.reacquire()

        rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
        rc.set_skip_docs(docs_to_skip)
        result = rc.generate_tasks(
            tasks_remaining, celery_app, db_session, r, lock_beat, tenant_id
        )

        if result is None:
            continue

        if result[1] == 0:
            continue

        task_logger.info(
            f"RedisConnector.generate_tasks finished for single cc_pair. "
            f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}"
        )

        total_tasks_generated += result[0]
        tasks_remaining -= result[0]
        if tasks_remaining <= 0:
            break

    if tasks_remaining <= 0:
        task_logger.info(
            f"RedisConnector.generate_tasks reached the task generation limit: "
            f"total_tasks_generated={total_tasks_generated} max_tasks={max_tasks}"
        )
    else:
        task_logger.info(
            f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
        )

    redis_global_ccpair.set_fence(total_tasks_generated)
    return total_tasks_generated


def try_generate_document_set_sync_tasks(
    celery_app: Celery,
    document_set_id: int,
@@ -433,19 +361,18 @@ def try_generate_user_group_sync_tasks(
    return tasks_generated


def monitor_connector_taskset(r: Redis) -> None:
    redis_global_ccpair = RedisGlobalConnectorCredentialPair(r)
    initial_count = redis_global_ccpair.payload
def monitor_document_sync_taskset(r: Redis) -> None:
    initial_count = get_document_sync_payload(r)
    if initial_count is None:
        return

    remaining = redis_global_ccpair.get_remaining()
    remaining = get_document_sync_remaining(r)
    task_logger.info(
        f"Stale document sync progress: remaining={remaining} initial={initial_count}"
        f"Document sync progress: remaining={remaining} initial={initial_count}"
    )
    if remaining == 0:
        redis_global_ccpair.reset()
        task_logger.info(f"Successfully synced stale documents. count={initial_count}")
        reset_document_sync(r)
        task_logger.info(f"Successfully synced all documents. count={initial_count}")


def monitor_document_set_taskset(

@@ -10,7 +10,7 @@ set_is_ee_based_on_env_variable()


def get_app() -> Celery:
    from onyx.background.celery.apps.indexing import celery_app
    from onyx.background.celery.apps.docfetching import celery_app

    return celery_app

@@ -0,0 +1,18 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""

from celery import Celery

from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable

set_is_ee_based_on_env_variable()


def get_app() -> Celery:
    from onyx.background.celery.apps.docprocessing import celery_app

    return celery_app


app = get_app()
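# Illustrative note (not part of the diff above): a factory stub like this is normally
# what the celery CLI is pointed at. The module path and queue name below are
# assumptions for demonstration only, not taken from this changeset:
#
#   celery -A onyx.background.celery.versioned_apps.docprocessing worker \
#       --loglevel=INFO --queues=docprocessing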
@@ -33,7 +33,7 @@ def save_checkpoint(
    """Save a checkpoint for a given index attempt to the file store"""
    checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)

    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()
    file_store.save_file(
        content=BytesIO(checkpoint.model_dump_json().encode()),
        display_name=checkpoint_pointer,
@@ -52,11 +52,11 @@ def save_checkpoint(


def load_checkpoint(
    db_session: Session, index_attempt_id: int, connector: BaseConnector
    index_attempt_id: int, connector: BaseConnector
) -> ConnectorCheckpoint:
    """Load a checkpoint for a given index attempt from the file store"""
    checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()
    checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
    checkpoint_data = checkpoint_io.read().decode("utf-8")
    if isinstance(connector, CheckpointedConnector):
@@ -71,7 +71,7 @@ def get_latest_valid_checkpoint(
    window_start: datetime,
    window_end: datetime,
    connector: BaseConnector,
) -> ConnectorCheckpoint:
) -> tuple[ConnectorCheckpoint, bool]:
    """Get the latest valid checkpoint for a given connector credential pair"""
    checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
        cc_pair_id=cc_pair_id,
@@ -83,7 +83,7 @@ def get_latest_valid_checkpoint(
    # don't keep using checkpoints if we've had a bunch of failed attempts in a row
    # where we make no progress. Only do this if we have had at least
    # _NUM_RECENT_ATTEMPTS_TO_CONSIDER completed attempts.
    if len(checkpoint_candidates) == _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
    if len(checkpoint_candidates) >= _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
        had_any_progress = False
        for candidate in checkpoint_candidates:
            if (
@@ -99,7 +99,7 @@ def get_latest_valid_checkpoint(
                f"found for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
                "from scratch."
            )
            return connector.build_dummy_checkpoint()
            return connector.build_dummy_checkpoint(), False

    # filter out any candidates that don't meet the criteria
    checkpoint_candidates = [
@@ -140,11 +140,10 @@ def get_latest_valid_checkpoint(
        logger.info(
            f"No valid checkpoint found for cc_pair={cc_pair_id}. Starting from scratch."
        )
        return checkpoint
        return checkpoint, False

    try:
        previous_checkpoint = load_checkpoint(
            db_session=db_session,
            index_attempt_id=latest_valid_checkpoint_candidate.id,
            connector=connector,
        )
@@ -153,14 +152,14 @@ def get_latest_valid_checkpoint(
            f"Failed to load checkpoint from previous failed attempt with ID "
            f"{latest_valid_checkpoint_candidate.id}. Falling back to default checkpoint."
        )
        return checkpoint
        return checkpoint, False

    logger.info(
        f"Using checkpoint from previous failed attempt with ID "
        f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: "
        f"{previous_checkpoint}"
    )
    return previous_checkpoint
    return previous_checkpoint, True


def get_index_attempts_with_old_checkpoints(
@@ -201,7 +200,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
    if not index_attempt.checkpoint_pointer:
        return None

    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()
    file_store.delete_file(index_attempt.checkpoint_pointer)

    index_attempt.checkpoint_pointer = None
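# Illustrative sketch (not part of the diff above): get_latest_valid_checkpoint now
# returns (checkpoint, found_existing_checkpoint) instead of a bare checkpoint, so
# callers must unpack both values. Variable names and the abbreviated argument list are
# assumptions for demonstration.
#
#   checkpoint, has_checkpoint = get_latest_valid_checkpoint(
#       db_session=db_session,
#       cc_pair_id=cc_pair_id,
#       search_settings_id=search_settings_id,
#       window_start=window_start,
#       window_end=window_end,
#       connector=connector,
#   )
#   if has_checkpoint:
#       logger.info("Resuming from a checkpoint left by a previous attempt")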
@@ -1,3 +1,4 @@
import sys
import time
import traceback
from collections import defaultdict
@@ -5,7 +6,7 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone

from pydantic import BaseModel
from celery import Celery
from sqlalchemy.orm import Session

from onyx.access.access import source_should_fetch_permissions_during_indexing
@@ -18,18 +19,25 @@ from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import DocExtractionContext
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
from onyx.db.connector_credential_pair import update_connector_credential_pair
@@ -49,13 +57,16 @@ from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.document_index.factory import get_default_document_index
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
    InformationContentClassificationModel,
)
@@ -68,7 +79,7 @@ from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()
logger = setup_logger(propagate=False)

INDEXING_TRACER_NUM_PRINT_ENTRIES = 5

@@ -146,6 +157,10 @@ def _get_connector_runner(
def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
    cleaned_batch = []
    for doc in doc_batch:
        if sys.getsizeof(doc) > MAX_FILE_SIZE_BYTES:
            logger.warning(
                f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
            )
        cleaned_doc = doc.model_copy()

        # Postgres cannot handle NUL characters in text fields
@@ -180,25 +195,11 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
    return cleaned_batch


class ConnectorStopSignal(Exception):
    """A custom exception used to signal a stop in processing."""


class RunIndexingContext(BaseModel):
    index_name: str
    cc_pair_id: int
    connector_id: int
    credential_id: int
    source: DocumentSource
|
||||
earliest_index_time: float
|
||||
from_beginning: bool
|
||||
is_primary: bool
|
||||
should_fetch_permissions_during_indexing: bool
|
||||
search_settings_status: IndexModelStatus
|
||||
|
||||
|
||||
def _check_connector_and_attempt_status(
|
||||
db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int
|
||||
db_session_temp: Session,
|
||||
cc_pair_id: int,
|
||||
search_settings_status: IndexModelStatus,
|
||||
index_attempt_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Checks the status of the connector credential pair and index attempt.
|
||||
@@ -206,27 +207,34 @@ def _check_connector_and_attempt_status(
|
||||
"""
|
||||
cc_pair_loop = get_connector_credential_pair_from_id(
|
||||
db_session_temp,
|
||||
ctx.cc_pair_id,
|
||||
cc_pair_id,
|
||||
)
|
||||
if not cc_pair_loop:
|
||||
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
|
||||
raise RuntimeError(f"CC pair {cc_pair_id} not found in DB.")
|
||||
|
||||
if (
|
||||
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and ctx.search_settings_status != IndexModelStatus.FUTURE
|
||||
and search_settings_status != IndexModelStatus.FUTURE
|
||||
) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING:
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
raise ConnectorStopSignal(f"Connector {cc_pair_loop.status.value.lower()}")
|
||||
|
||||
index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
if not index_attempt_loop:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
|
||||
if index_attempt_loop.status == IndexingStatus.CANCELED:
|
||||
raise ConnectorStopSignal(f"Index attempt {index_attempt_id} was canceled")
|
||||
|
||||
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
|
||||
raise RuntimeError(
|
||||
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
|
||||
f"Index Attempt is not running, status is {index_attempt_loop.status}"
|
||||
)
|
||||
|
||||
if index_attempt_loop.celery_task_id is None:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} has no celery task id")
|
||||
|
||||
|
||||
# TODO: delete from here if ends up unused
|
||||
def _check_failure_threshold(
|
||||
total_failures: int,
|
||||
document_count: int,
|
||||
@@ -257,6 +265,9 @@ def _check_failure_threshold(
|
||||
)
|
||||
|
||||
|
||||
# NOTE: this is the old run_indexing function that the new decoupled approach
|
||||
# is based on. Leaving this for comparison purposes, but if you see this comment
|
||||
# has been here for >1 month, please delete this function.
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
@@ -271,7 +282,12 @@ def _run_indexing(
|
||||
start_time = time.monotonic() # just used for logging
|
||||
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
index_attempt_start = get_index_attempt(
|
||||
db_session_temp,
|
||||
index_attempt_id,
|
||||
eager_load_cc_pair=True,
|
||||
eager_load_search_settings=True,
|
||||
)
|
||||
if not index_attempt_start:
|
||||
raise ValueError(
|
||||
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
|
||||
@@ -292,7 +308,7 @@ def _run_indexing(
|
||||
index_attempt_start.connector_credential_pair.last_successful_index_time
|
||||
is not None
|
||||
)
|
||||
ctx = RunIndexingContext(
|
||||
ctx = DocExtractionContext(
|
||||
index_name=index_attempt_start.search_settings.index_name,
|
||||
cc_pair_id=index_attempt_start.connector_credential_pair.id,
|
||||
connector_id=db_connector.id,
|
||||
@@ -317,6 +333,7 @@ def _run_indexing(
|
||||
and (from_beginning or not has_successful_attempt)
|
||||
),
|
||||
search_settings_status=index_attempt_start.search_settings.status,
|
||||
doc_extraction_complete_batch_num=None,
|
||||
)
|
||||
|
||||
last_successful_index_poll_range_end = (
|
||||
@@ -384,19 +401,6 @@ def _run_indexing(
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=(
|
||||
ctx.from_beginning
|
||||
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
|
||||
),
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# Initialize memory tracer. NOTE: won't actually do anything if
|
||||
# `INDEXING_TRACER_INTERVAL` is 0.
|
||||
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
|
||||
@@ -416,7 +420,9 @@ def _run_indexing(
|
||||
index_attempt: IndexAttempt | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
index_attempt = get_index_attempt(
|
||||
db_session_temp, index_attempt_id, eager_load_cc_pair=True
|
||||
)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
|
||||
@@ -439,7 +445,7 @@ def _run_indexing(
|
||||
):
|
||||
checkpoint = connector_runner.connector.build_dummy_checkpoint()
|
||||
else:
|
||||
checkpoint = get_latest_valid_checkpoint(
|
||||
checkpoint, _ = get_latest_valid_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
@@ -496,7 +502,10 @@ def _run_indexing(
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
# will exception if the connector/index attempt is marked as paused/failed
|
||||
_check_connector_and_attempt_status(
|
||||
db_session_temp, ctx, index_attempt_id
|
||||
db_session_temp,
|
||||
ctx.cc_pair_id,
|
||||
ctx.search_settings_status,
|
||||
index_attempt_id,
|
||||
)
|
||||
|
||||
# save record of any failures at the connector level
|
||||
@@ -554,7 +563,16 @@ def _run_indexing(
|
||||
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
|
||||
|
||||
# real work happens here!
|
||||
index_pipeline_result = indexing_pipeline(
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=(
|
||||
ctx.from_beginning
|
||||
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
|
||||
),
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
document_batch=doc_batch_cleaned,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
)
|
||||
@@ -815,6 +833,7 @@ def _run_indexing(
|
||||
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
app: Celery,
|
||||
index_attempt_id: int,
|
||||
tenant_id: str,
|
||||
connector_credential_pair_id: int,
|
||||
@@ -832,7 +851,6 @@ def run_indexing_entrypoint(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# TODO: remove long running session entirely
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
tenant_str = ""
|
||||
@@ -846,18 +864,514 @@ def run_indexing_entrypoint(
|
||||
credential_id = attempt.connector_credential_pair.credential_id
|
||||
|
||||
logger.info(
|
||||
f"Indexing starting{tenant_str}: "
|
||||
f"Docfetching starting{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
|
||||
connector_document_extraction(
|
||||
app,
|
||||
index_attempt_id,
|
||||
attempt.connector_credential_pair_id,
|
||||
attempt.search_settings_id,
|
||||
tenant_id,
|
||||
callback,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished{tenant_str}: "
|
||||
f"Docfetching finished{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
|
||||
|
||||
def connector_document_extraction(
|
||||
app: Celery,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> None:
|
||||
"""Extract documents from connector and queue them for indexing pipeline processing.
|
||||
|
||||
This is the first part of the split indexing process that runs the connector
|
||||
and extracts documents, storing them in the filestore for later processing.
|
||||
"""
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
logger.info(
|
||||
f"Document extraction starting: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
|
||||
# Get batch storage (transition to IN_PROGRESS is handled by run_indexing_entrypoint)
|
||||
batch_storage = get_document_batch_storage(cc_pair_id, index_attempt_id)
|
||||
|
||||
# Initialize memory tracer. NOTE: won't actually do anything if
|
||||
# `INDEXING_TRACER_INTERVAL` is 0.
|
||||
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
|
||||
memory_tracer.start()
|
||||
|
||||
index_attempt = None
|
||||
last_batch_num = 0 # used to continue from checkpointing
|
||||
# comes from _run_indexing
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session,
|
||||
index_attempt_id,
|
||||
eager_load_cc_pair=True,
|
||||
eager_load_search_settings=True,
|
||||
)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found")
|
||||
|
||||
if index_attempt.search_settings is None:
|
||||
raise ValueError("Search settings must be set for indexing")
|
||||
|
||||
# Clear the indexing trigger if it was set, to prevent duplicate indexing attempts
|
||||
if index_attempt.connector_credential_pair.indexing_trigger is not None:
|
||||
logger.info(
|
||||
"Clearing indexing trigger: "
|
||||
f"cc_pair={index_attempt.connector_credential_pair.id} "
|
||||
f"trigger={index_attempt.connector_credential_pair.indexing_trigger}"
|
||||
)
|
||||
mark_ccpair_with_indexing_trigger(
|
||||
index_attempt.connector_credential_pair.id, None, db_session
|
||||
)
|
||||
|
||||
db_connector = index_attempt.connector_credential_pair.connector
|
||||
db_credential = index_attempt.connector_credential_pair.credential
|
||||
is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
|
||||
from_beginning = index_attempt.from_beginning
|
||||
has_successful_attempt = (
|
||||
index_attempt.connector_credential_pair.last_successful_index_time
|
||||
is not None
|
||||
)
|
||||
|
||||
earliest_index_time = (
|
||||
db_connector.indexing_start.timestamp()
|
||||
if db_connector.indexing_start
|
||||
else 0
|
||||
)
|
||||
should_fetch_permissions_during_indexing = (
|
||||
index_attempt.connector_credential_pair.access_type == AccessType.SYNC
|
||||
and source_should_fetch_permissions_during_indexing(db_connector.source)
|
||||
and is_primary
|
||||
# if we've already successfully indexed, let the doc_sync job
|
||||
# take care of doc-level permissions
|
||||
and (from_beginning or not has_successful_attempt)
|
||||
)
|
||||
|
||||
# Set up time windows for polling
|
||||
last_successful_index_poll_range_end = (
|
||||
earliest_index_time
|
||||
if from_beginning
|
||||
else get_last_successful_attempt_poll_range_end(
|
||||
cc_pair_id=cc_pair_id,
|
||||
earliest_index=earliest_index_time,
|
||||
search_settings=index_attempt.search_settings,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
if last_successful_index_poll_range_end > POLL_CONNECTOR_OFFSET:
|
||||
window_start = datetime.fromtimestamp(
|
||||
last_successful_index_poll_range_end, tz=timezone.utc
|
||||
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
|
||||
else:
|
||||
# don't go into "negative" time if we've never indexed before
|
||||
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
|
||||
most_recent_attempt = next(
|
||||
iter(
|
||||
get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
db_session=db_session,
|
||||
limit=1,
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# if the last attempt failed, try and use the same window. This is necessary
|
||||
# to ensure correctness with checkpointing. If we don't do this, things like
|
||||
# new slack channels could be missed (since existing slack channels are
|
||||
# cached as part of the checkpoint).
|
||||
if (
|
||||
most_recent_attempt
|
||||
and most_recent_attempt.poll_range_end
|
||||
and (
|
||||
most_recent_attempt.status == IndexingStatus.FAILED
|
||||
or most_recent_attempt.status == IndexingStatus.CANCELED
|
||||
)
|
||||
):
|
||||
window_end = most_recent_attempt.poll_range_end
|
||||
else:
|
||||
window_end = datetime.now(tz=timezone.utc)
|
||||
|
||||
# set time range in db
|
||||
index_attempt.poll_range_start = window_start
|
||||
index_attempt.poll_range_end = window_end
|
||||
db_session.commit()
|
||||
|
||||
# TODO: maybe memory tracer here
|
||||
|
||||
# Set up connector runner
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session,
|
||||
attempt=index_attempt,
|
||||
batch_size=INDEX_BATCH_SIZE,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
include_permissions=should_fetch_permissions_during_indexing,
|
||||
)
|
||||
|
||||
# don't use a checkpoint if we're explicitly indexing from
|
||||
# the beginning in order to avoid weird interactions between
|
||||
# checkpointing / failure handling
|
||||
# OR
|
||||
# if the last attempt was successful
|
||||
if index_attempt.from_beginning or (
|
||||
most_recent_attempt and most_recent_attempt.status.is_successful()
|
||||
):
|
||||
logger.info(
|
||||
f"Cleaning up all old batches for index attempt {index_attempt_id} before starting new run"
|
||||
)
|
||||
batch_storage.cleanup_all_batches()
|
||||
checkpoint = connector_runner.connector.build_dummy_checkpoint()
|
||||
else:
|
||||
logger.info(
|
||||
f"Getting latest valid checkpoint for index attempt {index_attempt_id}"
|
||||
)
|
||||
checkpoint, resuming_from_checkpoint = get_latest_valid_checkpoint(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
connector=connector_runner.connector,
|
||||
)
|
||||
|
||||
# checkpoint resumption OR the connector already finished.
|
||||
if (
|
||||
isinstance(connector_runner.connector, CheckpointedConnector)
|
||||
and resuming_from_checkpoint
|
||||
) or (
|
||||
most_recent_attempt
|
||||
and most_recent_attempt.total_batches is not None
|
||||
and not checkpoint.has_more
|
||||
):
|
||||
reissued_batch_count, completed_batches = reissue_old_batches(
|
||||
batch_storage,
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
tenant_id,
|
||||
app,
|
||||
most_recent_attempt,
|
||||
)
|
||||
last_batch_num = reissued_batch_count + completed_batches
|
||||
index_attempt.completed_batches = completed_batches
|
||||
db_session.commit()
|
||||
else:
|
||||
logger.info(
|
||||
f"Cleaning up all batches for index attempt {index_attempt_id} before starting new run"
|
||||
)
|
||||
# for non-checkpointed connectors, throw out batches from previous unsuccessful attempts
|
||||
# because we'll be getting those documents again anyways.
|
||||
batch_storage.cleanup_all_batches()
|
||||
|
||||
# Save initial checkpoint
|
||||
save_checkpoint(
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
try:
|
||||
batch_num = last_batch_num # starts at 0 if no last batch
|
||||
total_doc_batches_queued = 0
|
||||
total_failures = 0
|
||||
document_count = 0
|
||||
|
||||
# Main extraction loop
|
||||
while checkpoint.has_more:
|
||||
logger.info(
|
||||
f"Running '{db_connector.source.value}' connector with checkpoint: {checkpoint}"
|
||||
)
|
||||
for document_batch, failure, next_checkpoint in connector_runner.run(
|
||||
checkpoint
|
||||
):
|
||||
# Check if connector is disabled mid run and stop if so unless it's the secondary
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
# contents still need to be initially pulled.
|
||||
if callback and callback.should_stop():
|
||||
raise ConnectorStopSignal("Connector stop signal detected")
|
||||
|
||||
# will exception if the connector/index attempt is marked as paused/failed
|
||||
with get_session_with_current_tenant() as db_session_tmp:
|
||||
_check_connector_and_attempt_status(
|
||||
db_session_tmp,
|
||||
cc_pair_id,
|
||||
index_attempt.search_settings.status,
|
||||
index_attempt_id,
|
||||
)
|
||||
|
||||
# save record of any failures at the connector level
|
||||
if failure is not None:
|
||||
total_failures += 1
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_index_attempt_error(
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
failure,
|
||||
db_session,
|
||||
)
|
||||
_check_failure_threshold(
|
||||
total_failures, document_count, batch_num, failure
|
||||
)
|
||||
|
||||
# Save checkpoint if provided
|
||||
if next_checkpoint:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
# below is all document processing work, so if there's no batch we can just continue
|
||||
if not document_batch:
|
||||
continue
|
||||
|
||||
# Clean documents and create batch
|
||||
doc_batch_cleaned = strip_null_characters(document_batch)
|
||||
batch_description = []
|
||||
|
||||
for doc in doc_batch_cleaned:
|
||||
batch_description.append(doc.to_short_descriptor())
|
||||
|
||||
doc_size = 0
|
||||
for section in doc.sections:
|
||||
if (
|
||||
isinstance(section, TextSection)
|
||||
and section.text is not None
|
||||
):
|
||||
doc_size += len(section.text)
|
||||
|
||||
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Document size: doc='{doc.to_short_descriptor()}' "
|
||||
f"size={doc_size} "
|
||||
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
|
||||
)
|
||||
|
||||
logger.debug(f"Indexing batch of documents: {batch_description}")
|
||||
memory_tracer.increment_and_maybe_trace()
|
||||
|
||||
# Store documents in storage
|
||||
batch_storage.store_batch(batch_num, doc_batch_cleaned)
|
||||
|
||||
# Create processing task data
|
||||
processing_batch_data = {
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"tenant_id": tenant_id,
|
||||
"batch_num": batch_num, # 0-indexed
|
||||
}
|
||||
|
||||
# Queue document processing task
|
||||
app.send_task(
|
||||
OnyxCeleryTask.DOCPROCESSING_TASK,
|
||||
kwargs=processing_batch_data,
|
||||
queue=OnyxCeleryQueues.DOCPROCESSING,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
total_doc_batches_queued += 1
|
||||
|
||||
logger.info(
|
||||
f"Queued document processing batch: "
|
||||
f"batch_num={batch_num} "
|
||||
f"docs={len(doc_batch_cleaned)} "
|
||||
f"attempt={index_attempt_id}"
|
||||
)
|
||||
|
||||
# Check checkpoint size periodically
|
||||
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
|
||||
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
|
||||
check_checkpoint_size(checkpoint)
|
||||
|
||||
# Save latest checkpoint
|
||||
# NOTE: checkpointing is used to track which batches have
|
||||
# been sent to the filestore, NOT which batches have been fully indexed
|
||||
# as it used to be.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
save_checkpoint(
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
elapsed_time = time.monotonic() - start_time
|
||||
|
||||
logger.info(
|
||||
f"Document extraction completed: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"batches_queued={total_doc_batches_queued} "
|
||||
f"elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
# Set total batches in database to signal extraction completion.
|
||||
# Used by check_for_indexing to determine if the index attempt is complete.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
IndexingCoordination.set_total_batches(
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
total_batches=batch_num,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Document extraction failed: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"error={str(e)}"
|
||||
)
|
||||
|
||||
# Do NOT clean up batches on failure; future runs will use those batches
|
||||
# while docfetching will continue from the saved checkpoint if one exists
|
||||
|
||||
if isinstance(e, ConnectorValidationError):
|
||||
# On validation errors during indexing, we want to cancel the indexing attempt
|
||||
# and mark the CCPair as invalid. This prevents the connector from being
|
||||
# used in the future until the credentials are updated.
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to validation error."
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
if not index_attempt:
|
||||
# should always be set by now
|
||||
raise RuntimeError("Should never happen.")
|
||||
|
||||
VALIDATION_ERROR_THRESHOLD = 5
|
||||
|
||||
recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
limit=VALIDATION_ERROR_THRESHOLD,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
num_validation_errors = len(
|
||||
[
|
||||
index_attempt
|
||||
for index_attempt in recent_index_attempts
|
||||
if index_attempt.error_msg
|
||||
and index_attempt.error_msg.startswith(
|
||||
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Connector {db_connector.id} has {num_validation_errors} consecutive validation"
|
||||
f" errors. Marking the CC Pair as invalid."
|
||||
)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
raise e
|
||||
elif isinstance(e, ConnectorStopSignal):
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to stop signal."
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=str(e),
|
||||
)
|
||||
|
||||
else:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
# don't overwrite attempts that are already failed/canceled for another reason
|
||||
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
if index_attempt and index_attempt.status in [
|
||||
IndexingStatus.CANCELED,
|
||||
IndexingStatus.FAILED,
|
||||
]:
|
||||
logger.info(
|
||||
f"Attempt {index_attempt_id} is already failed/canceled, skipping marking as failed."
|
||||
)
|
||||
raise e
|
||||
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
raise e
|
||||
|
||||
finally:
|
||||
memory_tracer.stop()
|
||||
|
||||
|
||||
def reissue_old_batches(
|
||||
batch_storage: DocumentBatchStorage,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str,
|
||||
app: Celery,
|
||||
most_recent_attempt: IndexAttempt | None,
|
||||
) -> tuple[int, int]:
|
||||
# When loading from a checkpoint, we need to start new docprocessing tasks
|
||||
# tied to the new index attempt for any batches left over in the file store
|
||||
old_batches = batch_storage.get_all_batches_for_cc_pair()
|
||||
batch_storage.update_old_batches_to_new_index_attempt(old_batches)
|
||||
for batch_id in old_batches:
|
||||
logger.info(
|
||||
f"Re-issuing docprocessing task for batch {batch_id} for index attempt {index_attempt_id}"
|
||||
)
|
||||
path_info = batch_storage.extract_path_info(batch_id)
|
||||
if path_info is None:
|
||||
continue
|
||||
if path_info.cc_pair_id != cc_pair_id:
|
||||
raise RuntimeError(f"Batch {batch_id} is not for cc pair {cc_pair_id}")
|
||||
|
||||
app.send_task(
|
||||
OnyxCeleryTask.DOCPROCESSING_TASK,
|
||||
kwargs={
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"tenant_id": tenant_id,
|
||||
"batch_num": path_info.batch_num, # use same batch num as previously
|
||||
},
|
||||
queue=OnyxCeleryQueues.DOCPROCESSING,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
recent_batches = most_recent_attempt.completed_batches if most_recent_attempt else 0
# resume from the batch num of the last attempt. This should be one more
# than the last batch created by docfetching regardless of whether the batch
# is still in the filestore waiting for processing or not.
last_batch_num = len(old_batches) + recent_batches
logger.info(
f"Starting from batch {last_batch_num} due to "
f"re-issued batches: {old_batches}, completed batches: {recent_batches}"
)
return len(old_batches), recent_batches
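For orientation: docfetching signals that extraction is done by writing total_batches (via IndexingCoordination.set_total_batches above), while docprocessing workers bump completed_batches. A minimal sketch of the completion check this enables; the exact comparison is an assumption based on the comments above, not code from this diff:

# hypothetical helper, e.g. consulted by check_for_indexing
def attempt_is_fully_processed(db_session: Session, index_attempt_id: int) -> bool:
    attempt = get_index_attempt(db_session, index_attempt_id)
    if attempt is None or attempt.total_batches is None:
        return False  # extraction still running (or attempt missing)
    return (attempt.completed_batches or 0) >= attempt.total_batches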
|
||||
@@ -1,12 +1,7 @@
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
@@ -16,9 +11,6 @@ from onyx.agents.agent_search.models import GraphSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.run_graph import run_agent_search_graph
|
||||
from onyx.agents.agent_search.run_graph import run_basic_graph
|
||||
from onyx.agents.agent_search.run_graph import (
|
||||
run_basic_graph as run_hackathon_graph,
|
||||
) # You can create your own graph
|
||||
from onyx.agents.agent_search.run_graph import run_dc_graph
|
||||
from onyx.agents.agent_search.run_graph import run_kb_graph
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
@@ -30,11 +22,9 @@ from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import SubQuestionKey
|
||||
from onyx.chat.models import ToolCallFinalResult
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
|
||||
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
|
||||
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
|
||||
from onyx.configs.constants import BASIC_KEY
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
@@ -54,190 +44,6 @@ logger = setup_logger()
|
||||
BASIC_SQ_KEY = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])


def _calc_score_for_pos(pos: int, max_acceptable_pos: int = 15) -> float:
"""
Calculate the score for a given position.
"""
if pos > max_acceptable_pos:
return 0

elif pos == 1:
return 1
elif pos == 2:
return 0.8
else:
return 4 / (pos + 5)
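A quick sanity check of the scoring curve above (plain arithmetic, added for illustration): the top hit scores 1.0, the second 0.8, and later positions decay as 4 / (pos + 5) until they drop to 0 past max_acceptable_pos.

>>> [_calc_score_for_pos(p) for p in (1, 2, 3, 15, 16)]
[1, 0.8, 0.5, 0.2, 0]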
|
||||
|
||||
|
||||
def _clean_doc_id_link(doc_link: str) -> str:
|
||||
"""
|
||||
Clean the google doc link.
|
||||
"""
|
||||
if "google.com" in doc_link:
|
||||
if "/edit" in doc_link:
|
||||
return "/edit".join(doc_link.split("/edit")[:-1])
|
||||
elif "/view" in doc_link:
|
||||
return "/view".join(doc_link.split("/view")[:-1])
|
||||
else:
|
||||
return doc_link
|
||||
|
||||
if "app.fireflies.ai" in doc_link:
|
||||
return "?".join(doc_link.split("?")[:-1])
|
||||
return doc_link
|
||||
|
||||
|
||||
def _get_doc_score(doc_id: str, doc_results: list[str]) -> float:
|
||||
"""
|
||||
Get the score of a document from the document results.
|
||||
"""
|
||||
|
||||
match_pos = None
|
||||
for pos, comp_doc in enumerate(doc_results, start=1):
|
||||
|
||||
clear_doc_id = _clean_doc_id_link(doc_id)
|
||||
clear_comp_doc = _clean_doc_id_link(comp_doc)
|
||||
|
||||
if clear_doc_id == clear_comp_doc:
|
||||
match_pos = pos
|
||||
|
||||
if match_pos is None:
|
||||
return 0.0
|
||||
|
||||
return _calc_score_for_pos(match_pos)
|
||||
|
||||
|
||||
def _append_empty_line(csv_path: str = HACKATHON_OUTPUT_CSV_PATH):
|
||||
"""
|
||||
Append an empty line to the CSV file.
|
||||
"""
|
||||
_append_answer_to_csv("", "", csv_path)
|
||||
|
||||
|
||||
def _append_ground_truth_to_csv(
|
||||
query: str,
|
||||
ground_truth_docs: list[str],
|
||||
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
|
||||
) -> None:
|
||||
"""
|
||||
Append the score to the CSV file.
|
||||
"""
|
||||
|
||||
file_exists = os.path.isfile(csv_path)
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
csv_dir = os.path.dirname(csv_path)
|
||||
if csv_dir and not os.path.exists(csv_dir):
|
||||
Path(csv_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
|
||||
writer = csv.writer(file)
|
||||
|
||||
# Write header if file is new
|
||||
if not file_exists:
|
||||
writer.writerow(["query", "position", "document_id", "answer", "score"])
|
||||
|
||||
# Write the ranking stats
|
||||
|
||||
for doc_id in ground_truth_docs:
|
||||
writer.writerow([query, "-1", _clean_doc_id_link(doc_id), "", ""])
|
||||
|
||||
logger.debug("Appended score to csv file")
|
||||
|
||||
|
||||
def _append_score_to_csv(
|
||||
query: str,
|
||||
score: float,
|
||||
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
|
||||
) -> None:
|
||||
"""
|
||||
Append the score to the CSV file.
|
||||
"""
|
||||
|
||||
file_exists = os.path.isfile(csv_path)
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
csv_dir = os.path.dirname(csv_path)
|
||||
if csv_dir and not os.path.exists(csv_dir):
|
||||
Path(csv_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
|
||||
writer = csv.writer(file)
|
||||
|
||||
# Write header if file is new
|
||||
if not file_exists:
|
||||
writer.writerow(["query", "position", "document_id", "answer", "score"])
|
||||
|
||||
# Write the ranking stats
|
||||
|
||||
writer.writerow([query, "", "", "", score])
|
||||
|
||||
logger.debug("Appended score to csv file")
|
||||
|
||||
|
||||
def _append_search_results_to_csv(
|
||||
query: str,
|
||||
doc_results: list[str],
|
||||
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
|
||||
) -> None:
|
||||
"""
|
||||
Append the search results to the CSV file.
|
||||
"""
|
||||
|
||||
file_exists = os.path.isfile(csv_path)
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
csv_dir = os.path.dirname(csv_path)
|
||||
if csv_dir and not os.path.exists(csv_dir):
|
||||
Path(csv_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
|
||||
writer = csv.writer(file)
|
||||
|
||||
# Write header if file is new
|
||||
if not file_exists:
|
||||
writer.writerow(["query", "position", "document_id", "answer", "score"])
|
||||
|
||||
# Write the ranking stats
|
||||
|
||||
for pos, doc in enumerate(doc_results, start=1):
|
||||
writer.writerow([query, pos, _clean_doc_id_link(doc), "", ""])
|
||||
|
||||
logger.debug("Appended search results to csv file")
|
||||
|
||||
|
||||
def _append_answer_to_csv(
|
||||
query: str,
|
||||
answer: str,
|
||||
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
|
||||
) -> None:
|
||||
"""
|
||||
Append ranking statistics to a CSV file.
|
||||
|
||||
Args:
|
||||
ranking_stats: List of tuples containing (query, hit_position, document_id)
|
||||
csv_path: Path to the CSV file to append to
|
||||
"""
|
||||
file_exists = os.path.isfile(csv_path)
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
csv_dir = os.path.dirname(csv_path)
|
||||
if csv_dir and not os.path.exists(csv_dir):
|
||||
Path(csv_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
|
||||
writer = csv.writer(file)
|
||||
|
||||
# Write header if file is new
|
||||
if not file_exists:
|
||||
writer.writerow(["query", "position", "document_id", "answer", "score"])
|
||||
|
||||
# Write the ranking stats
|
||||
|
||||
writer.writerow([query, "", "", answer, ""])
|
||||
|
||||
logger.debug("Appended answer to csv file")
|
||||
|
||||
|
||||
class Answer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -328,9 +134,6 @@ class Answer:
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
|
||||
_HACKATHON_TEST_EXECUTION = False
|
||||
|
||||
if self._processed_stream is not None:
|
||||
yield from self._processed_stream
|
||||
return
|
||||
@@ -351,117 +154,20 @@ class Answer:
|
||||
)
|
||||
):
|
||||
run_langgraph = run_dc_graph
|
||||
|
||||
elif (
|
||||
self.graph_config.inputs.persona
|
||||
and self.graph_config.inputs.persona.description.startswith(
|
||||
"Hackathon Test"
|
||||
)
|
||||
):
|
||||
_HACKATHON_TEST_EXECUTION = True
|
||||
run_langgraph = run_hackathon_graph
|
||||
|
||||
else:
|
||||
run_langgraph = run_basic_graph
|
||||
|
||||
if _HACKATHON_TEST_EXECUTION:
|
||||
stream = run_langgraph(self.graph_config)
|
||||
|
||||
input_data = str(self.graph_config.inputs.prompt_builder.raw_user_query)
|
||||
|
||||
if input_data.startswith("["):
|
||||
input_type = "json"
|
||||
input_list = json.loads(input_data)
|
||||
else:
|
||||
input_type = "list"
|
||||
input_list = input_data.split(";")
|
||||
|
||||
num_examples_with_ground_truth = 0
|
||||
total_score = 0.0
|
||||
|
||||
for question_num, question_data in enumerate(input_list):
|
||||
|
||||
ground_truth_docs = None
|
||||
if input_type == "json":
|
||||
question = question_data["question"]
|
||||
ground_truth = question_data.get("ground_truth")
|
||||
if ground_truth:
|
||||
ground_truth_docs = [x.get("doc_link") for x in ground_truth]
|
||||
logger.info(f"Question {question_num}: {question}")
|
||||
_append_ground_truth_to_csv(question, ground_truth_docs)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
question = question_data
|
||||
|
||||
self.graph_config.inputs.prompt_builder.raw_user_query = question
|
||||
self.graph_config.inputs.prompt_builder.user_message_and_token_cnt = (
|
||||
HumanMessage(
|
||||
content=question, additional_kwargs={}, response_metadata={}
|
||||
),
|
||||
2,
|
||||
)
|
||||
self.graph_config.tooling.force_use_tool.force_use = True
|
||||
|
||||
stream = run_langgraph(
|
||||
self.graph_config,
|
||||
)
|
||||
processed_stream = []
|
||||
for packet in stream:
|
||||
if self.is_cancelled():
|
||||
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
yield packet
|
||||
break
|
||||
processed_stream.append(packet)
|
||||
yield packet
|
||||
|
||||
llm_answer_segments: list[str] = []
|
||||
doc_results: list[str] | None = None
|
||||
for answer_piece in processed_stream:
|
||||
if isinstance(answer_piece, OnyxAnswerPiece):
|
||||
llm_answer_segments.append(answer_piece.answer_piece or "")
|
||||
elif isinstance(answer_piece, ToolCallFinalResult):
|
||||
doc_results = [x.get("link") for x in answer_piece.tool_result]
|
||||
|
||||
if doc_results:
|
||||
_append_search_results_to_csv(question, doc_results)
|
||||
|
||||
_append_answer_to_csv(question, "".join(llm_answer_segments))
|
||||
|
||||
if ground_truth_docs and doc_results:
|
||||
num_examples_with_ground_truth += 1
|
||||
doc_score = 0.0
|
||||
for doc_id in ground_truth_docs:
|
||||
doc_score += _get_doc_score(doc_id, doc_results)
|
||||
|
||||
_append_score_to_csv(question, doc_score)
|
||||
total_score += doc_score
|
||||
|
||||
self._processed_stream = processed_stream
|
||||
|
||||
if num_examples_with_ground_truth > 0:
|
||||
comprehensive_score = total_score / num_examples_with_ground_truth
|
||||
else:
|
||||
comprehensive_score = 0
|
||||
|
||||
_append_empty_line()
|
||||
|
||||
_append_score_to_csv(question, comprehensive_score)
|
||||
|
||||
else:
|
||||
|
||||
stream = run_langgraph(
|
||||
self.graph_config,
|
||||
)
|
||||
|
||||
processed_stream = []
|
||||
for packet in stream:
|
||||
if self.is_cancelled():
|
||||
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
yield packet
|
||||
break
|
||||
processed_stream.append(packet)
|
||||
processed_stream = []
|
||||
for packet in stream:
|
||||
if self.is_cancelled():
|
||||
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
yield packet
|
||||
self._processed_stream = processed_stream
|
||||
break
|
||||
processed_stream.append(packet)
|
||||
yield packet
|
||||
self._processed_stream = processed_stream
|
||||
|
||||
@property
|
||||
def llm_answer(self) -> str:
|
||||
@@ -504,23 +210,6 @@ class Answer:
|
||||
|
||||
return citations
|
||||
|
||||
def citations_by_subquestion(self) -> dict[SubQuestionKey, list[CitationInfo]]:
|
||||
citations_by_subquestion: dict[SubQuestionKey, list[CitationInfo]] = (
|
||||
defaultdict(list)
|
||||
)
|
||||
basic_subq_key = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
|
||||
for packet in self.processed_streamed_output:
|
||||
if isinstance(packet, CitationInfo):
|
||||
if packet.level_question_num is not None and packet.level is not None:
|
||||
citations_by_subquestion[
|
||||
SubQuestionKey(
|
||||
level=packet.level, question_num=packet.level_question_num
|
||||
)
|
||||
].append(packet)
|
||||
elif packet.level is None:
|
||||
citations_by_subquestion[basic_subq_key].append(packet)
|
||||
return citations_by_subquestion
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
if self._is_cancelled:
|
||||
return True
|
||||
|
||||
@@ -309,7 +309,10 @@ class ContextualPruningConfig(DocumentPruningConfig):
def from_doc_pruning_config(
cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
) -> "ContextualPruningConfig":
return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
return cls(
num_chunk_multiple=num_chunk_multiple,
**doc_pruning_config.model_dump(),
)
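For context, model_dump() is the Pydantic v2 replacement for the v1 dict() method, which is deprecated in v2; both return a plain dict of the model's fields. A tiny illustration with a hypothetical model:

from pydantic import BaseModel

class ExamplePruningConfig(BaseModel):
    max_chunks: int = 10
    use_sections: bool = True

ExamplePruningConfig().model_dump()  # {'max_chunks': 10, 'use_sections': True}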
|
||||
|
||||
|
||||
class CitationConfig(BaseModel):
|
||||
@@ -318,9 +321,6 @@ class CitationConfig(BaseModel):
|
||||
|
||||
class AnswerStyleConfig(BaseModel):
|
||||
citation_config: CitationConfig
|
||||
document_pruning_config: DocumentPruningConfig = Field(
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
# forces the LLM to return a structured response, see
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
# right now, only used by the simple chat API
|
||||
@@ -407,7 +407,7 @@ AnswerStream = Iterator[AnswerPacket]
|
||||
|
||||
class AnswerPostInfo(BaseModel):
|
||||
ai_message_files: list[FileDescriptor]
|
||||
qa_docs_response: QADocsResponse | None = None
|
||||
rephrased_query: str | None = None
|
||||
reference_db_search_docs: list[DbSearchDoc] | None = None
|
||||
dropped_indices: list[int] | None = None
|
||||
tool_result: ToolCallFinalResult | None = None
|
||||
|
||||
392 backend/onyx/chat/packet_proccessing/process_streamed_packets.py (new file)
@@ -0,0 +1,392 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import DefaultDict
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import AgenticMessageResponseIDInfo
|
||||
from onyx.chat.models import AgentSearchPacket
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import AnswerPostInfo
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import CustomToolResponse
|
||||
from onyx.chat.models import FileChatDisplay
|
||||
from onyx.chat.models import FinalUsedContextDocsResponse
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import MessageSpecificCitations
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import SubQuestionKey
|
||||
from onyx.chat.models import UserKnowledgeFilePacket
|
||||
from onyx.chat.packet_proccessing.tool_processing import (
|
||||
handle_image_generation_tool_response,
|
||||
)
|
||||
from onyx.chat.packet_proccessing.tool_processing import (
|
||||
handle_internet_search_tool_response,
|
||||
)
|
||||
from onyx.chat.packet_proccessing.tool_processing import (
|
||||
handle_search_tool_response_summary,
|
||||
)
|
||||
from onyx.configs.constants import BASIC_KEY
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
CustomToolUserFileSnapshot,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_RESPONSE_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
INTERNET_QUERY_FIELD,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchResponseSummary,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_runner import ToolCallFinalResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
COMMON_TOOL_RESPONSE_TYPES = {
|
||||
"image": ChatFileType.IMAGE,
|
||||
"csv": ChatFileType.CSV,
|
||||
}
|
||||
|
||||
# Type definitions for packet processing
|
||||
ChatPacket = Union[
|
||||
StreamingError,
|
||||
QADocsResponse,
|
||||
LLMRelevanceFilterResponse,
|
||||
FinalUsedContextDocsResponse,
|
||||
ChatMessageDetail,
|
||||
AllCitations,
|
||||
CitationInfo,
|
||||
FileChatDisplay,
|
||||
CustomToolResponse,
|
||||
MessageResponseIDInfo,
|
||||
MessageSpecificCitations,
|
||||
AgenticMessageResponseIDInfo,
|
||||
StreamStopInfo,
|
||||
AgentSearchPacket,
|
||||
UserKnowledgeFilePacket,
|
||||
Packet,
|
||||
]
|
||||
|
||||
|
||||
def process_streamed_packets(
|
||||
answer_processed_output: AnswerStream,
|
||||
reserved_message_id: int,
|
||||
selected_db_search_docs: list[DbSearchDoc] | None,
|
||||
retrieval_options: RetrievalDetails | None,
|
||||
db_session: Session,
|
||||
) -> Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]]:
|
||||
"""Process the streamed output from the answer and yield chat packets."""
|
||||
has_transmitted_answer_piece = False
|
||||
packet_index = 0
|
||||
current_message_index: int | None = None
|
||||
current_tool_index: int | None = None
|
||||
current_citation_index: int | None = None
|
||||
|
||||
# Track ongoing tool operations to prevent concurrent operations of the same type
|
||||
ongoing_search = False
|
||||
ongoing_image_generation = False
|
||||
ongoing_internet_search = False
|
||||
|
||||
# Track citations
|
||||
citations_emitted = False
|
||||
collected_citations: list[CitationInfo] = []
|
||||
|
||||
# Initialize info_by_subq mapping and temp citations storage
|
||||
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
|
||||
lambda: AnswerPostInfo(ai_message_files=[])
|
||||
)
|
||||
citations_by_key: DefaultDict[SubQuestionKey, list[CitationInfo]] = defaultdict(
|
||||
list
|
||||
)
|
||||
|
||||
for packet in answer_processed_output:
|
||||
# Determine the sub-question key context when applicable
|
||||
level = getattr(packet, "level", None)
|
||||
level_question_num = getattr(packet, "level_question_num", None)
|
||||
key = SubQuestionKey(
|
||||
level=level if level is not None else BASIC_KEY[0],
|
||||
question_num=(
|
||||
level_question_num if level_question_num is not None else BASIC_KEY[1]
|
||||
),
|
||||
)
|
||||
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
info_by_subq[key].tool_result = packet
|
||||
|
||||
# Original packet processing logic continues
|
||||
if isinstance(packet, ToolCallKickoff) and not isinstance(
|
||||
packet, ToolCallFinalResult
|
||||
):
|
||||
# Allocate a new index for this tool call
|
||||
current_tool_index = packet_index
|
||||
packet_index += 1
|
||||
|
||||
# Handle image generation tool start
|
||||
if (
|
||||
packet.tool_name == "run_image_generation"
|
||||
and not ongoing_image_generation
|
||||
):
|
||||
ongoing_image_generation = True
|
||||
yield Packet(
|
||||
ind=current_tool_index,
|
||||
obj=ImageGenerationToolStart(),
|
||||
)
|
||||
|
||||
if packet.tool_name == "run_search" and not ongoing_search:
|
||||
ongoing_search = True
|
||||
yield Packet(
|
||||
ind=current_tool_index,
|
||||
obj=SearchToolStart(),
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
ind=current_tool_index,
|
||||
obj=SearchToolDelta(
|
||||
queries=[packet.tool_args[QUERY_FIELD]],
|
||||
),
|
||||
)
|
||||
|
||||
if (
|
||||
packet.tool_name == "run_internet_search"
|
||||
and not ongoing_internet_search
|
||||
):
|
||||
ongoing_internet_search = True
|
||||
yield Packet(
|
||||
ind=current_tool_index,
|
||||
obj=SearchToolStart(
|
||||
is_internet_search=True,
|
||||
),
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
ind=current_tool_index,
|
||||
obj=SearchToolDelta(
|
||||
queries=[packet.tool_args[INTERNET_QUERY_FIELD]],
|
||||
),
|
||||
)
|
||||
|
||||
# Fallback: treat unknown tool kickoffs as custom tool start
|
||||
elif packet.tool_name not in {
|
||||
"run_search",
|
||||
"run_internet_search",
|
||||
"run_image_generation",
|
||||
}:
|
||||
yield Packet(
|
||||
ind=current_tool_index,
|
||||
obj=CustomToolStart(tool_name=packet.tool_name),
|
||||
)
|
||||
|
||||
elif isinstance(packet, ToolResponse):
|
||||
# Ensure we have a tool index; fallback to current packet_index if needed
|
||||
if current_tool_index is None:
|
||||
current_tool_index = packet_index
|
||||
packet_index += 1
|
||||
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
search_response = cast(SearchResponseSummary, packet.response)
|
||||
saved_search_docs, dropped_inds = (
|
||||
yield from handle_search_tool_response_summary(
|
||||
current_ind=current_tool_index,
|
||||
search_response=search_response,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
is_extended=False,
|
||||
dedupe_docs=bool(
|
||||
retrieval_options and retrieval_options.dedupe_docs
|
||||
),
|
||||
)
|
||||
)
|
||||
info_by_subq[key].reference_db_search_docs = saved_search_docs
|
||||
info_by_subq[key].dropped_indices = dropped_inds
|
||||
ongoing_search = False # Reset search state when tool ends
|
||||
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_SUMMARY_ID:
|
||||
internet_response = cast(InternetSearchResponseSummary, packet.response)
|
||||
saved_internet_docs = yield from handle_internet_search_tool_response(
|
||||
current_tool_index, internet_response
|
||||
)
|
||||
info_by_subq[key].reference_db_search_docs = saved_internet_docs
|
||||
ongoing_internet_search = False
|
||||
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
)
|
||||
yield from handle_image_generation_tool_response(
|
||||
current_tool_index, img_generation_response
|
||||
)
|
||||
ongoing_image_generation = (
|
||||
False # Reset image generation state when tool ends
|
||||
)
|
||||
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
summary = cast(CustomToolCallSummary, packet.response)
|
||||
# Emit start if not already started for this index
|
||||
# We emit start once per custom tool index
|
||||
yield Packet(
|
||||
ind=current_tool_index,
|
||||
obj=CustomToolStart(tool_name=summary.tool_name),
|
||||
)
|
||||
|
||||
# Decide whether we have file outputs or data
|
||||
file_ids: list[str] | None = None
|
||||
data: dict | list | str | int | float | bool | None = None
|
||||
if summary.response_type in ("image", "csv"):
|
||||
try:
|
||||
snapshot = cast(CustomToolUserFileSnapshot, summary.tool_result)
|
||||
file_ids = snapshot.file_ids
|
||||
except Exception:
|
||||
file_ids = None
|
||||
else:
|
||||
data = summary.tool_result # type: ignore[assignment]
|
||||
|
||||
yield Packet(
|
||||
ind=current_tool_index,
|
||||
obj=CustomToolDelta(
|
||||
tool_name=summary.tool_name,
|
||||
response_type=summary.response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
),
|
||||
)
|
||||
|
||||
# End this tool section
|
||||
yield Packet(
|
||||
ind=current_tool_index,
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
if packet.stop_reason == StreamStopReason.FINISHED:
|
||||
yield packet
|
||||
elif isinstance(packet, OnyxAnswerPiece):
|
||||
if has_transmitted_answer_piece:
|
||||
if packet.answer_piece is None:
|
||||
# Message is ending, use current message index
|
||||
if current_message_index is not None:
|
||||
yield Packet(
|
||||
ind=current_message_index,
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
# Reset for next message
|
||||
current_message_index = None
|
||||
has_transmitted_answer_piece = False
|
||||
else:
|
||||
# Continue with same index for message delta
|
||||
if current_message_index is not None:
|
||||
yield Packet(
|
||||
ind=current_message_index,
|
||||
obj=MessageDelta(
|
||||
content=packet.answer_piece or "",
|
||||
),
|
||||
)
|
||||
|
||||
elif packet.answer_piece:
|
||||
# New message starting, allocate new index
|
||||
current_message_index = packet_index
|
||||
packet_index += 1
|
||||
yield Packet(
|
||||
ind=current_message_index,
|
||||
obj=MessageStart(
|
||||
id=str(reserved_message_id),
|
||||
content=packet.answer_piece,
|
||||
),
|
||||
)
|
||||
has_transmitted_answer_piece = True
|
||||
elif isinstance(packet, CitationInfo):
|
||||
# Collect citations for batch processing
|
||||
if not citations_emitted:
|
||||
# First citation - allocate index but don't emit yet
|
||||
if current_citation_index is None:
|
||||
current_citation_index = packet_index
|
||||
packet_index += 1
|
||||
|
||||
# Collect citation info
|
||||
collected_citations.append(
|
||||
CitationInfo(
|
||||
citation_num=packet.citation_num,
|
||||
document_id=packet.document_id,
|
||||
)
|
||||
)
|
||||
|
||||
yield cast(ChatPacket, packet)
|
||||
|
||||
if current_message_index is not None:
|
||||
yield Packet(ind=current_message_index, obj=SectionEnd())
|
||||
|
||||
# Emit collected citations if any
|
||||
if collected_citations and current_citation_index is not None:
|
||||
yield Packet(ind=current_citation_index, obj=CitationStart())
|
||||
yield Packet(
|
||||
ind=current_citation_index, obj=CitationDelta(citations=collected_citations)
|
||||
)
|
||||
yield Packet(
|
||||
ind=current_citation_index,
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
|
||||
# Yield STOP packet to indicate streaming is complete
|
||||
yield Packet(ind=packet_index, obj=OverallStop())
|
||||
|
||||
# Build citation maps per sub-question key using available docs
|
||||
for key, citation_list in citations_by_key.items():
|
||||
info = info_by_subq[key]
|
||||
if not citation_list:
|
||||
continue
|
||||
|
||||
doc_id_to_saved_db_id = {
|
||||
doc.document_id: doc.id for doc in info.reference_db_search_docs or []
|
||||
}
|
||||
|
||||
citation_map: dict[int, int] = {}
|
||||
for c in citation_list:
|
||||
mapped_db_id = doc_id_to_saved_db_id.get(c.document_id)
|
||||
if mapped_db_id is not None and c.citation_num not in citation_map:
|
||||
citation_map[c.citation_num] = mapped_db_id
|
||||
|
||||
if citation_map:
|
||||
info.message_specific_citations = MessageSpecificCitations(
|
||||
citation_map=citation_map
|
||||
)
|
||||
|
||||
return info_by_subq
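The citation map built just above keys each citation number to the database id of the saved search doc. A minimal sketch of that same mapping logic, using made-up stand-in data instead of the real CitationInfo/DbSearchDoc models (illustrative only):

```python
# Hypothetical data: (citation_num, document_id) pairs and a document_id -> saved DB id lookup.
citations = [(1, "doc-a"), (2, "doc-b"), (1, "doc-a")]
doc_id_to_saved_db_id = {"doc-a": 101, "doc-b": 202}

citation_map: dict[int, int] = {}
for citation_num, document_id in citations:
    db_id = doc_id_to_saved_db_id.get(document_id)
    # Keep only the first mapping seen for each citation number.
    if db_id is not None and citation_num not in citation_map:
        citation_map[citation_num] = db_id

assert citation_map == {1: 101, 2: 202}
```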
backend/onyx/chat/packet_proccessing/tool_processing.py (new file, 164 lines)
@@ -0,0 +1,164 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from onyx.context.search.utils import chunks_or_sections_to_search_docs
|
||||
from onyx.context.search.utils import dedupe_documents
|
||||
from onyx.db.chat import create_db_search_doc
|
||||
from onyx.db.chat import create_search_doc_from_user_file
|
||||
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.models import (
|
||||
InternetSearchResponseSummary,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.utils import (
|
||||
internet_search_response_to_search_docs,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
|
||||
|
||||
def handle_search_tool_response_summary(
|
||||
current_ind: int,
|
||||
search_response: SearchResponseSummary,
|
||||
selected_search_docs: list[DbSearchDoc] | None,
|
||||
is_extended: bool,
|
||||
dedupe_docs: bool = False,
|
||||
user_files: list[UserFile] | None = None,
|
||||
loaded_user_files: list[InMemoryChatFile] | None = None,
|
||||
) -> Generator[Packet, None, tuple[list[DbSearchDoc], list[int] | None]]:
|
||||
dropped_inds = None
|
||||
|
||||
if not selected_search_docs:
|
||||
top_docs = chunks_or_sections_to_search_docs(search_response.top_sections)
|
||||
|
||||
deduped_docs = top_docs
|
||||
if (
|
||||
dedupe_docs and not is_extended
|
||||
): # Extended tool responses are already deduped
|
||||
deduped_docs, dropped_inds = dedupe_documents(top_docs)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in deduped_docs
|
||||
]
|
||||
|
||||
else:
|
||||
reference_db_search_docs = selected_search_docs
|
||||
|
||||
doc_ids = {doc.id for doc in reference_db_search_docs}
|
||||
if user_files is not None and loaded_user_files is not None:
|
||||
for user_file in user_files:
|
||||
if user_file.id in doc_ids:
|
||||
continue
|
||||
|
||||
associated_chat_file = next(
|
||||
(
|
||||
file
|
||||
for file in loaded_user_files
|
||||
if file.file_id == str(user_file.file_id)
|
||||
),
|
||||
None,
|
||||
)
|
||||
# Use create_search_doc_from_user_file to properly add the document to the database
|
||||
if associated_chat_file is not None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_doc = create_search_doc_from_user_file(
|
||||
user_file, associated_chat_file, db_session
|
||||
)
|
||||
reference_db_search_docs.append(db_doc)
|
||||
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
|
||||
yield Packet(
|
||||
ind=current_ind,
|
||||
obj=SearchToolDelta(
|
||||
documents=response_docs,
|
||||
),
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
ind=current_ind,
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
|
||||
return reference_db_search_docs, dropped_inds
|
||||
|
||||
|
||||
def handle_internet_search_tool_response(
|
||||
current_ind: int,
|
||||
internet_search_response: InternetSearchResponseSummary,
|
||||
) -> Generator[Packet, None, list[DbSearchDoc]]:
|
||||
server_search_docs = internet_search_response_to_search_docs(
|
||||
internet_search_response
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in server_search_docs
|
||||
]
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
|
||||
yield Packet(
|
||||
ind=current_ind,
|
||||
obj=SearchToolDelta(
|
||||
documents=response_docs,
|
||||
),
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
ind=current_ind,
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
|
||||
return reference_db_search_docs
|
||||
|
||||
|
||||
def handle_image_generation_tool_response(
|
||||
current_ind: int,
|
||||
img_generation_responses: list[ImageGenerationResponse],
|
||||
) -> Generator[Packet, None, None]:
|
||||
|
||||
# Save files and get file IDs
|
||||
file_ids = save_files(
|
||||
urls=[img.url for img in img_generation_responses if img.url],
|
||||
base64_files=[
|
||||
img.image_data for img in img_generation_responses if img.image_data
|
||||
],
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
ind=current_ind,
|
||||
obj=ImageGenerationToolDelta(
|
||||
images=[
|
||||
{
|
||||
"id": str(file_id),
|
||||
"url": "", # URL will be constructed by frontend
|
||||
"prompt": img.revised_prompt,
|
||||
}
|
||||
for file_id, img in zip(file_ids, img_generation_responses)
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
# Emit ImageToolEnd packet with file information
|
||||
yield Packet(
|
||||
ind=current_ind,
|
||||
obj=SectionEnd(),
|
||||
)
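These handlers stream packets to the client but also hand their results back through the generator's return value, which the caller captures with `yield from` (as with `saved_internet_docs = yield from handle_internet_search_tool_response(...)` earlier). A small self-contained sketch of that pattern, with made-up packet and doc types rather than the real models:

```python
from collections.abc import Generator


def handler(ind: int) -> Generator[str, None, list[int]]:
    yield f"packet-{ind}"        # streamed to the client
    yield f"section-end-{ind}"
    return [1, 2, 3]             # handed back to the caller via StopIteration.value


def caller() -> Generator[str, None, None]:
    saved_docs = yield from handler(0)   # re-yields every packet, then captures the return value
    yield f"saved {len(saved_docs)} docs"


print(list(caller()))  # ['packet-0', 'section-end-0', 'saved 3 docs']
```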
@@ -1,6 +1,5 @@
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
@@ -17,30 +16,26 @@ from onyx.chat.chat_utils import create_temporary_persona
|
||||
from onyx.chat.chat_utils import process_kg_commands
|
||||
from onyx.chat.models import AgenticMessageResponseIDInfo
|
||||
from onyx.chat.models import AgentMessageIDInfo
|
||||
from onyx.chat.models import AgentSearchPacket
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import AnswerPostInfo
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import ChatOnyxBotResponse
|
||||
from onyx.chat.models import CitationConfig
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import CustomToolResponse
|
||||
from onyx.chat.models import DocumentPruningConfig
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import FileChatDisplay
|
||||
from onyx.chat.models import FinalUsedContextDocsResponse
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import MessageSpecificCitations
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import SubQuestionKey
|
||||
from onyx.chat.models import UserKnowledgeFilePacket
|
||||
from onyx.chat.packet_proccessing.process_streamed_packets import ChatPacket
|
||||
from onyx.chat.packet_proccessing.process_streamed_packets import (
|
||||
process_streamed_packets,
|
||||
)
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
@@ -55,21 +50,13 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NO_AUTH_USER_ID
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.retrieval.search_runner import (
|
||||
inference_sections_from_ids,
|
||||
)
|
||||
from onyx.context.search.utils import chunks_or_sections_to_search_docs
|
||||
from onyx.context.search.utils import dedupe_documents
|
||||
from onyx.context.search.utils import drop_llm_indices
|
||||
from onyx.context.search.utils import relevant_sections_to_indices
|
||||
from onyx.db.chat import attach_files_to_chat_message
|
||||
from onyx.db.chat import create_db_search_doc
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import create_search_doc_from_user_file
|
||||
from onyx.db.chat import get_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_db_search_doc_by_id
|
||||
@@ -77,7 +64,6 @@ from onyx.db.chat import get_doc_query_identifiers_from_model
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from onyx.db.chat import update_chat_session_updated_at_timestamp
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.milestone import check_multi_assistant_milestone
|
||||
@@ -88,15 +74,12 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_all_chat_files
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.kg.models import KGException
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
@@ -110,47 +93,16 @@ from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import CustomToolConfig
|
||||
from onyx.tools.tool_constructor import ImageGenerationToolConfig
|
||||
from onyx.tools.tool_constructor import InternetSearchToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
CUSTOM_TOOL_RESPONSE_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_RESPONSE_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
internet_search_response_to_search_docs,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from onyx.tools.tool_runner import ToolCallFinalResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
@@ -201,113 +153,6 @@ def _translate_citations(
|
||||
return MessageSpecificCitations(citation_map=citation_to_saved_doc_id_map)
|
||||
|
||||
|
||||
def _handle_search_tool_response_summary(
|
||||
packet: ToolResponse,
|
||||
db_session: Session,
|
||||
selected_search_docs: list[DbSearchDoc] | None,
|
||||
dedupe_docs: bool = False,
|
||||
user_files: list[UserFile] | None = None,
|
||||
loaded_user_files: list[InMemoryChatFile] | None = None,
|
||||
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
|
||||
response_summary = cast(SearchResponseSummary, packet.response)
|
||||
|
||||
is_extended = isinstance(packet, ExtendedToolResponse)
|
||||
dropped_inds = None
|
||||
|
||||
if not selected_search_docs:
|
||||
top_docs = chunks_or_sections_to_search_docs(response_summary.top_sections)
|
||||
|
||||
deduped_docs = top_docs
|
||||
if (
|
||||
dedupe_docs and not is_extended
|
||||
): # Extended tool responses are already deduped
|
||||
deduped_docs, dropped_inds = dedupe_documents(top_docs)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in deduped_docs
|
||||
]
|
||||
|
||||
else:
|
||||
reference_db_search_docs = selected_search_docs
|
||||
|
||||
doc_ids = {doc.id for doc in reference_db_search_docs}
|
||||
if user_files is not None and loaded_user_files is not None:
|
||||
for user_file in user_files:
|
||||
if user_file.id in doc_ids:
|
||||
continue
|
||||
|
||||
associated_chat_file = next(
|
||||
(
|
||||
file
|
||||
for file in loaded_user_files
|
||||
if file.file_id == str(user_file.file_id)
|
||||
),
|
||||
None,
|
||||
)
|
||||
# Use create_search_doc_from_user_file to properly add the document to the database
|
||||
if associated_chat_file is not None:
|
||||
db_doc = create_search_doc_from_user_file(
|
||||
user_file, associated_chat_file, db_session
|
||||
)
|
||||
reference_db_search_docs.append(db_doc)
|
||||
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
|
||||
level, question_num = None, None
|
||||
if isinstance(packet, ExtendedToolResponse):
|
||||
level, question_num = packet.level, packet.level_question_num
|
||||
return (
|
||||
QADocsResponse(
|
||||
rephrased_query=response_summary.rephrased_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=response_summary.predicted_flow,
|
||||
predicted_search=response_summary.predicted_search,
|
||||
applied_source_filters=response_summary.final_filters.source_type,
|
||||
applied_time_cutoff=response_summary.final_filters.time_cutoff,
|
||||
recency_bias_multiplier=response_summary.recency_bias_multiplier,
|
||||
level=level,
|
||||
level_question_num=question_num,
|
||||
),
|
||||
reference_db_search_docs,
|
||||
dropped_inds,
|
||||
)
|
||||
|
||||
|
||||
def _handle_internet_search_tool_response_summary(
|
||||
packet: ToolResponse,
|
||||
db_session: Session,
|
||||
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
|
||||
internet_search_response = cast(InternetSearchResponse, packet.response)
|
||||
server_search_docs = internet_search_response_to_search_docs(
|
||||
internet_search_response
|
||||
)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in server_search_docs
|
||||
]
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
return (
|
||||
QADocsResponse(
|
||||
rephrased_query=internet_search_response.revised_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=QueryFlow.QUESTION_ANSWER,
|
||||
predicted_search=SearchType.SEMANTIC,
|
||||
applied_source_filters=[],
|
||||
applied_time_cutoff=None,
|
||||
recency_bias_multiplier=1.0,
|
||||
),
|
||||
reference_db_search_docs,
|
||||
)
|
||||
|
||||
|
||||
def _get_force_search_settings(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
tools: list[Tool],
|
||||
@@ -392,136 +237,9 @@ def _get_persona_for_chat_session(
|
||||
return persona
|
||||
|
||||
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
| OnyxAnswerPiece
|
||||
| AllCitations
|
||||
| CitationInfo
|
||||
| FileChatDisplay
|
||||
| CustomToolResponse
|
||||
| MessageSpecificCitations
|
||||
| MessageResponseIDInfo
|
||||
| AgenticMessageResponseIDInfo
|
||||
| StreamStopInfo
|
||||
| AgentSearchPacket
|
||||
| UserKnowledgeFilePacket
|
||||
)
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
|
||||
def _process_tool_response(
|
||||
packet: ToolResponse,
|
||||
db_session: Session,
|
||||
selected_db_search_docs: list[DbSearchDoc] | None,
|
||||
info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
|
||||
retrieval_options: RetrievalDetails | None,
|
||||
user_file_files: list[UserFile] | None,
|
||||
user_files: list[InMemoryChatFile] | None,
|
||||
) -> Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]]:
|
||||
level, level_question_num = (
|
||||
(packet.level, packet.level_question_num)
|
||||
if isinstance(packet, ExtendedToolResponse)
|
||||
else BASIC_KEY
|
||||
)
|
||||
|
||||
assert level is not None
|
||||
assert level_question_num is not None
|
||||
info = info_by_subq[SubQuestionKey(level=level, question_num=level_question_num)]
|
||||
|
||||
# TODO: don't need to dedupe here when we do it in agent flow
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
info.qa_docs_response,
|
||||
info.reference_db_search_docs,
|
||||
info.dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=bool(retrieval_options and retrieval_options.dedupe_docs),
|
||||
user_files=[],
|
||||
loaded_user_files=[],
|
||||
)
|
||||
|
||||
yield info.qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
if info.reference_db_search_docs is None:
|
||||
logger.warning("No reference docs found for relevance filtering")
|
||||
return info_by_subq
|
||||
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in info.reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if info.dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=info.reference_db_search_docs,
|
||||
dropped_indices=info.dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(llm_selected_doc_indices=llm_indices)
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(final_context_docs=packet.response)
|
||||
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(list[ImageGenerationResponse], packet.response)
|
||||
|
||||
file_ids = save_files(
|
||||
urls=[img.url for img in img_generation_response if img.url],
|
||||
base64_files=[
|
||||
img.image_data for img in img_generation_response if img.image_data
|
||||
],
|
||||
)
|
||||
info.ai_message_files.extend(
|
||||
[
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
)
|
||||
yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
info.qa_docs_response,
|
||||
info.reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield info.qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
response_type = custom_tool_response.response_type
|
||||
if response_type in COMMON_TOOL_RESPONSE_TYPES:
|
||||
file_ids = custom_tool_response.tool_result.file_ids
|
||||
file_type = COMMON_TOOL_RESPONSE_TYPES[response_type]
|
||||
info.ai_message_files.extend(
|
||||
[
|
||||
FileDescriptor(id=str(file_id), type=file_type)
|
||||
for file_id in file_ids
|
||||
]
|
||||
)
|
||||
yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
|
||||
else:
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
return info_by_subq
|
||||
|
||||
|
||||
def stream_chat_message_objects(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
@@ -561,6 +279,7 @@ def stream_chat_message_objects(
|
||||
new_msg_req.chunks_below = 0
|
||||
|
||||
llm: LLM
|
||||
answer: Answer
|
||||
|
||||
try:
|
||||
# Move these variables inside the try block
|
||||
@@ -725,9 +444,7 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
# load all files needed for this chat chain in memory
|
||||
files = load_all_chat_files(
|
||||
history_msgs, new_msg_req.file_descriptors, db_session
|
||||
)
|
||||
files = load_all_chat_files(history_msgs, new_msg_req.file_descriptors)
|
||||
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
|
||||
latest_query_files = [file for file in files if file.file_id in req_file_ids]
|
||||
user_file_ids = new_msg_req.user_file_ids or []
|
||||
@@ -906,7 +623,6 @@ def stream_chat_message_objects(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
)
|
||||
|
||||
@@ -936,6 +652,7 @@ def stream_chat_message_objects(
|
||||
),
|
||||
internet_search_tool_config=InternetSearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
),
|
||||
image_generation_tool_config=ImageGenerationToolConfig(
|
||||
additional_headers=litellm_additional_headers,
|
||||
@@ -985,7 +702,6 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
|
||||
answer = Answer(
|
||||
prompt_builder=prompt_builder,
|
||||
is_connected=is_connected,
|
||||
@@ -1012,43 +728,17 @@ def stream_chat_message_objects(
|
||||
tools=tools,
|
||||
db_session=db_session,
|
||||
use_agentic_search=new_msg_req.use_agentic_search,
|
||||
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
|
||||
)
|
||||
|
||||
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
|
||||
lambda: AnswerPostInfo(ai_message_files=[])
|
||||
# Process streamed packets using the new packet processing module
|
||||
info_by_subq = yield from process_streamed_packets(
|
||||
answer_processed_output=answer.processed_streamed_output,
|
||||
reserved_message_id=reserved_message_id,
|
||||
selected_db_search_docs=selected_db_search_docs,
|
||||
retrieval_options=retrieval_options,
|
||||
db_session=db_session,
|
||||
)
|
||||
refined_answer_improvement = True
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, ToolResponse):
|
||||
info_by_subq = yield from _process_tool_response(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_db_search_docs=selected_db_search_docs,
|
||||
info_by_subq=info_by_subq,
|
||||
retrieval_options=retrieval_options,
|
||||
user_file_files=user_file_models,
|
||||
user_files=in_memory_user_files,
|
||||
)
|
||||
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
if packet.stop_reason == StreamStopReason.FINISHED:
|
||||
yield packet
|
||||
elif isinstance(packet, RefinedAnswerImprovement):
|
||||
refined_answer_improvement = packet.refined_answer_improvement
|
||||
yield packet
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
level, level_question_num = (
|
||||
(packet.level, packet.level_question_num)
|
||||
if packet.level is not None
|
||||
and packet.level_question_num is not None
|
||||
else BASIC_KEY
|
||||
)
|
||||
info = info_by_subq[
|
||||
SubQuestionKey(level=level, question_num=level_question_num)
|
||||
]
|
||||
info.tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
@@ -1092,7 +782,6 @@ def stream_chat_message_objects(
|
||||
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
refined_answer_improvement=refined_answer_improvement,
|
||||
)
|
||||
|
||||
|
||||
@@ -1104,7 +793,6 @@ def _post_llm_answer_processing(
|
||||
llm_tokenizer_encode_func: Callable[[str], list[int]],
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
refined_answer_improvement: bool | None,
|
||||
) -> Generator[ChatPacket, None, None]:
|
||||
"""
|
||||
Stores messages in the db and yields some final packets to the frontend
|
||||
@@ -1116,20 +804,6 @@ def _post_llm_answer_processing(
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
subq_citations = answer.citations_by_subquestion()
|
||||
for subq_key in subq_citations:
|
||||
info = info_by_subq[subq_key]
|
||||
logger.debug("Post-LLM answer processing")
|
||||
if info.reference_db_search_docs:
|
||||
info.message_specific_citations = _translate_citations(
|
||||
citations_list=subq_citations[subq_key],
|
||||
db_docs=info.reference_db_search_docs,
|
||||
)
|
||||
|
||||
# TODO: AllCitations should contain subq info?
|
||||
if not answer.is_cancelled():
|
||||
yield AllCitations(citations=subq_citations[subq_key])
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
|
||||
basic_key = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
|
||||
@@ -1145,9 +819,7 @@ def _post_llm_answer_processing(
|
||||
)
|
||||
gen_ai_response_message = partial_response(
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
info.qa_docs_response.rephrased_query if info.qa_docs_response else None
|
||||
),
|
||||
rephrased_query=info.rephrased_query,
|
||||
reference_docs=info.reference_db_search_docs,
|
||||
files=info.ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
@@ -1206,7 +878,6 @@ def _post_llm_answer_processing(
|
||||
else None
|
||||
),
|
||||
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
|
||||
refined_answer_improvement=refined_answer_improvement,
|
||||
is_agentic=True,
|
||||
)
|
||||
agentic_message_ids.append(
|
||||
|
||||
@@ -187,12 +187,8 @@ class AnswerPromptBuilder:
|
||||
|
||||
final_messages_with_tokens.append(self.user_message_and_token_cnt)
|
||||
|
||||
if (
|
||||
self.new_messages_and_token_cnts
|
||||
and isinstance(self.user_message_and_token_cnt[0].content, str)
|
||||
and self.user_message_and_token_cnt[0].content.startswith("Refer")
|
||||
):
|
||||
final_messages_with_tokens.extend(self.new_messages_and_token_cnts[-2:])
|
||||
if self.new_messages_and_token_cnts:
|
||||
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
|
||||
|
||||
return drop_messages_history_overflow(
|
||||
final_messages_with_tokens, self.max_tokens
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.chat.models import (
|
||||
)
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.prompt_builder.citations_prompt import compute_max_document_tokens
|
||||
from onyx.configs.app_configs import MAX_FEDERATED_SECTIONS
|
||||
from onyx.configs.constants import IGNORE_FOR_QA
|
||||
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
@@ -67,6 +68,38 @@ def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
|
||||
return combined_ranges
|
||||
|
||||
|
||||
def _separate_federated_sections(
|
||||
sections: list[InferenceSection],
|
||||
section_relevance_list: list[bool] | None,
|
||||
) -> tuple[list[InferenceSection], list[InferenceSection], list[bool] | None]:
|
||||
"""
|
||||
Separates out the first NUM_FEDERATED_SECTIONS federated sections to be spared from pruning.
|
||||
Any remaining federated sections are treated as normal sections, and will get added if it
|
||||
fits within the allocated context window. This is done as federated sections do not have
|
||||
a score and would otherwise always get pruned.
|
||||
"""
|
||||
federated_sections: list[InferenceSection] = []
|
||||
normal_sections: list[InferenceSection] = []
|
||||
normal_section_relevance_list: list[bool] = []
|
||||
|
||||
for i, section in enumerate(sections):
|
||||
if (
|
||||
len(federated_sections) < MAX_FEDERATED_SECTIONS
|
||||
and section.center_chunk.is_federated
|
||||
):
|
||||
federated_sections.append(section)
|
||||
continue
|
||||
normal_sections.append(section)
|
||||
if section_relevance_list is not None:
|
||||
normal_section_relevance_list.append(section_relevance_list[i])
|
||||
|
||||
return (
|
||||
federated_sections[:MAX_FEDERATED_SECTIONS],
|
||||
normal_sections,
|
||||
normal_section_relevance_list if section_relevance_list is not None else None,
|
||||
)
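A small worked example of the separation above, using dicts with an `is_federated`-style flag in place of real InferenceSection objects and a stand-in limit of 1 (purely illustrative):

```python
MAX_FED = 1  # stand-in for MAX_FEDERATED_SECTIONS
sections = [{"id": "a", "fed": True}, {"id": "b", "fed": False}, {"id": "c", "fed": True}]
relevance = [True, False, True]

federated, normal, normal_rel = [], [], []
for sec, rel in zip(sections, relevance):
    if len(federated) < MAX_FED and sec["fed"]:
        federated.append(sec)      # spared from pruning
        continue
    normal.append(sec)             # pruned like any other section
    normal_rel.append(rel)

# federated -> [a]; normal -> [b, c]; normal_rel -> [False, True]
```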
|
||||
|
||||
|
||||
def _compute_limit(
|
||||
prompt_config: PromptConfig,
|
||||
llm_config: LLMConfig,
|
||||
@@ -105,7 +138,7 @@ def _compute_limit(
|
||||
return int(min(limit_options))
|
||||
|
||||
|
||||
def reorder_sections(
|
||||
def _reorder_sections(
|
||||
sections: list[InferenceSection],
|
||||
section_relevance_list: list[bool] | None,
|
||||
) -> list[InferenceSection]:
|
||||
@@ -113,11 +146,10 @@ def reorder_sections(
|
||||
return sections
|
||||
|
||||
reordered_sections: list[InferenceSection] = []
|
||||
if section_relevance_list is not None:
|
||||
for selection_target in [True, False]:
|
||||
for section, is_relevant in zip(sections, section_relevance_list):
|
||||
if is_relevant == selection_target:
|
||||
reordered_sections.append(section)
|
||||
for selection_target in [True, False]:
|
||||
for section, is_relevant in zip(sections, section_relevance_list):
|
||||
if is_relevant == selection_target:
|
||||
reordered_sections.append(section)
|
||||
return reordered_sections
|
||||
|
||||
|
||||
@@ -134,6 +166,7 @@ def _remove_sections_to_ignore(
|
||||
def _apply_pruning(
|
||||
sections: list[InferenceSection],
|
||||
section_relevance_list: list[bool] | None,
|
||||
keep_sections: list[InferenceSection],
|
||||
token_limit: int,
|
||||
is_manually_selected_docs: bool,
|
||||
use_sections: bool,
|
||||
@@ -144,10 +177,22 @@ def _apply_pruning(
|
||||
provider_type=llm_config.model_provider,
|
||||
model_name=llm_config.model_name,
|
||||
)
|
||||
sections = deepcopy(sections) # don't modify in place
|
||||
|
||||
# combine the section lists, making sure to add the keep_sections first
|
||||
sections = deepcopy(keep_sections) + deepcopy(sections)
|
||||
|
||||
# build combined relevance list, treating the keep_sections as relevant
|
||||
if section_relevance_list is not None:
|
||||
section_relevance_list = [True] * len(keep_sections) + section_relevance_list
|
||||
|
||||
# map unique_id: relevance for final ordering step
|
||||
section_id_to_relevance: dict[str, bool] = {}
|
||||
if section_relevance_list is not None:
|
||||
for sec, rel in zip(sections, section_relevance_list):
|
||||
section_id_to_relevance[sec.center_chunk.unique_id] = rel
|
||||
|
||||
# re-order docs with all the "relevant" docs at the front
|
||||
sections = reorder_sections(
|
||||
sections = _reorder_sections(
|
||||
sections=sections, section_relevance_list=section_relevance_list
|
||||
)
|
||||
# remove docs that are explicitly marked as not for QA
|
||||
@@ -274,6 +319,14 @@ def _apply_pruning(
|
||||
)
|
||||
sections = [sections[0]]
|
||||
|
||||
# sort by relevance, then by score (as we added the keep_sections first)
|
||||
sections.sort(
|
||||
key=lambda s: (
|
||||
not section_id_to_relevance.get(s.center_chunk.unique_id, True),
|
||||
-(s.center_chunk.score or 0.0),
|
||||
),
|
||||
)
|
||||
|
||||
return sections
|
||||
|
||||
|
||||
@@ -289,9 +342,16 @@ def prune_sections(
|
||||
if section_relevance_list is not None:
|
||||
assert len(sections) == len(section_relevance_list)
|
||||
|
||||
# get federated sections (up to MAX_FEDERATED_SECTIONS)
|
||||
# TODO: if we can somehow score the federated sections well, we don't need this
|
||||
federated_sections, normal_sections, normal_section_relevance_list = (
|
||||
_separate_federated_sections(sections, section_relevance_list)
|
||||
)
|
||||
|
||||
actual_num_chunks = (
|
||||
contextual_pruning_config.max_chunks
|
||||
* contextual_pruning_config.num_chunk_multiple
|
||||
+ len(federated_sections)
|
||||
if contextual_pruning_config.max_chunks
|
||||
else None
|
||||
)
|
||||
@@ -307,8 +367,9 @@ def prune_sections(
|
||||
)
|
||||
|
||||
return _apply_pruning(
|
||||
sections=sections,
|
||||
section_relevance_list=section_relevance_list,
|
||||
sections=normal_sections,
|
||||
section_relevance_list=normal_section_relevance_list,
|
||||
keep_sections=federated_sections,
|
||||
token_limit=token_limit,
|
||||
is_manually_selected_docs=contextual_pruning_config.is_manually_selected_docs,
|
||||
use_sections=contextual_pruning_config.use_sections, # Now default True
|
||||
|
||||
@@ -35,6 +35,9 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
) # 1 day
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
# Controls whether users can use User Knowledge (personal documents) in assistants
|
||||
DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() == "true"
|
||||
|
||||
# Controls whether to allow admin query history reports with:
|
||||
# 1. associated user emails
|
||||
# 2. anonymized user emails
|
||||
@@ -308,25 +311,40 @@ except ValueError:
|
||||
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
|
||||
)
|
||||
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 3
|
||||
CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT = 6
|
||||
try:
|
||||
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
|
||||
env_value = os.environ.get("CELERY_WORKER_DOCPROCESSING_CONCURRENCY")
|
||||
if not env_value:
|
||||
env_value = os.environ.get("NUM_INDEXING_WORKERS")
|
||||
|
||||
if not env_value:
|
||||
env_value = str(CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT)
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY = int(env_value)
|
||||
env_value = str(CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT)
|
||||
CELERY_WORKER_DOCPROCESSING_CONCURRENCY = int(env_value)
|
||||
except ValueError:
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
|
||||
CELERY_WORKER_DOCPROCESSING_CONCURRENCY = (
|
||||
CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT
|
||||
)
|
||||
|
||||
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT = 1
|
||||
try:
|
||||
env_value = os.environ.get("CELERY_WORKER_DOCFETCHING_CONCURRENCY")
|
||||
if not env_value:
|
||||
env_value = os.environ.get("NUM_DOCFETCHING_WORKERS")
|
||||
|
||||
if not env_value:
|
||||
env_value = str(CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT)
|
||||
CELERY_WORKER_DOCFETCHING_CONCURRENCY = int(env_value)
|
||||
except ValueError:
|
||||
CELERY_WORKER_DOCFETCHING_CONCURRENCY = (
|
||||
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT
|
||||
)
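The same read-an-env-var-with-fallbacks pattern repeats for each worker pool above. A hedged sketch of how it could be factored into a helper; the helper name is hypothetical and not part of the codebase:

```python
import os


def _int_env_with_fallbacks(default: int, *names: str) -> int:
    """Return the first non-empty env var among `names` parsed as int, else `default`."""
    for name in names:
        value = os.environ.get(name)
        if value:
            try:
                return int(value)
            except ValueError:
                break  # malformed value: fall back to the default, mirroring the except blocks above
    return default


# e.g. docprocessing concurrency falling back to the legacy NUM_INDEXING_WORKERS variable
# concurrency = _int_env_with_fallbacks(6, "CELERY_WORKER_DOCPROCESSING_CONCURRENCY", "NUM_INDEXING_WORKERS")
```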
|
||||
|
||||
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
|
||||
)
|
||||
|
||||
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
|
||||
VESPA_SYNC_MAX_TASKS = 1024
|
||||
VESPA_SYNC_MAX_TASKS = 8192
|
||||
|
||||
DB_YIELD_PER_DEFAULT = 64
|
||||
|
||||
@@ -450,6 +468,11 @@ GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
)
|
||||
|
||||
# Default size threshold for SharePoint files (20MB)
|
||||
SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
|
||||
)
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
@@ -481,6 +504,7 @@ LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
|
||||
|
||||
# Slack specific configs
|
||||
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8)
|
||||
MAX_SLACK_QUERY_EXPANSIONS = int(os.environ.get("MAX_SLACK_QUERY_EXPANSIONS", "5"))
|
||||
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
@@ -654,6 +678,14 @@ except json.JSONDecodeError:
|
||||
# LLM Model Update API endpoint
|
||||
LLM_MODEL_UPDATE_API_URL = os.environ.get("LLM_MODEL_UPDATE_API_URL")
|
||||
|
||||
# Federated Search Configs
|
||||
MAX_FEDERATED_SECTIONS = int(
|
||||
os.environ.get("MAX_FEDERATED_SECTIONS", "5")
|
||||
) # max no. of federated sections to always keep
|
||||
MAX_FEDERATED_CHUNKS = int(
|
||||
os.environ.get("MAX_FEDERATED_CHUNKS", "5")
|
||||
) # max no. of chunks to retrieve per federated connector
|
||||
|
||||
#####
|
||||
# Enterprise Edition Configs
|
||||
#####
|
||||
@@ -787,7 +819,3 @@ S3_AWS_SECRET_ACCESS_KEY = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
|
||||
# Forcing Vespa Language
|
||||
# English: en, German:de, etc. See: https://docs.vespa.ai/en/linguistics.html
|
||||
VESPA_LANGUAGE_OVERRIDE = os.environ.get("VESPA_LANGUAGE_OVERRIDE")
|
||||
|
||||
HACKATHON_OUTPUT_CSV_PATH = os.environ.get(
|
||||
"HACKATHON_OUTPUT_CSV_PATH", "/tmp/hackathon_output.csv"
|
||||
)
|
||||
|
||||
@@ -91,6 +91,10 @@ HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
|
||||
|
||||
# Internet Search
|
||||
BING_API_KEY = os.environ.get("BING_API_KEY") or None
|
||||
EXA_API_KEY = os.environ.get("EXA_API_KEY") or None
|
||||
|
||||
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
|
||||
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)
|
||||
|
||||
# Enable in-house model for detecting connector-based filtering in queries
|
||||
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
|
||||
|
||||
@@ -65,7 +65,8 @@ POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
|
||||
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
|
||||
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
|
||||
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
|
||||
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
|
||||
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
|
||||
@@ -121,6 +122,8 @@ CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT = 3 * 60 * 60 # 3 hours (in seconds)
|
||||
# hard termination should always fire first if the connector is hung
|
||||
CELERY_INDEXING_LOCK_TIMEOUT = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 900
|
||||
|
||||
# Heartbeat interval for indexing worker liveness detection
|
||||
INDEXING_WORKER_HEARTBEAT_INTERVAL = 30 # seconds
|
||||
|
||||
# how long a task should wait for associated fence to be ready
|
||||
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
|
||||
@@ -186,10 +189,21 @@ class DocumentSource(str, Enum):
|
||||
AIRTABLE = "airtable"
|
||||
HIGHSPOT = "highspot"
|
||||
|
||||
IMAP = "imap"
|
||||
|
||||
# Special case just for integration tests
|
||||
MOCK_CONNECTOR = "mock_connector"
|
||||
|
||||
|
||||
class FederatedConnectorSource(str, Enum):
|
||||
FEDERATED_SLACK = "federated_slack"
|
||||
|
||||
def to_non_federated_source(self) -> DocumentSource | None:
|
||||
if self == FederatedConnectorSource.FEDERATED_SLACK:
|
||||
return DocumentSource.SLACK
|
||||
return None
|
||||
|
||||
|
||||
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
|
||||
|
||||
|
||||
@@ -320,9 +334,12 @@ class OnyxCeleryQueues:
|
||||
CSV_GENERATION = "csv_generation"
|
||||
|
||||
# Indexing queue
|
||||
CONNECTOR_INDEXING = "connector_indexing"
|
||||
USER_FILES_INDEXING = "user_files_indexing"
|
||||
|
||||
# Document processing pipeline queue
|
||||
DOCPROCESSING = "docprocessing"
|
||||
CONNECTOR_DOC_FETCHING = "connector_doc_fetching"
|
||||
|
||||
# Monitoring queue
|
||||
MONITORING = "monitoring"
|
||||
|
||||
@@ -453,7 +470,11 @@ class OnyxCeleryTask:
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
|
||||
"connector_external_group_sync_generator_task"
|
||||
)
|
||||
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
|
||||
|
||||
# New split indexing tasks
|
||||
CONNECTOR_DOC_FETCHING_TASK = "connector_doc_fetching_task"
|
||||
DOCPROCESSING_TASK = "docprocessing_task"
|
||||
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
|
||||
@@ -34,7 +34,6 @@ from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
@@ -281,30 +280,28 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
|
||||
# TODO: Refactor to avoid direct DB access in connector
|
||||
# This will require broader refactoring across the codebase
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
image_section, _ = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=downloaded_file,
|
||||
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
|
||||
display_name=file_name,
|
||||
link=link,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
image_section, _ = store_image_and_create_section(
|
||||
image_data=downloaded_file,
|
||||
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
|
||||
display_name=file_name,
|
||||
link=link,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
|
||||
sections=[image_section],
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=file_name,
|
||||
doc_updated_at=last_modified,
|
||||
metadata={},
|
||||
)
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
|
||||
sections=[image_section],
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=file_name,
|
||||
doc_updated_at=last_modified,
|
||||
metadata={},
|
||||
)
|
||||
)
|
||||
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
except Exception:
|
||||
logger.exception(f"Error processing image {key}")
|
||||
continue
|
||||
|
||||
@@ -1,3 +1,17 @@
"""
# README (notes on Confluence pagination):

We've noticed that the `search/users` and `users/memberof` endpoints for Confluence Cloud use offset-based pagination as
opposed to cursor-based. We also know that page-retrieval uses cursor-based pagination.

Our default pagination strategy right now for cloud is to assume cursor-based.
However, if you notice that a cloud API is not being properly paginated (i.e., if the `_links.next` is not appearing in the
returned payload), then you can force offset-based pagination.

# TODO (@raunakab)
We haven't explored all of the cloud APIs' pagination strategies. @raunakab take time to go through this and figure them out.
"""
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
@@ -46,16 +60,13 @@ _REPLACEMENT_EXPANSIONS = "body.view.value"
|
||||
_USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
|
||||
|
||||
class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
_MINIMUM_PAGINATION_LIMIT = 50
|
||||
|
||||
|
||||
class OnyxConfluence:
|
||||
"""
|
||||
This is a custom Confluence class that:
|
||||
@@ -311,8 +322,8 @@ class OnyxConfluence:
|
||||
return confluence
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
# This uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling.
|
||||
def _make_rate_limited_confluence_method(
|
||||
self, name: str, credential_provider: CredentialsProviderInterface | None
|
||||
) -> Callable[..., Any]:
|
||||
@@ -378,25 +389,6 @@ class OnyxConfluence:
|
||||
|
||||
return wrapped_call
|
||||
|
||||
# def _wrap_methods(self) -> None:
|
||||
# """
|
||||
# For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
# wrap it with handle_confluence_rate_limit.
|
||||
# """
|
||||
# for attr_name in dir(self):
|
||||
# if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
|
||||
# setattr(
|
||||
# self,
|
||||
# attr_name,
|
||||
# handle_confluence_rate_limit(getattr(self, attr_name)),
|
||||
# )
|
||||
|
||||
# def _ensure_token_valid(self) -> None:
|
||||
# if self._token_is_expired():
|
||||
# self._refresh_token()
|
||||
# # Re-init the Confluence client with the originally stored args
|
||||
# self._confluence = Confluence(self._url, *self._args, **self._kwargs)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Dynamically intercept attribute/method access."""
|
||||
attr = getattr(self._confluence, name, None)
|
||||
@@ -483,6 +475,7 @@ class OnyxConfluence:
|
||||
limit: int | None = None,
|
||||
# Called with the next url to use to get the next page
|
||||
next_page_callback: Callable[[str], None] | None = None,
|
||||
force_offset_pagination: bool = False,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
This will paginate through the top level query.
|
||||
@@ -498,6 +491,10 @@ class OnyxConfluence:
|
||||
raw_response = self.get(
|
||||
path=url_suffix,
|
||||
advanced_mode=True,
|
||||
params={
|
||||
"body-format": "atlas_doc_format",
|
||||
"expand": "body.atlas_doc_format",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in confluence call to {url_suffix}")
|
||||
@@ -564,14 +561,32 @@ class OnyxConfluence:
|
||||
)
|
||||
raise e
|
||||
|
||||
# yield the results individually
|
||||
# Yield the results individually.
|
||||
results = cast(list[dict[str, Any]], next_response.get("results", []))
|
||||
# make sure we don't update the start by more than the amount
|
||||
|
||||
# Note 1:
|
||||
# Make sure we don't update the start by more than the amount
|
||||
# of results we were able to retrieve. The Confluence API has a
|
||||
# weird behavior where if you pass in a limit that is too large for
|
||||
# the configured server, it will artificially limit the amount of
|
||||
# results returned BUT will not apply this to the start parameter.
|
||||
# This will cause us to miss results.
|
||||
#
|
||||
# Note 2:
|
||||
# We specifically perform manual yielding (i.e., `for x in xs: yield x`) as opposed to using a `yield from xs`
|
||||
# because we *have to call the `next_page_callback`* prior to yielding the last element!
|
||||
#
|
||||
# If we did:
|
||||
#
|
||||
# ```py
|
||||
# yield from results
|
||||
# if next_page_callback:
|
||||
# next_page_callback(url_suffix)
|
||||
# ```
|
||||
#
|
||||
# then the logic would fail since the iterator would finish (and the calling scope would exit out of its driving
|
||||
# loop) prior to the callback being called.
|
||||
|
||||
old_url_suffix = url_suffix
|
||||
updated_start = get_start_param_from_url(old_url_suffix)
|
||||
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
|
||||
@@ -587,6 +602,12 @@ class OnyxConfluence:
|
||||
)
|
||||
# notify the caller of the new url
|
||||
next_page_callback(url_suffix)
|
||||
|
||||
elif force_offset_pagination and i == len(results) - 1:
|
||||
url_suffix = update_param_in_path(
|
||||
old_url_suffix, "start", str(updated_start)
|
||||
)
|
||||
|
||||
yield result
|
||||
|
||||
# we've observed that Confluence sometimes returns a next link despite giving
|
||||
@@ -684,7 +705,9 @@ class OnyxConfluence:
|
||||
url = "rest/api/search/user"
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
url += f"?cql={cql}{expand_string}"
|
||||
for user_result in self._paginate_url(url, limit):
|
||||
for user_result in self._paginate_url(
|
||||
url, limit, force_offset_pagination=True
|
||||
):
|
||||
# Example response:
|
||||
# {
|
||||
# 'user': {
|
||||
@@ -774,7 +797,7 @@ class OnyxConfluence:
|
||||
user_query = f"{user_field}={quote(user_value)}"
|
||||
|
||||
url = f"rest/api/user/memberof?{user_query}"
|
||||
yield from self._paginate_url(url, limit)
|
||||
yield from self._paginate_url(url, limit, force_offset_pagination=True)
|
||||
|
||||
def paginated_groups_retrieval(
|
||||
self,
|
||||
@@ -926,6 +949,9 @@ def extract_text_from_confluence_html(
|
||||
object_html = body.get("storage", body.get("view", {})).get("value")
|
||||
|
||||
soup = bs4.BeautifulSoup(object_html, "html.parser")
|
||||
|
||||
_remove_macro_stylings(soup=soup)
|
||||
|
||||
for user in soup.findAll("ri:user"):
|
||||
user_id = (
|
||||
user.attrs["ri:account-id"]
|
||||
@@ -1007,3 +1033,15 @@ def extract_text_from_confluence_html(
|
||||
logger.warning(f"Error processing ac:link-body: {e}")
|
||||
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def _remove_macro_stylings(soup: bs4.BeautifulSoup) -> None:
|
||||
for macro_root in soup.findAll("ac:structured-macro"):
|
||||
if not isinstance(macro_root, bs4.Tag):
|
||||
continue
|
||||
|
||||
macro_styling = macro_root.find(name="ac:parameter", attrs={"ac:name": "page"})
|
||||
if not macro_styling or not isinstance(macro_styling, bs4.Tag):
|
||||
continue
|
||||
|
||||
macro_styling.extract()
|
||||
|
||||
@@ -23,7 +23,6 @@ from onyx.configs.app_configs import (
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
from onyx.file_processing.extract_file_text import OnyxExtensionType
|
||||
@@ -224,19 +223,17 @@ def _process_image_attachment(
|
||||
"""Process an image attachment by saving it without generating a summary."""
|
||||
try:
|
||||
# Use the standardized image storage and section creation
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
section, file_name = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=raw_bytes,
|
||||
file_id=Path(attachment["id"]).name,
|
||||
display_name=attachment["title"],
|
||||
media_type=media_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
logger.info(f"Stored image attachment with file name: {file_name}")
|
||||
section, file_name = store_image_and_create_section(
|
||||
image_data=raw_bytes,
|
||||
file_id=Path(attachment["id"]).name,
|
||||
display_name=attachment["title"],
|
||||
media_type=media_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
logger.info(f"Stored image attachment with file name: {file_name}")
|
||||
|
||||
# Return empty text but include the file_name for later processing
|
||||
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
|
||||
# Return empty text but include the file_name for later processing
|
||||
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
|
||||
except Exception as e:
|
||||
msg = f"Image storage failed for {attachment['title']}: {e}"
|
||||
logger.error(msg, exc_info=e)
|
||||
|
||||
@@ -27,7 +27,8 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
|
||||
class CheckpointOutputWrapper(Generic[CT]):
|
||||
"""
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format.
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format,
|
||||
specifically for Document outputs.
|
||||
The connector format is easier for the connector implementor (e.g. it enforces that exactly
one new checkpoint is returned AND that the checkpoint is at the end), hence the different
|
||||
formats.
|
||||
@@ -131,7 +132,7 @@ class ConnectorRunner(Generic[CT]):
|
||||
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
|
||||
checkpoint_connector_generator
|
||||
):
|
||||
if document is not None:
|
||||
if document is not None and isinstance(document, Document):
|
||||
self.doc_batch.append(document)
|
||||
|
||||
if failure is not None:
|
||||
|
||||
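Editorial sketch (not part of the diff) of how the wrapper is consumed, mirroring the ConnectorRunner loop above; the import path and the drain helper are assumptions.

    from onyx.connectors.connector_runner import CheckpointOutputWrapper  # path assumed
    from onyx.connectors.interfaces import ConnectorCheckpoint
    from onyx.connectors.models import Document


    def drain_checkpoint_output(checkpoint_connector_generator):
        """Hypothetical helper: collect Documents and keep the latest checkpoint."""
        docs: list[Document] = []
        latest_checkpoint: ConnectorCheckpoint | None = None
        for document, failure, next_checkpoint in CheckpointOutputWrapper[ConnectorCheckpoint]()(
            checkpoint_connector_generator
        ):
            if document is not None and isinstance(document, Document):
                docs.append(document)
            if failure is not None:
                pass  # a ConnectorFailure would be logged or re-raised here
            if next_checkpoint is not None:
                latest_checkpoint = next_checkpoint
        return docs, latest_checkpoint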
@@ -109,8 +109,10 @@ def process_onyx_metadata(

    return (
        OnyxMetadata(
            source_type=metadata.get("connector_type"),
            link=metadata.get("link"),
            file_display_name=metadata.get("file_display_name"),
            title=metadata.get("title"),
            primary_owners=p_owners,
            secondary_owners=s_owners,
            doc_updated_at=doc_updated_at,
@@ -33,6 +33,7 @@ from onyx.connectors.google_site.connector import GoogleSitesConnector
from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.highspot.connector import HighspotConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.imap.connector import ImapConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CredentialsConnector
@@ -121,6 +122,7 @@ def identify_connector_class(
        DocumentSource.EGNYTE: EgnyteConnector,
        DocumentSource.AIRTABLE: AirtableConnector,
        DocumentSource.HIGHSPOT: HighspotConnector,
        DocumentSource.IMAP: ImapConnector,
        # just for integration tests
        DocumentSource.MOCK_CONNECTOR: MockConnector,
    }
@@ -5,8 +5,6 @@ from pathlib import Path
from typing import Any
from typing import IO

from sqlalchemy.orm import Session

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
@@ -18,7 +16,6 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
@@ -32,7 +29,6 @@ logger = setup_logger()

def _create_image_section(
    image_data: bytes,
    db_session: Session,
    parent_file_name: str,
    display_name: str,
    link: str | None = None,
@@ -58,7 +54,6 @@ def _create_image_section(
    # Store the image and create a section
    try:
        section, stored_file_name = store_image_and_create_section(
            db_session=db_session,
            image_data=image_data,
            file_id=file_id,
            display_name=display_name,
@@ -77,7 +72,6 @@ def _process_file(
    file: IO[Any],
    metadata: dict[str, Any] | None,
    pdf_pass: str | None,
    db_session: Session,
) -> list[Document]:
    """
    Process a file and return a list of Documents.
@@ -125,7 +119,6 @@ def _process_file(
            try:
                section, _ = _create_image_section(
                    image_data=image_data,
                    db_session=db_session,
                    parent_file_name=file_id,
                    display_name=title,
                )
@@ -171,10 +164,12 @@ def _process_file(
        custom_tags.update(more_custom_tags)

    # File-specific metadata overrides metadata processed so far
    source_type = onyx_metadata.source_type or source_type
    primary_owners = onyx_metadata.primary_owners or primary_owners
    secondary_owners = onyx_metadata.secondary_owners or secondary_owners
    time_updated = onyx_metadata.doc_updated_at or time_updated
    file_display_name = onyx_metadata.file_display_name or file_display_name
    title = onyx_metadata.title or onyx_metadata.file_display_name or title
    link = onyx_metadata.link or link

    # Build sections: first the text as a single Section
@@ -194,7 +189,6 @@ def _process_file(
        try:
            image_section, stored_file_name = _create_image_section(
                image_data=img_data,
                db_session=db_session,
                parent_file_name=file_id,
                display_name=f"{title} - image {idx}",
                idx=idx,
@@ -228,11 +222,21 @@ class LocalFileConnector(LoadConnector):
    """
    Connector that reads files from Postgres and yields Documents, including
    embedded image extraction without summarization.

    file_locations are S3/Filestore UUIDs
    file_names are the names of the files
    """

    # Note: file_names is a required parameter, but should not break backwards compatibility.
    # If add_file_names migration is not run, old file connector configs will not have file_names.
    # This is fine because the configs are not re-used to instantiate the connector.
    # file_names is only used for display purposes in the UI and file_locations is used as a fallback.
    def __init__(
        self,
        file_locations: list[Path | str],
        file_names: list[
            str
        ],  # Must accept this parameter as connector_specific_config is unpacked as args
        zip_metadata: dict[str, Any],
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
@@ -258,41 +262,39 @@ class LocalFileConnector(LoadConnector):
        """
        documents: list[Document] = []

        with get_session_with_current_tenant() as db_session:
            for file_id in self.file_locations:
                file_store = get_default_file_store(db_session)
                file_record = file_store.read_file_record(file_id=file_id)
                if not file_record:
                    # typically an unsupported extension
                    logger.warning(
                        f"No file record found for '{file_id}' in PG; skipping."
                    )
                    continue
        for file_id in self.file_locations:
            file_store = get_default_file_store()
            file_record = file_store.read_file_record(file_id=file_id)
            if not file_record:
                # typically an unsupported extension
                logger.warning(f"No file record found for '{file_id}' in PG; skipping.")
                continue

                metadata = self._get_file_metadata(file_id)
                file_io = file_store.read_file(file_id=file_id, mode="b")
                new_docs = _process_file(
                    file_id=file_id,
                    file_name=file_record.display_name,
                    file=file_io,
                    metadata=metadata,
                    pdf_pass=self.pdf_pass,
                    db_session=db_session,
                )
                documents.extend(new_docs)
            metadata = self._get_file_metadata(file_record.display_name)
            file_io = file_store.read_file(file_id=file_id, mode="b")
            new_docs = _process_file(
                file_id=file_id,
                file_name=file_record.display_name,
                file=file_io,
                metadata=metadata,
                pdf_pass=self.pdf_pass,
            )
            documents.extend(new_docs)

                if len(documents) >= self.batch_size:
                    yield documents

                    documents = []

        if documents:
            if len(documents) >= self.batch_size:
                yield documents

                documents = []

        if documents:
            yield documents


if __name__ == "__main__":
    connector = LocalFileConnector(
        file_locations=[os.environ["TEST_FILE"]], zip_metadata={}
        file_locations=[os.environ["TEST_FILE"]],
        file_names=[os.environ["TEST_FILE"]],
        zip_metadata={},
    )
    connector.load_credentials({"pdf_password": os.environ.get("PDF_PASSWORD")})
    doc_batches = connector.load_from_state()
@@ -35,6 +35,7 @@ _FIREFLIES_API_QUERY = """
        organizer_email
        participants
        date
        duration
        transcript_url
        sentences {
            text
@@ -101,7 +102,14 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
        sections=cast(list[TextSection | ImageSection], sections),
        source=DocumentSource.FIREFLIES,
        semantic_identifier=meeting_title,
        metadata={},
        metadata={
            k: str(v)
            for k, v in {
                "meeting_date": meeting_date,
                "duration_min": transcript.get("duration"),
            }.items()
            if v is not None
        },
        doc_updated_at=meeting_date,
        primary_owners=organizer_email_user_info,
        secondary_owners=meeting_participants_email_list,
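Editorial note (not part of the diff): the new metadata block stringifies values and drops None entries, so transcripts without a duration simply omit that key. A standalone illustration with made-up values:

    # Made-up sample values for illustration only
    transcript = {"duration": 42.5}
    meeting_date = "2024-01-01T00:00:00.000Z"

    metadata = {
        k: str(v)
        for k, v in {
            "meeting_date": meeting_date,
            "duration_min": transcript.get("duration"),
        }.items()
        if v is not None
    }
    # metadata == {"meeting_date": "2024-01-01T00:00:00.000Z", "duration_min": "42.5"}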
@@ -1,5 +1,4 @@
import copy
import time
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
@@ -17,17 +16,22 @@ from github.Issue import Issue
from github.NamedUser import NamedUser
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
from github.Requester import Requester
from pydantic import BaseModel
from typing_extensions import override

from onyx.access.models import ExternalAccess
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.constants import DocumentSource
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.github.models import SerializedRepository
from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception
from onyx.connectors.github.utils import deserialize_repository
from onyx.connectors.github.utils import get_external_access_permission
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import ConnectorFailure
@@ -46,17 +50,7 @@ CURSOR_LOG_FREQUENCY = 50
_MAX_NUM_RATE_LIMIT_RETRIES = 5

ONE_DAY = timedelta(days=1)


def _sleep_after_rate_limit_exception(github_client: Github) -> None:
    sleep_time = github_client.get_rate_limit().core.reset.replace(
        tzinfo=timezone.utc
    ) - datetime.now(tz=timezone.utc)
    sleep_time += timedelta(minutes=1)  # add an extra minute just to be safe
    logger.notice(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
    time.sleep(sleep_time.seconds)


SLIM_BATCH_SIZE = 100
# Cases
# X (from start) standard run, no fallback to cursor-based pagination
# X (from start) standard run errors, fallback to cursor-based pagination
@@ -72,6 +66,10 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
# checkpoint progress (no infinite loop)


class DocMetadata(BaseModel):
    repo: str


def get_nextUrl_key(pag_list: PaginatedList[PullRequest | Issue]) -> str:
    if "_PaginatedList__nextUrl" in pag_list.__dict__:
        return "_PaginatedList__nextUrl"
@@ -190,7 +188,7 @@ def _get_batch_rate_limited(
            getattr(obj, "raw_data")
        yield from objs
    except RateLimitExceededException:
        _sleep_after_rate_limit_exception(github_client)
        sleep_after_rate_limit_exception(github_client)
        yield from _get_batch_rate_limited(
            git_objs,
            page_num,
@@ -232,12 +230,17 @@ def _get_userinfo(user: NamedUser) -> dict[str, str]:
    }


def _convert_pr_to_document(pull_request: PullRequest) -> Document:
def _convert_pr_to_document(
    pull_request: PullRequest, repo_external_access: ExternalAccess | None
) -> Document:
    repo_name = pull_request.base.repo.full_name if pull_request.base else ""
    doc_metadata = DocMetadata(repo=repo_name)
    return Document(
        id=pull_request.html_url,
        sections=[
            TextSection(link=pull_request.html_url, text=pull_request.body or "")
        ],
        external_access=repo_external_access,
        source=DocumentSource.GITHUB,
        semantic_identifier=f"{pull_request.number}: {pull_request.title}",
        # updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -248,6 +251,8 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
            if pull_request.updated_at
            else None
        ),
        # this metadata is used in perm sync
        doc_metadata=doc_metadata.model_dump(),
        metadata={
            k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
            for k, v in {
@@ -301,14 +306,21 @@ def _fetch_issue_comments(issue: Issue) -> str:
    return "\nComment: ".join(comment.body for comment in comments)


def _convert_issue_to_document(issue: Issue) -> Document:
def _convert_issue_to_document(
    issue: Issue, repo_external_access: ExternalAccess | None
) -> Document:
    repo_name = issue.repository.full_name if issue.repository else ""
    doc_metadata = DocMetadata(repo=repo_name)
    return Document(
        id=issue.html_url,
        sections=[TextSection(link=issue.html_url, text=issue.body or "")],
        source=DocumentSource.GITHUB,
        external_access=repo_external_access,
        semantic_identifier=f"{issue.number}: {issue.title}",
        # updated_at is UTC time but is timezone unaware
        doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc),
        # this metadata is used in perm sync
        doc_metadata=doc_metadata.model_dump(),
        metadata={
            k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
            for k, v in {
@@ -343,18 +355,6 @@ def _convert_issue_to_document(issue: Issue) -> Document:
    )


class SerializedRepository(BaseModel):
    # id is part of the raw_data as well, just pulled out for convenience
    id: int
    headers: dict[str, str | int]
    raw_data: dict[str, Any]

    def to_Repository(self, requester: Requester) -> Repository.Repository:
        return Repository.Repository(
            requester, self.headers, self.raw_data, completed=True
        )


class GithubConnectorStage(Enum):
    START = "start"
    PRS = "prs"
@@ -394,7 +394,7 @@ def make_cursor_url_callback(
    return cursor_url_callback


class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoint]):
    def __init__(
        self,
        repo_owner: str,
@@ -423,7 +423,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
            )
        return None

    def _get_github_repo(
    def get_github_repo(
        self, github_client: Github, attempt_num: int = 0
    ) -> Repository.Repository:
        if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
@@ -434,10 +434,10 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
        try:
            return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
        except RateLimitExceededException:
            _sleep_after_rate_limit_exception(github_client)
            return self._get_github_repo(github_client, attempt_num + 1)
            sleep_after_rate_limit_exception(github_client)
            return self.get_github_repo(github_client, attempt_num + 1)

    def _get_github_repos(
    def get_github_repos(
        self, github_client: Github, attempt_num: int = 0
    ) -> list[Repository.Repository]:
        """Get specific repositories based on comma-separated repo_name string."""
@@ -465,10 +465,10 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):

            return repos
        except RateLimitExceededException:
            _sleep_after_rate_limit_exception(github_client)
            return self._get_github_repos(github_client, attempt_num + 1)
            sleep_after_rate_limit_exception(github_client)
            return self.get_github_repos(github_client, attempt_num + 1)

    def _get_all_repos(
    def get_all_repos(
        self, github_client: Github, attempt_num: int = 0
    ) -> list[Repository.Repository]:
        if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
@@ -487,8 +487,8 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
            user = github_client.get_user(self.repo_owner)
            return list(user.get_repos())
        except RateLimitExceededException:
            _sleep_after_rate_limit_exception(github_client)
            return self._get_all_repos(github_client, attempt_num + 1)
            sleep_after_rate_limit_exception(github_client)
            return self.get_all_repos(github_client, attempt_num + 1)

    def _pull_requests_func(
        self, repo: Repository.Repository
@@ -509,6 +509,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
        checkpoint: GithubConnectorCheckpoint,
        start: datetime | None = None,
        end: datetime | None = None,
        include_permissions: bool = False,
    ) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
        if self.github_client is None:
            raise ConnectorMissingCredentialError("GitHub")
@@ -521,13 +522,13 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
            if self.repositories:
                if "," in self.repositories:
                    # Multiple repositories specified
                    repos = self._get_github_repos(self.github_client)
                    repos = self.get_github_repos(self.github_client)
                else:
                    # Single repository (backward compatibility)
                    repos = [self._get_github_repo(self.github_client)]
                    repos = [self.get_github_repo(self.github_client)]
            else:
                # All repositories
                repos = self._get_all_repos(self.github_client)
                repos = self.get_all_repos(self.github_client)
            if not repos:
                checkpoint.has_more = False
                return checkpoint
@@ -547,28 +548,15 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
        if checkpoint.cached_repo is None:
            raise ValueError("No repo saved in checkpoint")

        # Try to access the requester - different PyGithub versions may use different attribute names
        try:
            # Try direct access to a known attribute name first
            if hasattr(self.github_client, "_requester"):
                requester = self.github_client._requester
            elif hasattr(self.github_client, "_Github__requester"):
                requester = self.github_client._Github__requester
            else:
                # If we can't find the requester attribute, we need to fall back to recreating the repo
                raise AttributeError("Could not find requester attribute")

            repo = checkpoint.cached_repo.to_Repository(requester)
        except Exception as e:
            # If all else fails, re-fetch the repo directly
            logger.warning(
                f"Failed to deserialize repository: {e}. Attempting to re-fetch."
            )
            repo_id = checkpoint.cached_repo.id
            repo = self.github_client.get_repo(repo_id)
        # Deserialize the repository from the checkpoint
        repo = deserialize_repository(checkpoint.cached_repo, self.github_client)

        cursor_url_callback = make_cursor_url_callback(checkpoint)

        repo_external_access: ExternalAccess | None = None
        if include_permissions:
            repo_external_access = get_external_access_permission(
                repo, self.github_client
            )
        if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
            logger.info(f"Fetching PRs for repo: {repo.name}")

@@ -603,7 +591,9 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
                ):
                    continue
                try:
                    yield _convert_pr_to_document(cast(PullRequest, pr))
                    yield _convert_pr_to_document(
                        cast(PullRequest, pr), repo_external_access
                    )
                except Exception as e:
                    error_msg = f"Error converting PR to document: {e}"
                    logger.exception(error_msg)
@@ -653,6 +643,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
                    self.github_client,
                )
            )
            logger.info(f"Fetched {len(issue_batch)} issues for repo: {repo.name}")
            checkpoint.curr_page += 1
            done_with_issues = False
            num_issues = 0
@@ -678,7 +669,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
                    continue

                try:
                    yield _convert_issue_to_document(issue)
                    yield _convert_issue_to_document(issue, repo_external_access)
                except Exception as e:
                    error_msg = f"Error converting issue to document: {e}"
                    logger.exception(error_msg)
@@ -715,12 +706,16 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
            checkpoint.stage = GithubConnectorStage.PRS
            checkpoint.reset()

        logger.info(f"{len(checkpoint.cached_repo_ids)} repos remaining")
        if checkpoint.cached_repo_ids:
            logger.info(
                f"{len(checkpoint.cached_repo_ids)} repos remaining (IDs: {checkpoint.cached_repo_ids})"
            )
        else:
            logger.info("No more repos remaining")

        return checkpoint

    @override
    def load_from_checkpoint(
    def _load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
@@ -741,7 +736,32 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
            adjusted_start_datetime = epoch

        return self._fetch_from_github(
            checkpoint, start=adjusted_start_datetime, end=end_datetime
            checkpoint,
            start=adjusted_start_datetime,
            end=end_datetime,
            include_permissions=include_permissions,
        )

    @override
    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: GithubConnectorCheckpoint,
    ) -> CheckpointOutput[GithubConnectorCheckpoint]:
        return self._load_from_checkpoint(
            start, end, checkpoint, include_permissions=False
        )

    @override
    def load_from_checkpoint_with_perm_sync(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: GithubConnectorCheckpoint,
    ) -> CheckpointOutput[GithubConnectorCheckpoint]:
        return self._load_from_checkpoint(
            start, end, checkpoint, include_permissions=True
        )

    def validate_connector_settings(self) -> None:
@@ -775,6 +795,9 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
                    test_repo = self.github_client.get_repo(
                        f"{self.repo_owner}/{repo_name}"
                    )
                    logger.info(
                        f"Successfully accessed repository: {self.repo_owner}/{repo_name}"
                    )
                    test_repo.get_contents("")
                    valid_repos = True
                    # If at least one repo is valid, we can proceed
@@ -882,7 +905,6 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):

if __name__ == "__main__":
    import os
    from onyx.connectors.connector_runner import ConnectorRunner

    # Initialize the connector
    connector = GithubConnector(
@@ -893,6 +915,12 @@ if __name__ == "__main__":
        {"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
    )

    if connector.github_client:
        get_external_access_permission(
            connector.get_github_repos(connector.github_client).pop(),
            connector.github_client,
        )

    # Create a time range from epoch to now
    end_time = datetime.now(timezone.utc)
    start_time = datetime.fromtimestamp(0, tz=timezone.utc)
backend/onyx/connectors/github/models.py (Normal file, 17 lines)
@@ -0,0 +1,17 @@
from typing import Any

from github import Repository
from github.Requester import Requester
from pydantic import BaseModel


class SerializedRepository(BaseModel):
    # id is part of the raw_data as well, just pulled out for convenience
    id: int
    headers: dict[str, str | int]
    raw_data: dict[str, Any]

    def to_Repository(self, requester: Requester) -> Repository.Repository:
        return Repository.Repository(
            requester, self.headers, self.raw_data, completed=True
        )
backend/onyx/connectors/github/rate_limit_utils.py (Normal file, 25 lines)
@@ -0,0 +1,25 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone

from github import Github

from onyx.utils.logger import setup_logger

logger = setup_logger()


def sleep_after_rate_limit_exception(github_client: Github) -> None:
    """
    Sleep until the GitHub rate limit resets.

    Args:
        github_client: The GitHub client that hit the rate limit
    """
    sleep_time = github_client.get_rate_limit().core.reset.replace(
        tzinfo=timezone.utc
    ) - datetime.now(tz=timezone.utc)
    sleep_time += timedelta(minutes=1)  # add an extra minute just to be safe
    logger.notice(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
    time.sleep(sleep_time.total_seconds())
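A hedged usage sketch for the new helper (the retry wrapper below is illustrative and not part of the diff; the connector tracks attempts via attempt_num instead):

    from github import Github
    from github import RateLimitExceededException

    from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception


    def get_repo_with_retry(github_client: Github, full_name: str, max_attempts: int = 3):
        # Illustrative retry loop around a single PyGithub call
        for _ in range(max_attempts):
            try:
                return github_client.get_repo(full_name)
            except RateLimitExceededException:
                sleep_after_rate_limit_exception(github_client)
        raise RuntimeError(f"Still rate-limited after {max_attempts} attempts: {full_name}")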
backend/onyx/connectors/github/utils.py (Normal file, 63 lines)
@@ -0,0 +1,63 @@
from collections.abc import Callable
from typing import cast

from github import Github
from github.Repository import Repository

from onyx.access.models import ExternalAccess
from onyx.connectors.github.models import SerializedRepository
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version

logger = setup_logger()


def get_external_access_permission(
    repo: Repository, github_client: Github
) -> ExternalAccess:
    """
    Get the external access permission for a repository.
    This functionality requires Enterprise Edition.
    """
    # Check if EE is enabled
    if not global_version.is_ee_version():
        # For the MIT version, return an empty ExternalAccess (private document)
        return ExternalAccess.empty()

    # Fetch the EE implementation
    ee_get_external_access_permission = cast(
        Callable[[Repository, Github, bool], ExternalAccess],
        fetch_versioned_implementation(
            "onyx.external_permissions.github.utils",
            "get_external_access_permission",
        ),
    )

    return ee_get_external_access_permission(repo, github_client, True)


def deserialize_repository(
    cached_repo: SerializedRepository, github_client: Github
) -> Repository:
    """
    Deserialize a SerializedRepository back into a Repository object.
    """
    # Try to access the requester - different PyGithub versions may use different attribute names
    try:
        # Try to get the requester using getattr to avoid linter errors
        requester = getattr(github_client, "_requester", None)
        if requester is None:
            requester = getattr(github_client, "_Github__requester", None)
        if requester is None:
            # If we can't find the requester attribute, we need to fall back to recreating the repo
            raise AttributeError("Could not find requester attribute")

        return cached_repo.to_Repository(requester)
    except Exception as e:
        # If all else fails, re-fetch the repo directly
        logger.warning(
            f"Failed to deserialize repository: {e}. Attempting to re-fetch."
        )
        repo_id = cached_repo.id
        return github_client.get_repo(repo_id)
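Editorial sketch (not part of the diff) of how these helpers fit together when resuming from a checkpoint; the wrapper function below is hypothetical:

    from github import Github
    from github.Repository import Repository

    from onyx.access.models import ExternalAccess
    from onyx.connectors.github.models import SerializedRepository
    from onyx.connectors.github.utils import deserialize_repository
    from onyx.connectors.github.utils import get_external_access_permission


    def restore_repo_and_permissions(
        cached_repo: SerializedRepository, github_client: Github
    ) -> tuple[Repository, ExternalAccess]:
        # Rebuild the Repository from the checkpoint cache (falls back to a re-fetch)
        repo = deserialize_repository(cached_repo, github_client)
        # Empty (private) access on the MIT build, the EE implementation otherwise
        external_access = get_external_access_permission(repo, github_client)
        return repo, external_access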
Some files were not shown because too many files have changed in this diff.